From 77fff0381a536a057f88209cdbf1bcc9d993f2f4 Mon Sep 17 00:00:00 2001 From: Eric Erhardt <eric.erhardt@microsoft.com> Date: Wed, 15 Aug 2018 13:20:56 -0500 Subject: [PATCH 1/7] Convert Ensemble stacking SubComponent to IComponentFactory. --- .../EntryPoints/ComponentFactory.cs | 19 +++++++++++++++ .../Microsoft.ML.Ensemble.csproj | 1 + .../OutputCombiners/BaseScalarStacking.cs | 2 +- .../OutputCombiners/BaseStacking.cs | 18 +++++++------- .../OutputCombiners/MultiStacking.cs | 24 ++++++++++++++++--- .../OutputCombiners/RegressionStacking.cs | 18 ++++++++++++-- .../OutputCombiners/Stacking.cs | 19 ++++++++++++--- 7 files changed, 82 insertions(+), 19 deletions(-) diff --git a/src/Microsoft.ML.Core/EntryPoints/ComponentFactory.cs b/src/Microsoft.ML.Core/EntryPoints/ComponentFactory.cs index d69a9d0b93..14f339627a 100644 --- a/src/Microsoft.ML.Core/EntryPoints/ComponentFactory.cs +++ b/src/Microsoft.ML.Core/EntryPoints/ComponentFactory.cs @@ -29,6 +29,25 @@ public interface IComponentFactory<out TComponent>: IComponentFactory TComponent CreateComponent(IHostEnvironment env); } + /// <summary> + /// A class for creating a component with no extra parameters (other than an <see cref="IHostEnvironment"/>) + /// that simply wraps a delegate which creates the component. + /// </summary> + public class SimpleComponentFactory<TComponent> : IComponentFactory<TComponent> + { + private Func<IHostEnvironment, TComponent> _factory; + + public SimpleComponentFactory(Func<IHostEnvironment, TComponent> factory) + { + _factory = factory; + } + + public TComponent CreateComponent(IHostEnvironment env) + { + return _factory(env); + } + } + /// <summary> /// An interface for creating a component when we take one extra parameter (and an <see cref="IHostEnvironment"/>). /// </summary> diff --git a/src/Microsoft.ML.Ensemble/Microsoft.ML.Ensemble.csproj b/src/Microsoft.ML.Ensemble/Microsoft.ML.Ensemble.csproj index ddd4557788..ef0ad01a40 100644 --- a/src/Microsoft.ML.Ensemble/Microsoft.ML.Ensemble.csproj +++ b/src/Microsoft.ML.Ensemble/Microsoft.ML.Ensemble.csproj @@ -10,6 +10,7 @@ <ProjectReference Include="..\Microsoft.ML.Data\Microsoft.ML.Data.csproj" /> <ProjectReference Include="..\Microsoft.ML.StandardLearners\Microsoft.ML.StandardLearners.csproj" /> <ProjectReference Include="..\Microsoft.ML.Transforms\Microsoft.ML.Transforms.csproj" /> + <ProjectReference Include="..\Microsoft.ML.FastTree\Microsoft.ML.FastTree.csproj" /> </ItemGroup> </Project> diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/BaseScalarStacking.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/BaseScalarStacking.cs index a5c9c757a4..9b9900a65f 100644 --- a/src/Microsoft.ML.Ensemble/OutputCombiners/BaseScalarStacking.cs +++ b/src/Microsoft.ML.Ensemble/OutputCombiners/BaseScalarStacking.cs @@ -9,7 +9,7 @@ namespace Microsoft.ML.Runtime.Ensemble.OutputCombiners { - public abstract class BaseScalarStacking<TSigBase> : BaseStacking<Single, TSigBase> + public abstract class BaseScalarStacking : BaseStacking<Single> { internal BaseScalarStacking(IHostEnvironment env, string name, ArgumentsBase args) : base(env, name, args) diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/BaseStacking.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/BaseStacking.cs index f30174a31d..311faf2448 100644 --- a/src/Microsoft.ML.Ensemble/OutputCombiners/BaseStacking.cs +++ b/src/Microsoft.ML.Ensemble/OutputCombiners/BaseStacking.cs @@ -7,6 +7,7 @@ using System.Threading.Tasks; using Microsoft.ML.Runtime.CommandLine; using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.EntryPoints; using Microsoft.ML.Runtime.Internal.Internallearn; using Microsoft.ML.Runtime.Internal.Utilities; using Microsoft.ML.Runtime.Model; @@ -15,7 +16,7 @@ namespace Microsoft.ML.Runtime.Ensemble.OutputCombiners { using ColumnRole = RoleMappedSchema.ColumnRole; - public abstract class BaseStacking<TOutput, TSigBase> : IStackingTrainer<TOutput> + public abstract class BaseStacking<TOutput> : IStackingTrainer<TOutput> { public abstract class ArgumentsBase { @@ -24,13 +25,10 @@ public abstract class ArgumentsBase [TGUI(Label = "Validation Dataset Proportion")] public Single ValidationDatasetProportion = 0.3f; - [Argument(ArgumentType.Multiple, HelpText = "Base predictor for meta learning", ShortName = "bp", SortOrder = 50, - Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly)] - [TGUI(Label = "Base predictor")] - public SubComponent<ITrainer<IPredictorProducing<TOutput>>, TSigBase> BasePredictorType; + public abstract IComponentFactory<ITrainer<IPredictorProducing<TOutput>>> BasePredictorFactory { get; set; } } - protected readonly SubComponent<ITrainer<IPredictorProducing<TOutput>>, TSigBase> BasePredictorType; + protected readonly IComponentFactory<ITrainer<IPredictorProducing<TOutput>>> BasePredictorFactory; protected readonly IHost Host; protected IPredictorProducing<TOutput> Meta; @@ -45,10 +43,10 @@ internal BaseStacking(IHostEnvironment env, string name, ArgumentsBase args) Host.CheckUserArg(0 <= args.ValidationDatasetProportion && args.ValidationDatasetProportion < 1, nameof(args.ValidationDatasetProportion), "The validation proportion for stacking should be greater than or equal to 0 and less than 1"); - Host.CheckUserArg(args.BasePredictorType.IsGood(), nameof(args.BasePredictorType)); + Host.CheckUserArg(args.BasePredictorFactory != null, nameof(args.BasePredictorFactory)); ValidationDatasetProportion = args.ValidationDatasetProportion; - BasePredictorType = args.BasePredictorType; + BasePredictorFactory = args.BasePredictorFactory; } internal BaseStacking(IHostEnvironment env, string name, ModelLoadContext ctx) @@ -135,7 +133,7 @@ public void Train(List<FeatureSubsetModel<IPredictorProducing<TOutput>>> models, using (var ch = host.Start("Training stacked model")) { ch.Check(Meta == null, "Train called multiple times"); - ch.Check(BasePredictorType != null); + ch.Check(BasePredictorFactory != null); var maps = new ValueMapper<VBuffer<Single>, TOutput>[models.Count]; for (int i = 0; i < maps.Length; i++) @@ -187,7 +185,7 @@ public void Train(List<FeatureSubsetModel<IPredictorProducing<TOutput>>> models, var view = bldr.GetDataView(); var rmd = new RoleMappedData(view, DefaultColumnNames.Label, DefaultColumnNames.Features); - var trainer = BasePredictorType.CreateInstance(host); + var trainer = BasePredictorFactory.CreateComponent(host); if (trainer.Info.NeedNormalization) ch.Warning("The trainer specified for stacking wants normalization, but we do not currently allow this."); Meta = trainer.Train(rmd); diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiStacking.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiStacking.cs index 2ef74c8169..9b00d3732b 100644 --- a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiStacking.cs +++ b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiStacking.cs @@ -8,7 +8,10 @@ using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.Ensemble.OutputCombiners; using Microsoft.ML.Runtime.EntryPoints; +using Microsoft.ML.Runtime.FastTree; +using Microsoft.ML.Runtime.Internal.Internallearn; using Microsoft.ML.Runtime.Internal.Utilities; +using Microsoft.ML.Runtime.Learners; using Microsoft.ML.Runtime.Model; [assembly: LoadableClass(typeof(MultiStacking), typeof(MultiStacking.Arguments), typeof(SignatureCombiner), @@ -20,7 +23,7 @@ namespace Microsoft.ML.Runtime.Ensemble.OutputCombiners { using TVectorPredictor = IPredictorProducing<VBuffer<Single>>; - public sealed class MultiStacking : BaseStacking<VBuffer<Single>, SignatureMultiClassClassifierTrainer>, ICanSaveModel, IMultiClassOutputCombiner + public sealed class MultiStacking : BaseStacking<VBuffer<Single>>, ICanSaveModel, IMultiClassOutputCombiner { public const string LoadName = "MultiStacking"; public const string LoaderSignature = "MultiStackingCombiner"; @@ -38,13 +41,28 @@ private static VersionInfo GetVersionInfo() [TlcModule.Component(Name = LoadName, FriendlyName = Stacking.UserName)] public sealed class Arguments : ArgumentsBase, ISupportMulticlassOutputCombinerFactory { + [Argument(ArgumentType.Multiple, HelpText = "Base predictor for meta learning", ShortName = "bp", SortOrder = 50, + Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, SignatureType = typeof(SignatureMultiClassClassifierTrainer))] + [TGUI(Label = "Base predictor")] + public IComponentFactory<ITrainer<IPredictorProducing<VBuffer<Single>>>> BasePredictorType; + + public override IComponentFactory<ITrainer<TVectorPredictor>> BasePredictorFactory + { + get { return BasePredictorType; } + set { BasePredictorType = value; } + } + public IMultiClassOutputCombiner CreateComponent(IHostEnvironment env) => new MultiStacking(env, this); public Arguments() { // REVIEW: Perhaps we can have a better non-parametetric learner. - BasePredictorType = new SubComponent<ITrainer<TVectorPredictor>, SignatureMultiClassClassifierTrainer>( - "OVA", "p=FastTreeBinaryClassification"); + BasePredictorType = new SimpleComponentFactory<ITrainer<TVectorPredictor>>( + env => new Ova(env, new Ova.Arguments() + { + PredictorType = new SimpleComponentFactory<ITrainer<IPredictorProducing<Single>>>( + e => new FastTreeBinaryClassificationTrainer(e, new FastTreeBinaryClassificationTrainer.Arguments())) + })); } } diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/RegressionStacking.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/RegressionStacking.cs index 0b5f8e6057..5ce1902a30 100644 --- a/src/Microsoft.ML.Ensemble/OutputCombiners/RegressionStacking.cs +++ b/src/Microsoft.ML.Ensemble/OutputCombiners/RegressionStacking.cs @@ -7,6 +7,8 @@ using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.Ensemble.OutputCombiners; using Microsoft.ML.Runtime.EntryPoints; +using Microsoft.ML.Runtime.FastTree; +using Microsoft.ML.Runtime.Internal.Internallearn; using Microsoft.ML.Runtime.Model; [assembly: LoadableClass(typeof(RegressionStacking), typeof(RegressionStacking.Arguments), typeof(SignatureCombiner), @@ -19,7 +21,7 @@ namespace Microsoft.ML.Runtime.Ensemble.OutputCombiners { using TScalarPredictor = IPredictorProducing<Single>; - public sealed class RegressionStacking : BaseScalarStacking<SignatureRegressorTrainer>, IRegressionOutputCombiner, ICanSaveModel + public sealed class RegressionStacking : BaseScalarStacking, IRegressionOutputCombiner, ICanSaveModel { public const string LoadName = "RegressionStacking"; public const string LoaderSignature = "RegressionStacking"; @@ -37,9 +39,21 @@ private static VersionInfo GetVersionInfo() [TlcModule.Component(Name = LoadName, FriendlyName = Stacking.UserName)] public sealed class Arguments : ArgumentsBase, ISupportRegressionOutputCombinerFactory { + [Argument(ArgumentType.Multiple, HelpText = "Base predictor for meta learning", ShortName = "bp", SortOrder = 50, + Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, SignatureType = typeof(SignatureRegressorTrainer))] + [TGUI(Label = "Base predictor")] + public IComponentFactory<ITrainer<IPredictorProducing<Single>>> BasePredictorType; + + public override IComponentFactory<ITrainer<TScalarPredictor>> BasePredictorFactory + { + get { return BasePredictorType; } + set { BasePredictorType = value; } + } + public Arguments() { - BasePredictorType = new SubComponent<ITrainer<TScalarPredictor>, SignatureRegressorTrainer>("FastTreeRegression"); + BasePredictorType = new SimpleComponentFactory<ITrainer<TScalarPredictor>>( + env => new FastTreeRegressionTrainer(env, new FastTreeRegressionTrainer.Arguments())); } public IRegressionOutputCombiner CreateComponent(IHostEnvironment env) => new RegressionStacking(env, this); diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/Stacking.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/Stacking.cs index f3481e9936..04b473cd45 100644 --- a/src/Microsoft.ML.Ensemble/OutputCombiners/Stacking.cs +++ b/src/Microsoft.ML.Ensemble/OutputCombiners/Stacking.cs @@ -5,9 +5,10 @@ using System; using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.Ensemble.OutputCombiners; using Microsoft.ML.Runtime.EntryPoints; +using Microsoft.ML.Runtime.FastTree; +using Microsoft.ML.Runtime.Internal.Internallearn; using Microsoft.ML.Runtime.Model; [assembly: LoadableClass(typeof(Stacking), typeof(Stacking.Arguments), typeof(SignatureCombiner), Stacking.UserName, Stacking.LoadName)] @@ -16,7 +17,7 @@ namespace Microsoft.ML.Runtime.Ensemble.OutputCombiners { using TScalarPredictor = IPredictorProducing<Single>; - public sealed class Stacking : BaseScalarStacking<SignatureBinaryClassifierTrainer>, IBinaryOutputCombiner, ICanSaveModel + public sealed class Stacking : BaseScalarStacking, IBinaryOutputCombiner, ICanSaveModel { public const string UserName = "Stacking"; public const string LoadName = "Stacking"; @@ -35,9 +36,21 @@ private static VersionInfo GetVersionInfo() [TlcModule.Component(Name = LoadName, FriendlyName = UserName)] public sealed class Arguments : ArgumentsBase, ISupportBinaryOutputCombinerFactory { + [Argument(ArgumentType.Multiple, HelpText = "Base predictor for meta learning", ShortName = "bp", SortOrder = 50, + Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, SignatureType = typeof(SignatureBinaryClassifierTrainer))] + [TGUI(Label = "Base predictor")] + public IComponentFactory<ITrainer<IPredictorProducing<Single>>> BasePredictorType; + + public override IComponentFactory<ITrainer<TScalarPredictor>> BasePredictorFactory + { + get { return BasePredictorType; } + set { BasePredictorType = value; } + } + public Arguments() { - BasePredictorType = new SubComponent<ITrainer<TScalarPredictor>, SignatureBinaryClassifierTrainer>("FastTreeBinaryClassification"); + BasePredictorType = new SimpleComponentFactory<ITrainer<TScalarPredictor>>( + env => new FastTreeBinaryClassificationTrainer(env, new FastTreeBinaryClassificationTrainer.Arguments())); } public IBinaryOutputCombiner CreateComponent(IHostEnvironment env) => new Stacking(env, this); From 84119d35a9e33874a902a3db00dc5c6213b6552a Mon Sep 17 00:00:00 2001 From: Eric Erhardt <eric.erhardt@microsoft.com> Date: Wed, 15 Aug 2018 17:27:31 -0500 Subject: [PATCH 2/7] Replace SubComponent with IComponentFactory in ML.Ensemble Working towards #585 --- .../SubModelSelector/BaseSubModelSelector.cs | 11 +++++------ .../Trainer/Binary/EnsembleTrainer.cs | 19 +++++++++++++++++-- .../Trainer/EnsembleTrainerBase.cs | 14 +++++++------- .../MulticlassDataPartitionEnsembleTrainer.cs | 19 +++++++++++++++++-- .../Regression/RegressionEnsembleTrainer.cs | 19 +++++++++++++++++-- 5 files changed, 63 insertions(+), 19 deletions(-) diff --git a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BaseSubModelSelector.cs b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BaseSubModelSelector.cs index 8465e1a5e8..ac79e400a9 100644 --- a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BaseSubModelSelector.cs +++ b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BaseSubModelSelector.cs @@ -54,16 +54,16 @@ public virtual IList<FeatureSubsetModel<IPredictorProducing<TOutput>>> Prune(ILi return models; } - private SubComponent<IEvaluator, SignatureEvaluator> GetEvaluatorSubComponent() + private IEvaluator GetEvaluator(IHostEnvironment env) { switch (PredictionKind) { case PredictionKind.BinaryClassification: - return new SubComponent<IEvaluator, SignatureEvaluator>(BinaryClassifierEvaluator.LoadName); + return new BinaryClassifierEvaluator(env, new BinaryClassifierEvaluator.Arguments()); case PredictionKind.Regression: - return new SubComponent<IEvaluator, SignatureEvaluator>(RegressionEvaluator.LoadName); + return new RegressionEvaluator(env, new RegressionEvaluator.Arguments()); case PredictionKind.MultiClassClassification: - return new SubComponent<IEvaluator, SignatureEvaluator>(MultiClassClassifierEvaluator.LoadName); + return new MultiClassClassifierEvaluator(env, new MultiClassClassifierEvaluator.Arguments()); default: throw Host.Except("Unrecognized prediction kind '{0}'", PredictionKind); } @@ -83,10 +83,9 @@ public virtual void CalculateMetrics(FeatureSubsetModel<IPredictorProducing<TOut IDataScorerTransform scorePipe = ScoreUtils.GetScorer(model.Predictor, testData, Host, testData.Schema); // REVIEW: Should we somehow allow the user to customize the evaluator? // By what mechanism should we allow that? - var evalComp = GetEvaluatorSubComponent(); RoleMappedData scoredTestData = new RoleMappedData(scorePipe, GetColumnRoles(testData.Schema, scorePipe.Schema)); - IEvaluator evaluator = evalComp.CreateInstance(Host); + IEvaluator evaluator = GetEvaluator(Host); // REVIEW: with the new evaluators, metrics of individual models are no longer // printed to the Console. Consider adding an option on the combiner to print them. // REVIEW: Consider adding an option to the combiner to save a data view diff --git a/src/Microsoft.ML.Ensemble/Trainer/Binary/EnsembleTrainer.cs b/src/Microsoft.ML.Ensemble/Trainer/Binary/EnsembleTrainer.cs index ab2dff5045..1ec0c94131 100644 --- a/src/Microsoft.ML.Ensemble/Trainer/Binary/EnsembleTrainer.cs +++ b/src/Microsoft.ML.Ensemble/Trainer/Binary/EnsembleTrainer.cs @@ -13,6 +13,8 @@ using Microsoft.ML.Runtime.Ensemble.Selector; using Microsoft.ML.Ensemble.EntryPoints; using Microsoft.ML.Runtime.Internal.Internallearn; +using Microsoft.ML.Runtime.EntryPoints; +using Microsoft.ML.Runtime.Learners; [assembly: LoadableClass(EnsembleTrainer.Summary, typeof(EnsembleTrainer), typeof(EnsembleTrainer.Arguments), new[] { typeof(SignatureBinaryClassifierTrainer), typeof(SignatureTrainer) }, @@ -26,7 +28,7 @@ namespace Microsoft.ML.Runtime.Ensemble /// A generic ensemble trainer for binary classification. /// </summary> public sealed class EnsembleTrainer : EnsembleTrainerBase<Single, TScalarPredictor, - IBinarySubModelSelector, IBinaryOutputCombiner, SignatureBinaryClassifierTrainer>, + IBinarySubModelSelector, IBinaryOutputCombiner>, IModelCombiner<TScalarPredictor, TScalarPredictor> { public const string LoadNameValue = "WeightedEnsemble"; @@ -44,9 +46,22 @@ public sealed class Arguments : ArgumentsBase [TGUI(Label = "Output combiner", Description = "Output combiner type")] public ISupportBinaryOutputCombinerFactory OutputCombiner = new MedianFactory(); + [Argument(ArgumentType.Multiple, HelpText = "Base predictor type", ShortName = "bp,basePredictorTypes", SortOrder = 1, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, SignatureType = typeof(SignatureBinaryClassifierTrainer))] + public IComponentFactory<ITrainer<TScalarPredictor>>[] BasePredictors; + + public override IComponentFactory<ITrainer<TScalarPredictor>>[] BasePredictorFactories + { + get { return BasePredictors; } + set { BasePredictors = value; } + } + public Arguments() { - BasePredictors = new[] { new SubComponent<ITrainer<TScalarPredictor>, SignatureBinaryClassifierTrainer>("LinearSVM") }; + BasePredictors = new[] + { + new SimpleComponentFactory<ITrainer<TScalarPredictor>>( + env => new LinearSvm(env, new LinearSvm.Arguments())) + }; } } diff --git a/src/Microsoft.ML.Ensemble/Trainer/EnsembleTrainerBase.cs b/src/Microsoft.ML.Ensemble/Trainer/EnsembleTrainerBase.cs index 0a350ef8ee..1f9aff7169 100644 --- a/src/Microsoft.ML.Ensemble/Trainer/EnsembleTrainerBase.cs +++ b/src/Microsoft.ML.Ensemble/Trainer/EnsembleTrainerBase.cs @@ -20,7 +20,7 @@ namespace Microsoft.ML.Runtime.Ensemble { using Stopwatch = System.Diagnostics.Stopwatch; - public abstract class EnsembleTrainerBase<TOutput, TPredictor, TSelector, TCombiner, TSig> : TrainerBase<TPredictor> + public abstract class EnsembleTrainerBase<TOutput, TPredictor, TSelector, TCombiner> : TrainerBase<TPredictor> where TPredictor : class, IPredictorProducing<TOutput> where TSelector : class, ISubModelSelector<TOutput> where TCombiner : class, IOutputCombiner<TOutput> @@ -53,8 +53,7 @@ public abstract class ArgumentsBase : LearnerInputBaseWithLabel [TGUI(Label = "Show Sub-Model Metrics")] public bool ShowMetrics; - [Argument(ArgumentType.Multiple, HelpText = "Base predictor type", ShortName = "bp,basePredictorTypes", SortOrder = 1, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly)] - public SubComponent<ITrainer<IPredictorProducing<TOutput>>, TSig>[] BasePredictors; + public abstract IComponentFactory<ITrainer<IPredictorProducing<TOutput>>>[] BasePredictorFactories { get; set; } } private const int DefaultNumModels = 50; @@ -78,21 +77,22 @@ private protected EnsembleTrainerBase(ArgumentsBase args, IHostEnvironment env, using (var ch = Host.Start("Init")) { - ch.CheckUserArg(Utils.Size(Args.BasePredictors) > 0, nameof(Args.BasePredictors), "This should have at-least one value"); + var predictorFactories = Args.BasePredictorFactories; + ch.CheckUserArg(Utils.Size(predictorFactories) > 0, nameof(Args.BasePredictorFactories), "This should have at-least one value"); NumModels = Args.NumModels ?? - (Args.BasePredictors.Length == 1 ? DefaultNumModels : Args.BasePredictors.Length); + (predictorFactories.Length == 1 ? DefaultNumModels : predictorFactories.Length); ch.CheckUserArg(NumModels > 0, nameof(Args.NumModels), "Must be positive, or null to indicate numModels is the number of base predictors"); - if (Utils.Size(Args.BasePredictors) > NumModels) + if (Utils.Size(predictorFactories) > NumModels) ch.Warning("The base predictor count is greater than models count. Some of the base predictors will be ignored."); _subsetSelector = Args.SamplingType.CreateComponent(Host); Trainers = new ITrainer<IPredictorProducing<TOutput>>[NumModels]; for (int i = 0; i < Trainers.Length; i++) - Trainers[i] = Args.BasePredictors[i % Args.BasePredictors.Length].CreateInstance(Host); + Trainers[i] = predictorFactories[i % predictorFactories.Length].CreateComponent(Host); // We infer normalization and calibration preferences from the trainers. However, even if the internal trainers // don't need caching we are performing multiple passes over the data, so it is probably appropriate to always cache. Info = new TrainerInfo( diff --git a/src/Microsoft.ML.Ensemble/Trainer/Multiclass/MulticlassDataPartitionEnsembleTrainer.cs b/src/Microsoft.ML.Ensemble/Trainer/Multiclass/MulticlassDataPartitionEnsembleTrainer.cs index 4421cd5838..c0a1a48e08 100644 --- a/src/Microsoft.ML.Ensemble/Trainer/Multiclass/MulticlassDataPartitionEnsembleTrainer.cs +++ b/src/Microsoft.ML.Ensemble/Trainer/Multiclass/MulticlassDataPartitionEnsembleTrainer.cs @@ -12,7 +12,9 @@ using Microsoft.ML.Runtime.Ensemble; using Microsoft.ML.Runtime.Ensemble.OutputCombiners; using Microsoft.ML.Runtime.Ensemble.Selector; +using Microsoft.ML.Runtime.EntryPoints; using Microsoft.ML.Runtime.Internal.Internallearn; +using Microsoft.ML.Runtime.Learners; [assembly: LoadableClass(MulticlassDataPartitionEnsembleTrainer.Summary, typeof(MulticlassDataPartitionEnsembleTrainer), typeof(MulticlassDataPartitionEnsembleTrainer.Arguments), @@ -28,7 +30,7 @@ namespace Microsoft.ML.Runtime.Ensemble /// </summary> public sealed class MulticlassDataPartitionEnsembleTrainer : EnsembleTrainerBase<VBuffer<Single>, EnsembleMultiClassPredictor, - IMulticlassSubModelSelector, IMultiClassOutputCombiner, SignatureMultiClassClassifierTrainer>, + IMulticlassSubModelSelector, IMultiClassOutputCombiner>, IModelCombiner<TVectorPredictor, TVectorPredictor> { public const string LoadNameValue = "WeightedEnsembleMulticlass"; @@ -45,9 +47,22 @@ public sealed class Arguments : ArgumentsBase [TGUI(Label = "Output combiner", Description = "Output combiner type")] public ISupportMulticlassOutputCombinerFactory OutputCombiner = new MultiMedian.Arguments(); + [Argument(ArgumentType.Multiple, HelpText = "Base predictor type", ShortName = "bp,basePredictorTypes", SortOrder = 1, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, SignatureType = typeof(SignatureMultiClassClassifierTrainer))] + public IComponentFactory<ITrainer<TVectorPredictor>>[] BasePredictors; + + public override IComponentFactory<ITrainer<TVectorPredictor>>[] BasePredictorFactories + { + get { return BasePredictors; } + set { BasePredictors = value; } + } + public Arguments() { - BasePredictors = new[] { new SubComponent<ITrainer<TVectorPredictor>, SignatureMultiClassClassifierTrainer>("MultiClassLogisticRegression") }; + BasePredictors = new[] + { + new SimpleComponentFactory<ITrainer<TVectorPredictor>>( + env => new MulticlassLogisticRegression(env, new MulticlassLogisticRegression.Arguments())) + }; } } diff --git a/src/Microsoft.ML.Ensemble/Trainer/Regression/RegressionEnsembleTrainer.cs b/src/Microsoft.ML.Ensemble/Trainer/Regression/RegressionEnsembleTrainer.cs index 1cc36f20cd..37d98eefd2 100644 --- a/src/Microsoft.ML.Ensemble/Trainer/Regression/RegressionEnsembleTrainer.cs +++ b/src/Microsoft.ML.Ensemble/Trainer/Regression/RegressionEnsembleTrainer.cs @@ -12,7 +12,9 @@ using Microsoft.ML.Runtime.Ensemble; using Microsoft.ML.Runtime.Ensemble.OutputCombiners; using Microsoft.ML.Runtime.Ensemble.Selector; +using Microsoft.ML.Runtime.EntryPoints; using Microsoft.ML.Runtime.Internal.Internallearn; +using Microsoft.ML.Runtime.Learners; [assembly: LoadableClass(typeof(RegressionEnsembleTrainer), typeof(RegressionEnsembleTrainer.Arguments), new[] { typeof(SignatureRegressorTrainer), typeof(SignatureTrainer) }, @@ -23,7 +25,7 @@ namespace Microsoft.ML.Runtime.Ensemble { using TScalarPredictor = IPredictorProducing<Single>; public sealed class RegressionEnsembleTrainer : EnsembleTrainerBase<Single, TScalarPredictor, - IRegressionSubModelSelector, IRegressionOutputCombiner, SignatureRegressorTrainer>, + IRegressionSubModelSelector, IRegressionOutputCombiner>, IModelCombiner<TScalarPredictor, TScalarPredictor> { public const string LoadNameValue = "EnsembleRegression"; @@ -39,9 +41,22 @@ public sealed class Arguments : ArgumentsBase [TGUI(Label = "Output combiner", Description = "Output combiner type")] public ISupportRegressionOutputCombinerFactory OutputCombiner = new MedianFactory(); + [Argument(ArgumentType.Multiple, HelpText = "Base predictor type", ShortName = "bp,basePredictorTypes", SortOrder = 1, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, SignatureType = typeof(SignatureRegressorTrainer))] + public IComponentFactory<ITrainer<TScalarPredictor>>[] BasePredictors; + + public override IComponentFactory<ITrainer<TScalarPredictor>>[] BasePredictorFactories + { + get { return BasePredictors; } + set { BasePredictors = value; } + } + public Arguments() { - BasePredictors = new[] { new SubComponent<ITrainer<TScalarPredictor>, SignatureRegressorTrainer>("OnlineGradientDescent") }; + BasePredictors = new[] + { + new SimpleComponentFactory<ITrainer<TScalarPredictor>>( + env => new OnlineGradientDescentTrainer(env, new OnlineGradientDescentTrainer.Arguments())) + }; } } From 7c06f93e9db7dbeab85fc34c5267f2cef5b61a87 Mon Sep 17 00:00:00 2001 From: Eric Erhardt <eric.erhardt@microsoft.com> Date: Thu, 16 Aug 2018 09:54:36 -0500 Subject: [PATCH 3/7] Convert SimpleComponentFactory classes to ComponentFactoryUtils static class. --- .../EntryPoints/ComponentFactory.cs | 90 ++++++++++++------- .../Commands/CrossValidationCommand.cs | 2 +- .../OutputCombiners/MultiStacking.cs | 4 +- .../OutputCombiners/RegressionStacking.cs | 2 +- .../OutputCombiners/Stacking.cs | 2 +- .../Trainer/Binary/EnsembleTrainer.cs | 2 +- .../MulticlassDataPartitionEnsembleTrainer.cs | 2 +- .../Regression/RegressionEnsembleTrainer.cs | 2 +- 8 files changed, 65 insertions(+), 41 deletions(-) diff --git a/src/Microsoft.ML.Core/EntryPoints/ComponentFactory.cs b/src/Microsoft.ML.Core/EntryPoints/ComponentFactory.cs index 14f339627a..8dea31c13a 100644 --- a/src/Microsoft.ML.Core/EntryPoints/ComponentFactory.cs +++ b/src/Microsoft.ML.Core/EntryPoints/ComponentFactory.cs @@ -30,57 +30,81 @@ public interface IComponentFactory<out TComponent>: IComponentFactory } /// <summary> - /// A class for creating a component with no extra parameters (other than an <see cref="IHostEnvironment"/>) - /// that simply wraps a delegate which creates the component. + /// An interface for creating a component when we take one extra parameter (and an <see cref="IHostEnvironment"/>). /// </summary> - public class SimpleComponentFactory<TComponent> : IComponentFactory<TComponent> + public interface IComponentFactory<in TArg1, out TComponent> : IComponentFactory { - private Func<IHostEnvironment, TComponent> _factory; - - public SimpleComponentFactory(Func<IHostEnvironment, TComponent> factory) - { - _factory = factory; - } - - public TComponent CreateComponent(IHostEnvironment env) - { - return _factory(env); - } + TComponent CreateComponent(IHostEnvironment env, TArg1 argument1); } /// <summary> - /// An interface for creating a component when we take one extra parameter (and an <see cref="IHostEnvironment"/>). + /// An interface for creating a component when we take two extra parameters (and an <see cref="IHostEnvironment"/>). /// </summary> - public interface IComponentFactory<in TArg1, out TComponent> : IComponentFactory + public interface IComponentFactory<in TArg1, in TArg2, out TComponent> : IComponentFactory { - TComponent CreateComponent(IHostEnvironment env, TArg1 argument1); + TComponent CreateComponent(IHostEnvironment env, TArg1 argument1, TArg2 argument2); } /// <summary> - /// A class for creating a component when we take one extra parameter - /// (and an <see cref="IHostEnvironment"/>) that simply wraps a delegate which - /// creates the component. + /// A utility class for creating IComponentFactory instances. /// </summary> - public class SimpleComponentFactory<TArg1, TComponent> : IComponentFactory<TArg1, TComponent> + public static class ComponentFactoryUtils { - private Func<IHostEnvironment, TArg1, TComponent> _factory; + /// <summary> + /// Creates a component factory with no extra parameters (other than an <see cref="IHostEnvironment"/>) + /// that simply wraps a delegate which creates the component. + /// </summary> + public static IComponentFactory<TComponent> CreateFromFunction<TComponent>(Func<IHostEnvironment, TComponent> factory) + { + return new SimpleComponentFactory<TComponent>(factory); + } - public SimpleComponentFactory(Func<IHostEnvironment, TArg1, TComponent> factory) + /// <summary> + /// Creates a component factory when we take one extra parameter (and an + /// <see cref="IHostEnvironment"/>) that simply wraps a delegate which creates the component. + /// </summary> + public static IComponentFactory<TArg1, TComponent> CreateFromFunction<TArg1, TComponent>(Func<IHostEnvironment, TArg1, TComponent> factory) { - _factory = factory; + return new SimpleComponentFactory<TArg1, TComponent>(factory); } - public TComponent CreateComponent(IHostEnvironment env, TArg1 argument1) + /// <summary> + /// A class for creating a component with no extra parameters (other than an <see cref="IHostEnvironment"/>) + /// that simply wraps a delegate which creates the component. + /// </summary> + private sealed class SimpleComponentFactory<TComponent> : IComponentFactory<TComponent> { - return _factory(env, argument1); + private readonly Func<IHostEnvironment, TComponent> _factory; + + public SimpleComponentFactory(Func<IHostEnvironment, TComponent> factory) + { + _factory = factory; + } + + public TComponent CreateComponent(IHostEnvironment env) + { + return _factory(env); + } } - } - /// <summary> - /// An interface for creating a component when we take two extra parameters (and an <see cref="IHostEnvironment"/>). - /// </summary> - public interface IComponentFactory<in TArg1, in TArg2, out TComponent> : IComponentFactory - { - TComponent CreateComponent(IHostEnvironment env, TArg1 argument1, TArg2 argument2); + /// <summary> + /// A class for creating a component when we take one extra parameter + /// (and an <see cref="IHostEnvironment"/>) that simply wraps a delegate which + /// creates the component. + /// </summary> + private sealed class SimpleComponentFactory<TArg1, TComponent> : IComponentFactory<TArg1, TComponent> + { + private readonly Func<IHostEnvironment, TArg1, TComponent> _factory; + + public SimpleComponentFactory(Func<IHostEnvironment, TArg1, TComponent> factory) + { + _factory = factory; + } + + public TComponent CreateComponent(IHostEnvironment env, TArg1 argument1) + { + return _factory(env, argument1); + } + } } } diff --git a/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs b/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs index affd949064..1c264e02d2 100644 --- a/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs +++ b/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs @@ -164,7 +164,7 @@ private void RunCore(IChannel ch, string cmd) new[] { new KeyValuePair<string, IComponentFactory<IDataView, IDataTransform>>( - "", new SimpleComponentFactory<IDataView, IDataTransform>( + "", ComponentFactoryUtils.CreateFromFunction<IDataView, IDataTransform>( (env, input) => { var args = new GenerateNumberTransform.Arguments(); diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiStacking.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiStacking.cs index 9b00d3732b..3038e2b089 100644 --- a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiStacking.cs +++ b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiStacking.cs @@ -57,10 +57,10 @@ public override IComponentFactory<ITrainer<TVectorPredictor>> BasePredictorFacto public Arguments() { // REVIEW: Perhaps we can have a better non-parametetric learner. - BasePredictorType = new SimpleComponentFactory<ITrainer<TVectorPredictor>>( + BasePredictorType = ComponentFactoryUtils.CreateFromFunction( env => new Ova(env, new Ova.Arguments() { - PredictorType = new SimpleComponentFactory<ITrainer<IPredictorProducing<Single>>>( + PredictorType = ComponentFactoryUtils.CreateFromFunction( e => new FastTreeBinaryClassificationTrainer(e, new FastTreeBinaryClassificationTrainer.Arguments())) })); } diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/RegressionStacking.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/RegressionStacking.cs index 5ce1902a30..18066c5362 100644 --- a/src/Microsoft.ML.Ensemble/OutputCombiners/RegressionStacking.cs +++ b/src/Microsoft.ML.Ensemble/OutputCombiners/RegressionStacking.cs @@ -52,7 +52,7 @@ public override IComponentFactory<ITrainer<TScalarPredictor>> BasePredictorFacto public Arguments() { - BasePredictorType = new SimpleComponentFactory<ITrainer<TScalarPredictor>>( + BasePredictorType = ComponentFactoryUtils.CreateFromFunction( env => new FastTreeRegressionTrainer(env, new FastTreeRegressionTrainer.Arguments())); } diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/Stacking.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/Stacking.cs index 04b473cd45..6be4ffea64 100644 --- a/src/Microsoft.ML.Ensemble/OutputCombiners/Stacking.cs +++ b/src/Microsoft.ML.Ensemble/OutputCombiners/Stacking.cs @@ -49,7 +49,7 @@ public override IComponentFactory<ITrainer<TScalarPredictor>> BasePredictorFacto public Arguments() { - BasePredictorType = new SimpleComponentFactory<ITrainer<TScalarPredictor>>( + BasePredictorType = ComponentFactoryUtils.CreateFromFunction( env => new FastTreeBinaryClassificationTrainer(env, new FastTreeBinaryClassificationTrainer.Arguments())); } diff --git a/src/Microsoft.ML.Ensemble/Trainer/Binary/EnsembleTrainer.cs b/src/Microsoft.ML.Ensemble/Trainer/Binary/EnsembleTrainer.cs index 1ec0c94131..5cf1f6ea4c 100644 --- a/src/Microsoft.ML.Ensemble/Trainer/Binary/EnsembleTrainer.cs +++ b/src/Microsoft.ML.Ensemble/Trainer/Binary/EnsembleTrainer.cs @@ -59,7 +59,7 @@ public Arguments() { BasePredictors = new[] { - new SimpleComponentFactory<ITrainer<TScalarPredictor>>( + ComponentFactoryUtils.CreateFromFunction( env => new LinearSvm(env, new LinearSvm.Arguments())) }; } diff --git a/src/Microsoft.ML.Ensemble/Trainer/Multiclass/MulticlassDataPartitionEnsembleTrainer.cs b/src/Microsoft.ML.Ensemble/Trainer/Multiclass/MulticlassDataPartitionEnsembleTrainer.cs index c0a1a48e08..3183292383 100644 --- a/src/Microsoft.ML.Ensemble/Trainer/Multiclass/MulticlassDataPartitionEnsembleTrainer.cs +++ b/src/Microsoft.ML.Ensemble/Trainer/Multiclass/MulticlassDataPartitionEnsembleTrainer.cs @@ -60,7 +60,7 @@ public Arguments() { BasePredictors = new[] { - new SimpleComponentFactory<ITrainer<TVectorPredictor>>( + ComponentFactoryUtils.CreateFromFunction( env => new MulticlassLogisticRegression(env, new MulticlassLogisticRegression.Arguments())) }; } diff --git a/src/Microsoft.ML.Ensemble/Trainer/Regression/RegressionEnsembleTrainer.cs b/src/Microsoft.ML.Ensemble/Trainer/Regression/RegressionEnsembleTrainer.cs index 37d98eefd2..283fe2f32d 100644 --- a/src/Microsoft.ML.Ensemble/Trainer/Regression/RegressionEnsembleTrainer.cs +++ b/src/Microsoft.ML.Ensemble/Trainer/Regression/RegressionEnsembleTrainer.cs @@ -54,7 +54,7 @@ public Arguments() { BasePredictors = new[] { - new SimpleComponentFactory<ITrainer<TScalarPredictor>>( + ComponentFactoryUtils.CreateFromFunction( env => new OnlineGradientDescentTrainer(env, new OnlineGradientDescentTrainer.Arguments())) }; } From d41dbe623d2a78c43ae13380516032158ac6c308 Mon Sep 17 00:00:00 2001 From: Eric Erhardt <eric.erhardt@microsoft.com> Date: Thu, 16 Aug 2018 14:04:52 -0500 Subject: [PATCH 4/7] Move Stacking arguments abstract properties to internal get methods. This reduces the public surface area. --- .../OutputCombiners/BaseStacking.cs | 12 ++++++------ .../OutputCombiners/MultiStacking.cs | 8 ++------ .../OutputCombiners/RegressionStacking.cs | 8 ++------ .../OutputCombiners/Stacking.cs | 8 ++------ .../Trainer/Binary/EnsembleTrainer.cs | 6 +----- .../Trainer/EnsembleTrainerBase.cs | 6 +++--- .../MulticlassDataPartitionEnsembleTrainer.cs | 6 +----- .../Trainer/Regression/RegressionEnsembleTrainer.cs | 6 +----- 8 files changed, 18 insertions(+), 42 deletions(-) diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/BaseStacking.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/BaseStacking.cs index 311faf2448..c294a1fcde 100644 --- a/src/Microsoft.ML.Ensemble/OutputCombiners/BaseStacking.cs +++ b/src/Microsoft.ML.Ensemble/OutputCombiners/BaseStacking.cs @@ -25,10 +25,10 @@ public abstract class ArgumentsBase [TGUI(Label = "Validation Dataset Proportion")] public Single ValidationDatasetProportion = 0.3f; - public abstract IComponentFactory<ITrainer<IPredictorProducing<TOutput>>> BasePredictorFactory { get; set; } + internal abstract IComponentFactory<ITrainer<IPredictorProducing<TOutput>>> GetPredictorFactory(); } - protected readonly IComponentFactory<ITrainer<IPredictorProducing<TOutput>>> BasePredictorFactory; + protected readonly IComponentFactory<ITrainer<IPredictorProducing<TOutput>>> BasePredictorType; protected readonly IHost Host; protected IPredictorProducing<TOutput> Meta; @@ -43,10 +43,10 @@ internal BaseStacking(IHostEnvironment env, string name, ArgumentsBase args) Host.CheckUserArg(0 <= args.ValidationDatasetProportion && args.ValidationDatasetProportion < 1, nameof(args.ValidationDatasetProportion), "The validation proportion for stacking should be greater than or equal to 0 and less than 1"); - Host.CheckUserArg(args.BasePredictorFactory != null, nameof(args.BasePredictorFactory)); ValidationDatasetProportion = args.ValidationDatasetProportion; - BasePredictorFactory = args.BasePredictorFactory; + BasePredictorType = args.GetPredictorFactory(); + Host.CheckValue(BasePredictorType, nameof(BasePredictorType)); } internal BaseStacking(IHostEnvironment env, string name, ModelLoadContext ctx) @@ -133,7 +133,7 @@ public void Train(List<FeatureSubsetModel<IPredictorProducing<TOutput>>> models, using (var ch = host.Start("Training stacked model")) { ch.Check(Meta == null, "Train called multiple times"); - ch.Check(BasePredictorFactory != null); + ch.Check(BasePredictorType != null); var maps = new ValueMapper<VBuffer<Single>, TOutput>[models.Count]; for (int i = 0; i < maps.Length; i++) @@ -185,7 +185,7 @@ public void Train(List<FeatureSubsetModel<IPredictorProducing<TOutput>>> models, var view = bldr.GetDataView(); var rmd = new RoleMappedData(view, DefaultColumnNames.Label, DefaultColumnNames.Features); - var trainer = BasePredictorFactory.CreateComponent(host); + var trainer = BasePredictorType.CreateComponent(host); if (trainer.Info.NeedNormalization) ch.Warning("The trainer specified for stacking wants normalization, but we do not currently allow this."); Meta = trainer.Train(rmd); diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiStacking.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiStacking.cs index 3038e2b089..40d59b631a 100644 --- a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiStacking.cs +++ b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiStacking.cs @@ -44,13 +44,9 @@ public sealed class Arguments : ArgumentsBase, ISupportMulticlassOutputCombinerF [Argument(ArgumentType.Multiple, HelpText = "Base predictor for meta learning", ShortName = "bp", SortOrder = 50, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, SignatureType = typeof(SignatureMultiClassClassifierTrainer))] [TGUI(Label = "Base predictor")] - public IComponentFactory<ITrainer<IPredictorProducing<VBuffer<Single>>>> BasePredictorType; + public IComponentFactory<ITrainer<TVectorPredictor>> BasePredictorType; - public override IComponentFactory<ITrainer<TVectorPredictor>> BasePredictorFactory - { - get { return BasePredictorType; } - set { BasePredictorType = value; } - } + internal override IComponentFactory<ITrainer<TVectorPredictor>> GetPredictorFactory() => BasePredictorType; public IMultiClassOutputCombiner CreateComponent(IHostEnvironment env) => new MultiStacking(env, this); diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/RegressionStacking.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/RegressionStacking.cs index 18066c5362..436e365b79 100644 --- a/src/Microsoft.ML.Ensemble/OutputCombiners/RegressionStacking.cs +++ b/src/Microsoft.ML.Ensemble/OutputCombiners/RegressionStacking.cs @@ -42,13 +42,9 @@ public sealed class Arguments : ArgumentsBase, ISupportRegressionOutputCombinerF [Argument(ArgumentType.Multiple, HelpText = "Base predictor for meta learning", ShortName = "bp", SortOrder = 50, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, SignatureType = typeof(SignatureRegressorTrainer))] [TGUI(Label = "Base predictor")] - public IComponentFactory<ITrainer<IPredictorProducing<Single>>> BasePredictorType; + public IComponentFactory<ITrainer<TScalarPredictor>> BasePredictorType; - public override IComponentFactory<ITrainer<TScalarPredictor>> BasePredictorFactory - { - get { return BasePredictorType; } - set { BasePredictorType = value; } - } + internal override IComponentFactory<ITrainer<TScalarPredictor>> GetPredictorFactory() => BasePredictorType; public Arguments() { diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/Stacking.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/Stacking.cs index 6be4ffea64..f0ced9d947 100644 --- a/src/Microsoft.ML.Ensemble/OutputCombiners/Stacking.cs +++ b/src/Microsoft.ML.Ensemble/OutputCombiners/Stacking.cs @@ -39,13 +39,9 @@ public sealed class Arguments : ArgumentsBase, ISupportBinaryOutputCombinerFacto [Argument(ArgumentType.Multiple, HelpText = "Base predictor for meta learning", ShortName = "bp", SortOrder = 50, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, SignatureType = typeof(SignatureBinaryClassifierTrainer))] [TGUI(Label = "Base predictor")] - public IComponentFactory<ITrainer<IPredictorProducing<Single>>> BasePredictorType; + public IComponentFactory<ITrainer<TScalarPredictor>> BasePredictorType; - public override IComponentFactory<ITrainer<TScalarPredictor>> BasePredictorFactory - { - get { return BasePredictorType; } - set { BasePredictorType = value; } - } + internal override IComponentFactory<ITrainer<TScalarPredictor>> GetPredictorFactory() => BasePredictorType; public Arguments() { diff --git a/src/Microsoft.ML.Ensemble/Trainer/Binary/EnsembleTrainer.cs b/src/Microsoft.ML.Ensemble/Trainer/Binary/EnsembleTrainer.cs index 5cf1f6ea4c..6a4b151fab 100644 --- a/src/Microsoft.ML.Ensemble/Trainer/Binary/EnsembleTrainer.cs +++ b/src/Microsoft.ML.Ensemble/Trainer/Binary/EnsembleTrainer.cs @@ -49,11 +49,7 @@ public sealed class Arguments : ArgumentsBase [Argument(ArgumentType.Multiple, HelpText = "Base predictor type", ShortName = "bp,basePredictorTypes", SortOrder = 1, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, SignatureType = typeof(SignatureBinaryClassifierTrainer))] public IComponentFactory<ITrainer<TScalarPredictor>>[] BasePredictors; - public override IComponentFactory<ITrainer<TScalarPredictor>>[] BasePredictorFactories - { - get { return BasePredictors; } - set { BasePredictors = value; } - } + internal override IComponentFactory<ITrainer<TScalarPredictor>>[] GetPredictorFactories() => BasePredictors; public Arguments() { diff --git a/src/Microsoft.ML.Ensemble/Trainer/EnsembleTrainerBase.cs b/src/Microsoft.ML.Ensemble/Trainer/EnsembleTrainerBase.cs index 1f9aff7169..5ebf5cfac0 100644 --- a/src/Microsoft.ML.Ensemble/Trainer/EnsembleTrainerBase.cs +++ b/src/Microsoft.ML.Ensemble/Trainer/EnsembleTrainerBase.cs @@ -53,7 +53,7 @@ public abstract class ArgumentsBase : LearnerInputBaseWithLabel [TGUI(Label = "Show Sub-Model Metrics")] public bool ShowMetrics; - public abstract IComponentFactory<ITrainer<IPredictorProducing<TOutput>>>[] BasePredictorFactories { get; set; } + internal abstract IComponentFactory<ITrainer<IPredictorProducing<TOutput>>>[] GetPredictorFactories(); } private const int DefaultNumModels = 50; @@ -77,8 +77,8 @@ private protected EnsembleTrainerBase(ArgumentsBase args, IHostEnvironment env, using (var ch = Host.Start("Init")) { - var predictorFactories = Args.BasePredictorFactories; - ch.CheckUserArg(Utils.Size(predictorFactories) > 0, nameof(Args.BasePredictorFactories), "This should have at-least one value"); + var predictorFactories = Args.GetPredictorFactories(); + ch.CheckUserArg(Utils.Size(predictorFactories) > 0, nameof(EnsembleTrainer.Arguments.BasePredictors), "This should have at-least one value"); NumModels = Args.NumModels ?? (predictorFactories.Length == 1 ? DefaultNumModels : predictorFactories.Length); diff --git a/src/Microsoft.ML.Ensemble/Trainer/Multiclass/MulticlassDataPartitionEnsembleTrainer.cs b/src/Microsoft.ML.Ensemble/Trainer/Multiclass/MulticlassDataPartitionEnsembleTrainer.cs index 3183292383..d5c60e24a6 100644 --- a/src/Microsoft.ML.Ensemble/Trainer/Multiclass/MulticlassDataPartitionEnsembleTrainer.cs +++ b/src/Microsoft.ML.Ensemble/Trainer/Multiclass/MulticlassDataPartitionEnsembleTrainer.cs @@ -50,11 +50,7 @@ public sealed class Arguments : ArgumentsBase [Argument(ArgumentType.Multiple, HelpText = "Base predictor type", ShortName = "bp,basePredictorTypes", SortOrder = 1, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, SignatureType = typeof(SignatureMultiClassClassifierTrainer))] public IComponentFactory<ITrainer<TVectorPredictor>>[] BasePredictors; - public override IComponentFactory<ITrainer<TVectorPredictor>>[] BasePredictorFactories - { - get { return BasePredictors; } - set { BasePredictors = value; } - } + internal override IComponentFactory<ITrainer<TVectorPredictor>>[] GetPredictorFactories() => BasePredictors; public Arguments() { diff --git a/src/Microsoft.ML.Ensemble/Trainer/Regression/RegressionEnsembleTrainer.cs b/src/Microsoft.ML.Ensemble/Trainer/Regression/RegressionEnsembleTrainer.cs index 283fe2f32d..3fe853aff5 100644 --- a/src/Microsoft.ML.Ensemble/Trainer/Regression/RegressionEnsembleTrainer.cs +++ b/src/Microsoft.ML.Ensemble/Trainer/Regression/RegressionEnsembleTrainer.cs @@ -44,11 +44,7 @@ public sealed class Arguments : ArgumentsBase [Argument(ArgumentType.Multiple, HelpText = "Base predictor type", ShortName = "bp,basePredictorTypes", SortOrder = 1, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, SignatureType = typeof(SignatureRegressorTrainer))] public IComponentFactory<ITrainer<TScalarPredictor>>[] BasePredictors; - public override IComponentFactory<ITrainer<TScalarPredictor>>[] BasePredictorFactories - { - get { return BasePredictors; } - set { BasePredictors = value; } - } + internal override IComponentFactory<ITrainer<TScalarPredictor>>[] GetPredictorFactories() => BasePredictors; public Arguments() { From 2beb2ce21da3e0b526614e3e9d71cdc7b869c8d4 Mon Sep 17 00:00:00 2001 From: Eric Erhardt <eric.erhardt@microsoft.com> Date: Fri, 17 Aug 2018 10:43:52 -0500 Subject: [PATCH 5/7] PR feedback --- src/Microsoft.ML.Core/EntryPoints/ComponentFactory.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Microsoft.ML.Core/EntryPoints/ComponentFactory.cs b/src/Microsoft.ML.Core/EntryPoints/ComponentFactory.cs index cb1c1fbaf6..1f8d5b0e8f 100644 --- a/src/Microsoft.ML.Core/EntryPoints/ComponentFactory.cs +++ b/src/Microsoft.ML.Core/EntryPoints/ComponentFactory.cs @@ -54,7 +54,7 @@ public interface IComponentFactory<in TArg1, in TArg2, in TArg3, out TComponent> } /// <summary> - /// A utility class for creating IComponentFactory instances. + /// A utility class for creating <see cref="IComponentFactory"/> instances. /// </summary> public static class ComponentFactoryUtils { From 4923971fbde6ee6eecf35eed4fd046a12c521d24 Mon Sep 17 00:00:00 2001 From: Eric Erhardt <eric.erhardt@microsoft.com> Date: Fri, 17 Aug 2018 14:33:36 -0500 Subject: [PATCH 6/7] PR Feedback --- .../Selector/SubModelSelector/BaseSubModelSelector.cs | 4 ++-- src/Microsoft.ML.Ensemble/Trainer/Binary/EnsembleTrainer.cs | 5 ++--- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BaseSubModelSelector.cs b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BaseSubModelSelector.cs index ac79e400a9..513ad95139 100644 --- a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BaseSubModelSelector.cs +++ b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BaseSubModelSelector.cs @@ -81,10 +81,10 @@ public virtual void CalculateMetrics(FeatureSubsetModel<IPredictorProducing<TOut // Because the training and test datasets are drawn from the same base dataset, the test data role mappings // are the same as for the train data. IDataScorerTransform scorePipe = ScoreUtils.GetScorer(model.Predictor, testData, Host, testData.Schema); - // REVIEW: Should we somehow allow the user to customize the evaluator? - // By what mechanism should we allow that? RoleMappedData scoredTestData = new RoleMappedData(scorePipe, GetColumnRoles(testData.Schema, scorePipe.Schema)); + // REVIEW: Should we somehow allow the user to customize the evaluator? + // By what mechanism should we allow that? IEvaluator evaluator = GetEvaluator(Host); // REVIEW: with the new evaluators, metrics of individual models are no longer // printed to the Console. Consider adding an option on the combiner to print them. diff --git a/src/Microsoft.ML.Ensemble/Trainer/Binary/EnsembleTrainer.cs b/src/Microsoft.ML.Ensemble/Trainer/Binary/EnsembleTrainer.cs index 6a4b151fab..9165360ccb 100644 --- a/src/Microsoft.ML.Ensemble/Trainer/Binary/EnsembleTrainer.cs +++ b/src/Microsoft.ML.Ensemble/Trainer/Binary/EnsembleTrainer.cs @@ -5,15 +5,14 @@ using System; using System.Collections.Generic; using System.Linq; +using Microsoft.ML.Ensemble.EntryPoints; using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.Ensemble; using Microsoft.ML.Runtime.Ensemble.OutputCombiners; using Microsoft.ML.Runtime.Ensemble.Selector; -using Microsoft.ML.Ensemble.EntryPoints; -using Microsoft.ML.Runtime.Internal.Internallearn; using Microsoft.ML.Runtime.EntryPoints; +using Microsoft.ML.Runtime.Internal.Internallearn; using Microsoft.ML.Runtime.Learners; [assembly: LoadableClass(EnsembleTrainer.Summary, typeof(EnsembleTrainer), typeof(EnsembleTrainer.Arguments), From 78a4178d1905dd2be2b9209ed4eb57c6b1f55ea2 Mon Sep 17 00:00:00 2001 From: Eric Erhardt <eric.erhardt@microsoft.com> Date: Fri, 17 Aug 2018 16:43:50 -0500 Subject: [PATCH 7/7] Trigger