diff --git a/src/Microsoft.ML.Core/EntryPoints/ComponentFactory.cs b/src/Microsoft.ML.Core/EntryPoints/ComponentFactory.cs index bd2613e3fd..1f8d5b0e8f 100644 --- a/src/Microsoft.ML.Core/EntryPoints/ComponentFactory.cs +++ b/src/Microsoft.ML.Core/EntryPoints/ComponentFactory.cs @@ -37,41 +37,6 @@ public interface IComponentFactory : IComponentFactory TComponent CreateComponent(IHostEnvironment env, TArg1 argument1); } - /// - /// A class for creating a component when we take one extra parameter - /// (and an ) that simply wraps a delegate which - /// creates the component. - /// - public class SimpleComponentFactory : IComponentFactory - { - private Func _factory; - - public SimpleComponentFactory(Func factory) - { - _factory = factory; - } - - public TComponent CreateComponent(IHostEnvironment env, TArg1 argument1) - { - return _factory(env, argument1); - } - } - - public class SimpleComponentFactory : IComponentFactory - { - private Func _factory; - - public SimpleComponentFactory(Func factory) - { - _factory = factory; - } - - public TComponent CreateComponent(IHostEnvironment env) - { - return _factory(env); - } - } - /// /// An interface for creating a component when we take two extra parameters (and an ). /// @@ -89,22 +54,94 @@ public interface IComponentFactory } /// - /// A class for creating a component when we take three extra parameters - /// (and an ) that simply wraps a delegate which - /// creates the component. + /// A utility class for creating instances. /// - public class SimpleComponentFactory : IComponentFactory + public static class ComponentFactoryUtils { - private Func _factory; + /// + /// Creates a component factory with no extra parameters (other than an ) + /// that simply wraps a delegate which creates the component. + /// + public static IComponentFactory CreateFromFunction(Func factory) + { + return new SimpleComponentFactory(factory); + } + + /// + /// Creates a component factory when we take one extra parameter (and an + /// ) that simply wraps a delegate which creates the component. + /// + public static IComponentFactory CreateFromFunction(Func factory) + { + return new SimpleComponentFactory(factory); + } + + /// + /// Creates a component factory when we take three extra parameters (and an + /// ) that simply wraps a delegate which creates the component. + /// + public static IComponentFactory CreateFromFunction(Func factory) + { + return new SimpleComponentFactory(factory); + } - public SimpleComponentFactory(Func factory) + /// + /// A class for creating a component with no extra parameters (other than an ) + /// that simply wraps a delegate which creates the component. + /// + private sealed class SimpleComponentFactory : IComponentFactory { - _factory = factory; + private readonly Func _factory; + + public SimpleComponentFactory(Func factory) + { + _factory = factory; + } + + public TComponent CreateComponent(IHostEnvironment env) + { + return _factory(env); + } } - public TComponent CreateComponent(IHostEnvironment env, TArg1 argument1, TArg2 argument2, TArg3 argument3) + /// + /// A class for creating a component when we take one extra parameter + /// (and an ) that simply wraps a delegate which + /// creates the component. + /// + private sealed class SimpleComponentFactory : IComponentFactory { - return _factory(env, argument1, argument2, argument3); + private readonly Func _factory; + + public SimpleComponentFactory(Func factory) + { + _factory = factory; + } + + public TComponent CreateComponent(IHostEnvironment env, TArg1 argument1) + { + return _factory(env, argument1); + } + } + + /// + /// A class for creating a component when we take three extra parameters + /// (and an ) that simply wraps a delegate which + /// creates the component. + /// + private sealed class SimpleComponentFactory : IComponentFactory + { + private readonly Func _factory; + + public SimpleComponentFactory(Func factory) + { + _factory = factory; + } + + public TComponent CreateComponent(IHostEnvironment env, TArg1 argument1, TArg2 argument2, TArg3 argument3) + { + return _factory(env, argument1, argument2, argument3); + } } } } diff --git a/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs b/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs index 3055afe917..971512cc97 100644 --- a/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs +++ b/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs @@ -164,7 +164,7 @@ private void RunCore(IChannel ch, string cmd) new[] { new KeyValuePair>( - "", new SimpleComponentFactory( + "", ComponentFactoryUtils.CreateFromFunction( (env, input) => { var args = new GenerateNumberTransform.Arguments(); diff --git a/src/Microsoft.ML.Data/Commands/ScoreCommand.cs b/src/Microsoft.ML.Data/Commands/ScoreCommand.cs index 635d751d2e..879b29d4ce 100644 --- a/src/Microsoft.ML.Data/Commands/ScoreCommand.cs +++ b/src/Microsoft.ML.Data/Commands/ScoreCommand.cs @@ -330,7 +330,7 @@ public static TScorerFactory GetScorerComponent( }; } - return new SimpleComponentFactory(factoryFunc); + return ComponentFactoryUtils.CreateFromFunction(factoryFunc); } /// diff --git a/src/Microsoft.ML.Ensemble/Microsoft.ML.Ensemble.csproj b/src/Microsoft.ML.Ensemble/Microsoft.ML.Ensemble.csproj index ddd4557788..ef0ad01a40 100644 --- a/src/Microsoft.ML.Ensemble/Microsoft.ML.Ensemble.csproj +++ b/src/Microsoft.ML.Ensemble/Microsoft.ML.Ensemble.csproj @@ -10,6 +10,7 @@ + diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/BaseScalarStacking.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/BaseScalarStacking.cs index a5c9c757a4..9b9900a65f 100644 --- a/src/Microsoft.ML.Ensemble/OutputCombiners/BaseScalarStacking.cs +++ b/src/Microsoft.ML.Ensemble/OutputCombiners/BaseScalarStacking.cs @@ -9,7 +9,7 @@ namespace Microsoft.ML.Runtime.Ensemble.OutputCombiners { - public abstract class BaseScalarStacking : BaseStacking + public abstract class BaseScalarStacking : BaseStacking { internal BaseScalarStacking(IHostEnvironment env, string name, ArgumentsBase args) : base(env, name, args) diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/BaseStacking.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/BaseStacking.cs index f30174a31d..c294a1fcde 100644 --- a/src/Microsoft.ML.Ensemble/OutputCombiners/BaseStacking.cs +++ b/src/Microsoft.ML.Ensemble/OutputCombiners/BaseStacking.cs @@ -7,6 +7,7 @@ using System.Threading.Tasks; using Microsoft.ML.Runtime.CommandLine; using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.EntryPoints; using Microsoft.ML.Runtime.Internal.Internallearn; using Microsoft.ML.Runtime.Internal.Utilities; using Microsoft.ML.Runtime.Model; @@ -15,7 +16,7 @@ namespace Microsoft.ML.Runtime.Ensemble.OutputCombiners { using ColumnRole = RoleMappedSchema.ColumnRole; - public abstract class BaseStacking : IStackingTrainer + public abstract class BaseStacking : IStackingTrainer { public abstract class ArgumentsBase { @@ -24,13 +25,10 @@ public abstract class ArgumentsBase [TGUI(Label = "Validation Dataset Proportion")] public Single ValidationDatasetProportion = 0.3f; - [Argument(ArgumentType.Multiple, HelpText = "Base predictor for meta learning", ShortName = "bp", SortOrder = 50, - Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly)] - [TGUI(Label = "Base predictor")] - public SubComponent>, TSigBase> BasePredictorType; + internal abstract IComponentFactory>> GetPredictorFactory(); } - protected readonly SubComponent>, TSigBase> BasePredictorType; + protected readonly IComponentFactory>> BasePredictorType; protected readonly IHost Host; protected IPredictorProducing Meta; @@ -45,10 +43,10 @@ internal BaseStacking(IHostEnvironment env, string name, ArgumentsBase args) Host.CheckUserArg(0 <= args.ValidationDatasetProportion && args.ValidationDatasetProportion < 1, nameof(args.ValidationDatasetProportion), "The validation proportion for stacking should be greater than or equal to 0 and less than 1"); - Host.CheckUserArg(args.BasePredictorType.IsGood(), nameof(args.BasePredictorType)); ValidationDatasetProportion = args.ValidationDatasetProportion; - BasePredictorType = args.BasePredictorType; + BasePredictorType = args.GetPredictorFactory(); + Host.CheckValue(BasePredictorType, nameof(BasePredictorType)); } internal BaseStacking(IHostEnvironment env, string name, ModelLoadContext ctx) @@ -187,7 +185,7 @@ public void Train(List>> models, var view = bldr.GetDataView(); var rmd = new RoleMappedData(view, DefaultColumnNames.Label, DefaultColumnNames.Features); - var trainer = BasePredictorType.CreateInstance(host); + var trainer = BasePredictorType.CreateComponent(host); if (trainer.Info.NeedNormalization) ch.Warning("The trainer specified for stacking wants normalization, but we do not currently allow this."); Meta = trainer.Train(rmd); diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiStacking.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiStacking.cs index 2ef74c8169..40d59b631a 100644 --- a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiStacking.cs +++ b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiStacking.cs @@ -8,7 +8,10 @@ using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.Ensemble.OutputCombiners; using Microsoft.ML.Runtime.EntryPoints; +using Microsoft.ML.Runtime.FastTree; +using Microsoft.ML.Runtime.Internal.Internallearn; using Microsoft.ML.Runtime.Internal.Utilities; +using Microsoft.ML.Runtime.Learners; using Microsoft.ML.Runtime.Model; [assembly: LoadableClass(typeof(MultiStacking), typeof(MultiStacking.Arguments), typeof(SignatureCombiner), @@ -20,7 +23,7 @@ namespace Microsoft.ML.Runtime.Ensemble.OutputCombiners { using TVectorPredictor = IPredictorProducing>; - public sealed class MultiStacking : BaseStacking, SignatureMultiClassClassifierTrainer>, ICanSaveModel, IMultiClassOutputCombiner + public sealed class MultiStacking : BaseStacking>, ICanSaveModel, IMultiClassOutputCombiner { public const string LoadName = "MultiStacking"; public const string LoaderSignature = "MultiStackingCombiner"; @@ -38,13 +41,24 @@ private static VersionInfo GetVersionInfo() [TlcModule.Component(Name = LoadName, FriendlyName = Stacking.UserName)] public sealed class Arguments : ArgumentsBase, ISupportMulticlassOutputCombinerFactory { + [Argument(ArgumentType.Multiple, HelpText = "Base predictor for meta learning", ShortName = "bp", SortOrder = 50, + Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, SignatureType = typeof(SignatureMultiClassClassifierTrainer))] + [TGUI(Label = "Base predictor")] + public IComponentFactory> BasePredictorType; + + internal override IComponentFactory> GetPredictorFactory() => BasePredictorType; + public IMultiClassOutputCombiner CreateComponent(IHostEnvironment env) => new MultiStacking(env, this); public Arguments() { // REVIEW: Perhaps we can have a better non-parametetric learner. - BasePredictorType = new SubComponent, SignatureMultiClassClassifierTrainer>( - "OVA", "p=FastTreeBinaryClassification"); + BasePredictorType = ComponentFactoryUtils.CreateFromFunction( + env => new Ova(env, new Ova.Arguments() + { + PredictorType = ComponentFactoryUtils.CreateFromFunction( + e => new FastTreeBinaryClassificationTrainer(e, new FastTreeBinaryClassificationTrainer.Arguments())) + })); } } diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/RegressionStacking.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/RegressionStacking.cs index 0b5f8e6057..436e365b79 100644 --- a/src/Microsoft.ML.Ensemble/OutputCombiners/RegressionStacking.cs +++ b/src/Microsoft.ML.Ensemble/OutputCombiners/RegressionStacking.cs @@ -7,6 +7,8 @@ using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.Ensemble.OutputCombiners; using Microsoft.ML.Runtime.EntryPoints; +using Microsoft.ML.Runtime.FastTree; +using Microsoft.ML.Runtime.Internal.Internallearn; using Microsoft.ML.Runtime.Model; [assembly: LoadableClass(typeof(RegressionStacking), typeof(RegressionStacking.Arguments), typeof(SignatureCombiner), @@ -19,7 +21,7 @@ namespace Microsoft.ML.Runtime.Ensemble.OutputCombiners { using TScalarPredictor = IPredictorProducing; - public sealed class RegressionStacking : BaseScalarStacking, IRegressionOutputCombiner, ICanSaveModel + public sealed class RegressionStacking : BaseScalarStacking, IRegressionOutputCombiner, ICanSaveModel { public const string LoadName = "RegressionStacking"; public const string LoaderSignature = "RegressionStacking"; @@ -37,9 +39,17 @@ private static VersionInfo GetVersionInfo() [TlcModule.Component(Name = LoadName, FriendlyName = Stacking.UserName)] public sealed class Arguments : ArgumentsBase, ISupportRegressionOutputCombinerFactory { + [Argument(ArgumentType.Multiple, HelpText = "Base predictor for meta learning", ShortName = "bp", SortOrder = 50, + Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, SignatureType = typeof(SignatureRegressorTrainer))] + [TGUI(Label = "Base predictor")] + public IComponentFactory> BasePredictorType; + + internal override IComponentFactory> GetPredictorFactory() => BasePredictorType; + public Arguments() { - BasePredictorType = new SubComponent, SignatureRegressorTrainer>("FastTreeRegression"); + BasePredictorType = ComponentFactoryUtils.CreateFromFunction( + env => new FastTreeRegressionTrainer(env, new FastTreeRegressionTrainer.Arguments())); } public IRegressionOutputCombiner CreateComponent(IHostEnvironment env) => new RegressionStacking(env, this); diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/Stacking.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/Stacking.cs index f3481e9936..f0ced9d947 100644 --- a/src/Microsoft.ML.Ensemble/OutputCombiners/Stacking.cs +++ b/src/Microsoft.ML.Ensemble/OutputCombiners/Stacking.cs @@ -5,9 +5,10 @@ using System; using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.Ensemble.OutputCombiners; using Microsoft.ML.Runtime.EntryPoints; +using Microsoft.ML.Runtime.FastTree; +using Microsoft.ML.Runtime.Internal.Internallearn; using Microsoft.ML.Runtime.Model; [assembly: LoadableClass(typeof(Stacking), typeof(Stacking.Arguments), typeof(SignatureCombiner), Stacking.UserName, Stacking.LoadName)] @@ -16,7 +17,7 @@ namespace Microsoft.ML.Runtime.Ensemble.OutputCombiners { using TScalarPredictor = IPredictorProducing; - public sealed class Stacking : BaseScalarStacking, IBinaryOutputCombiner, ICanSaveModel + public sealed class Stacking : BaseScalarStacking, IBinaryOutputCombiner, ICanSaveModel { public const string UserName = "Stacking"; public const string LoadName = "Stacking"; @@ -35,9 +36,17 @@ private static VersionInfo GetVersionInfo() [TlcModule.Component(Name = LoadName, FriendlyName = UserName)] public sealed class Arguments : ArgumentsBase, ISupportBinaryOutputCombinerFactory { + [Argument(ArgumentType.Multiple, HelpText = "Base predictor for meta learning", ShortName = "bp", SortOrder = 50, + Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, SignatureType = typeof(SignatureBinaryClassifierTrainer))] + [TGUI(Label = "Base predictor")] + public IComponentFactory> BasePredictorType; + + internal override IComponentFactory> GetPredictorFactory() => BasePredictorType; + public Arguments() { - BasePredictorType = new SubComponent, SignatureBinaryClassifierTrainer>("FastTreeBinaryClassification"); + BasePredictorType = ComponentFactoryUtils.CreateFromFunction( + env => new FastTreeBinaryClassificationTrainer(env, new FastTreeBinaryClassificationTrainer.Arguments())); } public IBinaryOutputCombiner CreateComponent(IHostEnvironment env) => new Stacking(env, this); diff --git a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BaseSubModelSelector.cs b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BaseSubModelSelector.cs index 8465e1a5e8..513ad95139 100644 --- a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BaseSubModelSelector.cs +++ b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BaseSubModelSelector.cs @@ -54,16 +54,16 @@ public virtual IList>> Prune(ILi return models; } - private SubComponent GetEvaluatorSubComponent() + private IEvaluator GetEvaluator(IHostEnvironment env) { switch (PredictionKind) { case PredictionKind.BinaryClassification: - return new SubComponent(BinaryClassifierEvaluator.LoadName); + return new BinaryClassifierEvaluator(env, new BinaryClassifierEvaluator.Arguments()); case PredictionKind.Regression: - return new SubComponent(RegressionEvaluator.LoadName); + return new RegressionEvaluator(env, new RegressionEvaluator.Arguments()); case PredictionKind.MultiClassClassification: - return new SubComponent(MultiClassClassifierEvaluator.LoadName); + return new MultiClassClassifierEvaluator(env, new MultiClassClassifierEvaluator.Arguments()); default: throw Host.Except("Unrecognized prediction kind '{0}'", PredictionKind); } @@ -81,12 +81,11 @@ public virtual void CalculateMetrics(FeatureSubsetModel public sealed class EnsembleTrainer : EnsembleTrainerBase, + IBinarySubModelSelector, IBinaryOutputCombiner>, IModelCombiner { public const string LoadNameValue = "WeightedEnsemble"; @@ -44,9 +45,18 @@ public sealed class Arguments : ArgumentsBase [TGUI(Label = "Output combiner", Description = "Output combiner type")] public ISupportBinaryOutputCombinerFactory OutputCombiner = new MedianFactory(); + [Argument(ArgumentType.Multiple, HelpText = "Base predictor type", ShortName = "bp,basePredictorTypes", SortOrder = 1, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, SignatureType = typeof(SignatureBinaryClassifierTrainer))] + public IComponentFactory>[] BasePredictors; + + internal override IComponentFactory>[] GetPredictorFactories() => BasePredictors; + public Arguments() { - BasePredictors = new[] { new SubComponent, SignatureBinaryClassifierTrainer>("LinearSVM") }; + BasePredictors = new[] + { + ComponentFactoryUtils.CreateFromFunction( + env => new LinearSvm(env, new LinearSvm.Arguments())) + }; } } diff --git a/src/Microsoft.ML.Ensemble/Trainer/EnsembleTrainerBase.cs b/src/Microsoft.ML.Ensemble/Trainer/EnsembleTrainerBase.cs index 0a350ef8ee..5ebf5cfac0 100644 --- a/src/Microsoft.ML.Ensemble/Trainer/EnsembleTrainerBase.cs +++ b/src/Microsoft.ML.Ensemble/Trainer/EnsembleTrainerBase.cs @@ -20,7 +20,7 @@ namespace Microsoft.ML.Runtime.Ensemble { using Stopwatch = System.Diagnostics.Stopwatch; - public abstract class EnsembleTrainerBase : TrainerBase + public abstract class EnsembleTrainerBase : TrainerBase where TPredictor : class, IPredictorProducing where TSelector : class, ISubModelSelector where TCombiner : class, IOutputCombiner @@ -53,8 +53,7 @@ public abstract class ArgumentsBase : LearnerInputBaseWithLabel [TGUI(Label = "Show Sub-Model Metrics")] public bool ShowMetrics; - [Argument(ArgumentType.Multiple, HelpText = "Base predictor type", ShortName = "bp,basePredictorTypes", SortOrder = 1, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly)] - public SubComponent>, TSig>[] BasePredictors; + internal abstract IComponentFactory>>[] GetPredictorFactories(); } private const int DefaultNumModels = 50; @@ -78,21 +77,22 @@ private protected EnsembleTrainerBase(ArgumentsBase args, IHostEnvironment env, using (var ch = Host.Start("Init")) { - ch.CheckUserArg(Utils.Size(Args.BasePredictors) > 0, nameof(Args.BasePredictors), "This should have at-least one value"); + var predictorFactories = Args.GetPredictorFactories(); + ch.CheckUserArg(Utils.Size(predictorFactories) > 0, nameof(EnsembleTrainer.Arguments.BasePredictors), "This should have at-least one value"); NumModels = Args.NumModels ?? - (Args.BasePredictors.Length == 1 ? DefaultNumModels : Args.BasePredictors.Length); + (predictorFactories.Length == 1 ? DefaultNumModels : predictorFactories.Length); ch.CheckUserArg(NumModels > 0, nameof(Args.NumModels), "Must be positive, or null to indicate numModels is the number of base predictors"); - if (Utils.Size(Args.BasePredictors) > NumModels) + if (Utils.Size(predictorFactories) > NumModels) ch.Warning("The base predictor count is greater than models count. Some of the base predictors will be ignored."); _subsetSelector = Args.SamplingType.CreateComponent(Host); Trainers = new ITrainer>[NumModels]; for (int i = 0; i < Trainers.Length; i++) - Trainers[i] = Args.BasePredictors[i % Args.BasePredictors.Length].CreateInstance(Host); + Trainers[i] = predictorFactories[i % predictorFactories.Length].CreateComponent(Host); // We infer normalization and calibration preferences from the trainers. However, even if the internal trainers // don't need caching we are performing multiple passes over the data, so it is probably appropriate to always cache. Info = new TrainerInfo( diff --git a/src/Microsoft.ML.Ensemble/Trainer/Multiclass/MulticlassDataPartitionEnsembleTrainer.cs b/src/Microsoft.ML.Ensemble/Trainer/Multiclass/MulticlassDataPartitionEnsembleTrainer.cs index 4421cd5838..d5c60e24a6 100644 --- a/src/Microsoft.ML.Ensemble/Trainer/Multiclass/MulticlassDataPartitionEnsembleTrainer.cs +++ b/src/Microsoft.ML.Ensemble/Trainer/Multiclass/MulticlassDataPartitionEnsembleTrainer.cs @@ -12,7 +12,9 @@ using Microsoft.ML.Runtime.Ensemble; using Microsoft.ML.Runtime.Ensemble.OutputCombiners; using Microsoft.ML.Runtime.Ensemble.Selector; +using Microsoft.ML.Runtime.EntryPoints; using Microsoft.ML.Runtime.Internal.Internallearn; +using Microsoft.ML.Runtime.Learners; [assembly: LoadableClass(MulticlassDataPartitionEnsembleTrainer.Summary, typeof(MulticlassDataPartitionEnsembleTrainer), typeof(MulticlassDataPartitionEnsembleTrainer.Arguments), @@ -28,7 +30,7 @@ namespace Microsoft.ML.Runtime.Ensemble /// public sealed class MulticlassDataPartitionEnsembleTrainer : EnsembleTrainerBase, EnsembleMultiClassPredictor, - IMulticlassSubModelSelector, IMultiClassOutputCombiner, SignatureMultiClassClassifierTrainer>, + IMulticlassSubModelSelector, IMultiClassOutputCombiner>, IModelCombiner { public const string LoadNameValue = "WeightedEnsembleMulticlass"; @@ -45,9 +47,18 @@ public sealed class Arguments : ArgumentsBase [TGUI(Label = "Output combiner", Description = "Output combiner type")] public ISupportMulticlassOutputCombinerFactory OutputCombiner = new MultiMedian.Arguments(); + [Argument(ArgumentType.Multiple, HelpText = "Base predictor type", ShortName = "bp,basePredictorTypes", SortOrder = 1, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, SignatureType = typeof(SignatureMultiClassClassifierTrainer))] + public IComponentFactory>[] BasePredictors; + + internal override IComponentFactory>[] GetPredictorFactories() => BasePredictors; + public Arguments() { - BasePredictors = new[] { new SubComponent, SignatureMultiClassClassifierTrainer>("MultiClassLogisticRegression") }; + BasePredictors = new[] + { + ComponentFactoryUtils.CreateFromFunction( + env => new MulticlassLogisticRegression(env, new MulticlassLogisticRegression.Arguments())) + }; } } diff --git a/src/Microsoft.ML.Ensemble/Trainer/Regression/RegressionEnsembleTrainer.cs b/src/Microsoft.ML.Ensemble/Trainer/Regression/RegressionEnsembleTrainer.cs index 1cc36f20cd..3fe853aff5 100644 --- a/src/Microsoft.ML.Ensemble/Trainer/Regression/RegressionEnsembleTrainer.cs +++ b/src/Microsoft.ML.Ensemble/Trainer/Regression/RegressionEnsembleTrainer.cs @@ -12,7 +12,9 @@ using Microsoft.ML.Runtime.Ensemble; using Microsoft.ML.Runtime.Ensemble.OutputCombiners; using Microsoft.ML.Runtime.Ensemble.Selector; +using Microsoft.ML.Runtime.EntryPoints; using Microsoft.ML.Runtime.Internal.Internallearn; +using Microsoft.ML.Runtime.Learners; [assembly: LoadableClass(typeof(RegressionEnsembleTrainer), typeof(RegressionEnsembleTrainer.Arguments), new[] { typeof(SignatureRegressorTrainer), typeof(SignatureTrainer) }, @@ -23,7 +25,7 @@ namespace Microsoft.ML.Runtime.Ensemble { using TScalarPredictor = IPredictorProducing; public sealed class RegressionEnsembleTrainer : EnsembleTrainerBase, + IRegressionSubModelSelector, IRegressionOutputCombiner>, IModelCombiner { public const string LoadNameValue = "EnsembleRegression"; @@ -39,9 +41,18 @@ public sealed class Arguments : ArgumentsBase [TGUI(Label = "Output combiner", Description = "Output combiner type")] public ISupportRegressionOutputCombinerFactory OutputCombiner = new MedianFactory(); + [Argument(ArgumentType.Multiple, HelpText = "Base predictor type", ShortName = "bp,basePredictorTypes", SortOrder = 1, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, SignatureType = typeof(SignatureRegressorTrainer))] + public IComponentFactory>[] BasePredictors; + + internal override IComponentFactory>[] GetPredictorFactories() => BasePredictors; + public Arguments() { - BasePredictors = new[] { new SubComponent, SignatureRegressorTrainer>("OnlineGradientDescent") }; + BasePredictors = new[] + { + ComponentFactoryUtils.CreateFromFunction( + env => new OnlineGradientDescentTrainer(env, new OnlineGradientDescentTrainer.Arguments())) + }; } } diff --git a/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs b/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs index 88ff3eaf62..e9f5f326f6 100644 --- a/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs +++ b/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs @@ -672,10 +672,10 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV trainScoreArgs.Trainer = new SubComponent(args.Trainer.Kind, args.Trainer.Settings); - trainScoreArgs.Scorer = new SimpleComponentFactory( + trainScoreArgs.Scorer = ComponentFactoryUtils.CreateFromFunction( (e, data, mapper, trainSchema) => Create(e, scorerArgs, data, mapper, trainSchema)); - var mapperFactory = new SimpleComponentFactory( + var mapperFactory = ComponentFactoryUtils.CreateFromFunction( (e, predictor) => new TreeEnsembleFeaturizerBindableMapper(e, scorerArgs, predictor)); var labelInput = AppendLabelTransform(host, ch, input, trainScoreArgs.LabelColumn, args.LabelPermutationSeed); diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Metacomponents.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Metacomponents.cs index 931ab95999..811f857046 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Metacomponents.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Metacomponents.cs @@ -27,10 +27,8 @@ void Metacomponents() var concat = new ConcatTransform(env, term, "Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth"); var trainer = new Ova(env, new Ova.Arguments { - PredictorType = new SimpleComponentFactory>> - ( - (e) => new FastTreeBinaryClassificationTrainer(e, new FastTreeBinaryClassificationTrainer.Arguments()) - ) + PredictorType = ComponentFactoryUtils.CreateFromFunction( + (e) => new FastTreeBinaryClassificationTrainer(e, new FastTreeBinaryClassificationTrainer.Arguments())) }); IDataView trainData = trainer.Info.WantCaching ? (IDataView)new CacheDataView(env, concat, prefetch: null) : concat;