From 77fff0381a536a057f88209cdbf1bcc9d993f2f4 Mon Sep 17 00:00:00 2001
From: Eric Erhardt <eric.erhardt@microsoft.com>
Date: Wed, 15 Aug 2018 13:20:56 -0500
Subject: [PATCH 1/7] Convert Ensemble stacking SubComponent to
 IComponentFactory.

---
 .../EntryPoints/ComponentFactory.cs           | 19 +++++++++++++++
 .../Microsoft.ML.Ensemble.csproj              |  1 +
 .../OutputCombiners/BaseScalarStacking.cs     |  2 +-
 .../OutputCombiners/BaseStacking.cs           | 18 +++++++-------
 .../OutputCombiners/MultiStacking.cs          | 24 ++++++++++++++++---
 .../OutputCombiners/RegressionStacking.cs     | 18 ++++++++++++--
 .../OutputCombiners/Stacking.cs               | 19 ++++++++++++---
 7 files changed, 82 insertions(+), 19 deletions(-)

diff --git a/src/Microsoft.ML.Core/EntryPoints/ComponentFactory.cs b/src/Microsoft.ML.Core/EntryPoints/ComponentFactory.cs
index d69a9d0b93..14f339627a 100644
--- a/src/Microsoft.ML.Core/EntryPoints/ComponentFactory.cs
+++ b/src/Microsoft.ML.Core/EntryPoints/ComponentFactory.cs
@@ -29,6 +29,25 @@ public interface IComponentFactory<out TComponent>: IComponentFactory
         TComponent CreateComponent(IHostEnvironment env);
     }
 
+    /// <summary>
+    /// A class for creating a component with no extra parameters (other than an <see cref="IHostEnvironment"/>)
+    /// that simply wraps a delegate which creates the component.
+    /// </summary>
+    public class SimpleComponentFactory<TComponent> : IComponentFactory<TComponent>
+    {
+        private Func<IHostEnvironment, TComponent> _factory;
+
+        public SimpleComponentFactory(Func<IHostEnvironment, TComponent> factory)
+        {
+            _factory = factory;
+        }
+
+        public TComponent CreateComponent(IHostEnvironment env)
+        {
+            return _factory(env);
+        }
+    }
+
     /// <summary>
     /// An interface for creating a component when we take one extra parameter (and an <see cref="IHostEnvironment"/>).
     /// </summary>
diff --git a/src/Microsoft.ML.Ensemble/Microsoft.ML.Ensemble.csproj b/src/Microsoft.ML.Ensemble/Microsoft.ML.Ensemble.csproj
index ddd4557788..ef0ad01a40 100644
--- a/src/Microsoft.ML.Ensemble/Microsoft.ML.Ensemble.csproj
+++ b/src/Microsoft.ML.Ensemble/Microsoft.ML.Ensemble.csproj
@@ -10,6 +10,7 @@
     <ProjectReference Include="..\Microsoft.ML.Data\Microsoft.ML.Data.csproj" />
     <ProjectReference Include="..\Microsoft.ML.StandardLearners\Microsoft.ML.StandardLearners.csproj" />
     <ProjectReference Include="..\Microsoft.ML.Transforms\Microsoft.ML.Transforms.csproj" />
+    <ProjectReference Include="..\Microsoft.ML.FastTree\Microsoft.ML.FastTree.csproj" />
   </ItemGroup>
 
 </Project>
diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/BaseScalarStacking.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/BaseScalarStacking.cs
index a5c9c757a4..9b9900a65f 100644
--- a/src/Microsoft.ML.Ensemble/OutputCombiners/BaseScalarStacking.cs
+++ b/src/Microsoft.ML.Ensemble/OutputCombiners/BaseScalarStacking.cs
@@ -9,7 +9,7 @@
 
 namespace Microsoft.ML.Runtime.Ensemble.OutputCombiners
 {
-    public abstract class BaseScalarStacking<TSigBase> : BaseStacking<Single, TSigBase>
+    public abstract class BaseScalarStacking : BaseStacking<Single>
     {
         internal BaseScalarStacking(IHostEnvironment env, string name, ArgumentsBase args)
             : base(env, name, args)
diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/BaseStacking.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/BaseStacking.cs
index f30174a31d..311faf2448 100644
--- a/src/Microsoft.ML.Ensemble/OutputCombiners/BaseStacking.cs
+++ b/src/Microsoft.ML.Ensemble/OutputCombiners/BaseStacking.cs
@@ -7,6 +7,7 @@
 using System.Threading.Tasks;
 using Microsoft.ML.Runtime.CommandLine;
 using Microsoft.ML.Runtime.Data;
+using Microsoft.ML.Runtime.EntryPoints;
 using Microsoft.ML.Runtime.Internal.Internallearn;
 using Microsoft.ML.Runtime.Internal.Utilities;
 using Microsoft.ML.Runtime.Model;
@@ -15,7 +16,7 @@
 namespace Microsoft.ML.Runtime.Ensemble.OutputCombiners
 {
     using ColumnRole = RoleMappedSchema.ColumnRole;
-    public abstract class BaseStacking<TOutput, TSigBase> : IStackingTrainer<TOutput>
+    public abstract class BaseStacking<TOutput> : IStackingTrainer<TOutput>
     {
         public abstract class ArgumentsBase
         {
@@ -24,13 +25,10 @@ public abstract class ArgumentsBase
             [TGUI(Label = "Validation Dataset Proportion")]
             public Single ValidationDatasetProportion = 0.3f;
 
-            [Argument(ArgumentType.Multiple, HelpText = "Base predictor for meta learning", ShortName = "bp", SortOrder = 50,
-                Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly)]
-            [TGUI(Label = "Base predictor")]
-            public SubComponent<ITrainer<IPredictorProducing<TOutput>>, TSigBase> BasePredictorType;
+            public abstract IComponentFactory<ITrainer<IPredictorProducing<TOutput>>> BasePredictorFactory { get; set; }
         }
 
-        protected readonly SubComponent<ITrainer<IPredictorProducing<TOutput>>, TSigBase> BasePredictorType;
+        protected readonly IComponentFactory<ITrainer<IPredictorProducing<TOutput>>> BasePredictorFactory;
         protected readonly IHost Host;
         protected IPredictorProducing<TOutput> Meta;
 
@@ -45,10 +43,10 @@ internal BaseStacking(IHostEnvironment env, string name, ArgumentsBase args)
             Host.CheckUserArg(0 <= args.ValidationDatasetProportion && args.ValidationDatasetProportion < 1,
                     nameof(args.ValidationDatasetProportion),
                     "The validation proportion for stacking should be greater than or equal to 0 and less than 1");
-            Host.CheckUserArg(args.BasePredictorType.IsGood(), nameof(args.BasePredictorType));
+            Host.CheckUserArg(args.BasePredictorFactory != null, nameof(args.BasePredictorFactory));
 
             ValidationDatasetProportion = args.ValidationDatasetProportion;
-            BasePredictorType = args.BasePredictorType;
+            BasePredictorFactory = args.BasePredictorFactory;
         }
 
         internal BaseStacking(IHostEnvironment env, string name, ModelLoadContext ctx)
@@ -135,7 +133,7 @@ public void Train(List<FeatureSubsetModel<IPredictorProducing<TOutput>>> models,
             using (var ch = host.Start("Training stacked model"))
             {
                 ch.Check(Meta == null, "Train called multiple times");
-                ch.Check(BasePredictorType != null);
+                ch.Check(BasePredictorFactory != null);
 
                 var maps = new ValueMapper<VBuffer<Single>, TOutput>[models.Count];
                 for (int i = 0; i < maps.Length; i++)
@@ -187,7 +185,7 @@ public void Train(List<FeatureSubsetModel<IPredictorProducing<TOutput>>> models,
                 var view = bldr.GetDataView();
                 var rmd = new RoleMappedData(view, DefaultColumnNames.Label, DefaultColumnNames.Features);
 
-                var trainer = BasePredictorType.CreateInstance(host);
+                var trainer = BasePredictorFactory.CreateComponent(host);
                 if (trainer.Info.NeedNormalization)
                     ch.Warning("The trainer specified for stacking wants normalization, but we do not currently allow this.");
                 Meta = trainer.Train(rmd);
diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiStacking.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiStacking.cs
index 2ef74c8169..9b00d3732b 100644
--- a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiStacking.cs
+++ b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiStacking.cs
@@ -8,7 +8,10 @@
 using Microsoft.ML.Runtime.Data;
 using Microsoft.ML.Runtime.Ensemble.OutputCombiners;
 using Microsoft.ML.Runtime.EntryPoints;
+using Microsoft.ML.Runtime.FastTree;
+using Microsoft.ML.Runtime.Internal.Internallearn;
 using Microsoft.ML.Runtime.Internal.Utilities;
+using Microsoft.ML.Runtime.Learners;
 using Microsoft.ML.Runtime.Model;
 
 [assembly: LoadableClass(typeof(MultiStacking), typeof(MultiStacking.Arguments), typeof(SignatureCombiner),
@@ -20,7 +23,7 @@
 namespace Microsoft.ML.Runtime.Ensemble.OutputCombiners
 {
     using TVectorPredictor = IPredictorProducing<VBuffer<Single>>;
-    public sealed class MultiStacking : BaseStacking<VBuffer<Single>, SignatureMultiClassClassifierTrainer>, ICanSaveModel, IMultiClassOutputCombiner
+    public sealed class MultiStacking : BaseStacking<VBuffer<Single>>, ICanSaveModel, IMultiClassOutputCombiner
     {
         public const string LoadName = "MultiStacking";
         public const string LoaderSignature = "MultiStackingCombiner";
@@ -38,13 +41,28 @@ private static VersionInfo GetVersionInfo()
         [TlcModule.Component(Name = LoadName, FriendlyName = Stacking.UserName)]
         public sealed class Arguments : ArgumentsBase, ISupportMulticlassOutputCombinerFactory
         {
+            [Argument(ArgumentType.Multiple, HelpText = "Base predictor for meta learning", ShortName = "bp", SortOrder = 50,
+                Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, SignatureType = typeof(SignatureMultiClassClassifierTrainer))]
+            [TGUI(Label = "Base predictor")]
+            public IComponentFactory<ITrainer<IPredictorProducing<VBuffer<Single>>>> BasePredictorType;
+
+            public override IComponentFactory<ITrainer<TVectorPredictor>> BasePredictorFactory
+            {
+                get { return BasePredictorType; }
+                set { BasePredictorType = value; }
+            }
+
             public IMultiClassOutputCombiner CreateComponent(IHostEnvironment env) => new MultiStacking(env, this);
 
             public Arguments()
             {
                 // REVIEW: Perhaps we can have a better non-parametetric learner.
-                BasePredictorType = new SubComponent<ITrainer<TVectorPredictor>, SignatureMultiClassClassifierTrainer>(
-                    "OVA", "p=FastTreeBinaryClassification");
+                BasePredictorType = new SimpleComponentFactory<ITrainer<TVectorPredictor>>(
+                    env => new Ova(env, new Ova.Arguments()
+                    {
+                        PredictorType = new SimpleComponentFactory<ITrainer<IPredictorProducing<Single>>>(
+                            e => new FastTreeBinaryClassificationTrainer(e, new FastTreeBinaryClassificationTrainer.Arguments()))
+                    }));
             }
         }
 
diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/RegressionStacking.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/RegressionStacking.cs
index 0b5f8e6057..5ce1902a30 100644
--- a/src/Microsoft.ML.Ensemble/OutputCombiners/RegressionStacking.cs
+++ b/src/Microsoft.ML.Ensemble/OutputCombiners/RegressionStacking.cs
@@ -7,6 +7,8 @@
 using Microsoft.ML.Runtime.Data;
 using Microsoft.ML.Runtime.Ensemble.OutputCombiners;
 using Microsoft.ML.Runtime.EntryPoints;
+using Microsoft.ML.Runtime.FastTree;
+using Microsoft.ML.Runtime.Internal.Internallearn;
 using Microsoft.ML.Runtime.Model;
 
 [assembly: LoadableClass(typeof(RegressionStacking), typeof(RegressionStacking.Arguments), typeof(SignatureCombiner),
@@ -19,7 +21,7 @@ namespace Microsoft.ML.Runtime.Ensemble.OutputCombiners
 {
     using TScalarPredictor = IPredictorProducing<Single>;
 
-    public sealed class RegressionStacking : BaseScalarStacking<SignatureRegressorTrainer>, IRegressionOutputCombiner, ICanSaveModel
+    public sealed class RegressionStacking : BaseScalarStacking, IRegressionOutputCombiner, ICanSaveModel
     {
         public const string LoadName = "RegressionStacking";
         public const string LoaderSignature = "RegressionStacking";
@@ -37,9 +39,21 @@ private static VersionInfo GetVersionInfo()
         [TlcModule.Component(Name = LoadName, FriendlyName = Stacking.UserName)]
         public sealed class Arguments : ArgumentsBase, ISupportRegressionOutputCombinerFactory
         {
+            [Argument(ArgumentType.Multiple, HelpText = "Base predictor for meta learning", ShortName = "bp", SortOrder = 50,
+                Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, SignatureType = typeof(SignatureRegressorTrainer))]
+            [TGUI(Label = "Base predictor")]
+            public IComponentFactory<ITrainer<IPredictorProducing<Single>>> BasePredictorType;
+
+            public override IComponentFactory<ITrainer<TScalarPredictor>> BasePredictorFactory
+            {
+                get { return BasePredictorType; }
+                set { BasePredictorType = value; }
+            }
+
             public Arguments()
             {
-                BasePredictorType = new SubComponent<ITrainer<TScalarPredictor>, SignatureRegressorTrainer>("FastTreeRegression");
+                BasePredictorType = new SimpleComponentFactory<ITrainer<TScalarPredictor>>(
+                    env => new FastTreeRegressionTrainer(env, new FastTreeRegressionTrainer.Arguments()));
             }
 
             public IRegressionOutputCombiner CreateComponent(IHostEnvironment env) => new RegressionStacking(env, this);
diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/Stacking.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/Stacking.cs
index f3481e9936..04b473cd45 100644
--- a/src/Microsoft.ML.Ensemble/OutputCombiners/Stacking.cs
+++ b/src/Microsoft.ML.Ensemble/OutputCombiners/Stacking.cs
@@ -5,9 +5,10 @@
 using System;
 using Microsoft.ML.Runtime;
 using Microsoft.ML.Runtime.CommandLine;
-using Microsoft.ML.Runtime.Data;
 using Microsoft.ML.Runtime.Ensemble.OutputCombiners;
 using Microsoft.ML.Runtime.EntryPoints;
+using Microsoft.ML.Runtime.FastTree;
+using Microsoft.ML.Runtime.Internal.Internallearn;
 using Microsoft.ML.Runtime.Model;
 
 [assembly: LoadableClass(typeof(Stacking), typeof(Stacking.Arguments), typeof(SignatureCombiner), Stacking.UserName, Stacking.LoadName)]
@@ -16,7 +17,7 @@
 namespace Microsoft.ML.Runtime.Ensemble.OutputCombiners
 {
     using TScalarPredictor = IPredictorProducing<Single>;
-    public sealed class Stacking : BaseScalarStacking<SignatureBinaryClassifierTrainer>, IBinaryOutputCombiner, ICanSaveModel
+    public sealed class Stacking : BaseScalarStacking, IBinaryOutputCombiner, ICanSaveModel
     {
         public const string UserName = "Stacking";
         public const string LoadName = "Stacking";
@@ -35,9 +36,21 @@ private static VersionInfo GetVersionInfo()
         [TlcModule.Component(Name = LoadName, FriendlyName = UserName)]
         public sealed class Arguments : ArgumentsBase, ISupportBinaryOutputCombinerFactory
         {
+            [Argument(ArgumentType.Multiple, HelpText = "Base predictor for meta learning", ShortName = "bp", SortOrder = 50,
+                Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, SignatureType = typeof(SignatureBinaryClassifierTrainer))]
+            [TGUI(Label = "Base predictor")]
+            public IComponentFactory<ITrainer<IPredictorProducing<Single>>> BasePredictorType;
+
+            public override IComponentFactory<ITrainer<TScalarPredictor>> BasePredictorFactory
+            {
+                get { return BasePredictorType; }
+                set { BasePredictorType = value; }
+            }
+
             public Arguments()
             {
-                BasePredictorType = new SubComponent<ITrainer<TScalarPredictor>, SignatureBinaryClassifierTrainer>("FastTreeBinaryClassification");
+                BasePredictorType = new SimpleComponentFactory<ITrainer<TScalarPredictor>>(
+                    env => new FastTreeBinaryClassificationTrainer(env, new FastTreeBinaryClassificationTrainer.Arguments()));
             }
 
             public IBinaryOutputCombiner CreateComponent(IHostEnvironment env) => new Stacking(env, this);

From 84119d35a9e33874a902a3db00dc5c6213b6552a Mon Sep 17 00:00:00 2001
From: Eric Erhardt <eric.erhardt@microsoft.com>
Date: Wed, 15 Aug 2018 17:27:31 -0500
Subject: [PATCH 2/7] Replace SubComponent with IComponentFactory in
 ML.Ensemble

Working towards #585
---
 .../SubModelSelector/BaseSubModelSelector.cs  | 11 +++++------
 .../Trainer/Binary/EnsembleTrainer.cs         | 19 +++++++++++++++++--
 .../Trainer/EnsembleTrainerBase.cs            | 14 +++++++-------
 .../MulticlassDataPartitionEnsembleTrainer.cs | 19 +++++++++++++++++--
 .../Regression/RegressionEnsembleTrainer.cs   | 19 +++++++++++++++++--
 5 files changed, 63 insertions(+), 19 deletions(-)

diff --git a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BaseSubModelSelector.cs b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BaseSubModelSelector.cs
index 8465e1a5e8..ac79e400a9 100644
--- a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BaseSubModelSelector.cs
+++ b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BaseSubModelSelector.cs
@@ -54,16 +54,16 @@ public virtual IList<FeatureSubsetModel<IPredictorProducing<TOutput>>> Prune(ILi
             return models;
         }
 
-        private SubComponent<IEvaluator, SignatureEvaluator> GetEvaluatorSubComponent()
+        private IEvaluator GetEvaluator(IHostEnvironment env)
         {
             switch (PredictionKind)
             {
                 case PredictionKind.BinaryClassification:
-                    return new SubComponent<IEvaluator, SignatureEvaluator>(BinaryClassifierEvaluator.LoadName);
+                    return new BinaryClassifierEvaluator(env, new BinaryClassifierEvaluator.Arguments());
                 case PredictionKind.Regression:
-                    return new SubComponent<IEvaluator, SignatureEvaluator>(RegressionEvaluator.LoadName);
+                    return new RegressionEvaluator(env, new RegressionEvaluator.Arguments());
                 case PredictionKind.MultiClassClassification:
-                    return new SubComponent<IEvaluator, SignatureEvaluator>(MultiClassClassifierEvaluator.LoadName);
+                    return new MultiClassClassifierEvaluator(env, new MultiClassClassifierEvaluator.Arguments());
                 default:
                     throw Host.Except("Unrecognized prediction kind '{0}'", PredictionKind);
             }
@@ -83,10 +83,9 @@ public virtual void CalculateMetrics(FeatureSubsetModel<IPredictorProducing<TOut
                 IDataScorerTransform scorePipe = ScoreUtils.GetScorer(model.Predictor, testData, Host, testData.Schema);
                 // REVIEW: Should we somehow allow the user to customize the evaluator?
                 // By what mechanism should we allow that?
-                var evalComp = GetEvaluatorSubComponent();
                 RoleMappedData scoredTestData = new RoleMappedData(scorePipe,
                     GetColumnRoles(testData.Schema, scorePipe.Schema));
-                IEvaluator evaluator = evalComp.CreateInstance(Host);
+                IEvaluator evaluator = GetEvaluator(Host);
                 // REVIEW: with the new evaluators, metrics of individual models are no longer
                 // printed to the Console. Consider adding an option on the combiner to print them.
                 // REVIEW: Consider adding an option to the combiner to save a data view
diff --git a/src/Microsoft.ML.Ensemble/Trainer/Binary/EnsembleTrainer.cs b/src/Microsoft.ML.Ensemble/Trainer/Binary/EnsembleTrainer.cs
index ab2dff5045..1ec0c94131 100644
--- a/src/Microsoft.ML.Ensemble/Trainer/Binary/EnsembleTrainer.cs
+++ b/src/Microsoft.ML.Ensemble/Trainer/Binary/EnsembleTrainer.cs
@@ -13,6 +13,8 @@
 using Microsoft.ML.Runtime.Ensemble.Selector;
 using Microsoft.ML.Ensemble.EntryPoints;
 using Microsoft.ML.Runtime.Internal.Internallearn;
+using Microsoft.ML.Runtime.EntryPoints;
+using Microsoft.ML.Runtime.Learners;
 
 [assembly: LoadableClass(EnsembleTrainer.Summary, typeof(EnsembleTrainer), typeof(EnsembleTrainer.Arguments),
     new[] { typeof(SignatureBinaryClassifierTrainer), typeof(SignatureTrainer) },
@@ -26,7 +28,7 @@ namespace Microsoft.ML.Runtime.Ensemble
     /// A generic ensemble trainer for binary classification.
     /// </summary>
     public sealed class EnsembleTrainer : EnsembleTrainerBase<Single, TScalarPredictor,
-        IBinarySubModelSelector, IBinaryOutputCombiner, SignatureBinaryClassifierTrainer>,
+        IBinarySubModelSelector, IBinaryOutputCombiner>,
         IModelCombiner<TScalarPredictor, TScalarPredictor>
     {
         public const string LoadNameValue = "WeightedEnsemble";
@@ -44,9 +46,22 @@ public sealed class Arguments : ArgumentsBase
             [TGUI(Label = "Output combiner", Description = "Output combiner type")]
             public ISupportBinaryOutputCombinerFactory OutputCombiner = new MedianFactory();
 
+            [Argument(ArgumentType.Multiple, HelpText = "Base predictor type", ShortName = "bp,basePredictorTypes", SortOrder = 1, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, SignatureType = typeof(SignatureBinaryClassifierTrainer))]
+            public IComponentFactory<ITrainer<TScalarPredictor>>[] BasePredictors;
+
+            public override IComponentFactory<ITrainer<TScalarPredictor>>[] BasePredictorFactories
+            {
+                get { return BasePredictors; }
+                set { BasePredictors = value; }
+            }
+
             public Arguments()
             {
-                BasePredictors = new[] { new SubComponent<ITrainer<TScalarPredictor>, SignatureBinaryClassifierTrainer>("LinearSVM") };
+                BasePredictors = new[]
+                {
+                    new SimpleComponentFactory<ITrainer<TScalarPredictor>>(
+                        env => new LinearSvm(env, new LinearSvm.Arguments()))
+                };
             }
         }
 
diff --git a/src/Microsoft.ML.Ensemble/Trainer/EnsembleTrainerBase.cs b/src/Microsoft.ML.Ensemble/Trainer/EnsembleTrainerBase.cs
index 0a350ef8ee..1f9aff7169 100644
--- a/src/Microsoft.ML.Ensemble/Trainer/EnsembleTrainerBase.cs
+++ b/src/Microsoft.ML.Ensemble/Trainer/EnsembleTrainerBase.cs
@@ -20,7 +20,7 @@ namespace Microsoft.ML.Runtime.Ensemble
 {
     using Stopwatch = System.Diagnostics.Stopwatch;
 
-    public abstract class EnsembleTrainerBase<TOutput, TPredictor, TSelector, TCombiner, TSig> : TrainerBase<TPredictor>
+    public abstract class EnsembleTrainerBase<TOutput, TPredictor, TSelector, TCombiner> : TrainerBase<TPredictor>
          where TPredictor : class, IPredictorProducing<TOutput>
          where TSelector : class, ISubModelSelector<TOutput>
          where TCombiner : class, IOutputCombiner<TOutput>
@@ -53,8 +53,7 @@ public abstract class ArgumentsBase : LearnerInputBaseWithLabel
             [TGUI(Label = "Show Sub-Model Metrics")]
             public bool ShowMetrics;
 
-            [Argument(ArgumentType.Multiple, HelpText = "Base predictor type", ShortName = "bp,basePredictorTypes", SortOrder = 1, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly)]
-            public SubComponent<ITrainer<IPredictorProducing<TOutput>>, TSig>[] BasePredictors;
+            public abstract IComponentFactory<ITrainer<IPredictorProducing<TOutput>>>[] BasePredictorFactories { get; set; }
         }
 
         private const int DefaultNumModels = 50;
@@ -78,21 +77,22 @@ private protected EnsembleTrainerBase(ArgumentsBase args, IHostEnvironment env,
 
             using (var ch = Host.Start("Init"))
             {
-                ch.CheckUserArg(Utils.Size(Args.BasePredictors) > 0, nameof(Args.BasePredictors), "This should have at-least one value");
+                var predictorFactories = Args.BasePredictorFactories;
+                ch.CheckUserArg(Utils.Size(predictorFactories) > 0, nameof(Args.BasePredictorFactories), "This should have at-least one value");
 
                 NumModels = Args.NumModels ??
-                    (Args.BasePredictors.Length == 1 ? DefaultNumModels : Args.BasePredictors.Length);
+                    (predictorFactories.Length == 1 ? DefaultNumModels : predictorFactories.Length);
 
                 ch.CheckUserArg(NumModels > 0, nameof(Args.NumModels), "Must be positive, or null to indicate numModels is the number of base predictors");
 
-                if (Utils.Size(Args.BasePredictors) > NumModels)
+                if (Utils.Size(predictorFactories) > NumModels)
                     ch.Warning("The base predictor count is greater than models count. Some of the base predictors will be ignored.");
 
                 _subsetSelector = Args.SamplingType.CreateComponent(Host);
 
                 Trainers = new ITrainer<IPredictorProducing<TOutput>>[NumModels];
                 for (int i = 0; i < Trainers.Length; i++)
-                    Trainers[i] = Args.BasePredictors[i % Args.BasePredictors.Length].CreateInstance(Host);
+                    Trainers[i] = predictorFactories[i % predictorFactories.Length].CreateComponent(Host);
                 // We infer normalization and calibration preferences from the trainers. However, even if the internal trainers
                 // don't need caching we are performing multiple passes over the data, so it is probably appropriate to always cache.
                 Info = new TrainerInfo(
diff --git a/src/Microsoft.ML.Ensemble/Trainer/Multiclass/MulticlassDataPartitionEnsembleTrainer.cs b/src/Microsoft.ML.Ensemble/Trainer/Multiclass/MulticlassDataPartitionEnsembleTrainer.cs
index 4421cd5838..c0a1a48e08 100644
--- a/src/Microsoft.ML.Ensemble/Trainer/Multiclass/MulticlassDataPartitionEnsembleTrainer.cs
+++ b/src/Microsoft.ML.Ensemble/Trainer/Multiclass/MulticlassDataPartitionEnsembleTrainer.cs
@@ -12,7 +12,9 @@
 using Microsoft.ML.Runtime.Ensemble;
 using Microsoft.ML.Runtime.Ensemble.OutputCombiners;
 using Microsoft.ML.Runtime.Ensemble.Selector;
+using Microsoft.ML.Runtime.EntryPoints;
 using Microsoft.ML.Runtime.Internal.Internallearn;
+using Microsoft.ML.Runtime.Learners;
 
 [assembly: LoadableClass(MulticlassDataPartitionEnsembleTrainer.Summary, typeof(MulticlassDataPartitionEnsembleTrainer),
     typeof(MulticlassDataPartitionEnsembleTrainer.Arguments),
@@ -28,7 +30,7 @@ namespace Microsoft.ML.Runtime.Ensemble
     /// </summary>
     public sealed class MulticlassDataPartitionEnsembleTrainer :
         EnsembleTrainerBase<VBuffer<Single>, EnsembleMultiClassPredictor,
-        IMulticlassSubModelSelector, IMultiClassOutputCombiner, SignatureMultiClassClassifierTrainer>,
+        IMulticlassSubModelSelector, IMultiClassOutputCombiner>,
         IModelCombiner<TVectorPredictor, TVectorPredictor>
     {
         public const string LoadNameValue = "WeightedEnsembleMulticlass";
@@ -45,9 +47,22 @@ public sealed class Arguments : ArgumentsBase
             [TGUI(Label = "Output combiner", Description = "Output combiner type")]
             public ISupportMulticlassOutputCombinerFactory OutputCombiner = new MultiMedian.Arguments();
 
+            [Argument(ArgumentType.Multiple, HelpText = "Base predictor type", ShortName = "bp,basePredictorTypes", SortOrder = 1, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, SignatureType = typeof(SignatureMultiClassClassifierTrainer))]
+            public IComponentFactory<ITrainer<TVectorPredictor>>[] BasePredictors;
+
+            public override IComponentFactory<ITrainer<TVectorPredictor>>[] BasePredictorFactories
+            {
+                get { return BasePredictors; }
+                set { BasePredictors = value; }
+            }
+
             public Arguments()
             {
-                BasePredictors = new[] { new SubComponent<ITrainer<TVectorPredictor>, SignatureMultiClassClassifierTrainer>("MultiClassLogisticRegression") };
+                BasePredictors = new[]
+                {
+                    new SimpleComponentFactory<ITrainer<TVectorPredictor>>(
+                        env => new MulticlassLogisticRegression(env, new MulticlassLogisticRegression.Arguments()))
+                };
             }
         }
 
diff --git a/src/Microsoft.ML.Ensemble/Trainer/Regression/RegressionEnsembleTrainer.cs b/src/Microsoft.ML.Ensemble/Trainer/Regression/RegressionEnsembleTrainer.cs
index 1cc36f20cd..37d98eefd2 100644
--- a/src/Microsoft.ML.Ensemble/Trainer/Regression/RegressionEnsembleTrainer.cs
+++ b/src/Microsoft.ML.Ensemble/Trainer/Regression/RegressionEnsembleTrainer.cs
@@ -12,7 +12,9 @@
 using Microsoft.ML.Runtime.Ensemble;
 using Microsoft.ML.Runtime.Ensemble.OutputCombiners;
 using Microsoft.ML.Runtime.Ensemble.Selector;
+using Microsoft.ML.Runtime.EntryPoints;
 using Microsoft.ML.Runtime.Internal.Internallearn;
+using Microsoft.ML.Runtime.Learners;
 
 [assembly: LoadableClass(typeof(RegressionEnsembleTrainer), typeof(RegressionEnsembleTrainer.Arguments),
     new[] { typeof(SignatureRegressorTrainer), typeof(SignatureTrainer) },
@@ -23,7 +25,7 @@ namespace Microsoft.ML.Runtime.Ensemble
 {
     using TScalarPredictor = IPredictorProducing<Single>;
     public sealed class RegressionEnsembleTrainer : EnsembleTrainerBase<Single, TScalarPredictor,
-       IRegressionSubModelSelector, IRegressionOutputCombiner, SignatureRegressorTrainer>,
+       IRegressionSubModelSelector, IRegressionOutputCombiner>,
        IModelCombiner<TScalarPredictor, TScalarPredictor>
     {
         public const string LoadNameValue = "EnsembleRegression";
@@ -39,9 +41,22 @@ public sealed class Arguments : ArgumentsBase
             [TGUI(Label = "Output combiner", Description = "Output combiner type")]
             public ISupportRegressionOutputCombinerFactory OutputCombiner = new MedianFactory();
 
+            [Argument(ArgumentType.Multiple, HelpText = "Base predictor type", ShortName = "bp,basePredictorTypes", SortOrder = 1, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, SignatureType = typeof(SignatureRegressorTrainer))]
+            public IComponentFactory<ITrainer<TScalarPredictor>>[] BasePredictors;
+
+            public override IComponentFactory<ITrainer<TScalarPredictor>>[] BasePredictorFactories
+            {
+                get { return BasePredictors; }
+                set { BasePredictors = value; }
+            }
+
             public Arguments()
             {
-                BasePredictors = new[] { new SubComponent<ITrainer<TScalarPredictor>, SignatureRegressorTrainer>("OnlineGradientDescent") };
+                BasePredictors = new[]
+                {
+                    new SimpleComponentFactory<ITrainer<TScalarPredictor>>(
+                        env => new OnlineGradientDescentTrainer(env, new OnlineGradientDescentTrainer.Arguments()))
+                };
             }
         }
 

From 7c06f93e9db7dbeab85fc34c5267f2cef5b61a87 Mon Sep 17 00:00:00 2001
From: Eric Erhardt <eric.erhardt@microsoft.com>
Date: Thu, 16 Aug 2018 09:54:36 -0500
Subject: [PATCH 3/7] Convert SimpleComponentFactory classes to
 ComponentFactoryUtils static class.

---
 .../EntryPoints/ComponentFactory.cs           | 90 ++++++++++++-------
 .../Commands/CrossValidationCommand.cs        |  2 +-
 .../OutputCombiners/MultiStacking.cs          |  4 +-
 .../OutputCombiners/RegressionStacking.cs     |  2 +-
 .../OutputCombiners/Stacking.cs               |  2 +-
 .../Trainer/Binary/EnsembleTrainer.cs         |  2 +-
 .../MulticlassDataPartitionEnsembleTrainer.cs |  2 +-
 .../Regression/RegressionEnsembleTrainer.cs   |  2 +-
 8 files changed, 65 insertions(+), 41 deletions(-)

diff --git a/src/Microsoft.ML.Core/EntryPoints/ComponentFactory.cs b/src/Microsoft.ML.Core/EntryPoints/ComponentFactory.cs
index 14f339627a..8dea31c13a 100644
--- a/src/Microsoft.ML.Core/EntryPoints/ComponentFactory.cs
+++ b/src/Microsoft.ML.Core/EntryPoints/ComponentFactory.cs
@@ -30,57 +30,81 @@ public interface IComponentFactory<out TComponent>: IComponentFactory
     }
 
     /// <summary>
-    /// A class for creating a component with no extra parameters (other than an <see cref="IHostEnvironment"/>)
-    /// that simply wraps a delegate which creates the component.
+    /// An interface for creating a component when we take one extra parameter (and an <see cref="IHostEnvironment"/>).
     /// </summary>
-    public class SimpleComponentFactory<TComponent> : IComponentFactory<TComponent>
+    public interface IComponentFactory<in TArg1, out TComponent> : IComponentFactory
     {
-        private Func<IHostEnvironment, TComponent> _factory;
-
-        public SimpleComponentFactory(Func<IHostEnvironment, TComponent> factory)
-        {
-            _factory = factory;
-        }
-
-        public TComponent CreateComponent(IHostEnvironment env)
-        {
-            return _factory(env);
-        }
+        TComponent CreateComponent(IHostEnvironment env, TArg1 argument1);
     }
 
     /// <summary>
-    /// An interface for creating a component when we take one extra parameter (and an <see cref="IHostEnvironment"/>).
+    /// An interface for creating a component when we take two extra parameters (and an <see cref="IHostEnvironment"/>).
     /// </summary>
-    public interface IComponentFactory<in TArg1, out TComponent> : IComponentFactory
+    public interface IComponentFactory<in TArg1, in TArg2, out TComponent> : IComponentFactory
     {
-        TComponent CreateComponent(IHostEnvironment env, TArg1 argument1);
+        TComponent CreateComponent(IHostEnvironment env, TArg1 argument1, TArg2 argument2);
     }
 
     /// <summary>
-    /// A class for creating a component when we take one extra parameter
-    /// (and an <see cref="IHostEnvironment"/>) that simply wraps a delegate which
-    /// creates the component.
+    /// A utility class for creating IComponentFactory instances.
     /// </summary>
-    public class SimpleComponentFactory<TArg1, TComponent> : IComponentFactory<TArg1, TComponent>
+    public static class ComponentFactoryUtils
     {
-        private Func<IHostEnvironment, TArg1, TComponent> _factory;
+        /// <summary>
+        /// Creates a component factory with no extra parameters (other than an <see cref="IHostEnvironment"/>)
+        /// that simply wraps a delegate which creates the component.
+        /// </summary>
+        public static IComponentFactory<TComponent> CreateFromFunction<TComponent>(Func<IHostEnvironment, TComponent> factory)
+        {
+            return new SimpleComponentFactory<TComponent>(factory);
+        }
 
-        public SimpleComponentFactory(Func<IHostEnvironment, TArg1, TComponent> factory)
+        /// <summary>
+        /// Creates a component factory when we take one extra parameter (and an
+        /// <see cref="IHostEnvironment"/>) that simply wraps a delegate which creates the component.
+        /// </summary>
+        public static IComponentFactory<TArg1, TComponent> CreateFromFunction<TArg1, TComponent>(Func<IHostEnvironment, TArg1, TComponent> factory)
         {
-            _factory = factory;
+            return new SimpleComponentFactory<TArg1, TComponent>(factory);
         }
 
-        public TComponent CreateComponent(IHostEnvironment env, TArg1 argument1)
+        /// <summary>
+        /// Wraps a delegate that creates a component from nothing but an <see cref="IHostEnvironment"/>.
+        /// The delegate is required: a null delegate is rejected at construction, not on first use.
+        /// </summary>
+        private sealed class SimpleComponentFactory<TComponent> : IComponentFactory<TComponent>
         {
-            return _factory(env, argument1);
+            private readonly Func<IHostEnvironment, TComponent> _factory;
+
+            public SimpleComponentFactory(Func<IHostEnvironment, TComponent> factory)
+            {
+                _factory = factory ?? throw new ArgumentNullException(nameof(factory));
+            }
+
+            public TComponent CreateComponent(IHostEnvironment env)
+            {
+                return _factory(env);
+            }
         }
-    }
 
-    /// <summary>
-    /// An interface for creating a component when we take two extra parameters (and an <see cref="IHostEnvironment"/>).
-    /// </summary>
-    public interface IComponentFactory<in TArg1, in TArg2, out TComponent> : IComponentFactory
-    {
-        TComponent CreateComponent(IHostEnvironment env, TArg1 argument1, TArg2 argument2);
+        /// <summary>
+        /// Wraps a delegate that creates a component from an <see cref="IHostEnvironment"/>
+        /// plus one extra argument. The delegate is required: a null delegate is rejected
+        /// at construction, not on the first call to CreateComponent.
+        /// </summary>
+        private sealed class SimpleComponentFactory<TArg1, TComponent> : IComponentFactory<TArg1, TComponent>
+        {
+            private readonly Func<IHostEnvironment, TArg1, TComponent> _factory;
+
+            public SimpleComponentFactory(Func<IHostEnvironment, TArg1, TComponent> factory)
+            {
+                _factory = factory ?? throw new ArgumentNullException(nameof(factory));
+            }
+
+            public TComponent CreateComponent(IHostEnvironment env, TArg1 argument1)
+            {
+                return _factory(env, argument1);
+            }
+        }
     }
 }
diff --git a/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs b/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs
index affd949064..1c264e02d2 100644
--- a/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs
+++ b/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs
@@ -164,7 +164,7 @@ private void RunCore(IChannel ch, string cmd)
                         new[]
                         {
                             new KeyValuePair<string, IComponentFactory<IDataView, IDataTransform>>(
-                                "", new SimpleComponentFactory<IDataView, IDataTransform>(
+                                "", ComponentFactoryUtils.CreateFromFunction<IDataView, IDataTransform>(
                                     (env, input) =>
                                     {
                                         var args = new GenerateNumberTransform.Arguments();
diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiStacking.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiStacking.cs
index 9b00d3732b..3038e2b089 100644
--- a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiStacking.cs
+++ b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiStacking.cs
@@ -57,10 +57,10 @@ public override IComponentFactory<ITrainer<TVectorPredictor>> BasePredictorFacto
             public Arguments()
             {
                 // REVIEW: Perhaps we can have a better non-parametetric learner.
-                BasePredictorType = new SimpleComponentFactory<ITrainer<TVectorPredictor>>(
+                BasePredictorType = ComponentFactoryUtils.CreateFromFunction(
                     env => new Ova(env, new Ova.Arguments()
                     {
-                        PredictorType = new SimpleComponentFactory<ITrainer<IPredictorProducing<Single>>>(
+                        PredictorType = ComponentFactoryUtils.CreateFromFunction(
                             e => new FastTreeBinaryClassificationTrainer(e, new FastTreeBinaryClassificationTrainer.Arguments()))
                     }));
             }
diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/RegressionStacking.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/RegressionStacking.cs
index 5ce1902a30..18066c5362 100644
--- a/src/Microsoft.ML.Ensemble/OutputCombiners/RegressionStacking.cs
+++ b/src/Microsoft.ML.Ensemble/OutputCombiners/RegressionStacking.cs
@@ -52,7 +52,7 @@ public override IComponentFactory<ITrainer<TScalarPredictor>> BasePredictorFacto
 
             public Arguments()
             {
-                BasePredictorType = new SimpleComponentFactory<ITrainer<TScalarPredictor>>(
+                BasePredictorType = ComponentFactoryUtils.CreateFromFunction(
                     env => new FastTreeRegressionTrainer(env, new FastTreeRegressionTrainer.Arguments()));
             }
 
diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/Stacking.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/Stacking.cs
index 04b473cd45..6be4ffea64 100644
--- a/src/Microsoft.ML.Ensemble/OutputCombiners/Stacking.cs
+++ b/src/Microsoft.ML.Ensemble/OutputCombiners/Stacking.cs
@@ -49,7 +49,7 @@ public override IComponentFactory<ITrainer<TScalarPredictor>> BasePredictorFacto
 
             public Arguments()
             {
-                BasePredictorType = new SimpleComponentFactory<ITrainer<TScalarPredictor>>(
+                BasePredictorType = ComponentFactoryUtils.CreateFromFunction(
                     env => new FastTreeBinaryClassificationTrainer(env, new FastTreeBinaryClassificationTrainer.Arguments()));
             }
 
diff --git a/src/Microsoft.ML.Ensemble/Trainer/Binary/EnsembleTrainer.cs b/src/Microsoft.ML.Ensemble/Trainer/Binary/EnsembleTrainer.cs
index 1ec0c94131..5cf1f6ea4c 100644
--- a/src/Microsoft.ML.Ensemble/Trainer/Binary/EnsembleTrainer.cs
+++ b/src/Microsoft.ML.Ensemble/Trainer/Binary/EnsembleTrainer.cs
@@ -59,7 +59,7 @@ public Arguments()
             {
                 BasePredictors = new[]
                 {
-                    new SimpleComponentFactory<ITrainer<TScalarPredictor>>(
+                    ComponentFactoryUtils.CreateFromFunction(
                         env => new LinearSvm(env, new LinearSvm.Arguments()))
                 };
             }
diff --git a/src/Microsoft.ML.Ensemble/Trainer/Multiclass/MulticlassDataPartitionEnsembleTrainer.cs b/src/Microsoft.ML.Ensemble/Trainer/Multiclass/MulticlassDataPartitionEnsembleTrainer.cs
index c0a1a48e08..3183292383 100644
--- a/src/Microsoft.ML.Ensemble/Trainer/Multiclass/MulticlassDataPartitionEnsembleTrainer.cs
+++ b/src/Microsoft.ML.Ensemble/Trainer/Multiclass/MulticlassDataPartitionEnsembleTrainer.cs
@@ -60,7 +60,7 @@ public Arguments()
             {
                 BasePredictors = new[]
                 {
-                    new SimpleComponentFactory<ITrainer<TVectorPredictor>>(
+                    ComponentFactoryUtils.CreateFromFunction(
                         env => new MulticlassLogisticRegression(env, new MulticlassLogisticRegression.Arguments()))
                 };
             }
diff --git a/src/Microsoft.ML.Ensemble/Trainer/Regression/RegressionEnsembleTrainer.cs b/src/Microsoft.ML.Ensemble/Trainer/Regression/RegressionEnsembleTrainer.cs
index 37d98eefd2..283fe2f32d 100644
--- a/src/Microsoft.ML.Ensemble/Trainer/Regression/RegressionEnsembleTrainer.cs
+++ b/src/Microsoft.ML.Ensemble/Trainer/Regression/RegressionEnsembleTrainer.cs
@@ -54,7 +54,7 @@ public Arguments()
             {
                 BasePredictors = new[]
                 {
-                    new SimpleComponentFactory<ITrainer<TScalarPredictor>>(
+                    ComponentFactoryUtils.CreateFromFunction(
                         env => new OnlineGradientDescentTrainer(env, new OnlineGradientDescentTrainer.Arguments()))
                 };
             }

From d41dbe623d2a78c43ae13380516032158ac6c308 Mon Sep 17 00:00:00 2001
From: Eric Erhardt <eric.erhardt@microsoft.com>
Date: Thu, 16 Aug 2018 14:04:52 -0500
Subject: [PATCH 4/7] Move Stacking arguments abstract properties to internal
 get methods.

This reduces the public surface area.
---
 .../OutputCombiners/BaseStacking.cs                  | 12 ++++++------
 .../OutputCombiners/MultiStacking.cs                 |  8 ++------
 .../OutputCombiners/RegressionStacking.cs            |  8 ++------
 .../OutputCombiners/Stacking.cs                      |  8 ++------
 .../Trainer/Binary/EnsembleTrainer.cs                |  6 +-----
 .../Trainer/EnsembleTrainerBase.cs                   |  6 +++---
 .../MulticlassDataPartitionEnsembleTrainer.cs        |  6 +-----
 .../Trainer/Regression/RegressionEnsembleTrainer.cs  |  6 +-----
 8 files changed, 18 insertions(+), 42 deletions(-)

diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/BaseStacking.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/BaseStacking.cs
index 311faf2448..c294a1fcde 100644
--- a/src/Microsoft.ML.Ensemble/OutputCombiners/BaseStacking.cs
+++ b/src/Microsoft.ML.Ensemble/OutputCombiners/BaseStacking.cs
@@ -25,10 +25,10 @@ public abstract class ArgumentsBase
             [TGUI(Label = "Validation Dataset Proportion")]
             public Single ValidationDatasetProportion = 0.3f;
 
-            public abstract IComponentFactory<ITrainer<IPredictorProducing<TOutput>>> BasePredictorFactory { get; set; }
+            internal abstract IComponentFactory<ITrainer<IPredictorProducing<TOutput>>> GetPredictorFactory();
         }
 
-        protected readonly IComponentFactory<ITrainer<IPredictorProducing<TOutput>>> BasePredictorFactory;
+        protected readonly IComponentFactory<ITrainer<IPredictorProducing<TOutput>>> BasePredictorType;
         protected readonly IHost Host;
         protected IPredictorProducing<TOutput> Meta;
 
@@ -43,10 +43,10 @@ internal BaseStacking(IHostEnvironment env, string name, ArgumentsBase args)
             Host.CheckUserArg(0 <= args.ValidationDatasetProportion && args.ValidationDatasetProportion < 1,
                     nameof(args.ValidationDatasetProportion),
                     "The validation proportion for stacking should be greater than or equal to 0 and less than 1");
-            Host.CheckUserArg(args.BasePredictorFactory != null, nameof(args.BasePredictorFactory));
 
             ValidationDatasetProportion = args.ValidationDatasetProportion;
-            BasePredictorFactory = args.BasePredictorFactory;
+            BasePredictorType = args.GetPredictorFactory();
+            Host.CheckValue(BasePredictorType, nameof(BasePredictorType));
         }
 
         internal BaseStacking(IHostEnvironment env, string name, ModelLoadContext ctx)
@@ -133,7 +133,7 @@ public void Train(List<FeatureSubsetModel<IPredictorProducing<TOutput>>> models,
             using (var ch = host.Start("Training stacked model"))
             {
                 ch.Check(Meta == null, "Train called multiple times");
-                ch.Check(BasePredictorFactory != null);
+                ch.Check(BasePredictorType != null);
 
                 var maps = new ValueMapper<VBuffer<Single>, TOutput>[models.Count];
                 for (int i = 0; i < maps.Length; i++)
@@ -185,7 +185,7 @@ public void Train(List<FeatureSubsetModel<IPredictorProducing<TOutput>>> models,
                 var view = bldr.GetDataView();
                 var rmd = new RoleMappedData(view, DefaultColumnNames.Label, DefaultColumnNames.Features);
 
-                var trainer = BasePredictorFactory.CreateComponent(host);
+                var trainer = BasePredictorType.CreateComponent(host);
                 if (trainer.Info.NeedNormalization)
                     ch.Warning("The trainer specified for stacking wants normalization, but we do not currently allow this.");
                 Meta = trainer.Train(rmd);
diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiStacking.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiStacking.cs
index 3038e2b089..40d59b631a 100644
--- a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiStacking.cs
+++ b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiStacking.cs
@@ -44,13 +44,9 @@ public sealed class Arguments : ArgumentsBase, ISupportMulticlassOutputCombinerF
             [Argument(ArgumentType.Multiple, HelpText = "Base predictor for meta learning", ShortName = "bp", SortOrder = 50,
                 Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, SignatureType = typeof(SignatureMultiClassClassifierTrainer))]
             [TGUI(Label = "Base predictor")]
-            public IComponentFactory<ITrainer<IPredictorProducing<VBuffer<Single>>>> BasePredictorType;
+            public IComponentFactory<ITrainer<TVectorPredictor>> BasePredictorType;
 
-            public override IComponentFactory<ITrainer<TVectorPredictor>> BasePredictorFactory
-            {
-                get { return BasePredictorType; }
-                set { BasePredictorType = value; }
-            }
+            internal override IComponentFactory<ITrainer<TVectorPredictor>> GetPredictorFactory() => BasePredictorType;
 
             public IMultiClassOutputCombiner CreateComponent(IHostEnvironment env) => new MultiStacking(env, this);
 
diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/RegressionStacking.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/RegressionStacking.cs
index 18066c5362..436e365b79 100644
--- a/src/Microsoft.ML.Ensemble/OutputCombiners/RegressionStacking.cs
+++ b/src/Microsoft.ML.Ensemble/OutputCombiners/RegressionStacking.cs
@@ -42,13 +42,9 @@ public sealed class Arguments : ArgumentsBase, ISupportRegressionOutputCombinerF
             [Argument(ArgumentType.Multiple, HelpText = "Base predictor for meta learning", ShortName = "bp", SortOrder = 50,
                 Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, SignatureType = typeof(SignatureRegressorTrainer))]
             [TGUI(Label = "Base predictor")]
-            public IComponentFactory<ITrainer<IPredictorProducing<Single>>> BasePredictorType;
+            public IComponentFactory<ITrainer<TScalarPredictor>> BasePredictorType;
 
-            public override IComponentFactory<ITrainer<TScalarPredictor>> BasePredictorFactory
-            {
-                get { return BasePredictorType; }
-                set { BasePredictorType = value; }
-            }
+            internal override IComponentFactory<ITrainer<TScalarPredictor>> GetPredictorFactory() => BasePredictorType;
 
             public Arguments()
             {
diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/Stacking.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/Stacking.cs
index 6be4ffea64..f0ced9d947 100644
--- a/src/Microsoft.ML.Ensemble/OutputCombiners/Stacking.cs
+++ b/src/Microsoft.ML.Ensemble/OutputCombiners/Stacking.cs
@@ -39,13 +39,9 @@ public sealed class Arguments : ArgumentsBase, ISupportBinaryOutputCombinerFacto
             [Argument(ArgumentType.Multiple, HelpText = "Base predictor for meta learning", ShortName = "bp", SortOrder = 50,
                 Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, SignatureType = typeof(SignatureBinaryClassifierTrainer))]
             [TGUI(Label = "Base predictor")]
-            public IComponentFactory<ITrainer<IPredictorProducing<Single>>> BasePredictorType;
+            public IComponentFactory<ITrainer<TScalarPredictor>> BasePredictorType;
 
-            public override IComponentFactory<ITrainer<TScalarPredictor>> BasePredictorFactory
-            {
-                get { return BasePredictorType; }
-                set { BasePredictorType = value; }
-            }
+            internal override IComponentFactory<ITrainer<TScalarPredictor>> GetPredictorFactory() => BasePredictorType;
 
             public Arguments()
             {
diff --git a/src/Microsoft.ML.Ensemble/Trainer/Binary/EnsembleTrainer.cs b/src/Microsoft.ML.Ensemble/Trainer/Binary/EnsembleTrainer.cs
index 5cf1f6ea4c..6a4b151fab 100644
--- a/src/Microsoft.ML.Ensemble/Trainer/Binary/EnsembleTrainer.cs
+++ b/src/Microsoft.ML.Ensemble/Trainer/Binary/EnsembleTrainer.cs
@@ -49,11 +49,7 @@ public sealed class Arguments : ArgumentsBase
             [Argument(ArgumentType.Multiple, HelpText = "Base predictor type", ShortName = "bp,basePredictorTypes", SortOrder = 1, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, SignatureType = typeof(SignatureBinaryClassifierTrainer))]
             public IComponentFactory<ITrainer<TScalarPredictor>>[] BasePredictors;
 
-            public override IComponentFactory<ITrainer<TScalarPredictor>>[] BasePredictorFactories
-            {
-                get { return BasePredictors; }
-                set { BasePredictors = value; }
-            }
+            internal override IComponentFactory<ITrainer<TScalarPredictor>>[] GetPredictorFactories() => BasePredictors;
 
             public Arguments()
             {
diff --git a/src/Microsoft.ML.Ensemble/Trainer/EnsembleTrainerBase.cs b/src/Microsoft.ML.Ensemble/Trainer/EnsembleTrainerBase.cs
index 1f9aff7169..5ebf5cfac0 100644
--- a/src/Microsoft.ML.Ensemble/Trainer/EnsembleTrainerBase.cs
+++ b/src/Microsoft.ML.Ensemble/Trainer/EnsembleTrainerBase.cs
@@ -53,7 +53,7 @@ public abstract class ArgumentsBase : LearnerInputBaseWithLabel
             [TGUI(Label = "Show Sub-Model Metrics")]
             public bool ShowMetrics;
 
-            public abstract IComponentFactory<ITrainer<IPredictorProducing<TOutput>>>[] BasePredictorFactories { get; set; }
+            internal abstract IComponentFactory<ITrainer<IPredictorProducing<TOutput>>>[] GetPredictorFactories();
         }
 
         private const int DefaultNumModels = 50;
@@ -77,8 +77,8 @@ private protected EnsembleTrainerBase(ArgumentsBase args, IHostEnvironment env,
 
             using (var ch = Host.Start("Init"))
             {
-                var predictorFactories = Args.BasePredictorFactories;
-                ch.CheckUserArg(Utils.Size(predictorFactories) > 0, nameof(Args.BasePredictorFactories), "This should have at-least one value");
+                var predictorFactories = Args.GetPredictorFactories();
+                ch.CheckUserArg(Utils.Size(predictorFactories) > 0, nameof(EnsembleTrainer.Arguments.BasePredictors), "This should have at-least one value");
 
                 NumModels = Args.NumModels ??
                     (predictorFactories.Length == 1 ? DefaultNumModels : predictorFactories.Length);
diff --git a/src/Microsoft.ML.Ensemble/Trainer/Multiclass/MulticlassDataPartitionEnsembleTrainer.cs b/src/Microsoft.ML.Ensemble/Trainer/Multiclass/MulticlassDataPartitionEnsembleTrainer.cs
index 3183292383..d5c60e24a6 100644
--- a/src/Microsoft.ML.Ensemble/Trainer/Multiclass/MulticlassDataPartitionEnsembleTrainer.cs
+++ b/src/Microsoft.ML.Ensemble/Trainer/Multiclass/MulticlassDataPartitionEnsembleTrainer.cs
@@ -50,11 +50,7 @@ public sealed class Arguments : ArgumentsBase
             [Argument(ArgumentType.Multiple, HelpText = "Base predictor type", ShortName = "bp,basePredictorTypes", SortOrder = 1, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, SignatureType = typeof(SignatureMultiClassClassifierTrainer))]
             public IComponentFactory<ITrainer<TVectorPredictor>>[] BasePredictors;
 
-            public override IComponentFactory<ITrainer<TVectorPredictor>>[] BasePredictorFactories
-            {
-                get { return BasePredictors; }
-                set { BasePredictors = value; }
-            }
+            internal override IComponentFactory<ITrainer<TVectorPredictor>>[] GetPredictorFactories() => BasePredictors;
 
             public Arguments()
             {
diff --git a/src/Microsoft.ML.Ensemble/Trainer/Regression/RegressionEnsembleTrainer.cs b/src/Microsoft.ML.Ensemble/Trainer/Regression/RegressionEnsembleTrainer.cs
index 283fe2f32d..3fe853aff5 100644
--- a/src/Microsoft.ML.Ensemble/Trainer/Regression/RegressionEnsembleTrainer.cs
+++ b/src/Microsoft.ML.Ensemble/Trainer/Regression/RegressionEnsembleTrainer.cs
@@ -44,11 +44,7 @@ public sealed class Arguments : ArgumentsBase
             [Argument(ArgumentType.Multiple, HelpText = "Base predictor type", ShortName = "bp,basePredictorTypes", SortOrder = 1, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, SignatureType = typeof(SignatureRegressorTrainer))]
             public IComponentFactory<ITrainer<TScalarPredictor>>[] BasePredictors;
 
-            public override IComponentFactory<ITrainer<TScalarPredictor>>[] BasePredictorFactories
-            {
-                get { return BasePredictors; }
-                set { BasePredictors = value; }
-            }
+            internal override IComponentFactory<ITrainer<TScalarPredictor>>[] GetPredictorFactories() => BasePredictors;
 
             public Arguments()
             {

From 2beb2ce21da3e0b526614e3e9d71cdc7b869c8d4 Mon Sep 17 00:00:00 2001
From: Eric Erhardt <eric.erhardt@microsoft.com>
Date: Fri, 17 Aug 2018 10:43:52 -0500
Subject: [PATCH 5/7] PR feedback

---
 src/Microsoft.ML.Core/EntryPoints/ComponentFactory.cs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/Microsoft.ML.Core/EntryPoints/ComponentFactory.cs b/src/Microsoft.ML.Core/EntryPoints/ComponentFactory.cs
index cb1c1fbaf6..1f8d5b0e8f 100644
--- a/src/Microsoft.ML.Core/EntryPoints/ComponentFactory.cs
+++ b/src/Microsoft.ML.Core/EntryPoints/ComponentFactory.cs
@@ -54,7 +54,7 @@ public interface IComponentFactory<in TArg1, in TArg2, in TArg3, out TComponent>
     }
 
     /// <summary>
-    /// A utility class for creating IComponentFactory instances.
+    /// A utility class for creating <see cref="IComponentFactory"/> instances.
     /// </summary>
     public static class ComponentFactoryUtils
     {

From 4923971fbde6ee6eecf35eed4fd046a12c521d24 Mon Sep 17 00:00:00 2001
From: Eric Erhardt <eric.erhardt@microsoft.com>
Date: Fri, 17 Aug 2018 14:33:36 -0500
Subject: [PATCH 6/7] PR Feedback

---
 .../Selector/SubModelSelector/BaseSubModelSelector.cs        | 4 ++--
 src/Microsoft.ML.Ensemble/Trainer/Binary/EnsembleTrainer.cs  | 5 ++---
 2 files changed, 4 insertions(+), 5 deletions(-)

diff --git a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BaseSubModelSelector.cs b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BaseSubModelSelector.cs
index ac79e400a9..513ad95139 100644
--- a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BaseSubModelSelector.cs
+++ b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BaseSubModelSelector.cs
@@ -81,10 +81,10 @@ public virtual void CalculateMetrics(FeatureSubsetModel<IPredictorProducing<TOut
                 // Because the training and test datasets are drawn from the same base dataset, the test data role mappings
                 // are the same as for the train data.
                 IDataScorerTransform scorePipe = ScoreUtils.GetScorer(model.Predictor, testData, Host, testData.Schema);
-                // REVIEW: Should we somehow allow the user to customize the evaluator?
-                // By what mechanism should we allow that?
                 RoleMappedData scoredTestData = new RoleMappedData(scorePipe,
                     GetColumnRoles(testData.Schema, scorePipe.Schema));
+                // REVIEW: Should we somehow allow the user to customize the evaluator?
+                // By what mechanism should we allow that?
                 IEvaluator evaluator = GetEvaluator(Host);
                 // REVIEW: with the new evaluators, metrics of individual models are no longer
                 // printed to the Console. Consider adding an option on the combiner to print them.
diff --git a/src/Microsoft.ML.Ensemble/Trainer/Binary/EnsembleTrainer.cs b/src/Microsoft.ML.Ensemble/Trainer/Binary/EnsembleTrainer.cs
index 6a4b151fab..9165360ccb 100644
--- a/src/Microsoft.ML.Ensemble/Trainer/Binary/EnsembleTrainer.cs
+++ b/src/Microsoft.ML.Ensemble/Trainer/Binary/EnsembleTrainer.cs
@@ -5,15 +5,14 @@
 using System;
 using System.Collections.Generic;
 using System.Linq;
+using Microsoft.ML.Ensemble.EntryPoints;
 using Microsoft.ML.Runtime;
 using Microsoft.ML.Runtime.CommandLine;
-using Microsoft.ML.Runtime.Data;
 using Microsoft.ML.Runtime.Ensemble;
 using Microsoft.ML.Runtime.Ensemble.OutputCombiners;
 using Microsoft.ML.Runtime.Ensemble.Selector;
-using Microsoft.ML.Ensemble.EntryPoints;
-using Microsoft.ML.Runtime.Internal.Internallearn;
 using Microsoft.ML.Runtime.EntryPoints;
+using Microsoft.ML.Runtime.Internal.Internallearn;
 using Microsoft.ML.Runtime.Learners;
 
 [assembly: LoadableClass(EnsembleTrainer.Summary, typeof(EnsembleTrainer), typeof(EnsembleTrainer.Arguments),

From 78a4178d1905dd2be2b9209ed4eb57c6b1f55ea2 Mon Sep 17 00:00:00 2001
From: Eric Erhardt <eric.erhardt@microsoft.com>
Date: Fri, 17 Aug 2018 16:43:50 -0500
Subject: [PATCH 7/7] Trigger