diff --git a/src/Microsoft.ML.Core/EntryPoints/ComponentFactory.cs b/src/Microsoft.ML.Core/EntryPoints/ComponentFactory.cs index bd2613e3fd..1f8d5b0e8f 100644 --- a/src/Microsoft.ML.Core/EntryPoints/ComponentFactory.cs +++ b/src/Microsoft.ML.Core/EntryPoints/ComponentFactory.cs @@ -37,41 +37,6 @@ public interface IComponentFactory<in TArg1, out TComponent> : IComponentFactory TComponent CreateComponent(IHostEnvironment env, TArg1 argument1); } - /// <summary> - /// A class for creating a component when we take one extra parameter - /// (and an <see cref="IHostEnvironment"/>) that simply wraps a delegate which - /// creates the component. - /// </summary> - public class SimpleComponentFactory<TArg1, TComponent> : IComponentFactory<TArg1, TComponent> - { - private Func<IHostEnvironment, TArg1, TComponent> _factory; - - public SimpleComponentFactory(Func<IHostEnvironment, TArg1, TComponent> factory) - { - _factory = factory; - } - - public TComponent CreateComponent(IHostEnvironment env, TArg1 argument1) - { - return _factory(env, argument1); - } - } - - public class SimpleComponentFactory<TComponent> : IComponentFactory<TComponent> - { - private Func<IHostEnvironment, TComponent> _factory; - - public SimpleComponentFactory(Func<IHostEnvironment, TComponent> factory) - { - _factory = factory; - } - - public TComponent CreateComponent(IHostEnvironment env) - { - return _factory(env); - } - } - /// <summary> /// An interface for creating a component when we take two extra parameters (and an <see cref="IHostEnvironment"/>). /// </summary> @@ -89,22 +54,94 @@ public interface IComponentFactory<in TArg1, in TArg2, in TArg3, out TComponent> } /// <summary> - /// A class for creating a component when we take three extra parameters - /// (and an <see cref="IHostEnvironment"/>) that simply wraps a delegate which - /// creates the component. + /// A utility class for creating <see cref="IComponentFactory"/> instances. /// </summary> - public class SimpleComponentFactory<TArg1, TArg2, TArg3, TComponent> : IComponentFactory<TArg1, TArg2, TArg3, TComponent> + public static class ComponentFactoryUtils { - private Func<IHostEnvironment, TArg1, TArg2, TArg3, TComponent> _factory; + /// <summary> + /// Creates a component factory with no extra parameters (other than an <see cref="IHostEnvironment"/>) + /// that simply wraps a delegate which creates the component. + /// </summary> + public static IComponentFactory<TComponent> CreateFromFunction<TComponent>(Func<IHostEnvironment, TComponent> factory) + { + return new SimpleComponentFactory<TComponent>(factory); + } + + /// <summary> + /// Creates a component factory when we take one extra parameter (and an + /// <see cref="IHostEnvironment"/>) that simply wraps a delegate which creates the component. + /// </summary> + public static IComponentFactory<TArg1, TComponent> CreateFromFunction<TArg1, TComponent>(Func<IHostEnvironment, TArg1, TComponent> factory) + { + return new SimpleComponentFactory<TArg1, TComponent>(factory); + } + + /// <summary> + /// Creates a component factory when we take three extra parameters (and an + /// <see cref="IHostEnvironment"/>) that simply wraps a delegate which creates the component. + /// </summary> + public static IComponentFactory<TArg1, TArg2, TArg3, TComponent> CreateFromFunction<TArg1, TArg2, TArg3, TComponent>(Func<IHostEnvironment, TArg1, TArg2, TArg3, TComponent> factory) + { + return new SimpleComponentFactory<TArg1, TArg2, TArg3, TComponent>(factory); + } - public SimpleComponentFactory(Func<IHostEnvironment, TArg1, TArg2, TArg3, TComponent> factory) + /// <summary> + /// A class for creating a component with no extra parameters (other than an <see cref="IHostEnvironment"/>) + /// that simply wraps a delegate which creates the component. + /// </summary> + private sealed class SimpleComponentFactory<TComponent> : IComponentFactory<TComponent> { - _factory = factory; + private readonly Func<IHostEnvironment, TComponent> _factory; + + public SimpleComponentFactory(Func<IHostEnvironment, TComponent> factory) + { + _factory = factory; + } + + public TComponent CreateComponent(IHostEnvironment env) + { + return _factory(env); + } } - public TComponent CreateComponent(IHostEnvironment env, TArg1 argument1, TArg2 argument2, TArg3 argument3) + /// <summary> + /// A class for creating a component when we take one extra parameter + /// (and an <see cref="IHostEnvironment"/>) that simply wraps a delegate which + /// creates the component. + /// </summary> + private sealed class SimpleComponentFactory<TArg1, TComponent> : IComponentFactory<TArg1, TComponent> { - return _factory(env, argument1, argument2, argument3); + private readonly Func<IHostEnvironment, TArg1, TComponent> _factory; + + public SimpleComponentFactory(Func<IHostEnvironment, TArg1, TComponent> factory) + { + _factory = factory; + } + + public TComponent CreateComponent(IHostEnvironment env, TArg1 argument1) + { + return _factory(env, argument1); + } + } + + /// <summary> + /// A class for creating a component when we take three extra parameters + /// (and an <see cref="IHostEnvironment"/>) that simply wraps a delegate which + /// creates the component. + /// </summary> + private sealed class SimpleComponentFactory<TArg1, TArg2, TArg3, TComponent> : IComponentFactory<TArg1, TArg2, TArg3, TComponent> + { + private readonly Func<IHostEnvironment, TArg1, TArg2, TArg3, TComponent> _factory; + + public SimpleComponentFactory(Func<IHostEnvironment, TArg1, TArg2, TArg3, TComponent> factory) + { + _factory = factory; + } + + public TComponent CreateComponent(IHostEnvironment env, TArg1 argument1, TArg2 argument2, TArg3 argument3) + { + return _factory(env, argument1, argument2, argument3); + } } } } diff --git a/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs b/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs index 3055afe917..971512cc97 100644 --- a/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs +++ b/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs @@ -164,7 +164,7 @@ private void RunCore(IChannel ch, string cmd) new[] { new KeyValuePair<string, IComponentFactory<IDataView, IDataTransform>>( - "", new SimpleComponentFactory<IDataView, IDataTransform>( + "", ComponentFactoryUtils.CreateFromFunction<IDataView, IDataTransform>( (env, input) => { var args = new GenerateNumberTransform.Arguments(); diff --git a/src/Microsoft.ML.Data/Commands/ScoreCommand.cs b/src/Microsoft.ML.Data/Commands/ScoreCommand.cs index 635d751d2e..879b29d4ce 100644 --- a/src/Microsoft.ML.Data/Commands/ScoreCommand.cs +++ b/src/Microsoft.ML.Data/Commands/ScoreCommand.cs @@ -330,7 +330,7 @@ public static TScorerFactory GetScorerComponent( }; } - return new SimpleComponentFactory<IDataView, ISchemaBoundMapper, RoleMappedSchema, IDataScorerTransform>(factoryFunc); + return ComponentFactoryUtils.CreateFromFunction(factoryFunc); } /// <summary> diff --git a/src/Microsoft.ML.Ensemble/Microsoft.ML.Ensemble.csproj b/src/Microsoft.ML.Ensemble/Microsoft.ML.Ensemble.csproj index ddd4557788..ef0ad01a40 100644 --- a/src/Microsoft.ML.Ensemble/Microsoft.ML.Ensemble.csproj +++ b/src/Microsoft.ML.Ensemble/Microsoft.ML.Ensemble.csproj @@ -10,6 +10,7 @@ <ProjectReference Include="..\Microsoft.ML.Data\Microsoft.ML.Data.csproj" /> <ProjectReference Include="..\Microsoft.ML.StandardLearners\Microsoft.ML.StandardLearners.csproj" /> <ProjectReference Include="..\Microsoft.ML.Transforms\Microsoft.ML.Transforms.csproj" /> + <ProjectReference Include="..\Microsoft.ML.FastTree\Microsoft.ML.FastTree.csproj" /> </ItemGroup> </Project> diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/BaseScalarStacking.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/BaseScalarStacking.cs index a5c9c757a4..9b9900a65f 100644 --- a/src/Microsoft.ML.Ensemble/OutputCombiners/BaseScalarStacking.cs +++ b/src/Microsoft.ML.Ensemble/OutputCombiners/BaseScalarStacking.cs @@ -9,7 +9,7 @@ namespace Microsoft.ML.Runtime.Ensemble.OutputCombiners { - public abstract class BaseScalarStacking<TSigBase> : BaseStacking<Single, TSigBase> + public abstract class BaseScalarStacking : BaseStacking<Single> { internal BaseScalarStacking(IHostEnvironment env, string name, ArgumentsBase args) : base(env, name, args) diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/BaseStacking.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/BaseStacking.cs index f30174a31d..c294a1fcde 100644 --- a/src/Microsoft.ML.Ensemble/OutputCombiners/BaseStacking.cs +++ b/src/Microsoft.ML.Ensemble/OutputCombiners/BaseStacking.cs @@ -7,6 +7,7 @@ using System.Threading.Tasks; using Microsoft.ML.Runtime.CommandLine; using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.EntryPoints; using Microsoft.ML.Runtime.Internal.Internallearn; using Microsoft.ML.Runtime.Internal.Utilities; using Microsoft.ML.Runtime.Model; @@ -15,7 +16,7 @@ namespace Microsoft.ML.Runtime.Ensemble.OutputCombiners { using ColumnRole = RoleMappedSchema.ColumnRole; - public abstract class BaseStacking<TOutput, TSigBase> : IStackingTrainer<TOutput> + public abstract class BaseStacking<TOutput> : IStackingTrainer<TOutput> { public abstract class ArgumentsBase { @@ -24,13 +25,10 @@ public abstract class ArgumentsBase [TGUI(Label = "Validation Dataset Proportion")] public Single ValidationDatasetProportion = 0.3f; - [Argument(ArgumentType.Multiple, HelpText = "Base predictor for meta learning", ShortName = "bp", SortOrder = 50, - Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly)] - [TGUI(Label = "Base predictor")] - public SubComponent<ITrainer<IPredictorProducing<TOutput>>, TSigBase> BasePredictorType; + internal abstract IComponentFactory<ITrainer<IPredictorProducing<TOutput>>> GetPredictorFactory(); } - protected readonly SubComponent<ITrainer<IPredictorProducing<TOutput>>, TSigBase> BasePredictorType; + protected readonly IComponentFactory<ITrainer<IPredictorProducing<TOutput>>> BasePredictorType; protected readonly IHost Host; protected IPredictorProducing<TOutput> Meta; @@ -45,10 +43,10 @@ internal BaseStacking(IHostEnvironment env, string name, ArgumentsBase args) Host.CheckUserArg(0 <= args.ValidationDatasetProportion && args.ValidationDatasetProportion < 1, nameof(args.ValidationDatasetProportion), "The validation proportion for stacking should be greater than or equal to 0 and less than 1"); - Host.CheckUserArg(args.BasePredictorType.IsGood(), nameof(args.BasePredictorType)); ValidationDatasetProportion = args.ValidationDatasetProportion; - BasePredictorType = args.BasePredictorType; + BasePredictorType = args.GetPredictorFactory(); + Host.CheckValue(BasePredictorType, nameof(BasePredictorType)); } internal BaseStacking(IHostEnvironment env, string name, ModelLoadContext ctx) @@ -187,7 +185,7 @@ public void Train(List<FeatureSubsetModel<IPredictorProducing<TOutput>>> models, var view = bldr.GetDataView(); var rmd = new RoleMappedData(view, DefaultColumnNames.Label, DefaultColumnNames.Features); - var trainer = BasePredictorType.CreateInstance(host); + var trainer = BasePredictorType.CreateComponent(host); if (trainer.Info.NeedNormalization) ch.Warning("The trainer specified for stacking wants normalization, but we do not currently allow this."); Meta = trainer.Train(rmd); diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiStacking.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiStacking.cs index 2ef74c8169..40d59b631a 100644 --- a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiStacking.cs +++ b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiStacking.cs @@ -8,7 +8,10 @@ using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.Ensemble.OutputCombiners; using Microsoft.ML.Runtime.EntryPoints; +using Microsoft.ML.Runtime.FastTree; +using Microsoft.ML.Runtime.Internal.Internallearn; using Microsoft.ML.Runtime.Internal.Utilities; +using Microsoft.ML.Runtime.Learners; using Microsoft.ML.Runtime.Model; [assembly: LoadableClass(typeof(MultiStacking), typeof(MultiStacking.Arguments), typeof(SignatureCombiner), @@ -20,7 +23,7 @@ namespace Microsoft.ML.Runtime.Ensemble.OutputCombiners { using TVectorPredictor = IPredictorProducing<VBuffer<Single>>; - public sealed class MultiStacking : BaseStacking<VBuffer<Single>, SignatureMultiClassClassifierTrainer>, ICanSaveModel, IMultiClassOutputCombiner + public sealed class MultiStacking : BaseStacking<VBuffer<Single>>, ICanSaveModel, IMultiClassOutputCombiner { public const string LoadName = "MultiStacking"; public const string LoaderSignature = "MultiStackingCombiner"; @@ -38,13 +41,24 @@ private static VersionInfo GetVersionInfo() [TlcModule.Component(Name = LoadName, FriendlyName = Stacking.UserName)] public sealed class Arguments : ArgumentsBase, ISupportMulticlassOutputCombinerFactory { + [Argument(ArgumentType.Multiple, HelpText = "Base predictor for meta learning", ShortName = "bp", SortOrder = 50, + Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, SignatureType = typeof(SignatureMultiClassClassifierTrainer))] + [TGUI(Label = "Base predictor")] + public IComponentFactory<ITrainer<TVectorPredictor>> BasePredictorType; + + internal override IComponentFactory<ITrainer<TVectorPredictor>> GetPredictorFactory() => BasePredictorType; + public IMultiClassOutputCombiner CreateComponent(IHostEnvironment env) => new MultiStacking(env, this); public Arguments() { // REVIEW: Perhaps we can have a better non-parametetric learner. - BasePredictorType = new SubComponent<ITrainer<TVectorPredictor>, SignatureMultiClassClassifierTrainer>( - "OVA", "p=FastTreeBinaryClassification"); + BasePredictorType = ComponentFactoryUtils.CreateFromFunction( + env => new Ova(env, new Ova.Arguments() + { + PredictorType = ComponentFactoryUtils.CreateFromFunction( + e => new FastTreeBinaryClassificationTrainer(e, new FastTreeBinaryClassificationTrainer.Arguments())) + })); } } diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/RegressionStacking.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/RegressionStacking.cs index 0b5f8e6057..436e365b79 100644 --- a/src/Microsoft.ML.Ensemble/OutputCombiners/RegressionStacking.cs +++ b/src/Microsoft.ML.Ensemble/OutputCombiners/RegressionStacking.cs @@ -7,6 +7,8 @@ using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.Ensemble.OutputCombiners; using Microsoft.ML.Runtime.EntryPoints; +using Microsoft.ML.Runtime.FastTree; +using Microsoft.ML.Runtime.Internal.Internallearn; using Microsoft.ML.Runtime.Model; [assembly: LoadableClass(typeof(RegressionStacking), typeof(RegressionStacking.Arguments), typeof(SignatureCombiner), @@ -19,7 +21,7 @@ namespace Microsoft.ML.Runtime.Ensemble.OutputCombiners { using TScalarPredictor = IPredictorProducing<Single>; - public sealed class RegressionStacking : BaseScalarStacking<SignatureRegressorTrainer>, IRegressionOutputCombiner, ICanSaveModel + public sealed class RegressionStacking : BaseScalarStacking, IRegressionOutputCombiner, ICanSaveModel { public const string LoadName = "RegressionStacking"; public const string LoaderSignature = "RegressionStacking"; @@ -37,9 +39,17 @@ private static VersionInfo GetVersionInfo() [TlcModule.Component(Name = LoadName, FriendlyName = Stacking.UserName)] public sealed class Arguments : ArgumentsBase, ISupportRegressionOutputCombinerFactory { + [Argument(ArgumentType.Multiple, HelpText = "Base predictor for meta learning", ShortName = "bp", SortOrder = 50, + Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, SignatureType = typeof(SignatureRegressorTrainer))] + [TGUI(Label = "Base predictor")] + public IComponentFactory<ITrainer<TScalarPredictor>> BasePredictorType; + + internal override IComponentFactory<ITrainer<TScalarPredictor>> GetPredictorFactory() => BasePredictorType; + public Arguments() { - BasePredictorType = new SubComponent<ITrainer<TScalarPredictor>, SignatureRegressorTrainer>("FastTreeRegression"); + BasePredictorType = ComponentFactoryUtils.CreateFromFunction( + env => new FastTreeRegressionTrainer(env, new FastTreeRegressionTrainer.Arguments())); } public IRegressionOutputCombiner CreateComponent(IHostEnvironment env) => new RegressionStacking(env, this); diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/Stacking.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/Stacking.cs index f3481e9936..f0ced9d947 100644 --- a/src/Microsoft.ML.Ensemble/OutputCombiners/Stacking.cs +++ b/src/Microsoft.ML.Ensemble/OutputCombiners/Stacking.cs @@ -5,9 +5,10 @@ using System; using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.Ensemble.OutputCombiners; using Microsoft.ML.Runtime.EntryPoints; +using Microsoft.ML.Runtime.FastTree; +using Microsoft.ML.Runtime.Internal.Internallearn; using Microsoft.ML.Runtime.Model; [assembly: LoadableClass(typeof(Stacking), typeof(Stacking.Arguments), typeof(SignatureCombiner), Stacking.UserName, Stacking.LoadName)] @@ -16,7 +17,7 @@ namespace Microsoft.ML.Runtime.Ensemble.OutputCombiners { using TScalarPredictor = IPredictorProducing<Single>; - public sealed class Stacking : BaseScalarStacking<SignatureBinaryClassifierTrainer>, IBinaryOutputCombiner, ICanSaveModel + public sealed class Stacking : BaseScalarStacking, IBinaryOutputCombiner, ICanSaveModel { public const string UserName = "Stacking"; public const string LoadName = "Stacking"; @@ -35,9 +36,17 @@ private static VersionInfo GetVersionInfo() [TlcModule.Component(Name = LoadName, FriendlyName = UserName)] public sealed class Arguments : ArgumentsBase, ISupportBinaryOutputCombinerFactory { + [Argument(ArgumentType.Multiple, HelpText = "Base predictor for meta learning", ShortName = "bp", SortOrder = 50, + Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, SignatureType = typeof(SignatureBinaryClassifierTrainer))] + [TGUI(Label = "Base predictor")] + public IComponentFactory<ITrainer<TScalarPredictor>> BasePredictorType; + + internal override IComponentFactory<ITrainer<TScalarPredictor>> GetPredictorFactory() => BasePredictorType; + public Arguments() { - BasePredictorType = new SubComponent<ITrainer<TScalarPredictor>, SignatureBinaryClassifierTrainer>("FastTreeBinaryClassification"); + BasePredictorType = ComponentFactoryUtils.CreateFromFunction( + env => new FastTreeBinaryClassificationTrainer(env, new FastTreeBinaryClassificationTrainer.Arguments())); } public IBinaryOutputCombiner CreateComponent(IHostEnvironment env) => new Stacking(env, this); diff --git a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BaseSubModelSelector.cs b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BaseSubModelSelector.cs index 8465e1a5e8..513ad95139 100644 --- a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BaseSubModelSelector.cs +++ b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BaseSubModelSelector.cs @@ -54,16 +54,16 @@ public virtual IList<FeatureSubsetModel<IPredictorProducing<TOutput>>> Prune(ILi return models; } - private SubComponent<IEvaluator, SignatureEvaluator> GetEvaluatorSubComponent() + private IEvaluator GetEvaluator(IHostEnvironment env) { switch (PredictionKind) { case PredictionKind.BinaryClassification: - return new SubComponent<IEvaluator, SignatureEvaluator>(BinaryClassifierEvaluator.LoadName); + return new BinaryClassifierEvaluator(env, new BinaryClassifierEvaluator.Arguments()); case PredictionKind.Regression: - return new SubComponent<IEvaluator, SignatureEvaluator>(RegressionEvaluator.LoadName); + return new RegressionEvaluator(env, new RegressionEvaluator.Arguments()); case PredictionKind.MultiClassClassification: - return new SubComponent<IEvaluator, SignatureEvaluator>(MultiClassClassifierEvaluator.LoadName); + return new MultiClassClassifierEvaluator(env, new MultiClassClassifierEvaluator.Arguments()); default: throw Host.Except("Unrecognized prediction kind '{0}'", PredictionKind); } @@ -81,12 +81,11 @@ public virtual void CalculateMetrics(FeatureSubsetModel<IPredictorProducing<TOut // Because the training and test datasets are drawn from the same base dataset, the test data role mappings // are the same as for the train data. IDataScorerTransform scorePipe = ScoreUtils.GetScorer(model.Predictor, testData, Host, testData.Schema); - // REVIEW: Should we somehow allow the user to customize the evaluator? - // By what mechanism should we allow that? - var evalComp = GetEvaluatorSubComponent(); RoleMappedData scoredTestData = new RoleMappedData(scorePipe, GetColumnRoles(testData.Schema, scorePipe.Schema)); - IEvaluator evaluator = evalComp.CreateInstance(Host); + // REVIEW: Should we somehow allow the user to customize the evaluator? + // By what mechanism should we allow that? + IEvaluator evaluator = GetEvaluator(Host); // REVIEW: with the new evaluators, metrics of individual models are no longer // printed to the Console. Consider adding an option on the combiner to print them. // REVIEW: Consider adding an option to the combiner to save a data view diff --git a/src/Microsoft.ML.Ensemble/Trainer/Binary/EnsembleTrainer.cs b/src/Microsoft.ML.Ensemble/Trainer/Binary/EnsembleTrainer.cs index ab2dff5045..9165360ccb 100644 --- a/src/Microsoft.ML.Ensemble/Trainer/Binary/EnsembleTrainer.cs +++ b/src/Microsoft.ML.Ensemble/Trainer/Binary/EnsembleTrainer.cs @@ -5,14 +5,15 @@ using System; using System.Collections.Generic; using System.Linq; +using Microsoft.ML.Ensemble.EntryPoints; using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.Ensemble; using Microsoft.ML.Runtime.Ensemble.OutputCombiners; using Microsoft.ML.Runtime.Ensemble.Selector; -using Microsoft.ML.Ensemble.EntryPoints; +using Microsoft.ML.Runtime.EntryPoints; using Microsoft.ML.Runtime.Internal.Internallearn; +using Microsoft.ML.Runtime.Learners; [assembly: LoadableClass(EnsembleTrainer.Summary, typeof(EnsembleTrainer), typeof(EnsembleTrainer.Arguments), new[] { typeof(SignatureBinaryClassifierTrainer), typeof(SignatureTrainer) }, @@ -26,7 +27,7 @@ namespace Microsoft.ML.Runtime.Ensemble /// A generic ensemble trainer for binary classification. /// </summary> public sealed class EnsembleTrainer : EnsembleTrainerBase<Single, TScalarPredictor, - IBinarySubModelSelector, IBinaryOutputCombiner, SignatureBinaryClassifierTrainer>, + IBinarySubModelSelector, IBinaryOutputCombiner>, IModelCombiner<TScalarPredictor, TScalarPredictor> { public const string LoadNameValue = "WeightedEnsemble"; @@ -44,9 +45,18 @@ public sealed class Arguments : ArgumentsBase [TGUI(Label = "Output combiner", Description = "Output combiner type")] public ISupportBinaryOutputCombinerFactory OutputCombiner = new MedianFactory(); + [Argument(ArgumentType.Multiple, HelpText = "Base predictor type", ShortName = "bp,basePredictorTypes", SortOrder = 1, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, SignatureType = typeof(SignatureBinaryClassifierTrainer))] + public IComponentFactory<ITrainer<TScalarPredictor>>[] BasePredictors; + + internal override IComponentFactory<ITrainer<TScalarPredictor>>[] GetPredictorFactories() => BasePredictors; + public Arguments() { - BasePredictors = new[] { new SubComponent<ITrainer<TScalarPredictor>, SignatureBinaryClassifierTrainer>("LinearSVM") }; + BasePredictors = new[] + { + ComponentFactoryUtils.CreateFromFunction( + env => new LinearSvm(env, new LinearSvm.Arguments())) + }; } } diff --git a/src/Microsoft.ML.Ensemble/Trainer/EnsembleTrainerBase.cs b/src/Microsoft.ML.Ensemble/Trainer/EnsembleTrainerBase.cs index 0a350ef8ee..5ebf5cfac0 100644 --- a/src/Microsoft.ML.Ensemble/Trainer/EnsembleTrainerBase.cs +++ b/src/Microsoft.ML.Ensemble/Trainer/EnsembleTrainerBase.cs @@ -20,7 +20,7 @@ namespace Microsoft.ML.Runtime.Ensemble { using Stopwatch = System.Diagnostics.Stopwatch; - public abstract class EnsembleTrainerBase<TOutput, TPredictor, TSelector, TCombiner, TSig> : TrainerBase<TPredictor> + public abstract class EnsembleTrainerBase<TOutput, TPredictor, TSelector, TCombiner> : TrainerBase<TPredictor> where TPredictor : class, IPredictorProducing<TOutput> where TSelector : class, ISubModelSelector<TOutput> where TCombiner : class, IOutputCombiner<TOutput> @@ -53,8 +53,7 @@ public abstract class ArgumentsBase : LearnerInputBaseWithLabel [TGUI(Label = "Show Sub-Model Metrics")] public bool ShowMetrics; - [Argument(ArgumentType.Multiple, HelpText = "Base predictor type", ShortName = "bp,basePredictorTypes", SortOrder = 1, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly)] - public SubComponent<ITrainer<IPredictorProducing<TOutput>>, TSig>[] BasePredictors; + internal abstract IComponentFactory<ITrainer<IPredictorProducing<TOutput>>>[] GetPredictorFactories(); } private const int DefaultNumModels = 50; @@ -78,21 +77,22 @@ private protected EnsembleTrainerBase(ArgumentsBase args, IHostEnvironment env, using (var ch = Host.Start("Init")) { - ch.CheckUserArg(Utils.Size(Args.BasePredictors) > 0, nameof(Args.BasePredictors), "This should have at-least one value"); + var predictorFactories = Args.GetPredictorFactories(); + ch.CheckUserArg(Utils.Size(predictorFactories) > 0, nameof(EnsembleTrainer.Arguments.BasePredictors), "This should have at-least one value"); NumModels = Args.NumModels ?? - (Args.BasePredictors.Length == 1 ? DefaultNumModels : Args.BasePredictors.Length); + (predictorFactories.Length == 1 ? DefaultNumModels : predictorFactories.Length); ch.CheckUserArg(NumModels > 0, nameof(Args.NumModels), "Must be positive, or null to indicate numModels is the number of base predictors"); - if (Utils.Size(Args.BasePredictors) > NumModels) + if (Utils.Size(predictorFactories) > NumModels) ch.Warning("The base predictor count is greater than models count. Some of the base predictors will be ignored."); _subsetSelector = Args.SamplingType.CreateComponent(Host); Trainers = new ITrainer<IPredictorProducing<TOutput>>[NumModels]; for (int i = 0; i < Trainers.Length; i++) - Trainers[i] = Args.BasePredictors[i % Args.BasePredictors.Length].CreateInstance(Host); + Trainers[i] = predictorFactories[i % predictorFactories.Length].CreateComponent(Host); // We infer normalization and calibration preferences from the trainers. However, even if the internal trainers // don't need caching we are performing multiple passes over the data, so it is probably appropriate to always cache. Info = new TrainerInfo( diff --git a/src/Microsoft.ML.Ensemble/Trainer/Multiclass/MulticlassDataPartitionEnsembleTrainer.cs b/src/Microsoft.ML.Ensemble/Trainer/Multiclass/MulticlassDataPartitionEnsembleTrainer.cs index 4421cd5838..d5c60e24a6 100644 --- a/src/Microsoft.ML.Ensemble/Trainer/Multiclass/MulticlassDataPartitionEnsembleTrainer.cs +++ b/src/Microsoft.ML.Ensemble/Trainer/Multiclass/MulticlassDataPartitionEnsembleTrainer.cs @@ -12,7 +12,9 @@ using Microsoft.ML.Runtime.Ensemble; using Microsoft.ML.Runtime.Ensemble.OutputCombiners; using Microsoft.ML.Runtime.Ensemble.Selector; +using Microsoft.ML.Runtime.EntryPoints; using Microsoft.ML.Runtime.Internal.Internallearn; +using Microsoft.ML.Runtime.Learners; [assembly: LoadableClass(MulticlassDataPartitionEnsembleTrainer.Summary, typeof(MulticlassDataPartitionEnsembleTrainer), typeof(MulticlassDataPartitionEnsembleTrainer.Arguments), @@ -28,7 +30,7 @@ namespace Microsoft.ML.Runtime.Ensemble /// </summary> public sealed class MulticlassDataPartitionEnsembleTrainer : EnsembleTrainerBase<VBuffer<Single>, EnsembleMultiClassPredictor, - IMulticlassSubModelSelector, IMultiClassOutputCombiner, SignatureMultiClassClassifierTrainer>, + IMulticlassSubModelSelector, IMultiClassOutputCombiner>, IModelCombiner<TVectorPredictor, TVectorPredictor> { public const string LoadNameValue = "WeightedEnsembleMulticlass"; @@ -45,9 +47,18 @@ public sealed class Arguments : ArgumentsBase [TGUI(Label = "Output combiner", Description = "Output combiner type")] public ISupportMulticlassOutputCombinerFactory OutputCombiner = new MultiMedian.Arguments(); + [Argument(ArgumentType.Multiple, HelpText = "Base predictor type", ShortName = "bp,basePredictorTypes", SortOrder = 1, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, SignatureType = typeof(SignatureMultiClassClassifierTrainer))] + public IComponentFactory<ITrainer<TVectorPredictor>>[] BasePredictors; + + internal override IComponentFactory<ITrainer<TVectorPredictor>>[] GetPredictorFactories() => BasePredictors; + public Arguments() { - BasePredictors = new[] { new SubComponent<ITrainer<TVectorPredictor>, SignatureMultiClassClassifierTrainer>("MultiClassLogisticRegression") }; + BasePredictors = new[] + { + ComponentFactoryUtils.CreateFromFunction( + env => new MulticlassLogisticRegression(env, new MulticlassLogisticRegression.Arguments())) + }; } } diff --git a/src/Microsoft.ML.Ensemble/Trainer/Regression/RegressionEnsembleTrainer.cs b/src/Microsoft.ML.Ensemble/Trainer/Regression/RegressionEnsembleTrainer.cs index 1cc36f20cd..3fe853aff5 100644 --- a/src/Microsoft.ML.Ensemble/Trainer/Regression/RegressionEnsembleTrainer.cs +++ b/src/Microsoft.ML.Ensemble/Trainer/Regression/RegressionEnsembleTrainer.cs @@ -12,7 +12,9 @@ using Microsoft.ML.Runtime.Ensemble; using Microsoft.ML.Runtime.Ensemble.OutputCombiners; using Microsoft.ML.Runtime.Ensemble.Selector; +using Microsoft.ML.Runtime.EntryPoints; using Microsoft.ML.Runtime.Internal.Internallearn; +using Microsoft.ML.Runtime.Learners; [assembly: LoadableClass(typeof(RegressionEnsembleTrainer), typeof(RegressionEnsembleTrainer.Arguments), new[] { typeof(SignatureRegressorTrainer), typeof(SignatureTrainer) }, @@ -23,7 +25,7 @@ namespace Microsoft.ML.Runtime.Ensemble { using TScalarPredictor = IPredictorProducing<Single>; public sealed class RegressionEnsembleTrainer : EnsembleTrainerBase<Single, TScalarPredictor, - IRegressionSubModelSelector, IRegressionOutputCombiner, SignatureRegressorTrainer>, + IRegressionSubModelSelector, IRegressionOutputCombiner>, IModelCombiner<TScalarPredictor, TScalarPredictor> { public const string LoadNameValue = "EnsembleRegression"; @@ -39,9 +41,18 @@ public sealed class Arguments : ArgumentsBase [TGUI(Label = "Output combiner", Description = "Output combiner type")] public ISupportRegressionOutputCombinerFactory OutputCombiner = new MedianFactory(); + [Argument(ArgumentType.Multiple, HelpText = "Base predictor type", ShortName = "bp,basePredictorTypes", SortOrder = 1, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, SignatureType = typeof(SignatureRegressorTrainer))] + public IComponentFactory<ITrainer<TScalarPredictor>>[] BasePredictors; + + internal override IComponentFactory<ITrainer<TScalarPredictor>>[] GetPredictorFactories() => BasePredictors; + public Arguments() { - BasePredictors = new[] { new SubComponent<ITrainer<TScalarPredictor>, SignatureRegressorTrainer>("OnlineGradientDescent") }; + BasePredictors = new[] + { + ComponentFactoryUtils.CreateFromFunction( + env => new OnlineGradientDescentTrainer(env, new OnlineGradientDescentTrainer.Arguments())) + }; } } diff --git a/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs b/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs index 88ff3eaf62..e9f5f326f6 100644 --- a/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs +++ b/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs @@ -672,10 +672,10 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV trainScoreArgs.Trainer = new SubComponent<ITrainer, SignatureTrainer>(args.Trainer.Kind, args.Trainer.Settings); - trainScoreArgs.Scorer = new SimpleComponentFactory<IDataView, ISchemaBoundMapper, RoleMappedSchema, IDataScorerTransform>( + trainScoreArgs.Scorer = ComponentFactoryUtils.CreateFromFunction<IDataView, ISchemaBoundMapper, RoleMappedSchema, IDataScorerTransform>( (e, data, mapper, trainSchema) => Create(e, scorerArgs, data, mapper, trainSchema)); - var mapperFactory = new SimpleComponentFactory<IPredictor, ISchemaBindableMapper>( + var mapperFactory = ComponentFactoryUtils.CreateFromFunction<IPredictor, ISchemaBindableMapper>( (e, predictor) => new TreeEnsembleFeaturizerBindableMapper(e, scorerArgs, predictor)); var labelInput = AppendLabelTransform(host, ch, input, trainScoreArgs.LabelColumn, args.LabelPermutationSeed); diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Metacomponents.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Metacomponents.cs index 931ab95999..811f857046 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Metacomponents.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Metacomponents.cs @@ -27,10 +27,8 @@ void Metacomponents() var concat = new ConcatTransform(env, term, "Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth"); var trainer = new Ova(env, new Ova.Arguments { - PredictorType = new SimpleComponentFactory<ITrainer<IPredictorProducing<float>>> - ( - (e) => new FastTreeBinaryClassificationTrainer(e, new FastTreeBinaryClassificationTrainer.Arguments()) - ) + PredictorType = ComponentFactoryUtils.CreateFromFunction( + (e) => new FastTreeBinaryClassificationTrainer(e, new FastTreeBinaryClassificationTrainer.Arguments())) }); IDataView trainData = trainer.Info.WantCaching ? (IDataView)new CacheDataView(env, concat, prefetch: null) : concat;