From 3beaaeea3980b483a35a6ff78c0256902a0698f8 Mon Sep 17 00:00:00 2001 From: Eric Erhardt <eric.erhardt@microsoft.com> Date: Wed, 29 Aug 2018 16:04:04 -0700 Subject: [PATCH 1/2] Remove SubComponent usages in ML.Maml --- src/Microsoft.ML.Maml/ChainCommand.cs | 7 ++++--- src/Microsoft.ML.Maml/HelpCommand.cs | 9 +++++---- src/Microsoft.ML.Maml/MAML.cs | 4 +--- src/Microsoft.ML.ResultProcessor/ResultProcessor.cs | 10 +++++----- 4 files changed, 15 insertions(+), 15 deletions(-) diff --git a/src/Microsoft.ML.Maml/ChainCommand.cs b/src/Microsoft.ML.Maml/ChainCommand.cs index 829923a60c..4dc599360b 100644 --- a/src/Microsoft.ML.Maml/ChainCommand.cs +++ b/src/Microsoft.ML.Maml/ChainCommand.cs @@ -7,6 +7,7 @@ using Microsoft.ML.Runtime.Command; using Microsoft.ML.Runtime.CommandLine; using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.EntryPoints; using Microsoft.ML.Runtime.Tools; [assembly: LoadableClass(ChainCommand.Summary, typeof(ChainCommand), typeof(ChainCommand.Arguments), typeof(SignatureCommand), @@ -21,8 +22,8 @@ public sealed class ChainCommand : ICommand public sealed class Arguments { #pragma warning disable 649 // never assigned - [Argument(ArgumentType.Multiple, HelpText = "Command", ShortName = "cmd")] - public SubComponent<ICommand, SignatureCommand>[] Command; + [Argument(ArgumentType.Multiple, HelpText = "Command", ShortName = "cmd", SignatureType = typeof(SignatureCommand))] + public IComponentFactory<ICommand>[] Command; #pragma warning restore 649 // never assigned } @@ -61,7 +62,7 @@ public void Run() chCmd.Info("Executing: {0}", sub); chCmd.Info("====================================================================================="); - var cmd = sub.CreateInstance(_host); + var cmd = sub.CreateComponent(_host); cmd.Run(); count++; diff --git a/src/Microsoft.ML.Maml/HelpCommand.cs b/src/Microsoft.ML.Maml/HelpCommand.cs index a815ffc0e5..bf7cb37e8c 100644 --- a/src/Microsoft.ML.Maml/HelpCommand.cs +++ b/src/Microsoft.ML.Maml/HelpCommand.cs @@ -12,6 +12,7 @@ using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.Command; using Microsoft.ML.Runtime.CommandLine; +using Microsoft.ML.Runtime.EntryPoints; using Microsoft.ML.Runtime.Internal.Utilities; using Microsoft.ML.Runtime.Tools; @@ -51,8 +52,8 @@ public sealed class Arguments [Argument(ArgumentType.Multiple, HelpText = "Extra DLLs", ShortName = "dll")] public string[] ExtraAssemblies; - [Argument(ArgumentType.LastOccurenceWins, Hide = true)] - public SubComponent<IGenerator, SignatureModuleGenerator> Generator; + [Argument(ArgumentType.LastOccurenceWins, Hide = true, SignatureType = typeof(SignatureModuleGenerator))] + public IComponentFactory<string, IGenerator> Generator; #pragma warning restore 649 // never assigned } @@ -87,9 +88,9 @@ public HelpCommand(IHostEnvironment env, Arguments args) _extraAssemblies = args.ExtraAssemblies; - if (args.Generator.IsGood()) + if (args.Generator != null) { - _generator = args.Generator.CreateInstance(_env, "maml.exe ? " + CmdParser.GetSettings(env, args, new Arguments())); + _generator = args.Generator.CreateComponent(_env, "maml.exe ? " + CmdParser.GetSettings(env, args, new Arguments())); } } diff --git a/src/Microsoft.ML.Maml/MAML.cs b/src/Microsoft.ML.Maml/MAML.cs index 1f341dd30e..cb75151379 100644 --- a/src/Microsoft.ML.Maml/MAML.cs +++ b/src/Microsoft.ML.Maml/MAML.cs @@ -122,9 +122,7 @@ internal static int MainCore(TlcEnvironment env, string args, bool alwaysPrintSt return -1; } - var cmdDef = new SubComponent<ICommand, SignatureCommand>(kind, settings); - - if (!ComponentCatalog.TryCreateInstance(mainHost, out ICommand cmd, cmdDef)) + if (!ComponentCatalog.TryCreateInstance<ICommand, SignatureCommand>(mainHost, out ICommand cmd, kind, settings)) { // Telemetry: Log telemetryPipe.Send(TelemetryMessage.CreateCommand("UnknownCommand", settings)); diff --git a/src/Microsoft.ML.ResultProcessor/ResultProcessor.cs b/src/Microsoft.ML.ResultProcessor/ResultProcessor.cs index 5048467d6c..b07be5b2c0 100644 --- a/src/Microsoft.ML.ResultProcessor/ResultProcessor.cs +++ b/src/Microsoft.ML.ResultProcessor/ResultProcessor.cs @@ -433,12 +433,12 @@ private static bool ValidateMamlOutput(string filename, string[] rawLines, out L { if (Utils.Size(chainArgs.Command) == 0) return null; - var acceptableCommand = chainArgs.Command.FirstOrDefault(x => - string.Equals(x.Kind, "CV", StringComparison.OrdinalIgnoreCase) || - string.Equals(x.Kind, "TrainTest", StringComparison.OrdinalIgnoreCase) || - string.Equals(x.Kind, "Test", StringComparison.OrdinalIgnoreCase)); + var acceptableCommand = chainArgs.Command.Cast<ICommandLineComponentFactory>().FirstOrDefault(x => + string.Equals(x.Name, "CV", StringComparison.OrdinalIgnoreCase) || + string.Equals(x.Name, "TrainTest", StringComparison.OrdinalIgnoreCase) || + string.Equals(x.Name, "Test", StringComparison.OrdinalIgnoreCase)); if (acceptableCommand == null || !ParseCommandArguments(env, - acceptableCommand.Kind + " " + acceptableCommand.SubComponentSettings, out commandArgs, out command, trimExe)) + acceptableCommand.Name + " " + acceptableCommand.GetSettingsString(), out commandArgs, out command, trimExe)) { return null; } From c235086b3aca42d134bb62e12edcae381aba191f Mon Sep 17 00:00:00 2001 From: Eric Erhardt <eric.erhardt@microsoft.com> Date: Wed, 29 Aug 2018 19:12:34 -0700 Subject: [PATCH 2/2] Remove SubComponent usage from ComponentCreation. --- src/Microsoft.ML.Api/ComponentCreation.cs | 78 +++++++++++++++++------ 1 file changed, 57 insertions(+), 21 deletions(-) diff --git a/src/Microsoft.ML.Api/ComponentCreation.cs b/src/Microsoft.ML.Api/ComponentCreation.cs index 0a1e1cd605..4c01b77057 100644 --- a/src/Microsoft.ML.Api/ComponentCreation.cs +++ b/src/Microsoft.ML.Api/ComponentCreation.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using System; using System.Collections.Generic; using System.IO; using Microsoft.ML.Runtime.CommandLine; @@ -243,7 +244,8 @@ public static IDataLoader CreateLoader(this IHostEnvironment env, string setting { Contracts.CheckValue(env, nameof(env)); Contracts.CheckValue(files, nameof(files)); - return CreateCore<IDataLoader, SignatureDataLoader>(env, settings, files); + Type factoryType = typeof(IComponentFactory<IMultiStreamSource, IDataLoader>); + return CreateCore<IDataLoader>(env, factoryType, typeof(SignatureDataLoader), settings, files); } /// <summary> @@ -262,7 +264,7 @@ public static IDataSaver CreateSaver<TArgs>(this IHostEnvironment env, TArgs arg public static IDataSaver CreateSaver(this IHostEnvironment env, string settings) { Contracts.CheckValue(env, nameof(env)); - return CreateCore<IDataSaver, SignatureDataSaver>(env, settings); + return CreateCore<IDataSaver>(env, typeof(SignatureDataSaver), settings); } /// <summary> @@ -283,7 +285,8 @@ public static IDataTransform CreateTransform(this IHostEnvironment env, string s { Contracts.CheckValue(env, nameof(env)); env.CheckValue(source, nameof(source)); - return CreateCore<IDataTransform, SignatureDataTransform>(env, settings, source); + Type factoryType = typeof(IComponentFactory<IDataView, IDataTransform>); + return CreateCore<IDataTransform>(env, factoryType, typeof(SignatureDataTransform), settings, source); } /// <summary> @@ -305,18 +308,17 @@ public static IDataScorerTransform CreateScorer(this IHostEnvironment env, strin env.CheckValue(predictor, nameof(predictor)); env.CheckValueOrNull(trainSchema); - ICommandLineComponentFactory scorerFactorySettings = ParseScorerSettings(settings); - var bindable = ScoreUtils.GetSchemaBindableMapper(env, predictor.Pred, scorerFactorySettings: scorerFactorySettings); - var mapper = bindable.Bind(env, data.Schema); - return CreateCore<IDataScorerTransform, SignatureDataScorer>(env, settings, data.Data, mapper, trainSchema); - } + Type factoryType = typeof(IComponentFactory<IDataView, ISchemaBoundMapper, RoleMappedSchema, IDataScorerTransform>); + Type signatureType = typeof(SignatureDataScorer); - private static ICommandLineComponentFactory ParseScorerSettings(string settings) - { - return CmdParser.CreateComponentFactory( - typeof(IComponentFactory<IDataView, ISchemaBoundMapper, RoleMappedSchema, IDataScorerTransform>), - typeof(SignatureDataScorer), + ICommandLineComponentFactory scorerFactorySettings = CmdParser.CreateComponentFactory( + factoryType, + signatureType, settings); + + var bindable = ScoreUtils.GetSchemaBindableMapper(env, predictor.Pred, scorerFactorySettings: scorerFactorySettings); + var mapper = bindable.Bind(env, data.Schema); + return CreateCore<IDataScorerTransform>(env, factoryType, signatureType, settings, data.Data, mapper, trainSchema); } /// <summary> @@ -344,7 +346,7 @@ public static IEvaluator CreateEvaluator(this IHostEnvironment env, string setti { Contracts.CheckValue(env, nameof(env)); env.CheckNonWhiteSpace(settings, nameof(settings)); - return CreateCore<IEvaluator, SignatureEvaluator>(env, settings); + return CreateCore<IEvaluator>(env, typeof(SignatureEvaluator), settings); } /// <summary> @@ -369,14 +371,40 @@ internal static ITrainer CreateTrainer<TArgs>(this IHostEnvironment env, TArgs a internal static ITrainer CreateTrainer(this IHostEnvironment env, string settings, out string loadName) { Contracts.CheckValue(env, nameof(env)); - return CreateCore<ITrainer, SignatureTrainer>(env, settings, out loadName); + return CreateCore<ITrainer>(env, typeof(SignatureTrainer), settings, out loadName); + } + + private static TRes CreateCore<TRes>( + IHostEnvironment env, + Type signatureType, + string settings, + params object[] extraArgs) + where TRes : class + { + return CreateCore<TRes>(env, signatureType, settings, out string loadName, extraArgs); + } + + private static TRes CreateCore<TRes>( + IHostEnvironment env, + Type signatureType, + string settings, + out string loadName, + params object[] extraArgs) + where TRes : class + { + return CreateCore<TRes>(env, typeof(IComponentFactory<TRes>), signatureType, settings, out loadName, extraArgs); } - private static TRes CreateCore<TRes, TSig>(IHostEnvironment env, string settings, params object[] extraArgs) + private static TRes CreateCore<TRes>( + IHostEnvironment env, + Type factoryType, + Type signatureType, + string settings, + params object[] extraArgs) where TRes : class { string loadName; - return CreateCore<TRes, TSig>(env, settings, out loadName, extraArgs); + return CreateCore<TRes>(env, factoryType, signatureType, settings, out loadName, extraArgs); } private static TRes CreateCore<TRes, TArgs, TSig>(IHostEnvironment env, TArgs args, params object[] extraArgs) @@ -387,15 +415,23 @@ private static TRes CreateCore<TRes, TArgs, TSig>(IHostEnvironment env, TArgs ar return CreateCore<TRes, TArgs, TSig>(env, args, out loadName, extraArgs); } - private static TRes CreateCore<TRes, TSig>(IHostEnvironment env, string settings, out string loadName, params object[] extraArgs) + private static TRes CreateCore<TRes>( + IHostEnvironment env, + Type factoryType, + Type signatureType, + string settings, + out string loadName, + params object[] extraArgs) where TRes : class { Contracts.AssertValue(env); + env.AssertValue(factoryType); + env.AssertValue(signatureType); env.AssertValue(settings, "settings"); - var sc = SubComponent.Parse<TRes, TSig>(settings); - loadName = sc.Kind; - return sc.CreateInstance(env, extraArgs); + var factory = CmdParser.CreateComponentFactory(factoryType, signatureType, settings); + loadName = factory.Name; + return ComponentCatalog.CreateInstance<TRes>(env, factory.SignatureType, factory.Name, factory.GetSettingsString(), extraArgs); } private static TRes CreateCore<TRes, TArgs, TSig>(IHostEnvironment env, TArgs args, out string loadName, params object[] extraArgs)