diff --git a/src/Microsoft.ML.Api/ComponentCreation.cs b/src/Microsoft.ML.Api/ComponentCreation.cs index 0a1e1cd605..4c01b77057 100644 --- a/src/Microsoft.ML.Api/ComponentCreation.cs +++ b/src/Microsoft.ML.Api/ComponentCreation.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using System; using System.Collections.Generic; using System.IO; using Microsoft.ML.Runtime.CommandLine; @@ -243,7 +244,8 @@ public static IDataLoader CreateLoader(this IHostEnvironment env, string setting { Contracts.CheckValue(env, nameof(env)); Contracts.CheckValue(files, nameof(files)); - return CreateCore(env, settings, files); + Type factoryType = typeof(IComponentFactory); + return CreateCore(env, factoryType, typeof(SignatureDataLoader), settings, files); } /// @@ -262,7 +264,7 @@ public static IDataSaver CreateSaver(this IHostEnvironment env, TArgs arg public static IDataSaver CreateSaver(this IHostEnvironment env, string settings) { Contracts.CheckValue(env, nameof(env)); - return CreateCore(env, settings); + return CreateCore(env, typeof(SignatureDataSaver), settings); } /// @@ -283,7 +285,8 @@ public static IDataTransform CreateTransform(this IHostEnvironment env, string s { Contracts.CheckValue(env, nameof(env)); env.CheckValue(source, nameof(source)); - return CreateCore(env, settings, source); + Type factoryType = typeof(IComponentFactory); + return CreateCore(env, factoryType, typeof(SignatureDataTransform), settings, source); } /// @@ -305,18 +308,17 @@ public static IDataScorerTransform CreateScorer(this IHostEnvironment env, strin env.CheckValue(predictor, nameof(predictor)); env.CheckValueOrNull(trainSchema); - ICommandLineComponentFactory scorerFactorySettings = ParseScorerSettings(settings); - var bindable = ScoreUtils.GetSchemaBindableMapper(env, predictor.Pred, scorerFactorySettings: scorerFactorySettings); - var mapper = bindable.Bind(env, data.Schema); - return CreateCore(env, settings, data.Data, mapper, trainSchema); - } + Type factoryType = typeof(IComponentFactory); + Type signatureType = typeof(SignatureDataScorer); - private static ICommandLineComponentFactory ParseScorerSettings(string settings) - { - return CmdParser.CreateComponentFactory( - typeof(IComponentFactory), - typeof(SignatureDataScorer), + ICommandLineComponentFactory scorerFactorySettings = CmdParser.CreateComponentFactory( + factoryType, + signatureType, settings); + + var bindable = ScoreUtils.GetSchemaBindableMapper(env, predictor.Pred, scorerFactorySettings: scorerFactorySettings); + var mapper = bindable.Bind(env, data.Schema); + return CreateCore(env, factoryType, signatureType, settings, data.Data, mapper, trainSchema); } /// @@ -344,7 +346,7 @@ public static IEvaluator CreateEvaluator(this IHostEnvironment env, string setti { Contracts.CheckValue(env, nameof(env)); env.CheckNonWhiteSpace(settings, nameof(settings)); - return CreateCore(env, settings); + return CreateCore(env, typeof(SignatureEvaluator), settings); } /// @@ -369,14 +371,40 @@ internal static ITrainer CreateTrainer(this IHostEnvironment env, TArgs a internal static ITrainer CreateTrainer(this IHostEnvironment env, string settings, out string loadName) { Contracts.CheckValue(env, nameof(env)); - return CreateCore(env, settings, out loadName); + return CreateCore(env, typeof(SignatureTrainer), settings, out loadName); + } + + private static TRes CreateCore( + IHostEnvironment env, + Type signatureType, + string settings, + params object[] extraArgs) + where TRes : class + { + return CreateCore(env, signatureType, settings, out string loadName, extraArgs); + } + + private static TRes CreateCore( + IHostEnvironment env, + Type signatureType, + string settings, + out string loadName, + params object[] extraArgs) + where TRes : class + { + return CreateCore(env, typeof(IComponentFactory), signatureType, settings, out loadName, extraArgs); } - private static TRes CreateCore(IHostEnvironment env, string settings, params object[] extraArgs) + private static TRes CreateCore( + IHostEnvironment env, + Type factoryType, + Type signatureType, + string settings, + params object[] extraArgs) where TRes : class { string loadName; - return CreateCore(env, settings, out loadName, extraArgs); + return CreateCore(env, factoryType, signatureType, settings, out loadName, extraArgs); } private static TRes CreateCore(IHostEnvironment env, TArgs args, params object[] extraArgs) @@ -387,15 +415,23 @@ private static TRes CreateCore(IHostEnvironment env, TArgs ar return CreateCore(env, args, out loadName, extraArgs); } - private static TRes CreateCore(IHostEnvironment env, string settings, out string loadName, params object[] extraArgs) + private static TRes CreateCore( + IHostEnvironment env, + Type factoryType, + Type signatureType, + string settings, + out string loadName, + params object[] extraArgs) where TRes : class { Contracts.AssertValue(env); + env.AssertValue(factoryType); + env.AssertValue(signatureType); env.AssertValue(settings, "settings"); - var sc = SubComponent.Parse(settings); - loadName = sc.Kind; - return sc.CreateInstance(env, extraArgs); + var factory = CmdParser.CreateComponentFactory(factoryType, signatureType, settings); + loadName = factory.Name; + return ComponentCatalog.CreateInstance(env, factory.SignatureType, factory.Name, factory.GetSettingsString(), extraArgs); } private static TRes CreateCore(IHostEnvironment env, TArgs args, out string loadName, params object[] extraArgs) diff --git a/src/Microsoft.ML.Maml/ChainCommand.cs b/src/Microsoft.ML.Maml/ChainCommand.cs index 829923a60c..4dc599360b 100644 --- a/src/Microsoft.ML.Maml/ChainCommand.cs +++ b/src/Microsoft.ML.Maml/ChainCommand.cs @@ -7,6 +7,7 @@ using Microsoft.ML.Runtime.Command; using Microsoft.ML.Runtime.CommandLine; using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.EntryPoints; using Microsoft.ML.Runtime.Tools; [assembly: LoadableClass(ChainCommand.Summary, typeof(ChainCommand), typeof(ChainCommand.Arguments), typeof(SignatureCommand), @@ -21,8 +22,8 @@ public sealed class ChainCommand : ICommand public sealed class Arguments { #pragma warning disable 649 // never assigned - [Argument(ArgumentType.Multiple, HelpText = "Command", ShortName = "cmd")] - public SubComponent[] Command; + [Argument(ArgumentType.Multiple, HelpText = "Command", ShortName = "cmd", SignatureType = typeof(SignatureCommand))] + public IComponentFactory[] Command; #pragma warning restore 649 // never assigned } @@ -61,7 +62,7 @@ public void Run() chCmd.Info("Executing: {0}", sub); chCmd.Info("====================================================================================="); - var cmd = sub.CreateInstance(_host); + var cmd = sub.CreateComponent(_host); cmd.Run(); count++; diff --git a/src/Microsoft.ML.Maml/HelpCommand.cs b/src/Microsoft.ML.Maml/HelpCommand.cs index a815ffc0e5..bf7cb37e8c 100644 --- a/src/Microsoft.ML.Maml/HelpCommand.cs +++ b/src/Microsoft.ML.Maml/HelpCommand.cs @@ -12,6 +12,7 @@ using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.Command; using Microsoft.ML.Runtime.CommandLine; +using Microsoft.ML.Runtime.EntryPoints; using Microsoft.ML.Runtime.Internal.Utilities; using Microsoft.ML.Runtime.Tools; @@ -51,8 +52,8 @@ public sealed class Arguments [Argument(ArgumentType.Multiple, HelpText = "Extra DLLs", ShortName = "dll")] public string[] ExtraAssemblies; - [Argument(ArgumentType.LastOccurenceWins, Hide = true)] - public SubComponent Generator; + [Argument(ArgumentType.LastOccurenceWins, Hide = true, SignatureType = typeof(SignatureModuleGenerator))] + public IComponentFactory Generator; #pragma warning restore 649 // never assigned } @@ -87,9 +88,9 @@ public HelpCommand(IHostEnvironment env, Arguments args) _extraAssemblies = args.ExtraAssemblies; - if (args.Generator.IsGood()) + if (args.Generator != null) { - _generator = args.Generator.CreateInstance(_env, "maml.exe ? " + CmdParser.GetSettings(env, args, new Arguments())); + _generator = args.Generator.CreateComponent(_env, "maml.exe ? " + CmdParser.GetSettings(env, args, new Arguments())); } } diff --git a/src/Microsoft.ML.Maml/MAML.cs b/src/Microsoft.ML.Maml/MAML.cs index 1f341dd30e..cb75151379 100644 --- a/src/Microsoft.ML.Maml/MAML.cs +++ b/src/Microsoft.ML.Maml/MAML.cs @@ -122,9 +122,7 @@ internal static int MainCore(TlcEnvironment env, string args, bool alwaysPrintSt return -1; } - var cmdDef = new SubComponent(kind, settings); - - if (!ComponentCatalog.TryCreateInstance(mainHost, out ICommand cmd, cmdDef)) + if (!ComponentCatalog.TryCreateInstance(mainHost, out ICommand cmd, kind, settings)) { // Telemetry: Log telemetryPipe.Send(TelemetryMessage.CreateCommand("UnknownCommand", settings)); diff --git a/src/Microsoft.ML.ResultProcessor/ResultProcessor.cs b/src/Microsoft.ML.ResultProcessor/ResultProcessor.cs index 5048467d6c..b07be5b2c0 100644 --- a/src/Microsoft.ML.ResultProcessor/ResultProcessor.cs +++ b/src/Microsoft.ML.ResultProcessor/ResultProcessor.cs @@ -433,12 +433,12 @@ private static bool ValidateMamlOutput(string filename, string[] rawLines, out L { if (Utils.Size(chainArgs.Command) == 0) return null; - var acceptableCommand = chainArgs.Command.FirstOrDefault(x => - string.Equals(x.Kind, "CV", StringComparison.OrdinalIgnoreCase) || - string.Equals(x.Kind, "TrainTest", StringComparison.OrdinalIgnoreCase) || - string.Equals(x.Kind, "Test", StringComparison.OrdinalIgnoreCase)); + var acceptableCommand = chainArgs.Command.Cast().FirstOrDefault(x => + string.Equals(x.Name, "CV", StringComparison.OrdinalIgnoreCase) || + string.Equals(x.Name, "TrainTest", StringComparison.OrdinalIgnoreCase) || + string.Equals(x.Name, "Test", StringComparison.OrdinalIgnoreCase)); if (acceptableCommand == null || !ParseCommandArguments(env, - acceptableCommand.Kind + " " + acceptableCommand.SubComponentSettings, out commandArgs, out command, trimExe)) + acceptableCommand.Name + " " + acceptableCommand.GetSettingsString(), out commandArgs, out command, trimExe)) { return null; }