WIP Introduce I*PredictionKind*TrainerFactory and propagate them #670
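This change drops the single weakly typed IFastTreeTrainerFactory (an IComponentFactory&lt;ITrainer&gt;) in favor of one factory interface per prediction kind — IBinaryTrainerFactory, IMulticlassTrainerFactory, IRankingTrainerFactory, IRegressionTrainerFactory, and IClusteringTrainerFactory — each registered as a TlcModule component kind. Trainer Arguments classes across the Ensemble, FastTree, GAM, and random-forest learners then implement the factory matching their prediction kind, so CreateComponent returns an ITrainer&lt;IPredictorProducing&lt;...&gt;&gt; with the correct prediction type, and the Arguments types are exported as entry-point modules. An illustrative consumption sketch follows the new interface declarations in PredictorInterfaces.cs below.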

Closed
27 changes: 27 additions & 0 deletions src/Microsoft.ML.Data/Dirty/PredictorInterfaces.cs
@@ -9,6 +9,7 @@
using System.IO;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.Internal.Calibration;
using Microsoft.ML.Runtime.EntryPoints;

namespace Microsoft.ML.Runtime.Internal.Internallearn
{
@@ -204,4 +205,30 @@ public interface ICanGetTrainingLabelNames : IPredictor
{
string[] GetLabelNamesOrNull(out ColumnType labelType);
}

[TlcModule.ComponentKind("BinaryTrainerFactory")]
public interface IBinaryTrainerFactory : IComponentFactory<ITrainer<IPredictorProducing<float>>>
{
}

[TlcModule.ComponentKind("MulticlassTrainerFactory")]
public interface IMulticlassTrainerFactory : IComponentFactory<ITrainer<IPredictorProducing<VBuffer<float>>>>
{
}

[TlcModule.ComponentKind("RankingTrainerFactory")]
public interface IRankingTrainerFactory : IComponentFactory<ITrainer<IPredictorProducing<float>>>
{
}

[TlcModule.ComponentKind("RegressionTrainerFactory")]
public interface IRegressionTrainerFactory : IComponentFactory<ITrainer<IPredictorProducing<float>>>
{
}

[TlcModule.ComponentKind("ClusteringTrainer")]
public interface IClusteringTrainerFactory : IComponentFactory<ITrainer<IPredictorProducing<VBuffer<float>>>>
{
}

}
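For illustration only (not part of this diff): a component that needs a nested binary-classification trainer can declare the strongly typed factory as an argument and get a trainer back with the matching prediction type, rather than a bare ITrainer it would have to downcast. The HypotheticalSweepArguments class and its field names below are assumptions made for the sketch; only the factory interfaces and CreateComponent come from this change.

```csharp
using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.CommandLine;
using Microsoft.ML.Runtime.FastTree;
using Microsoft.ML.Runtime.Internal.Internallearn;

// Sketch of a consumer. Any Arguments type that implements IBinaryTrainerFactory
// (e.g. FastTreeBinaryClassificationTrainer.Arguments in this PR) can be assigned here.
public sealed class HypotheticalSweepArguments
{
    [Argument(ArgumentType.Multiple, HelpText = "Binary trainer to wrap", ShortName = "tr")]
    public IBinaryTrainerFactory Trainer = new FastTreeBinaryClassificationTrainer.Arguments();

    // The factory returns the trainer already typed by its prediction kind.
    public ITrainer<IPredictorProducing<float>> CreateTrainer(IHostEnvironment env)
        => Trainer.CreateComponent(env);
}
```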
9 changes: 7 additions & 2 deletions src/Microsoft.ML.Ensemble/Trainer/Binary/EnsembleTrainer.cs
@@ -7,7 +7,7 @@
using System.Linq;
using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.CommandLine;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.EntryPoints;
using Microsoft.ML.Runtime.Ensemble;
using Microsoft.ML.Runtime.Ensemble.OutputCombiners;
using Microsoft.ML.Runtime.Ensemble.Selector;
@@ -18,6 +18,8 @@
new[] { typeof(SignatureBinaryClassifierTrainer), typeof(SignatureTrainer) },
EnsembleTrainer.UserNameValue, EnsembleTrainer.LoadNameValue, "pe", "ParallelEnsemble")]

[assembly: EntryPointModule(typeof(EnsembleTrainer.Arguments))]

namespace Microsoft.ML.Runtime.Ensemble
{
using TDistPredictor = IDistPredictorProducing<Single, Single>;
@@ -33,7 +35,7 @@ public sealed class EnsembleTrainer : EnsembleTrainerBase<Single, TScalarPredict
public const string UserNameValue = "Parallel Ensemble (bagging, stacking, etc)";
public const string Summary = "A generic ensemble classifier for binary classification.";

public sealed class Arguments : ArgumentsBase
public sealed class Arguments : ArgumentsBase, IBinaryTrainerFactory
{
[Argument(ArgumentType.Multiple, HelpText = "Algorithm to prune the base learners for selective Ensemble", ShortName = "pt", SortOrder = 4)]
[TGUI(Label = "Sub-Model Selector(pruning) Type",
@@ -48,6 +50,8 @@ public Arguments()
{
BasePredictors = new[] { new SubComponent<ITrainer<TScalarPredictor>, SignatureBinaryClassifierTrainer>("LinearSVM") };
}

public ITrainer<TScalarPredictor> CreateComponent(IHostEnvironment env) => new EnsembleTrainer(env, this);
}

private readonly ISupportBinaryOutputCombinerFactory _outputCombiner;
@@ -83,4 +87,5 @@ public TScalarPredictor CombineModels(IEnumerable<TScalarPredictor> models)
models.Select(k => new FeatureSubsetModel<TScalarPredictor>(k)).ToArray(), combiner);
}
}

}
@@ -12,6 +12,7 @@
using Microsoft.ML.Runtime.Ensemble;
using Microsoft.ML.Runtime.Ensemble.OutputCombiners;
using Microsoft.ML.Runtime.Ensemble.Selector;
using Microsoft.ML.Runtime.EntryPoints;
using Microsoft.ML.Runtime.Internal.Internallearn;

[assembly: LoadableClass(MulticlassDataPartitionEnsembleTrainer.Summary, typeof(MulticlassDataPartitionEnsembleTrainer),
@@ -20,6 +21,8 @@
MulticlassDataPartitionEnsembleTrainer.UserNameValue,
MulticlassDataPartitionEnsembleTrainer.LoadNameValue)]

[assembly: EntryPointModule(typeof(MulticlassDataPartitionEnsembleTrainer.Arguments))]

namespace Microsoft.ML.Runtime.Ensemble
{
using TVectorPredictor = IPredictorProducing<VBuffer<Single>>;
@@ -35,7 +38,7 @@ public sealed class MulticlassDataPartitionEnsembleTrainer :
public const string UserNameValue = "Multi-class Parallel Ensemble (bagging, stacking, etc)";
public const string Summary = "A generic ensemble classifier for multi-class classification.";

public sealed class Arguments : ArgumentsBase
public sealed class Arguments : ArgumentsBase, IMulticlassTrainerFactory
{
[Argument(ArgumentType.Multiple, HelpText = "Algorithm to prune the base learners for selective Ensemble", ShortName = "pt", SortOrder = 4)]
[TGUI(Label = "Sub-Model Selector(pruning) Type", Description = "Algorithm to prune the base learners for selective Ensemble")]
@@ -49,6 +52,8 @@ public Arguments()
{
BasePredictors = new[] { new SubComponent<ITrainer<TVectorPredictor>, SignatureMultiClassClassifierTrainer>("MultiClassLogisticRegression") };
}

public ITrainer<TVectorPredictor> CreateComponent(IHostEnvironment env) => new MulticlassDataPartitionEnsembleTrainer(env, this);
}

private readonly ISupportMulticlassOutputCombinerFactory _outputCombiner;
@@ -8,17 +8,19 @@
using Microsoft.ML.Ensemble.EntryPoints;
using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.CommandLine;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.Ensemble;
using Microsoft.ML.Runtime.Ensemble.OutputCombiners;
using Microsoft.ML.Runtime.Ensemble.Selector;
using Microsoft.ML.Runtime.EntryPoints;
using Microsoft.ML.Runtime.Internal.Internallearn;

[assembly: LoadableClass(typeof(RegressionEnsembleTrainer), typeof(RegressionEnsembleTrainer.Arguments),
new[] { typeof(SignatureRegressorTrainer), typeof(SignatureTrainer) },
RegressionEnsembleTrainer.UserNameValue,
RegressionEnsembleTrainer.LoadNameValue)]

[assembly: EntryPointModule(typeof(RegressionEnsembleTrainer.Arguments))]

namespace Microsoft.ML.Runtime.Ensemble
{
using TScalarPredictor = IPredictorProducing<Single>;
@@ -29,7 +31,7 @@ public sealed class RegressionEnsembleTrainer : EnsembleTrainerBase<Single, TSca
public const string LoadNameValue = "EnsembleRegression";
public const string UserNameValue = "Regression Ensemble (bagging, stacking, etc)";

public sealed class Arguments : ArgumentsBase
public sealed class Arguments : ArgumentsBase, IRegressionTrainerFactory
{
[Argument(ArgumentType.Multiple, HelpText = "Algorithm to prune the base learners for selective Ensemble", ShortName = "pt", SortOrder = 4)]
[TGUI(Label = "Sub-Model Selector(pruning) Type", Description = "Algorithm to prune the base learners for selective Ensemble")]
@@ -43,6 +45,8 @@ public Arguments()
{
BasePredictors = new[] { new SubComponent<ITrainer<TScalarPredictor>, SignatureRegressorTrainer>("OnlineGradientDescent") };
}

public ITrainer<TScalarPredictor> CreateComponent(IHostEnvironment env) => new RegressionEnsembleTrainer(env, this);
}

private readonly ISupportRegressionOutputCombinerFactory _outputCombiner;
30 changes: 13 additions & 17 deletions src/Microsoft.ML.FastTree/FastTreeArguments.cs
@@ -15,43 +15,39 @@

namespace Microsoft.ML.Runtime.FastTree
{
[TlcModule.ComponentKind("FastTreeTrainer")]
public interface IFastTreeTrainerFactory : IComponentFactory<ITrainer>
{
}

/// <include file='doc.xml' path='doc/members/member[@name="FastTree"]/*' />
public sealed partial class FastTreeBinaryClassificationTrainer
{
[TlcModule.Component(Name = LoadNameValue, FriendlyName = UserNameValue, Desc = Summary)]
public sealed class Arguments : BoostedTreeArgs, IFastTreeTrainerFactory
public sealed class Arguments : BoostedTreeArgs, IBinaryTrainerFactory
{
[Argument(ArgumentType.LastOccurenceWins, HelpText = "Should we use derivatives optimized for unbalanced sets", ShortName = "us")]
[TGUI(Label = "Optimize for unbalanced")]
public bool UnbalancedSets = false;

public ITrainer CreateComponent(IHostEnvironment env) => new FastTreeBinaryClassificationTrainer(env, this);
public ITrainer<IPredictorProducing<float>> CreateComponent(IHostEnvironment env) => new FastTreeBinaryClassificationTrainer(env, this);
}
}

public sealed partial class FastTreeRegressionTrainer
{
[TlcModule.Component(Name = LoadNameValue, FriendlyName = UserNameValue, Desc = Summary)]
public sealed class Arguments : BoostedTreeArgs, IFastTreeTrainerFactory
public sealed class Arguments : BoostedTreeArgs, IRegressionTrainerFactory
{
public Arguments()
{
EarlyStoppingMetrics = 1; // Use L1 by default.
}

public ITrainer CreateComponent(IHostEnvironment env) => new FastTreeRegressionTrainer(env, this);
public ITrainer<IPredictorProducing<float>> CreateComponent(IHostEnvironment env) => new FastTreeRegressionTrainer(env, this);
}
}

public sealed partial class FastTreeTweedieTrainer
{
[TlcModule.Component(Name = LoadNameValue, FriendlyName = UserNameValue, Desc = Summary)]
public sealed class Arguments : BoostedTreeArgs, IFastTreeTrainerFactory
public sealed class Arguments : BoostedTreeArgs, IRegressionTrainerFactory
{
// REVIEW: It is possible to estimate this index parameter from the distribution of data, using
// a combination of univariate optimization and grid search, following section 4.2 of the paper. However
Expand All @@ -61,14 +57,14 @@ public sealed class Arguments : BoostedTreeArgs, IFastTreeTrainerFactory
"and intermediate values are compound Poisson loss.")]
public Double Index = 1.5;

public ITrainer CreateComponent(IHostEnvironment env) => new FastTreeTweedieTrainer(env, this);
public ITrainer<IPredictorProducing<float>> CreateComponent(IHostEnvironment env) => new FastTreeTweedieTrainer(env, this);
}
}

public sealed partial class FastTreeRankingTrainer
{
[TlcModule.Component(Name = LoadNameValue, FriendlyName = UserNameValue, Desc = Summary)]
public sealed class Arguments : BoostedTreeArgs, IFastTreeTrainerFactory
public sealed class Arguments : BoostedTreeArgs, IRankingTrainerFactory
{
[Argument(ArgumentType.LastOccurenceWins, HelpText = "Comma seperated list of gains associated to each relevance label.", ShortName = "gains")]
[TGUI(NoSweep = true)]
@@ -110,7 +106,7 @@ public Arguments()
EarlyStoppingMetrics = 1;
}

public ITrainer CreateComponent(IHostEnvironment env) => new FastTreeRankingTrainer(env, this);
public ITrainer<IPredictorProducing<float>> CreateComponent(IHostEnvironment env) => new FastTreeRankingTrainer(env, this);

internal override void Check(IExceptionContext ectx)
{
@@ -228,14 +224,14 @@ public abstract class TreeArgs : LearnerInputBaseWithGroupId
// REVIEW: Different from original FastRank arguments (shortname l vs. nl). Different default from TLC FR Wrapper (20 vs. 20).
[Argument(ArgumentType.LastOccurenceWins, HelpText = "The max number of leaves in each regression tree", ShortName = "nl", SortOrder = 2)]
[TGUI(Description = "The maximum number of leaves per tree", SuggestedSweeps = "2-128;log;inc:4")]
[TlcModule.SweepableLongParamAttribute("NumLeaves", 2, 128, isLogScale:true, stepSize:4)]
[TlcModule.SweepableLongParamAttribute("NumLeaves", 2, 128, isLogScale: true, stepSize: 4)]
public int NumLeaves = 20;

// REVIEW: Arrays not supported in GUI
// REVIEW: Different shortname than FastRank module. Same as the TLC FRWrapper.
[Argument(ArgumentType.LastOccurenceWins, HelpText = "The minimal number of documents allowed in a leaf of a regression tree, out of the subsampled data", ShortName = "mil", SortOrder = 3)]
[TGUI(Description = "Minimum number of training instances required to form a leaf", SuggestedSweeps = "1,10,50")]
[TlcModule.SweepableDiscreteParamAttribute("MinDocumentsInLeafs", new object[] {1, 10, 50})]
[TlcModule.SweepableDiscreteParamAttribute("MinDocumentsInLeafs", new object[] { 1, 10, 50 })]
public int MinDocumentsInLeafs = 10;

// REVIEW: Different shortname than FastRank module. Same as the TLC FRWrapper.
@@ -364,17 +360,17 @@ public enum OptimizationAlgorithmType { GradientDescent, AcceleratedGradientDesc

[Argument(ArgumentType.LastOccurenceWins, HelpText = "The learning rate", ShortName = "lr", SortOrder = 4)]
[TGUI(Label = "Learning Rate", SuggestedSweeps = "0.025-0.4;log")]
[TlcModule.SweepableFloatParamAttribute("LearningRates", 0.025f, 0.4f, isLogScale:true)]
[TlcModule.SweepableFloatParamAttribute("LearningRates", 0.025f, 0.4f, isLogScale: true)]
public Double LearningRates = 0.2;

[Argument(ArgumentType.AtMostOnce, HelpText = "Shrinkage", ShortName = "shrk")]
[TGUI(Label = "Shrinkage", SuggestedSweeps = "0.25-4;log")]
[TlcModule.SweepableFloatParamAttribute("Shrinkage", 0.025f, 4f, isLogScale:true)]
[TlcModule.SweepableFloatParamAttribute("Shrinkage", 0.025f, 4f, isLogScale: true)]
public Double Shrinkage = 1;

[Argument(ArgumentType.AtMostOnce, HelpText = "Dropout rate for tree regularization", ShortName = "tdrop")]
[TGUI(SuggestedSweeps = "0,0.000000001,0.05,0.1,0.2")]
[TlcModule.SweepableDiscreteParamAttribute("DropoutRate", new object[] { 0.0f, 1E-9f, 0.05f, 0.1f, 0.2f})]
[TlcModule.SweepableDiscreteParamAttribute("DropoutRate", new object[] { 0.0f, 1E-9f, 0.05f, 0.1f, 0.2f })]
public Double DropoutRate = 0;

[Argument(ArgumentType.AtMostOnce, HelpText = "Sample each query 1 in k times in the GetDerivatives function", ShortName = "sr")]
1 change: 0 additions & 1 deletion src/Microsoft.ML.FastTree/FastTreeClassification.cs
@@ -8,7 +8,6 @@
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.EntryPoints;
using Microsoft.ML.Runtime.FastTree;
using Microsoft.ML.Runtime.FastTree.Internal;
9 changes: 7 additions & 2 deletions src/Microsoft.ML.FastTree/GamTrainer.cs
@@ -47,6 +47,8 @@
"GAM Vizualization Command", GamPredictorBase.VisualizationCommand.LoadName, "gamviz", DocName = "command/GamViz.md")]

[assembly: LoadableClass(typeof(void), typeof(Gam), null, typeof(SignatureEntryPointModule), "GAM")]
[assembly: EntryPointModule(typeof(BinaryClassificationGamTrainer.Arguments))]
[assembly: EntryPointModule(typeof(RegressionGamTrainer.Arguments))]

namespace Microsoft.ML.Runtime.FastTree
{
@@ -57,8 +59,9 @@ namespace Microsoft.ML.Runtime.FastTree
public sealed class RegressionGamTrainer :
GamTrainerBase<RegressionGamTrainer.Arguments, RegressionGamPredictor>
{
public partial class Arguments : ArgumentsBase
public partial class Arguments : ArgumentsBase, IRegressionTrainerFactory
{
public ITrainer<IPredictorProducing<float>> CreateComponent(IHostEnvironment env) => new RegressionGamTrainer(env, this);
}

internal const string LoadNameValue = "RegressionGamTrainer";
@@ -89,7 +92,7 @@ protected override ObjectiveFunctionBase CreateObjectiveFunction()
public sealed class BinaryClassificationGamTrainer :
GamTrainerBase<BinaryClassificationGamTrainer.Arguments, BinaryClassGamPredictor>
{
public sealed class Arguments : ArgumentsBase
public sealed class Arguments : ArgumentsBase, IBinaryTrainerFactory
{
[Argument(ArgumentType.LastOccurenceWins, HelpText = "Should we use derivatives optimized for unbalanced sets", ShortName = "us")]
[TGUI(Label = "Optimize for unbalanced")]
@@ -100,6 +103,8 @@ public sealed class Arguments : ArgumentsBase

[Argument(ArgumentType.AtMostOnce, HelpText = "The maximum number of examples to use when training the calibrator", Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
public int MaxCalibrationExamples = 1000000;

public ITrainer<IPredictorProducing<float>> CreateComponent(IHostEnvironment env) => new BinaryClassificationGamTrainer(env, this);
}

internal const string LoadNameValue = "BinaryClassificationGamTrainer";
6 changes: 4 additions & 2 deletions src/Microsoft.ML.FastTree/RandomForestClassification.cs
@@ -8,7 +8,6 @@
using System.Linq;
using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.CommandLine;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.EntryPoints;
using Microsoft.ML.Runtime.FastTree;
using Microsoft.ML.Runtime.FastTree.Internal;
@@ -30,6 +29,7 @@
FastForestClassificationPredictor.LoaderSignature)]

[assembly: LoadableClass(typeof(void), typeof(FastForest), null, typeof(SignatureEntryPointModule), "FastForest")]
[assembly: EntryPointModule(typeof(FastForestClassification.Arguments))]

namespace Microsoft.ML.Runtime.FastTree
{
@@ -110,7 +110,7 @@ public static IPredictorProducing<Float> Create(IHostEnvironment env, ModelLoadC
public sealed partial class FastForestClassification :
RandomForestTrainerBase<FastForestClassification.Arguments, IPredictorWithFeatureWeights<Float>>
{
public sealed class Arguments : FastForestArgumentsBase
public sealed class Arguments : FastForestArgumentsBase, IBinaryTrainerFactory
{
[Argument(ArgumentType.AtMostOnce, HelpText = "Upper bound on absolute value of single tree output", ShortName = "mo")]
public Double MaxTreeOutput = 100;
@@ -120,6 +120,8 @@

[Argument(ArgumentType.AtMostOnce, HelpText = "The maximum number of examples to use when training the calibrator", Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
public int MaxCalibrationExamples = 1000000;

public ITrainer<IPredictorProducing<float>> CreateComponent(IHostEnvironment env) => new FastForestClassification(env, this);
}

internal const string LoadNameValue = "FastForestClassification";
5 changes: 4 additions & 1 deletion src/Microsoft.ML.FastTree/RandomForestRegression.cs
@@ -25,6 +25,7 @@
[assembly: LoadableClass(typeof(FastForestRegressionPredictor), null, typeof(SignatureLoadModel),
"FastForest Regression Executor",
FastForestRegressionPredictor.LoaderSignature)]
[assembly: EntryPointModule(typeof(FastForestRegression.Arguments))]

namespace Microsoft.ML.Runtime.FastTree
{
@@ -140,11 +141,13 @@ public ISchemaBindableMapper CreateMapper(Double[] quantiles)
/// <include file='doc.xml' path='doc/members/member[@name="FastForest"]/*' />
public sealed partial class FastForestRegression : RandomForestTrainerBase<FastForestRegression.Arguments, FastForestRegressionPredictor>
{
public sealed class Arguments : FastForestArgumentsBase
public sealed class Arguments : FastForestArgumentsBase, IRegressionTrainerFactory
{
[Argument(ArgumentType.LastOccurenceWins, HelpText = "Shuffle the labels on every iteration. " +
"Useful probably only if using this tree as a tree leaf featurizer for multiclass.")]
public bool ShuffleLabels;

public ITrainer<IPredictorProducing<float>> CreateComponent(IHostEnvironment env) => new FastForestRegression(env, this);
}

internal const string Summary = "Trains a random forest to fit target values using least-squares.";