Skip to content

Misc SubComponent removals #773

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Sep 5, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 57 additions & 21 deletions src/Microsoft.ML.Api/ComponentCreation.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.IO;
using Microsoft.ML.Runtime.CommandLine;
Expand Down Expand Up @@ -243,7 +244,8 @@ public static IDataLoader CreateLoader(this IHostEnvironment env, string setting
{
Contracts.CheckValue(env, nameof(env));
Contracts.CheckValue(files, nameof(files));
return CreateCore<IDataLoader, SignatureDataLoader>(env, settings, files);
Type factoryType = typeof(IComponentFactory<IMultiStreamSource, IDataLoader>);
return CreateCore<IDataLoader>(env, factoryType, typeof(SignatureDataLoader), settings, files);
}

/// <summary>
Expand All @@ -262,7 +264,7 @@ public static IDataSaver CreateSaver<TArgs>(this IHostEnvironment env, TArgs arg
public static IDataSaver CreateSaver(this IHostEnvironment env, string settings)
{
Contracts.CheckValue(env, nameof(env));
return CreateCore<IDataSaver, SignatureDataSaver>(env, settings);
return CreateCore<IDataSaver>(env, typeof(SignatureDataSaver), settings);
}

/// <summary>
Expand All @@ -283,7 +285,8 @@ public static IDataTransform CreateTransform(this IHostEnvironment env, string s
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(source, nameof(source));
return CreateCore<IDataTransform, SignatureDataTransform>(env, settings, source);
Type factoryType = typeof(IComponentFactory<IDataView, IDataTransform>);
return CreateCore<IDataTransform>(env, factoryType, typeof(SignatureDataTransform), settings, source);
}

/// <summary>
Expand All @@ -305,18 +308,17 @@ public static IDataScorerTransform CreateScorer(this IHostEnvironment env, strin
env.CheckValue(predictor, nameof(predictor));
env.CheckValueOrNull(trainSchema);

ICommandLineComponentFactory scorerFactorySettings = ParseScorerSettings(settings);
var bindable = ScoreUtils.GetSchemaBindableMapper(env, predictor.Pred, scorerFactorySettings: scorerFactorySettings);
var mapper = bindable.Bind(env, data.Schema);
return CreateCore<IDataScorerTransform, SignatureDataScorer>(env, settings, data.Data, mapper, trainSchema);
}
Type factoryType = typeof(IComponentFactory<IDataView, ISchemaBoundMapper, RoleMappedSchema, IDataScorerTransform>);
Type signatureType = typeof(SignatureDataScorer);

private static ICommandLineComponentFactory ParseScorerSettings(string settings)
{
return CmdParser.CreateComponentFactory(
typeof(IComponentFactory<IDataView, ISchemaBoundMapper, RoleMappedSchema, IDataScorerTransform>),
typeof(SignatureDataScorer),
ICommandLineComponentFactory scorerFactorySettings = CmdParser.CreateComponentFactory(
factoryType,
signatureType,
settings);

var bindable = ScoreUtils.GetSchemaBindableMapper(env, predictor.Pred, scorerFactorySettings: scorerFactorySettings);
var mapper = bindable.Bind(env, data.Schema);
return CreateCore<IDataScorerTransform>(env, factoryType, signatureType, settings, data.Data, mapper, trainSchema);
}

/// <summary>
Expand Down Expand Up @@ -344,7 +346,7 @@ public static IEvaluator CreateEvaluator(this IHostEnvironment env, string setti
{
Contracts.CheckValue(env, nameof(env));
env.CheckNonWhiteSpace(settings, nameof(settings));
return CreateCore<IEvaluator, SignatureEvaluator>(env, settings);
return CreateCore<IEvaluator>(env, typeof(SignatureEvaluator), settings);
}

/// <summary>
Expand All @@ -369,14 +371,40 @@ internal static ITrainer CreateTrainer<TArgs>(this IHostEnvironment env, TArgs a
internal static ITrainer CreateTrainer(this IHostEnvironment env, string settings, out string loadName)
{
Contracts.CheckValue(env, nameof(env));
return CreateCore<ITrainer, SignatureTrainer>(env, settings, out loadName);
return CreateCore<ITrainer>(env, typeof(SignatureTrainer), settings, out loadName);
}

private static TRes CreateCore<TRes>(
IHostEnvironment env,
Type signatureType,
string settings,
params object[] extraArgs)
where TRes : class
{
return CreateCore<TRes>(env, signatureType, settings, out string loadName, extraArgs);
}

private static TRes CreateCore<TRes>(
IHostEnvironment env,
Type signatureType,
string settings,
out string loadName,
params object[] extraArgs)
where TRes : class
{
return CreateCore<TRes>(env, typeof(IComponentFactory<TRes>), signatureType, settings, out loadName, extraArgs);
}

private static TRes CreateCore<TRes, TSig>(IHostEnvironment env, string settings, params object[] extraArgs)
private static TRes CreateCore<TRes>(
IHostEnvironment env,
Type factoryType,
Type signatureType,
string settings,
params object[] extraArgs)
where TRes : class
{
string loadName;
return CreateCore<TRes, TSig>(env, settings, out loadName, extraArgs);
return CreateCore<TRes>(env, factoryType, signatureType, settings, out loadName, extraArgs);
}

private static TRes CreateCore<TRes, TArgs, TSig>(IHostEnvironment env, TArgs args, params object[] extraArgs)
Expand All @@ -387,15 +415,23 @@ private static TRes CreateCore<TRes, TArgs, TSig>(IHostEnvironment env, TArgs ar
return CreateCore<TRes, TArgs, TSig>(env, args, out loadName, extraArgs);
}

private static TRes CreateCore<TRes, TSig>(IHostEnvironment env, string settings, out string loadName, params object[] extraArgs)
private static TRes CreateCore<TRes>(
IHostEnvironment env,
Type factoryType,
Type signatureType,
string settings,
out string loadName,
params object[] extraArgs)
where TRes : class
{
Contracts.AssertValue(env);
env.AssertValue(factoryType);
env.AssertValue(signatureType);
env.AssertValue(settings, "settings");

var sc = SubComponent.Parse<TRes, TSig>(settings);
loadName = sc.Kind;
return sc.CreateInstance(env, extraArgs);
var factory = CmdParser.CreateComponentFactory(factoryType, signatureType, settings);
loadName = factory.Name;
return ComponentCatalog.CreateInstance<TRes>(env, factory.SignatureType, factory.Name, factory.GetSettingsString(), extraArgs);
}

private static TRes CreateCore<TRes, TArgs, TSig>(IHostEnvironment env, TArgs args, out string loadName, params object[] extraArgs)
Expand Down
7 changes: 4 additions & 3 deletions src/Microsoft.ML.Maml/ChainCommand.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
using Microsoft.ML.Runtime.Command;
using Microsoft.ML.Runtime.CommandLine;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.EntryPoints;
using Microsoft.ML.Runtime.Tools;

[assembly: LoadableClass(ChainCommand.Summary, typeof(ChainCommand), typeof(ChainCommand.Arguments), typeof(SignatureCommand),
Expand All @@ -21,8 +22,8 @@ public sealed class ChainCommand : ICommand
public sealed class Arguments
{
#pragma warning disable 649 // never assigned
[Argument(ArgumentType.Multiple, HelpText = "Command", ShortName = "cmd")]
public SubComponent<ICommand, SignatureCommand>[] Command;
[Argument(ArgumentType.Multiple, HelpText = "Command", ShortName = "cmd", SignatureType = typeof(SignatureCommand))]
public IComponentFactory<ICommand>[] Command;
#pragma warning restore 649 // never assigned
}

Expand Down Expand Up @@ -61,7 +62,7 @@ public void Run()
chCmd.Info("Executing: {0}", sub);
chCmd.Info("=====================================================================================");

var cmd = sub.CreateInstance(_host);
var cmd = sub.CreateComponent(_host);
cmd.Run();
count++;

Expand Down
9 changes: 5 additions & 4 deletions src/Microsoft.ML.Maml/HelpCommand.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.Command;
using Microsoft.ML.Runtime.CommandLine;
using Microsoft.ML.Runtime.EntryPoints;
using Microsoft.ML.Runtime.Internal.Utilities;
using Microsoft.ML.Runtime.Tools;

Expand Down Expand Up @@ -51,8 +52,8 @@ public sealed class Arguments
[Argument(ArgumentType.Multiple, HelpText = "Extra DLLs", ShortName = "dll")]
public string[] ExtraAssemblies;

[Argument(ArgumentType.LastOccurenceWins, Hide = true)]
public SubComponent<IGenerator, SignatureModuleGenerator> Generator;
[Argument(ArgumentType.LastOccurenceWins, Hide = true, SignatureType = typeof(SignatureModuleGenerator))]
public IComponentFactory<string, IGenerator> Generator;
#pragma warning restore 649 // never assigned
}

Expand Down Expand Up @@ -87,9 +88,9 @@ public HelpCommand(IHostEnvironment env, Arguments args)

_extraAssemblies = args.ExtraAssemblies;

if (args.Generator.IsGood())
if (args.Generator != null)
{
_generator = args.Generator.CreateInstance(_env, "maml.exe ? " + CmdParser.GetSettings(env, args, new Arguments()));
_generator = args.Generator.CreateComponent(_env, "maml.exe ? " + CmdParser.GetSettings(env, args, new Arguments()));
}
}

Expand Down
4 changes: 1 addition & 3 deletions src/Microsoft.ML.Maml/MAML.cs
Original file line number Diff line number Diff line change
Expand Up @@ -122,9 +122,7 @@ internal static int MainCore(TlcEnvironment env, string args, bool alwaysPrintSt
return -1;
}

var cmdDef = new SubComponent<ICommand, SignatureCommand>(kind, settings);

if (!ComponentCatalog.TryCreateInstance(mainHost, out ICommand cmd, cmdDef))
if (!ComponentCatalog.TryCreateInstance<ICommand, SignatureCommand>(mainHost, out ICommand cmd, kind, settings))
{
// Telemetry: Log
telemetryPipe.Send(TelemetryMessage.CreateCommand("UnknownCommand", settings));
Expand Down
10 changes: 5 additions & 5 deletions src/Microsoft.ML.ResultProcessor/ResultProcessor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -433,12 +433,12 @@ private static bool ValidateMamlOutput(string filename, string[] rawLines, out L
{
if (Utils.Size(chainArgs.Command) == 0)
return null;
var acceptableCommand = chainArgs.Command.FirstOrDefault(x =>
string.Equals(x.Kind, "CV", StringComparison.OrdinalIgnoreCase) ||
string.Equals(x.Kind, "TrainTest", StringComparison.OrdinalIgnoreCase) ||
string.Equals(x.Kind, "Test", StringComparison.OrdinalIgnoreCase));
var acceptableCommand = chainArgs.Command.Cast<ICommandLineComponentFactory>().FirstOrDefault(x =>
string.Equals(x.Name, "CV", StringComparison.OrdinalIgnoreCase) ||
string.Equals(x.Name, "TrainTest", StringComparison.OrdinalIgnoreCase) ||
string.Equals(x.Name, "Test", StringComparison.OrdinalIgnoreCase));
if (acceptableCommand == null || !ParseCommandArguments(env,
acceptableCommand.Kind + " " + acceptableCommand.SubComponentSettings, out commandArgs, out command, trimExe))
acceptableCommand.Name + " " + acceptableCommand.GetSettingsString(), out commandArgs, out command, trimExe))
{
return null;
}
Expand Down