From 0c1ff77d1164f09f09361d93183302f9e7d9f6c9 Mon Sep 17 00:00:00 2001 From: Abhishek Goswami Date: Mon, 7 Jan 2019 01:28:49 +0000 Subject: [PATCH 1/8] Changes for FastTree & related learners --- src/Microsoft.ML.FastTree/BoostingFastTree.cs | 15 +- src/Microsoft.ML.FastTree/FastTree.cs | 38 ++++- .../FastTreeClassification.cs | 25 ++- src/Microsoft.ML.FastTree/FastTreeRanking.cs | 24 ++- .../FastTreeRegression.cs | 21 ++- src/Microsoft.ML.FastTree/FastTreeTweedie.cs | 25 ++- src/Microsoft.ML.FastTree/RandomForest.cs | 16 +- .../RandomForestClassification.cs | 25 ++- .../RandomForestRegression.cs | 25 ++- .../TreeTrainersCatalog.cs | 148 ++++++++++++++--- .../TreeTrainersStatic.cs | 149 ++++++++++++++++-- .../Algorithms/SmacSweeper.cs | 3 +- test/Microsoft.ML.Tests/Scenarios/OvaTest.cs | 3 +- .../TrainerEstimators/TreeEstimators.cs | 15 +- 14 files changed, 461 insertions(+), 71 deletions(-) diff --git a/src/Microsoft.ML.FastTree/BoostingFastTree.cs b/src/Microsoft.ML.FastTree/BoostingFastTree.cs index 41995b1585..dc5d9d9967 100644 --- a/src/Microsoft.ML.FastTree/BoostingFastTree.cs +++ b/src/Microsoft.ML.FastTree/BoostingFastTree.cs @@ -28,9 +28,8 @@ protected BoostingFastTreeTrainerBase(IHostEnvironment env, int numLeaves, int numTrees, int minDatapointsInLeaves, - double learningRate, - Action advancedSettings) - : base(env, label, featureColumn, weightColumn, groupIdColumn, numLeaves, numTrees, minDatapointsInLeaves, advancedSettings) + double learningRate) + : base(env, label, featureColumn, weightColumn, groupIdColumn, numLeaves, numTrees, minDatapointsInLeaves) { if (Args.LearningRates != learningRate) @@ -40,6 +39,16 @@ protected BoostingFastTreeTrainerBase(IHostEnvironment env, } } + protected BoostingFastTreeTrainerBase(IHostEnvironment env, + SchemaShape.Column label, + string featureColumn, + string weightColumn, + string groupIdColumn, + Action advancedSettings) + : base(env, label, featureColumn, weightColumn, groupIdColumn, advancedSettings) + { + } 
+ protected override void CheckArgs(IChannel ch) { if (Args.OptimizationAlgorithm == BoostedTreeArgs.OptimizationAlgorithmType.AcceleratedGradientDescent) diff --git a/src/Microsoft.ML.FastTree/FastTree.cs b/src/Microsoft.ML.FastTree/FastTree.cs index 9c5e7ec4cd..f5ae2db79f 100644 --- a/src/Microsoft.ML.FastTree/FastTree.cs +++ b/src/Microsoft.ML.FastTree/FastTree.cs @@ -114,8 +114,7 @@ private protected FastTreeTrainerBase(IHostEnvironment env, string groupIdColumn, int numLeaves, int numTrees, - int minDatapointsInLeaves, - Action advancedSettings) + int minDatapointsInLeaves) : base(Contracts.CheckRef(env, nameof(env)).Register(RegisterName), TrainerUtils.MakeR4VecFeature(featureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(weightColumn), TrainerUtils.MakeU4ScalarColumn(groupIdColumn)) { Args = new TArgs(); @@ -126,6 +125,41 @@ private protected FastTreeTrainerBase(IHostEnvironment env, Args.NumTrees = numTrees; Args.MinDocumentsInLeafs = minDatapointsInLeaves; + Args.LabelColumn = label.Name; + Args.FeatureColumn = featureColumn; + + if (weightColumn != null) + Args.WeightColumn = Optional.Explicit(weightColumn); + + if (groupIdColumn != null) + Args.GroupIdColumn = Optional.Explicit(groupIdColumn); + + // The discretization step renders this trainer non-parametric, and therefore it does not need normalization. + // Also since it builds its own internal discretized columnar structures, it cannot benefit from caching. + // Finally, even the binary classifiers, being logitboost, tend to not benefit from external calibration. + Info = new TrainerInfo(normalization: false, caching: false, calibration: NeedCalibration, supportValid: true, supportTest: true); + // REVIEW: CLR 4.6 has a bug that is only exposed in Scope, and if we trigger GC.Collect in scope environment + // with memory consumption more than 5GB, GC get stuck in infinite loop. 
+ // Before, we could check a specific type of the environment here, but now it is internal, so we will need another + // mechanism to detect that we are running in Scope. + AllowGC = true; + + Initialize(env); + } + + /// + /// Constructor to use when instantiating the classes deriving from here through the API. + /// + private protected FastTreeTrainerBase(IHostEnvironment env, + SchemaShape.Column label, + string featureColumn, + string weightColumn, + string groupIdColumn, + Action advancedSettings) + : base(Contracts.CheckRef(env, nameof(env)).Register(RegisterName), TrainerUtils.MakeR4VecFeature(featureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(weightColumn), TrainerUtils.MakeU4ScalarColumn(groupIdColumn)) + { + Args = new TArgs(); + //apply the advanced args, if the user supplied any advancedSettings?.Invoke(Args); diff --git a/src/Microsoft.ML.FastTree/FastTreeClassification.cs b/src/Microsoft.ML.FastTree/FastTreeClassification.cs index be17d13661..74c20e3d88 100644 --- a/src/Microsoft.ML.FastTree/FastTreeClassification.cs +++ b/src/Microsoft.ML.FastTree/FastTreeClassification.cs @@ -127,7 +127,6 @@ public sealed partial class FastTreeBinaryClassificationTrainer : /// The minimal number of documents allowed in a leaf of a regression tree, out of the subsampled data. /// The max number of leaves in each regression tree. /// Total number of decision trees to create in the ensemble. - /// A delegate to apply all the advanced arguments to the algorithm. 
public FastTreeBinaryClassificationTrainer(IHostEnvironment env, string labelColumn = DefaultColumnNames.Label, string featureColumn = DefaultColumnNames.Features, @@ -135,9 +134,27 @@ public FastTreeBinaryClassificationTrainer(IHostEnvironment env, int numLeaves = Defaults.NumLeaves, int numTrees = Defaults.NumTrees, int minDatapointsInLeaves = Defaults.MinDocumentsInLeaves, - double learningRate = Defaults.LearningRates, - Action advancedSettings = null) - : base(env, TrainerUtils.MakeBoolScalarLabel(labelColumn), featureColumn, weightColumn, null, numLeaves, numTrees, minDatapointsInLeaves, learningRate, advancedSettings) + double learningRate = Defaults.LearningRates) + : base(env, TrainerUtils.MakeBoolScalarLabel(labelColumn), featureColumn, weightColumn, null, numLeaves, numTrees, minDatapointsInLeaves, learningRate) + { + // Set the sigmoid parameter to the 2 * learning rate, for traditional FastTreeClassification loss + _sigmoidParameter = 2.0 * Args.LearningRates; + } + + /// + /// Initializes a new instance of + /// + /// The private instance of . + /// The name of the label column. + /// The name of the feature column. + /// The name for the column containing the initial weight. + /// A delegate to apply all the advanced arguments to the algorithm. 
+ public FastTreeBinaryClassificationTrainer(IHostEnvironment env, + string labelColumn, + string featureColumn, + string weightColumn, + Action advancedSettings) + : base(env, TrainerUtils.MakeBoolScalarLabel(labelColumn), featureColumn, weightColumn, null, advancedSettings) { // Set the sigmoid parameter to the 2 * learning rate, for traditional FastTreeClassification loss _sigmoidParameter = 2.0 * Args.LearningRates; diff --git a/src/Microsoft.ML.FastTree/FastTreeRanking.cs b/src/Microsoft.ML.FastTree/FastTreeRanking.cs index 0c967ab1e5..83cc1cc270 100644 --- a/src/Microsoft.ML.FastTree/FastTreeRanking.cs +++ b/src/Microsoft.ML.FastTree/FastTreeRanking.cs @@ -71,7 +71,6 @@ public sealed partial class FastTreeRankingTrainer /// Total number of decision trees to create in the ensemble. /// The minimal number of documents allowed in a leaf of a regression tree, out of the subsampled data. /// The learning rate. - /// A delegate to apply all the advanced arguments to the algorithm. public FastTreeRankingTrainer(IHostEnvironment env, string labelColumn = DefaultColumnNames.Label, string featureColumn = DefaultColumnNames.Features, @@ -80,9 +79,28 @@ public FastTreeRankingTrainer(IHostEnvironment env, int numLeaves = Defaults.NumLeaves, int numTrees = Defaults.NumTrees, int minDatapointsInLeaves = Defaults.MinDocumentsInLeaves, - double learningRate = Defaults.LearningRates, + double learningRate = Defaults.LearningRates) + : base(env, TrainerUtils.MakeR4ScalarColumn(labelColumn), featureColumn, weightColumn, groupIdColumn, numLeaves, numTrees, minDatapointsInLeaves, learningRate) + { + Host.CheckNonEmpty(groupIdColumn, nameof(groupIdColumn)); + } + + /// + /// Initializes a new instance of + /// + /// The private instance of . + /// The name of the label column. + /// The name of the feature column. + /// The name for the column containing the group ID. + /// The name for the column containing the initial weight. 
+ /// A delegate to apply all the advanced arguments to the algorithm. + public FastTreeRankingTrainer(IHostEnvironment env, + string labelColumn = DefaultColumnNames.Label, + string featureColumn = DefaultColumnNames.Features, + string groupIdColumn = DefaultColumnNames.GroupId, + string weightColumn = null, Action advancedSettings = null) - : base(env, TrainerUtils.MakeR4ScalarColumn(labelColumn), featureColumn, weightColumn, groupIdColumn, numLeaves, numTrees, minDatapointsInLeaves, learningRate, advancedSettings) + : base(env, TrainerUtils.MakeR4ScalarColumn(labelColumn), featureColumn, weightColumn, groupIdColumn, advancedSettings) { Host.CheckNonEmpty(groupIdColumn, nameof(groupIdColumn)); } diff --git a/src/Microsoft.ML.FastTree/FastTreeRegression.cs b/src/Microsoft.ML.FastTree/FastTreeRegression.cs index 2cb321fafe..7bddb0c485 100644 --- a/src/Microsoft.ML.FastTree/FastTreeRegression.cs +++ b/src/Microsoft.ML.FastTree/FastTreeRegression.cs @@ -62,7 +62,6 @@ public sealed partial class FastTreeRegressionTrainer /// The minimal number of documents allowed in a leaf of a regression tree, out of the subsampled data. /// The max number of leaves in each regression tree. /// Total number of decision trees to create in the ensemble. - /// A delegate to apply all the advanced arguments to the algorithm. 
public FastTreeRegressionTrainer(IHostEnvironment env, string labelColumn = DefaultColumnNames.Label, string featureColumn = DefaultColumnNames.Features, @@ -70,9 +69,25 @@ public FastTreeRegressionTrainer(IHostEnvironment env, int numLeaves = Defaults.NumLeaves, int numTrees = Defaults.NumTrees, int minDatapointsInLeaves = Defaults.MinDocumentsInLeaves, - double learningRate = Defaults.LearningRates, + double learningRate = Defaults.LearningRates) + : base(env, TrainerUtils.MakeR4ScalarColumn(labelColumn), featureColumn, weightColumn, null, numLeaves, numTrees, minDatapointsInLeaves, learningRate) + { + } + + /// + /// Initializes a new instance of + /// + /// The private instance of . + /// The name of the label column. + /// The name of the feature column. + /// The name for the column containing the initial weight. + /// A delegate to apply all the advanced arguments to the algorithm. + public FastTreeRegressionTrainer(IHostEnvironment env, + string labelColumn = DefaultColumnNames.Label, + string featureColumn = DefaultColumnNames.Features, + string weightColumn = null, Action advancedSettings = null) - : base(env, TrainerUtils.MakeR4ScalarColumn(labelColumn), featureColumn, weightColumn, null, numLeaves, numTrees, minDatapointsInLeaves, learningRate, advancedSettings) + : base(env, TrainerUtils.MakeR4ScalarColumn(labelColumn), featureColumn, weightColumn, null, advancedSettings) { } diff --git a/src/Microsoft.ML.FastTree/FastTreeTweedie.cs b/src/Microsoft.ML.FastTree/FastTreeTweedie.cs index 8e9234c808..720f8c29ae 100644 --- a/src/Microsoft.ML.FastTree/FastTreeTweedie.cs +++ b/src/Microsoft.ML.FastTree/FastTreeTweedie.cs @@ -59,7 +59,6 @@ public sealed partial class FastTreeTweedieTrainer /// The minimal number of documents allowed in a leaf of a regression tree, out of the subsampled data. /// The max number of leaves in each regression tree. /// Total number of decision trees to create in the ensemble. 
- /// A delegate to apply all the advanced arguments to the algorithm. public FastTreeTweedieTrainer(IHostEnvironment env, string labelColumn = DefaultColumnNames.Label, string featureColumn = DefaultColumnNames.Features, @@ -67,9 +66,29 @@ public FastTreeTweedieTrainer(IHostEnvironment env, int numLeaves = Defaults.NumLeaves, int numTrees = Defaults.NumTrees, int minDatapointsInLeaves = Defaults.MinDocumentsInLeaves, - double learningRate = Defaults.LearningRates, + double learningRate = Defaults.LearningRates) + : base(env, TrainerUtils.MakeR4ScalarColumn(labelColumn), featureColumn, weightColumn, null, numLeaves, numTrees, minDatapointsInLeaves, learningRate) + { + Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); + Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); + + Initialize(); + } + + /// + /// Initializes a new instance of + /// + /// The private instance of . + /// The name of the label column. + /// The name of the feature column. + /// The name for the column containing the initial weight. + /// A delegate to apply all the advanced arguments to the algorithm. 
+ public FastTreeTweedieTrainer(IHostEnvironment env, + string labelColumn = DefaultColumnNames.Label, + string featureColumn = DefaultColumnNames.Features, + string weightColumn = null, Action advancedSettings = null) - : base(env, TrainerUtils.MakeR4ScalarColumn(labelColumn), featureColumn, weightColumn, null, numLeaves, numTrees, minDatapointsInLeaves, learningRate, advancedSettings) + : base(env, TrainerUtils.MakeR4ScalarColumn(labelColumn), featureColumn, weightColumn, null, advancedSettings) { Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); diff --git a/src/Microsoft.ML.FastTree/RandomForest.cs b/src/Microsoft.ML.FastTree/RandomForest.cs index 43aa5a0149..ba4762d4e0 100644 --- a/src/Microsoft.ML.FastTree/RandomForest.cs +++ b/src/Microsoft.ML.FastTree/RandomForest.cs @@ -36,9 +36,23 @@ protected RandomForestTrainerBase(IHostEnvironment env, int numTrees, int minDatapointsInLeaves, double learningRate, + bool quantileEnabled = false) + : base(env, label, featureColumn, weightColumn, null, numLeaves, numTrees, minDatapointsInLeaves) + { + _quantileEnabled = quantileEnabled; + } + + /// + /// Constructor invoked by the API code-path. 
+ /// + protected RandomForestTrainerBase(IHostEnvironment env, + SchemaShape.Column label, + string featureColumn, + string weightColumn, + string groupIdColumn, Action advancedSettings, bool quantileEnabled = false) - : base(env, label, featureColumn, weightColumn, null, numLeaves, numTrees, minDatapointsInLeaves, advancedSettings) + : base(env, label, featureColumn, weightColumn, null, advancedSettings) { _quantileEnabled = quantileEnabled; } diff --git a/src/Microsoft.ML.FastTree/RandomForestClassification.cs b/src/Microsoft.ML.FastTree/RandomForestClassification.cs index 7417a67d5f..80d23128bb 100644 --- a/src/Microsoft.ML.FastTree/RandomForestClassification.cs +++ b/src/Microsoft.ML.FastTree/RandomForestClassification.cs @@ -145,7 +145,6 @@ public sealed class Arguments : FastForestArgumentsBase /// Total number of decision trees to create in the ensemble. /// The minimal number of documents allowed in a leaf of a regression tree, out of the subsampled data. /// The learning rate. - /// A delegate to apply all the advanced arguments to the algorithm. 
public FastForestClassification(IHostEnvironment env, string labelColumn = DefaultColumnNames.Label, string featureColumn = DefaultColumnNames.Features, @@ -153,9 +152,27 @@ public FastForestClassification(IHostEnvironment env, int numLeaves = Defaults.NumLeaves, int numTrees = Defaults.NumTrees, int minDatapointsInLeaves = Defaults.MinDocumentsInLeaves, - double learningRate = Defaults.LearningRates, - Action advancedSettings = null) - : base(env, TrainerUtils.MakeBoolScalarLabel(labelColumn), featureColumn, weightColumn, null, numLeaves, numTrees, minDatapointsInLeaves, learningRate, advancedSettings) + double learningRate = Defaults.LearningRates) + : base(env, TrainerUtils.MakeBoolScalarLabel(labelColumn), featureColumn, weightColumn, null, numLeaves, numTrees, minDatapointsInLeaves, learningRate) + { + Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); + Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); + } + + /// + /// Initializes a new instance of + /// + /// The private instance of . + /// The name of the label column. + /// The name of the feature column. + /// The name for the column containing the initial weight. + /// A delegate to apply all the advanced arguments to the algorithm. 
+ public FastForestClassification(IHostEnvironment env, + string labelColumn, + string featureColumn, + string weightColumn, + Action advancedSettings) + : base(env, TrainerUtils.MakeBoolScalarLabel(labelColumn), featureColumn, weightColumn, null, advancedSettings) { Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); diff --git a/src/Microsoft.ML.FastTree/RandomForestRegression.cs b/src/Microsoft.ML.FastTree/RandomForestRegression.cs index f90a37bfa7..71136be494 100644 --- a/src/Microsoft.ML.FastTree/RandomForestRegression.cs +++ b/src/Microsoft.ML.FastTree/RandomForestRegression.cs @@ -163,7 +163,6 @@ public sealed class Arguments : FastForestArgumentsBase /// Total number of decision trees to create in the ensemble. /// The minimal number of documents allowed in a leaf of a regression tree, out of the subsampled data. /// The learning rate. - /// A delegate to apply all the advanced arguments to the algorithm. public FastForestRegression(IHostEnvironment env, string labelColumn = DefaultColumnNames.Label, string featureColumn = DefaultColumnNames.Features, @@ -171,9 +170,27 @@ public FastForestRegression(IHostEnvironment env, int numLeaves = Defaults.NumLeaves, int numTrees = Defaults.NumTrees, int minDatapointsInLeaves = Defaults.MinDocumentsInLeaves, - double learningRate = Defaults.LearningRates, - Action advancedSettings = null) - : base(env, TrainerUtils.MakeR4ScalarColumn(labelColumn), featureColumn, weightColumn, null, numLeaves, numTrees, minDatapointsInLeaves, learningRate, advancedSettings) + double learningRate = Defaults.LearningRates) + : base(env, TrainerUtils.MakeR4ScalarColumn(labelColumn), featureColumn, weightColumn, null, numLeaves, numTrees, minDatapointsInLeaves, learningRate) + { + Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); + Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); + } + + /// + /// Initializes a new instance of + /// + /// The private instance of . 
+ /// The name of the label column. + /// The name of the feature column. + /// The optional name for the column containing the initial weight. + /// A delegate to apply all the advanced arguments to the algorithm. + public FastForestRegression(IHostEnvironment env, + string labelColumn, + string featureColumn, + string weightColumn, + Action advancedSettings) + : base(env, TrainerUtils.MakeR4ScalarColumn(labelColumn), featureColumn, weightColumn, null, advancedSettings) { Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); diff --git a/src/Microsoft.ML.FastTree/TreeTrainersCatalog.cs b/src/Microsoft.ML.FastTree/TreeTrainersCatalog.cs index afc670cdb4..ebb81eefc9 100644 --- a/src/Microsoft.ML.FastTree/TreeTrainersCatalog.cs +++ b/src/Microsoft.ML.FastTree/TreeTrainersCatalog.cs @@ -24,7 +24,6 @@ public static class TreeExtensions /// The maximum number of leaves per decision tree. /// The minimal number of datapoints allowed in a leaf of a regression tree, out of the subsampled data. /// The learning rate. - /// Algorithm advanced settings. 
public static FastTreeRegressionTrainer FastTree(this RegressionContext.RegressionTrainers ctx, string labelColumn = DefaultColumnNames.Label, string featureColumn = DefaultColumnNames.Features, @@ -32,12 +31,30 @@ public static FastTreeRegressionTrainer FastTree(this RegressionContext.Regressi int numLeaves = Defaults.NumLeaves, int numTrees = Defaults.NumTrees, int minDatapointsInLeaves = Defaults.MinDocumentsInLeaves, - double learningRate = Defaults.LearningRates, - Action advancedSettings = null) + double learningRate = Defaults.LearningRates) { Contracts.CheckValue(ctx, nameof(ctx)); var env = CatalogUtils.GetEnvironment(ctx); - return new FastTreeRegressionTrainer(env, labelColumn, featureColumn, weights, numLeaves, numTrees, minDatapointsInLeaves, learningRate, advancedSettings); + return new FastTreeRegressionTrainer(env, labelColumn, featureColumn, weights, numLeaves, numTrees, minDatapointsInLeaves, learningRate); + } + + /// + /// Predict a target using a decision tree regression model trained with the . + /// + /// The . + /// The label column. + /// The feature column. + /// The optional weights column. + /// Algorithm advanced settings. + public static FastTreeRegressionTrainer FastTree(this RegressionContext.RegressionTrainers ctx, + string labelColumn, + string featureColumn, + string weights, + Action advancedSettings) + { + Contracts.CheckValue(ctx, nameof(ctx)); + var env = CatalogUtils.GetEnvironment(ctx); + return new FastTreeRegressionTrainer(env, labelColumn, featureColumn, weights, advancedSettings); } /// @@ -51,7 +68,6 @@ public static FastTreeRegressionTrainer FastTree(this RegressionContext.Regressi /// The maximum number of leaves per decision tree. /// The minimal number of datapoints allowed in a leaf of the tree, out of the subsampled data. /// The learning rate. - /// Algorithm advanced settings. 
public static FastTreeBinaryClassificationTrainer FastTree(this BinaryClassificationContext.BinaryClassificationTrainers ctx, string labelColumn = DefaultColumnNames.Label, string featureColumn = DefaultColumnNames.Features, @@ -59,12 +75,30 @@ public static FastTreeBinaryClassificationTrainer FastTree(this BinaryClassifica int numLeaves = Defaults.NumLeaves, int numTrees = Defaults.NumTrees, int minDatapointsInLeaves = Defaults.MinDocumentsInLeaves, - double learningRate = Defaults.LearningRates, - Action advancedSettings = null) + double learningRate = Defaults.LearningRates) + { + Contracts.CheckValue(ctx, nameof(ctx)); + var env = CatalogUtils.GetEnvironment(ctx); + return new FastTreeBinaryClassificationTrainer(env, labelColumn, featureColumn, weights, numLeaves, numTrees, minDatapointsInLeaves, learningRate); + } + + /// + /// Predict a target using a decision tree binary classification model trained with the . + /// + /// The . + /// The labelColumn column. + /// The featureColumn column. + /// The optional weights column. + /// Algorithm advanced settings. + public static FastTreeBinaryClassificationTrainer FastTree(this BinaryClassificationContext.BinaryClassificationTrainers ctx, + string labelColumn, + string featureColumn, + string weights, + Action advancedSettings) { Contracts.CheckValue(ctx, nameof(ctx)); var env = CatalogUtils.GetEnvironment(ctx); - return new FastTreeBinaryClassificationTrainer(env, labelColumn, featureColumn, weights, numLeaves, numTrees, minDatapointsInLeaves, learningRate, advancedSettings); + return new FastTreeBinaryClassificationTrainer(env, labelColumn, featureColumn, weights, advancedSettings); } /// @@ -79,7 +113,6 @@ public static FastTreeBinaryClassificationTrainer FastTree(this BinaryClassifica /// The maximum number of leaves per decision tree. /// The minimal number of datapoints allowed in a leaf of the tree, out of the subsampled data. /// The learning rate. - /// Algorithm advanced settings. 
public static FastTreeRankingTrainer FastTree(this RankingContext.RankingTrainers ctx, string labelColumn = DefaultColumnNames.Label, string featureColumn = DefaultColumnNames.Features, @@ -88,12 +121,32 @@ public static FastTreeRankingTrainer FastTree(this RankingContext.RankingTrainer int numLeaves = Defaults.NumLeaves, int numTrees = Defaults.NumTrees, int minDatapointsInLeaves = Defaults.MinDocumentsInLeaves, - double learningRate = Defaults.LearningRates, - Action advancedSettings = null) + double learningRate = Defaults.LearningRates) + { + Contracts.CheckValue(ctx, nameof(ctx)); + var env = CatalogUtils.GetEnvironment(ctx); + return new FastTreeRankingTrainer(env, labelColumn, featureColumn, groupId, weights, numLeaves, numTrees, minDatapointsInLeaves, learningRate); + } + + /// + /// Ranks a series of inputs based on their relevance, training a decision tree ranking model through the . + /// + /// The . + /// The labelColumn column. + /// The featureColumn column. + /// The groupId column. + /// The optional weights column. + /// Algorithm advanced settings. + public static FastTreeRankingTrainer FastTree(this RankingContext.RankingTrainers ctx, + string labelColumn, + string featureColumn, + string groupId, + string weights, + Action advancedSettings) { Contracts.CheckValue(ctx, nameof(ctx)); var env = CatalogUtils.GetEnvironment(ctx); - return new FastTreeRankingTrainer(env, labelColumn, featureColumn, groupId, weights, numLeaves, numTrees, minDatapointsInLeaves, learningRate, advancedSettings); + return new FastTreeRankingTrainer(env, labelColumn, featureColumn, groupId, weights, advancedSettings); } /// @@ -157,7 +210,6 @@ public static RegressionGamTrainer GeneralizedAdditiveModels(this RegressionCont /// The maximum number of leaves per decision tree. /// The minimal number of datapoints allowed in a leaf of the tree, out of the subsampled data. /// The learning rate. - /// Algorithm advanced settings. 
public static FastTreeTweedieTrainer FastTreeTweedie(this RegressionContext.RegressionTrainers ctx, string labelColumn = DefaultColumnNames.Label, string featureColumn = DefaultColumnNames.Features, @@ -165,12 +217,30 @@ public static FastTreeTweedieTrainer FastTreeTweedie(this RegressionContext.Regr int numLeaves = Defaults.NumLeaves, int numTrees = Defaults.NumTrees, int minDatapointsInLeaves = Defaults.MinDocumentsInLeaves, - double learningRate = Defaults.LearningRates, + double learningRate = Defaults.LearningRates) + { + Contracts.CheckValue(ctx, nameof(ctx)); + var env = CatalogUtils.GetEnvironment(ctx); + return new FastTreeTweedieTrainer(env, labelColumn, featureColumn, weights, numLeaves, numTrees, minDatapointsInLeaves, learningRate); + } + + /// + /// Predict a target using a decision tree regression model trained with the . + /// + /// The . + /// The labelColumn column. + /// The featureColumn column. + /// The optional weights column. + /// Algorithm advanced settings. + public static FastTreeTweedieTrainer FastTreeTweedie(this RegressionContext.RegressionTrainers ctx, + string labelColumn, + string featureColumn, + string weights, Action advancedSettings = null) { Contracts.CheckValue(ctx, nameof(ctx)); var env = CatalogUtils.GetEnvironment(ctx); - return new FastTreeTweedieTrainer(env, labelColumn, featureColumn, weights, numLeaves, numTrees, minDatapointsInLeaves, learningRate, advancedSettings); + return new FastTreeTweedieTrainer(env, labelColumn, featureColumn, weights, advancedSettings); } /// @@ -184,7 +254,6 @@ public static FastTreeTweedieTrainer FastTreeTweedie(this RegressionContext.Regr /// The maximum number of leaves per decision tree. /// The minimal number of datapoints allowed in a leaf of the tree, out of the subsampled data. /// The learning rate. - /// Algorithm advanced settings. 
public static FastForestRegression FastForest(this RegressionContext.RegressionTrainers ctx, string labelColumn = DefaultColumnNames.Label, string featureColumn = DefaultColumnNames.Features, @@ -192,12 +261,30 @@ public static FastForestRegression FastForest(this RegressionContext.RegressionT int numLeaves = Defaults.NumLeaves, int numTrees = Defaults.NumTrees, int minDatapointsInLeaves = Defaults.MinDocumentsInLeaves, - double learningRate = Defaults.LearningRates, + double learningRate = Defaults.LearningRates) + { + Contracts.CheckValue(ctx, nameof(ctx)); + var env = CatalogUtils.GetEnvironment(ctx); + return new FastForestRegression(env, labelColumn, featureColumn, weights, numLeaves, numTrees, minDatapointsInLeaves, learningRate); + } + + /// + /// Predict a target using a decision tree regression model trained with the . + /// + /// The . + /// The labelColumn column. + /// The featureColumn column. + /// The optional weights column. + /// Algorithm advanced settings. + public static FastForestRegression FastForest(this RegressionContext.RegressionTrainers ctx, + string labelColumn, + string featureColumn, + string weights, Action advancedSettings = null) { Contracts.CheckValue(ctx, nameof(ctx)); var env = CatalogUtils.GetEnvironment(ctx); - return new FastForestRegression(env, labelColumn, featureColumn, weights, numLeaves, numTrees, minDatapointsInLeaves, learningRate, advancedSettings); + return new FastForestRegression(env, labelColumn, featureColumn, weights, advancedSettings); } /// @@ -211,7 +298,6 @@ public static FastForestRegression FastForest(this RegressionContext.RegressionT /// The maximum number of leaves per decision tree. /// The minimal number of datapoints allowed in a leaf of the tree, out of the subsampled data. /// The learning rate. - /// Algorithm advanced settings. 
public static FastForestClassification FastForest(this BinaryClassificationContext.BinaryClassificationTrainers ctx, string labelColumn = DefaultColumnNames.Label, string featureColumn = DefaultColumnNames.Features, @@ -219,12 +305,30 @@ public static FastForestClassification FastForest(this BinaryClassificationConte int numLeaves = Defaults.NumLeaves, int numTrees = Defaults.NumTrees, int minDatapointsInLeaves = Defaults.MinDocumentsInLeaves, - double learningRate = Defaults.LearningRates, - Action advancedSettings = null) + double learningRate = Defaults.LearningRates) + { + Contracts.CheckValue(ctx, nameof(ctx)); + var env = CatalogUtils.GetEnvironment(ctx); + return new FastForestClassification(env, labelColumn, featureColumn, weights,numLeaves, numTrees, minDatapointsInLeaves, learningRate); + } + + /// + /// Predict a target using a decision tree regression model trained with the . + /// + /// The . + /// The labelColumn column. + /// The featureColumn column. + /// The optional weights column. + /// Algorithm advanced settings. + public static FastForestClassification FastForest(this BinaryClassificationContext.BinaryClassificationTrainers ctx, + string labelColumn, + string featureColumn, + string weights, + Action advancedSettings) { Contracts.CheckValue(ctx, nameof(ctx)); var env = CatalogUtils.GetEnvironment(ctx); - return new FastForestClassification(env, labelColumn, featureColumn, weights,numLeaves, numTrees, minDatapointsInLeaves, learningRate, advancedSettings); + return new FastForestClassification(env, labelColumn, featureColumn, weights, advancedSettings); } } } diff --git a/src/Microsoft.ML.StaticPipe/TreeTrainersStatic.cs b/src/Microsoft.ML.StaticPipe/TreeTrainersStatic.cs index 1e9ff0f4de..0c81119760 100644 --- a/src/Microsoft.ML.StaticPipe/TreeTrainersStatic.cs +++ b/src/Microsoft.ML.StaticPipe/TreeTrainersStatic.cs @@ -26,7 +26,6 @@ public static class TreeRegressionExtensions /// The maximum number of leaves per decision tree. 
/// The minimal number of datapoints allowed in a leaf of a regression tree, out of the subsampled data. /// The learning rate. - /// Algorithm advanced settings. /// A delegate that is called every time the /// method is called on the /// instance created out of this. This delegate will receive @@ -45,16 +44,55 @@ public static Scalar FastTree(this RegressionContext.RegressionTrainers c int numTrees = Defaults.NumTrees, int minDatapointsInLeaves = Defaults.MinDocumentsInLeaves, double learningRate = Defaults.LearningRates, - Action advancedSettings = null, Action onFit = null) { - CheckUserValues(label, features, weights, numLeaves, numTrees, minDatapointsInLeaves, learningRate, advancedSettings, onFit); + CheckUserValues(label, features, weights, numLeaves, numTrees, minDatapointsInLeaves, learningRate, onFit); var rec = new TrainerEstimatorReconciler.Regression( (env, labelName, featuresName, weightsName) => { var trainer = new FastTreeRegressionTrainer(env, labelName, featuresName, weightsName, numLeaves, - numTrees, minDatapointsInLeaves, learningRate, advancedSettings); + numTrees, minDatapointsInLeaves, learningRate); + if (onFit != null) + return trainer.WithOnFitDelegate(trans => onFit(trans.Model)); + return trainer; + }, label, features, weights); + + return rec.Score; + } + + /// + /// FastTree extension method. + /// Predicts a target using a decision tree regression model trained with the . + /// + /// The . + /// The label column. + /// The features column. + /// The optional weights column. + /// Algorithm advanced settings. + /// A delegate that is called every time the + /// method is called on the + /// instance created out of this. This delegate will receive + /// the linear model that was trained. Note that this action cannot change the result in any way; + /// it is only a way for the caller to be informed about what was learnt. + /// The Score output column indicating the predicted value. 
+ /// + /// + /// + /// + public static Scalar FastTree(this RegressionContext.RegressionTrainers ctx, + Scalar label, Vector features, Scalar weights, + Action advancedSettings, + Action onFit = null) + { + CheckUserValues(label, features, weights, advancedSettings, onFit); + + var rec = new TrainerEstimatorReconciler.Regression( + (env, labelName, featuresName, weightsName) => + { + var trainer = new FastTreeRegressionTrainer(env, labelName, featuresName, weightsName, advancedSettings); if (onFit != null) return trainer.WithOnFitDelegate(trans => onFit(trans.Model)); return trainer; @@ -75,7 +113,6 @@ public static Scalar FastTree(this RegressionContext.RegressionTrainers c /// The maximum number of leaves per decision tree. /// The minimal number of datapoints allowed in a leaf of the tree, out of the subsampled data. /// The learning rate. - /// Algorithm advanced settings. /// A delegate that is called every time the /// method is called on the /// instance created out of this. This delegate will receive @@ -95,16 +132,58 @@ public static (Scalar score, Scalar probability, Scalar pred int numTrees = Defaults.NumTrees, int minDatapointsInLeaves = Defaults.MinDocumentsInLeaves, double learningRate = Defaults.LearningRates, - Action advancedSettings = null, Action> onFit = null) { - CheckUserValues(label, features, weights, numLeaves, numTrees, minDatapointsInLeaves, learningRate, advancedSettings, onFit); + CheckUserValues(label, features, weights, numLeaves, numTrees, minDatapointsInLeaves, learningRate, onFit); var rec = new TrainerEstimatorReconciler.BinaryClassifier( (env, labelName, featuresName, weightsName) => { var trainer = new FastTreeBinaryClassificationTrainer(env, labelName, featuresName, weightsName, numLeaves, - numTrees, minDatapointsInLeaves, learningRate, advancedSettings); + numTrees, minDatapointsInLeaves, learningRate); + + if (onFit != null) + return trainer.WithOnFitDelegate(trans => onFit(trans.Model)); + else + return trainer; + }, 
label, features, weights); + + return rec.Output; + } + + /// + /// FastTree extension method. + /// Predict a target using a decision tree binary classification model trained with the . + /// + /// The . + /// The label column. + /// The features column. + /// The optional weights column. + /// Algorithm advanced settings. + /// A delegate that is called every time the + /// method is called on the + /// instance created out of this. This delegate will receive + /// the linear model that was trained. Note that this action cannot change the result in any way; + /// it is only a way for the caller to be informed about what was learnt. + /// The set of output columns including in order the predicted binary classification score (which will range + /// from negative to positive infinity), the calibrated prediction (from 0 to 1), and the predicted label. + /// + /// + /// + /// + public static (Scalar score, Scalar probability, Scalar predictedLabel) FastTree(this BinaryClassificationContext.BinaryClassificationTrainers ctx, + Scalar label, Vector features, Scalar weights, + Action advancedSettings, + Action> onFit = null) + { + CheckUserValues(label, features, weights, advancedSettings, onFit); + + var rec = new TrainerEstimatorReconciler.BinaryClassifier( + (env, labelName, featuresName, weightsName) => + { + var trainer = new FastTreeBinaryClassificationTrainer(env, labelName, featuresName, weightsName, advancedSettings); if (onFit != null) return trainer.WithOnFitDelegate(trans => onFit(trans.Model)); @@ -128,29 +207,62 @@ public static (Scalar score, Scalar probability, Scalar pred /// The maximum number of leaves per decision tree. /// The minimal number of datapoints allowed in a leaf of a regression tree, out of the subsampled data. /// The learning rate. - /// Algorithm advanced settings. /// A delegate that is called every time the /// method is called on the /// instance created out of this. This delegate will receive /// the linear model that was trained.
Note that this action cannot change the result in any way; /// it is only a way for the caller to be informed about what was learnt. /// The Score output column indicating the predicted value. - public static Scalar FastTree(this RankingContext.RankingTrainers ctx, + public static Scalar FastTree(this RankingContext.RankingTrainers ctx, Scalar label, Vector features, Key groupId, Scalar weights = null, int numLeaves = Defaults.NumLeaves, int numTrees = Defaults.NumTrees, int minDatapointsInLeaves = Defaults.MinDocumentsInLeaves, double learningRate = Defaults.LearningRates, - Action advancedSettings = null, Action onFit = null) { - CheckUserValues(label, features, weights, numLeaves, numTrees, minDatapointsInLeaves, learningRate, advancedSettings, onFit); + CheckUserValues(label, features, weights, numLeaves, numTrees, minDatapointsInLeaves, learningRate, onFit); var rec = new TrainerEstimatorReconciler.Ranker( (env, labelName, featuresName, groupIdName, weightsName) => { var trainer = new FastTreeRankingTrainer(env, labelName, featuresName, groupIdName, weightsName, numLeaves, - numTrees, minDatapointsInLeaves, learningRate, advancedSettings); + numTrees, minDatapointsInLeaves, learningRate); + if (onFit != null) + return trainer.WithOnFitDelegate(trans => onFit(trans.Model)); + return trainer; + }, label, features, groupId, weights); + + return rec.Score; + } + + /// + /// FastTree . + /// Ranks a series of inputs based on their relevance, training a decision tree ranking model through the . + /// + /// The . + /// The label column. + /// The features column. + /// The groupId column. + /// The optional weights column. + /// Algorithm advanced settings. + /// A delegate that is called every time the + /// method is called on the + /// instance created out of this. This delegate will receive + /// the linear model that was trained. Note that this action cannot change the result in any way; + /// it is only a way for the caller to be informed about what was learnt. 
+ /// The Score output column indicating the predicted value. + public static Scalar FastTree(this RankingContext.RankingTrainers ctx, + Scalar label, Vector features, Key groupId, Scalar weights, + Action advancedSettings, + Action onFit = null) + { + CheckUserValues(label, features, weights, advancedSettings, onFit); + + var rec = new TrainerEstimatorReconciler.Ranker( + (env, labelName, featuresName, groupIdName, weightsName) => + { + var trainer = new FastTreeRankingTrainer(env, labelName, featuresName, groupIdName, weightsName, advancedSettings); if (onFit != null) return trainer.WithOnFitDelegate(trans => onFit(trans.Model)); return trainer; @@ -164,7 +276,6 @@ internal static void CheckUserValues(PipelineColumn label, Vector feature int numTrees, int minDatapointsInLeaves, double learningRate, - Delegate advancedSettings, Delegate onFit) { Contracts.CheckValue(label, nameof(label)); @@ -174,6 +285,16 @@ internal static void CheckUserValues(PipelineColumn label, Vector feature Contracts.CheckParam(numTrees > 0, nameof(numTrees), "Must be positive"); Contracts.CheckParam(minDatapointsInLeaves > 0, nameof(minDatapointsInLeaves), "Must be positive"); Contracts.CheckParam(learningRate > 0, nameof(learningRate), "Must be positive"); + Contracts.CheckValueOrNull(onFit); + } + + internal static void CheckUserValues(PipelineColumn label, Vector features, Scalar weights, + Delegate advancedSettings, + Delegate onFit) + { + Contracts.CheckValue(label, nameof(label)); + Contracts.CheckValue(features, nameof(features)); + Contracts.CheckValueOrNull(weights); Contracts.CheckValueOrNull(advancedSettings); Contracts.CheckValueOrNull(onFit); } diff --git a/src/Microsoft.ML.Sweeper/Algorithms/SmacSweeper.cs b/src/Microsoft.ML.Sweeper/Algorithms/SmacSweeper.cs index 730573b73b..5af931160f 100644 --- a/src/Microsoft.ML.Sweeper/Algorithms/SmacSweeper.cs +++ b/src/Microsoft.ML.Sweeper/Algorithms/SmacSweeper.cs @@ -135,7 +135,8 @@ private FastForestRegressionModelParameters 
FitModel(IEnumerable pre { // Set relevant random forest arguments. // Train random forest. - var trainer = new FastForestRegression(_host, DefaultColumnNames.Label, DefaultColumnNames.Features, advancedSettings: s => + var trainer = new FastForestRegression(_host, DefaultColumnNames.Label, DefaultColumnNames.Features, null, + advancedSettings: s => { s.FeatureFraction = _args.SplitRatio; s.NumTrees = _args.NumOfTrees; diff --git a/test/Microsoft.ML.Tests/Scenarios/OvaTest.cs b/test/Microsoft.ML.Tests/Scenarios/OvaTest.cs index e3e84bccd5..e5c73f976b 100644 --- a/test/Microsoft.ML.Tests/Scenarios/OvaTest.cs +++ b/test/Microsoft.ML.Tests/Scenarios/OvaTest.cs @@ -104,7 +104,8 @@ public void OvaFastTree() // Pipeline var pipeline = new Ova( mlContext, - new FastTreeBinaryClassificationTrainer(mlContext, "Label", "Features", advancedSettings: s => { s.NumThreads = 1; }), + new FastTreeBinaryClassificationTrainer(mlContext, DefaultColumnNames.Label, DefaultColumnNames.Features, null, + advancedSettings: s => { s.NumThreads = 1; }), useProbabilities: false); var model = pipeline.Fit(data); diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs b/test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs index 4938061b22..95da39af33 100644 --- a/test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs +++ b/test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs @@ -24,10 +24,13 @@ public void FastTreeBinaryEstimator() { var (pipe, dataView) = GetBinaryClassificationPipeline(); - var trainer = new FastTreeBinaryClassificationTrainer(Env, "Label", "Features", numTrees: 10, numLeaves: 5, advancedSettings: s => - { - s.NumThreads = 1; - }); + var trainer = new FastTreeBinaryClassificationTrainer(Env, DefaultColumnNames.Label, DefaultColumnNames.Features, null, + advancedSettings: s => { + s.NumThreads = 1; + s.NumTrees = 10; + s.NumLeaves = 5; + }); + var pipeWithTrainer = pipe.Append(trainer); TestEstimatorCore(pipeWithTrainer, dataView); @@ -80,7 
+83,7 @@ public void FastForestClassificationEstimator() { var (pipe, dataView) = GetBinaryClassificationPipeline(); - var trainer = new FastForestClassification(Env, "Label", "Features", advancedSettings: s => + var trainer = new FastForestClassification(Env, DefaultColumnNames.Label, DefaultColumnNames.Features, null, advancedSettings: s => { s.NumLeaves = 10; s.NumTrees = 20; @@ -211,7 +214,7 @@ public void TweedieRegressorEstimator() public void FastForestRegressorEstimator() { var dataView = GetRegressionPipeline(); - var trainer = new FastForestRegression(Env, "Label", "Features", advancedSettings: s => + var trainer = new FastForestRegression(Env, DefaultColumnNames.Label, DefaultColumnNames.Features, null, advancedSettings: s => { s.BaggingSize = 2; s.NumTrees = 10; From 0d380ed60ffec3f9972d85d7bf0b03ee3860b9c0 Mon Sep 17 00:00:00 2001 From: Abhishek Goswami Date: Mon, 7 Jan 2019 01:44:09 +0000 Subject: [PATCH 2/8] Removing defaults from some of the newly added APIs --- src/Microsoft.ML.FastTree/FastTreeRanking.cs | 8 ++++---- src/Microsoft.ML.FastTree/FastTreeRegression.cs | 6 +++--- src/Microsoft.ML.FastTree/FastTreeTweedie.cs | 8 ++++---- .../TrainerEstimators/TreeEstimators.cs | 6 +++--- 4 files changed, 14 insertions(+), 14 deletions(-) diff --git a/src/Microsoft.ML.FastTree/FastTreeRanking.cs b/src/Microsoft.ML.FastTree/FastTreeRanking.cs index 83cc1cc270..8df1f8587b 100644 --- a/src/Microsoft.ML.FastTree/FastTreeRanking.cs +++ b/src/Microsoft.ML.FastTree/FastTreeRanking.cs @@ -95,10 +95,10 @@ public FastTreeRankingTrainer(IHostEnvironment env, /// The name for the column containing the initial weight. /// A delegate to apply all the advanced arguments to the algorithm. 
public FastTreeRankingTrainer(IHostEnvironment env, - string labelColumn = DefaultColumnNames.Label, - string featureColumn = DefaultColumnNames.Features, - string groupIdColumn = DefaultColumnNames.GroupId, - string weightColumn = null, + string labelColumn, + string featureColumn, + string groupIdColumn, + string weightColumn, Action advancedSettings = null) : base(env, TrainerUtils.MakeR4ScalarColumn(labelColumn), featureColumn, weightColumn, groupIdColumn, advancedSettings) { diff --git a/src/Microsoft.ML.FastTree/FastTreeRegression.cs b/src/Microsoft.ML.FastTree/FastTreeRegression.cs index 7bddb0c485..77e7d3364e 100644 --- a/src/Microsoft.ML.FastTree/FastTreeRegression.cs +++ b/src/Microsoft.ML.FastTree/FastTreeRegression.cs @@ -83,9 +83,9 @@ public FastTreeRegressionTrainer(IHostEnvironment env, /// The name for the column containing the initial weight. /// A delegate to apply all the advanced arguments to the algorithm. public FastTreeRegressionTrainer(IHostEnvironment env, - string labelColumn = DefaultColumnNames.Label, - string featureColumn = DefaultColumnNames.Features, - string weightColumn = null, + string labelColumn, + string featureColumn, + string weightColumn, Action advancedSettings = null) : base(env, TrainerUtils.MakeR4ScalarColumn(labelColumn), featureColumn, weightColumn, null, advancedSettings) { diff --git a/src/Microsoft.ML.FastTree/FastTreeTweedie.cs b/src/Microsoft.ML.FastTree/FastTreeTweedie.cs index 720f8c29ae..ec4d19e399 100644 --- a/src/Microsoft.ML.FastTree/FastTreeTweedie.cs +++ b/src/Microsoft.ML.FastTree/FastTreeTweedie.cs @@ -84,10 +84,10 @@ public FastTreeTweedieTrainer(IHostEnvironment env, /// The name for the column containing the initial weight. /// A delegate to apply all the advanced arguments to the algorithm. 
public FastTreeTweedieTrainer(IHostEnvironment env, - string labelColumn = DefaultColumnNames.Label, - string featureColumn = DefaultColumnNames.Features, - string weightColumn = null, - Action advancedSettings = null) + string labelColumn, + string featureColumn, + string weightColumn, + Action advancedSettings) : base(env, TrainerUtils.MakeR4ScalarColumn(labelColumn), featureColumn, weightColumn, null, advancedSettings) { Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs b/test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs index 95da39af33..d610e26ab2 100644 --- a/test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs +++ b/test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs @@ -104,7 +104,7 @@ public void FastTreeRankerEstimator() { var (pipe, dataView) = GetRankingPipeline(); - var trainer = new FastTreeRankingTrainer(Env, "Label0", "NumericFeatures", "Group", + var trainer = new FastTreeRankingTrainer(Env, "Label0", "NumericFeatures", "Group", null, advancedSettings: s => { s.NumTrees = 10; }); var pipeWithTrainer = pipe.Append(trainer); TestEstimatorCore(pipeWithTrainer, dataView); @@ -139,7 +139,7 @@ public void LightGBMRankerEstimator() public void FastTreeRegressorEstimator() { var dataView = GetRegressionPipeline(); - var trainer = new FastTreeRegressionTrainer(Env, "Label", "Features", advancedSettings: s => + var trainer = new FastTreeRegressionTrainer(Env, DefaultColumnNames.Label, DefaultColumnNames.Features, null, advancedSettings: s => { s.NumTrees = 10; s.NumThreads = 1; @@ -196,7 +196,7 @@ public void GAMRegressorEstimator() public void TweedieRegressorEstimator() { var dataView = GetRegressionPipeline(); - var trainer = new FastTreeTweedieTrainer(Env, "Label", "Features", advancedSettings: s => + var trainer = new FastTreeTweedieTrainer(Env, "Label", "Features", null, advancedSettings: s => { s.EntropyCoefficient = 0.3; s.OptimizationAlgorithm = 
BoostedTreeArgs.OptimizationAlgorithmType.AcceleratedGradientDescent; From 9c9c114fc403eef5ffb14b5970c0b0d5f2ac59f9 Mon Sep 17 00:00:00 2001 From: Abhishek Goswami Date: Mon, 7 Jan 2019 01:57:43 +0000 Subject: [PATCH 3/8] Argument -> Options --- .../FastTreeArguments.cs | 20 ++++++++-------- .../FastTreeClassification.cs | 14 +++++------ src/Microsoft.ML.FastTree/FastTreeRanking.cs | 18 +++++++------- .../FastTreeRegression.cs | 16 ++++++------- src/Microsoft.ML.FastTree/FastTreeTweedie.cs | 16 ++++++------- .../RandomForestClassification.cs | 18 +++++++------- .../RandomForestRegression.cs | 24 +++++++++---------- .../TreeTrainersCatalog.cs | 12 +++++----- .../TreeTrainersStatic.cs | 6 ++--- .../Common/EntryPoints/core_ep-list.tsv | 12 +++++----- .../UnitTests/TestEntryPoints.cs | 2 +- .../TestPredictors.cs | 6 ++--- 12 files changed, 82 insertions(+), 82 deletions(-) diff --git a/src/Microsoft.ML.FastTree/FastTreeArguments.cs b/src/Microsoft.ML.FastTree/FastTreeArguments.cs index 9485396078..3d66975c30 100644 --- a/src/Microsoft.ML.FastTree/FastTreeArguments.cs +++ b/src/Microsoft.ML.FastTree/FastTreeArguments.cs @@ -8,10 +8,10 @@ using Microsoft.ML.Internal.Internallearn; using Microsoft.ML.Trainers.FastTree; -[assembly: EntryPointModule(typeof(FastTreeBinaryClassificationTrainer.Arguments))] -[assembly: EntryPointModule(typeof(FastTreeRegressionTrainer.Arguments))] -[assembly: EntryPointModule(typeof(FastTreeTweedieTrainer.Arguments))] -[assembly: EntryPointModule(typeof(FastTreeRankingTrainer.Arguments))] +[assembly: EntryPointModule(typeof(FastTreeBinaryClassificationTrainer.Options))] +[assembly: EntryPointModule(typeof(FastTreeRegressionTrainer.Options))] +[assembly: EntryPointModule(typeof(FastTreeTweedieTrainer.Options))] +[assembly: EntryPointModule(typeof(FastTreeRankingTrainer.Options))] namespace Microsoft.ML.Trainers.FastTree { @@ -24,7 +24,7 @@ internal interface IFastTreeTrainerFactory : IComponentFactory public sealed partial class 
FastTreeBinaryClassificationTrainer { [TlcModule.Component(Name = LoadNameValue, FriendlyName = UserNameValue, Desc = Summary)] - public sealed class Arguments : BoostedTreeArgs, IFastTreeTrainerFactory + public sealed class Options : BoostedTreeArgs, IFastTreeTrainerFactory { [Argument(ArgumentType.LastOccurenceWins, HelpText = "Should we use derivatives optimized for unbalanced sets", ShortName = "us")] [TGUI(Label = "Optimize for unbalanced")] @@ -37,9 +37,9 @@ public sealed class Arguments : BoostedTreeArgs, IFastTreeTrainerFactory public sealed partial class FastTreeRegressionTrainer { [TlcModule.Component(Name = LoadNameValue, FriendlyName = UserNameValue, Desc = Summary)] - public sealed class Arguments : BoostedTreeArgs, IFastTreeTrainerFactory + public sealed class Options : BoostedTreeArgs, IFastTreeTrainerFactory { - public Arguments() + public Options() { EarlyStoppingMetrics = 1; // Use L1 by default. } @@ -51,7 +51,7 @@ public Arguments() public sealed partial class FastTreeTweedieTrainer { [TlcModule.Component(Name = LoadNameValue, FriendlyName = UserNameValue, Desc = Summary)] - public sealed class Arguments : BoostedTreeArgs, IFastTreeTrainerFactory + public sealed class Options : BoostedTreeArgs, IFastTreeTrainerFactory { // REVIEW: It is possible to estimate this index parameter from the distribution of data, using // a combination of univariate optimization and grid search, following section 4.2 of the paper. 
However @@ -68,7 +68,7 @@ public sealed class Arguments : BoostedTreeArgs, IFastTreeTrainerFactory public sealed partial class FastTreeRankingTrainer { [TlcModule.Component(Name = LoadNameValue, FriendlyName = UserNameValue, Desc = Summary)] - public sealed class Arguments : BoostedTreeArgs, IFastTreeTrainerFactory + public sealed class Options : BoostedTreeArgs, IFastTreeTrainerFactory { [Argument(ArgumentType.LastOccurenceWins, HelpText = "Comma seperated list of gains associated to each relevance label.", ShortName = "gains")] [TGUI(NoSweep = true)] @@ -105,7 +105,7 @@ public sealed class Arguments : BoostedTreeArgs, IFastTreeTrainerFactory [TGUI(NotGui = true)] public bool NormalizeQueryLambdas; - public Arguments() + public Options() { EarlyStoppingMetrics = 1; } diff --git a/src/Microsoft.ML.FastTree/FastTreeClassification.cs b/src/Microsoft.ML.FastTree/FastTreeClassification.cs index 74c20e3d88..0502dc8295 100644 --- a/src/Microsoft.ML.FastTree/FastTreeClassification.cs +++ b/src/Microsoft.ML.FastTree/FastTreeClassification.cs @@ -17,7 +17,7 @@ using Microsoft.ML.Trainers.FastTree.Internal; using Microsoft.ML.Training; -[assembly: LoadableClass(FastTreeBinaryClassificationTrainer.Summary, typeof(FastTreeBinaryClassificationTrainer), typeof(FastTreeBinaryClassificationTrainer.Arguments), +[assembly: LoadableClass(FastTreeBinaryClassificationTrainer.Summary, typeof(FastTreeBinaryClassificationTrainer), typeof(FastTreeBinaryClassificationTrainer.Options), new[] { typeof(SignatureBinaryClassifierTrainer), typeof(SignatureTrainer), typeof(SignatureTreeEnsembleTrainer), typeof(SignatureFeatureScorerTrainer) }, FastTreeBinaryClassificationTrainer.UserNameValue, FastTreeBinaryClassificationTrainer.LoadNameValue, @@ -103,7 +103,7 @@ private static IPredictorProducing Create(IHostEnvironment env, ModelLoad /// public sealed partial class FastTreeBinaryClassificationTrainer : - BoostingFastTreeTrainerBase>, IPredictorWithFeatureWeights> + BoostingFastTreeTrainerBase>, 
IPredictorWithFeatureWeights> { /// /// The LoadName for the assembly containing the trainer. @@ -153,7 +153,7 @@ public FastTreeBinaryClassificationTrainer(IHostEnvironment env, string labelColumn, string featureColumn, string weightColumn, - Action advancedSettings) + Action advancedSettings) : base(env, TrainerUtils.MakeBoolScalarLabel(labelColumn), featureColumn, weightColumn, null, advancedSettings) { // Set the sigmoid parameter to the 2 * learning rate, for traditional FastTreeClassification loss @@ -161,9 +161,9 @@ public FastTreeBinaryClassificationTrainer(IHostEnvironment env, } /// - /// Initializes a new instance of by using the legacy class. + /// Initializes a new instance of by using the legacy class. /// - internal FastTreeBinaryClassificationTrainer(IHostEnvironment env, Arguments args) + internal FastTreeBinaryClassificationTrainer(IHostEnvironment env, Options args) : base(env, args, TrainerUtils.MakeBoolScalarLabel(args.LabelColumn)) { // Set the sigmoid parameter to the 2 * learning rate, for traditional FastTreeClassification loss @@ -419,14 +419,14 @@ public static partial class FastTree ShortName = FastTreeBinaryClassificationTrainer.ShortName, XmlInclude = new[] { @"", @"" })] - public static CommonOutputs.BinaryClassificationOutput TrainBinary(IHostEnvironment env, FastTreeBinaryClassificationTrainer.Arguments input) + public static CommonOutputs.BinaryClassificationOutput TrainBinary(IHostEnvironment env, FastTreeBinaryClassificationTrainer.Options input) { Contracts.CheckValue(env, nameof(env)); var host = env.Register("TrainFastTree"); host.CheckValue(input, nameof(input)); EntryPointUtils.CheckInputArgs(host, input); - return LearnerEntryPointsUtils.Train(host, input, + return LearnerEntryPointsUtils.Train(host, input, () => new FastTreeBinaryClassificationTrainer(host, input), () => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumn), () => LearnerEntryPointsUtils.FindColumn(host, 
input.TrainingData.Schema, input.WeightColumn), diff --git a/src/Microsoft.ML.FastTree/FastTreeRanking.cs b/src/Microsoft.ML.FastTree/FastTreeRanking.cs index 8df1f8587b..703c3985b7 100644 --- a/src/Microsoft.ML.FastTree/FastTreeRanking.cs +++ b/src/Microsoft.ML.FastTree/FastTreeRanking.cs @@ -20,7 +20,7 @@ using Microsoft.ML.Training; // REVIEW: Do we really need all these names? -[assembly: LoadableClass(FastTreeRankingTrainer.Summary, typeof(FastTreeRankingTrainer), typeof(FastTreeRankingTrainer.Arguments), +[assembly: LoadableClass(FastTreeRankingTrainer.Summary, typeof(FastTreeRankingTrainer), typeof(FastTreeRankingTrainer.Options), new[] { typeof(SignatureRankerTrainer), typeof(SignatureTrainer), typeof(SignatureTreeEnsembleTrainer), typeof(SignatureFeatureScorerTrainer) }, FastTreeRankingTrainer.UserNameValue, FastTreeRankingTrainer.LoadNameValue, @@ -43,7 +43,7 @@ namespace Microsoft.ML.Trainers.FastTree { /// public sealed partial class FastTreeRankingTrainer - : BoostingFastTreeTrainerBase, FastTreeRankingModelParameters> + : BoostingFastTreeTrainerBase, FastTreeRankingModelParameters> { internal const string LoadNameValue = "FastTreeRanking"; internal const string UserNameValue = "FastTree (Boosted Trees) Ranking"; @@ -99,16 +99,16 @@ public FastTreeRankingTrainer(IHostEnvironment env, string featureColumn, string groupIdColumn, string weightColumn, - Action advancedSettings = null) + Action advancedSettings = null) : base(env, TrainerUtils.MakeR4ScalarColumn(labelColumn), featureColumn, weightColumn, groupIdColumn, advancedSettings) { Host.CheckNonEmpty(groupIdColumn, nameof(groupIdColumn)); } /// - /// Initializes a new instance of by using the legacy class. + /// Initializes a new instance of by using the legacy class. 
/// - internal FastTreeRankingTrainer(IHostEnvironment env, Arguments args) + internal FastTreeRankingTrainer(IHostEnvironment env, Options args) : base(env, args, TrainerUtils.MakeR4ScalarColumn(args.LabelColumn)) { } @@ -566,7 +566,7 @@ private enum DupeIdInfo // Keeps track of labels of top 3 documents per query public short[][] TrainQueriesTopLabels; - public LambdaRankObjectiveFunction(Dataset trainset, short[] labels, Arguments args, IParallelTraining parallelTraining) + public LambdaRankObjectiveFunction(Dataset trainset, short[] labels, Options args, IParallelTraining parallelTraining) : base(trainset, args.LearningRates, args.Shrinkage, @@ -664,7 +664,7 @@ private void SetupSecondaryGains(Arguments args) } #endif - private void SetupBaselineRisk(Arguments args) + private void SetupBaselineRisk(Options args) { double[] scores = Dataset.Skeleton.GetData("BaselineScores"); if (scores == null) @@ -1180,14 +1180,14 @@ public static partial class FastTree ShortName = FastTreeRankingTrainer.ShortName, XmlInclude = new[] { @"", @""})] - public static CommonOutputs.RankingOutput TrainRanking(IHostEnvironment env, FastTreeRankingTrainer.Arguments input) + public static CommonOutputs.RankingOutput TrainRanking(IHostEnvironment env, FastTreeRankingTrainer.Options input) { Contracts.CheckValue(env, nameof(env)); var host = env.Register("TrainFastTree"); host.CheckValue(input, nameof(input)); EntryPointUtils.CheckInputArgs(host, input); - return LearnerEntryPointsUtils.Train(host, input, + return LearnerEntryPointsUtils.Train(host, input, () => new FastTreeRankingTrainer(host, input), () => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumn), () => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.WeightColumn), diff --git a/src/Microsoft.ML.FastTree/FastTreeRegression.cs b/src/Microsoft.ML.FastTree/FastTreeRegression.cs index 77e7d3364e..a6acd1940a 100644 --- a/src/Microsoft.ML.FastTree/FastTreeRegression.cs 
+++ b/src/Microsoft.ML.FastTree/FastTreeRegression.cs @@ -15,7 +15,7 @@ using Microsoft.ML.Trainers.FastTree.Internal; using Microsoft.ML.Training; -[assembly: LoadableClass(FastTreeRegressionTrainer.Summary, typeof(FastTreeRegressionTrainer), typeof(FastTreeRegressionTrainer.Arguments), +[assembly: LoadableClass(FastTreeRegressionTrainer.Summary, typeof(FastTreeRegressionTrainer), typeof(FastTreeRegressionTrainer.Options), new[] { typeof(SignatureRegressorTrainer), typeof(SignatureTrainer), typeof(SignatureTreeEnsembleTrainer), typeof(SignatureFeatureScorerTrainer) }, FastTreeRegressionTrainer.UserNameValue, FastTreeRegressionTrainer.LoadNameValue, @@ -35,7 +35,7 @@ namespace Microsoft.ML.Trainers.FastTree { /// public sealed partial class FastTreeRegressionTrainer - : BoostingFastTreeTrainerBase, FastTreeRegressionModelParameters> + : BoostingFastTreeTrainerBase, FastTreeRegressionModelParameters> { public const string LoadNameValue = "FastTreeRegression"; internal const string UserNameValue = "FastTree (Boosted Trees) Regression"; @@ -86,15 +86,15 @@ public FastTreeRegressionTrainer(IHostEnvironment env, string labelColumn, string featureColumn, string weightColumn, - Action advancedSettings = null) + Action advancedSettings = null) : base(env, TrainerUtils.MakeR4ScalarColumn(labelColumn), featureColumn, weightColumn, null, advancedSettings) { } /// - /// Initializes a new instance of by using the legacy class. + /// Initializes a new instance of by using the legacy class. 
/// - internal FastTreeRegressionTrainer(IHostEnvironment env, Arguments args) + internal FastTreeRegressionTrainer(IHostEnvironment env, Options args) : base(env, args, TrainerUtils.MakeR4ScalarColumn(args.LabelColumn)) { } @@ -414,7 +414,7 @@ public ObjectiveImpl(Dataset trainData, RegressionGamTrainer.Arguments args) : _labels = GetDatasetRegressionLabels(trainData); } - public ObjectiveImpl(Dataset trainData, Arguments args) + public ObjectiveImpl(Dataset trainData, Options args) : base( trainData, args.LearningRates, @@ -516,14 +516,14 @@ public static partial class FastTree ShortName = FastTreeRegressionTrainer.ShortName, XmlInclude = new[] { @"", @""})] - public static CommonOutputs.RegressionOutput TrainRegression(IHostEnvironment env, FastTreeRegressionTrainer.Arguments input) + public static CommonOutputs.RegressionOutput TrainRegression(IHostEnvironment env, FastTreeRegressionTrainer.Options input) { Contracts.CheckValue(env, nameof(env)); var host = env.Register("TrainFastTree"); host.CheckValue(input, nameof(input)); EntryPointUtils.CheckInputArgs(host, input); - return LearnerEntryPointsUtils.Train(host, input, + return LearnerEntryPointsUtils.Train(host, input, () => new FastTreeRegressionTrainer(host, input), () => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumn), () => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.WeightColumn), diff --git a/src/Microsoft.ML.FastTree/FastTreeTweedie.cs b/src/Microsoft.ML.FastTree/FastTreeTweedie.cs index ec4d19e399..3fce1f56e3 100644 --- a/src/Microsoft.ML.FastTree/FastTreeTweedie.cs +++ b/src/Microsoft.ML.FastTree/FastTreeTweedie.cs @@ -16,7 +16,7 @@ using Microsoft.ML.Trainers.FastTree.Internal; using Microsoft.ML.Training; -[assembly: LoadableClass(FastTreeTweedieTrainer.Summary, typeof(FastTreeTweedieTrainer), typeof(FastTreeTweedieTrainer.Arguments), +[assembly: LoadableClass(FastTreeTweedieTrainer.Summary, typeof(FastTreeTweedieTrainer), 
typeof(FastTreeTweedieTrainer.Options), new[] { typeof(SignatureRegressorTrainer), typeof(SignatureTrainer), typeof(SignatureTreeEnsembleTrainer), typeof(SignatureFeatureScorerTrainer) }, FastTreeTweedieTrainer.UserNameValue, FastTreeTweedieTrainer.LoadNameValue, @@ -33,7 +33,7 @@ namespace Microsoft.ML.Trainers.FastTree // https://arxiv.org/pdf/1508.06378.pdf /// public sealed partial class FastTreeTweedieTrainer - : BoostingFastTreeTrainerBase, FastTreeTweedieModelParameters> + : BoostingFastTreeTrainerBase, FastTreeTweedieModelParameters> { internal const string LoadNameValue = "FastTreeTweedieRegression"; internal const string UserNameValue = "FastTree (Boosted Trees) Tweedie Regression"; @@ -87,7 +87,7 @@ public FastTreeTweedieTrainer(IHostEnvironment env, string labelColumn, string featureColumn, string weightColumn, - Action advancedSettings) + Action advancedSettings) : base(env, TrainerUtils.MakeR4ScalarColumn(labelColumn), featureColumn, weightColumn, null, advancedSettings) { Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); @@ -97,9 +97,9 @@ public FastTreeTweedieTrainer(IHostEnvironment env, } /// - /// Initializes a new instance of by using the legacy class. + /// Initializes a new instance of by using the legacy class. /// - internal FastTreeTweedieTrainer(IHostEnvironment env, Arguments args) + internal FastTreeTweedieTrainer(IHostEnvironment env, Options args) : base(env, args, TrainerUtils.MakeR4ScalarColumn(args.LabelColumn)) { Initialize(); @@ -355,7 +355,7 @@ private sealed class ObjectiveImpl : ObjectiveFunctionBase, IStepSearch private readonly Double _index2; // 2 minus the index parameter. 
private readonly Double _maxClamp; - public ObjectiveImpl(Dataset trainData, Arguments args) + public ObjectiveImpl(Dataset trainData, Options args) : base( trainData, args.LearningRates, @@ -535,14 +535,14 @@ public static partial class FastTree UserName = FastTreeTweedieTrainer.UserNameValue, ShortName = FastTreeTweedieTrainer.ShortName, XmlInclude = new[] { @"" })] - public static CommonOutputs.RegressionOutput TrainTweedieRegression(IHostEnvironment env, FastTreeTweedieTrainer.Arguments input) + public static CommonOutputs.RegressionOutput TrainTweedieRegression(IHostEnvironment env, FastTreeTweedieTrainer.Options input) { Contracts.CheckValue(env, nameof(env)); var host = env.Register("TrainTweeedie"); host.CheckValue(input, nameof(input)); EntryPointUtils.CheckInputArgs(host, input); - return LearnerEntryPointsUtils.Train(host, input, + return LearnerEntryPointsUtils.Train(host, input, () => new FastTreeTweedieTrainer(host, input), () => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumn), () => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.WeightColumn), diff --git a/src/Microsoft.ML.FastTree/RandomForestClassification.cs b/src/Microsoft.ML.FastTree/RandomForestClassification.cs index 80d23128bb..73566636fc 100644 --- a/src/Microsoft.ML.FastTree/RandomForestClassification.cs +++ b/src/Microsoft.ML.FastTree/RandomForestClassification.cs @@ -17,7 +17,7 @@ using Microsoft.ML.Trainers.FastTree.Internal; using Microsoft.ML.Training; -[assembly: LoadableClass(FastForestClassification.Summary, typeof(FastForestClassification), typeof(FastForestClassification.Arguments), +[assembly: LoadableClass(FastForestClassification.Summary, typeof(FastForestClassification), typeof(FastForestClassification.Options), new[] { typeof(SignatureBinaryClassifierTrainer), typeof(SignatureTrainer), typeof(SignatureTreeEnsembleTrainer), typeof(SignatureFeatureScorerTrainer) }, FastForestClassification.UserNameValue, 
FastForestClassification.LoadNameValue, @@ -110,9 +110,9 @@ private static IPredictorProducing Create(IHostEnvironment env, ModelLoad /// public sealed partial class FastForestClassification : - RandomForestTrainerBase>, IPredictorWithFeatureWeights> + RandomForestTrainerBase>, IPredictorWithFeatureWeights> { - public sealed class Arguments : FastForestArgumentsBase + public sealed class Options : FastForestArgumentsBase { [Argument(ArgumentType.AtMostOnce, HelpText = "Upper bound on absolute value of single tree output", ShortName = "mo")] public Double MaxTreeOutput = 100; @@ -171,7 +171,7 @@ public FastForestClassification(IHostEnvironment env, string labelColumn, string featureColumn, string weightColumn, - Action advancedSettings) + Action advancedSettings) : base(env, TrainerUtils.MakeBoolScalarLabel(labelColumn), featureColumn, weightColumn, null, advancedSettings) { Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); @@ -179,9 +179,9 @@ public FastForestClassification(IHostEnvironment env, } /// - /// Initializes a new instance of by using the legacy class. + /// Initializes a new instance of by using the legacy class. 
/// - public FastForestClassification(IHostEnvironment env, Arguments args) + public FastForestClassification(IHostEnvironment env, Options args) : base(env, args, TrainerUtils.MakeBoolScalarLabel(args.LabelColumn)) { } @@ -248,7 +248,7 @@ private sealed class ObjectiveFunctionImpl : RandomForestObjectiveFunction { private readonly bool[] _labels; - public ObjectiveFunctionImpl(Dataset trainSet, bool[] trainSetLabels, Arguments args) + public ObjectiveFunctionImpl(Dataset trainSet, bool[] trainSetLabels, Options args) : base(trainSet, args, args.MaxTreeOutput) { _labels = trainSetLabels; @@ -272,14 +272,14 @@ public static partial class FastForest ShortName = FastForestClassification.ShortName, XmlInclude = new[] { @"", @""})] - public static CommonOutputs.BinaryClassificationOutput TrainBinary(IHostEnvironment env, FastForestClassification.Arguments input) + public static CommonOutputs.BinaryClassificationOutput TrainBinary(IHostEnvironment env, FastForestClassification.Options input) { Contracts.CheckValue(env, nameof(env)); var host = env.Register("TrainFastForest"); host.CheckValue(input, nameof(input)); EntryPointUtils.CheckInputArgs(host, input); - return LearnerEntryPointsUtils.Train(host, input, + return LearnerEntryPointsUtils.Train(host, input, () => new FastForestClassification(host, input), () => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumn), () => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.WeightColumn), diff --git a/src/Microsoft.ML.FastTree/RandomForestRegression.cs b/src/Microsoft.ML.FastTree/RandomForestRegression.cs index 71136be494..b7cb4c97fa 100644 --- a/src/Microsoft.ML.FastTree/RandomForestRegression.cs +++ b/src/Microsoft.ML.FastTree/RandomForestRegression.cs @@ -15,7 +15,7 @@ using Microsoft.ML.Trainers.FastTree.Internal; using Microsoft.ML.Training; -[assembly: LoadableClass(FastForestRegression.Summary, typeof(FastForestRegression), 
typeof(FastForestRegression.Arguments), +[assembly: LoadableClass(FastForestRegression.Summary, typeof(FastForestRegression), typeof(FastForestRegression.Options), new[] { typeof(SignatureRegressorTrainer), typeof(SignatureTrainer), typeof(SignatureTreeEnsembleTrainer), typeof(SignatureFeatureScorerTrainer) }, FastForestRegression.UserNameValue, FastForestRegression.LoadNameValue, @@ -136,9 +136,9 @@ ISchemaBindableMapper IQuantileRegressionPredictor.CreateMapper(Double[] quantil /// public sealed partial class FastForestRegression - : RandomForestTrainerBase, FastForestRegressionModelParameters> + : RandomForestTrainerBase, FastForestRegressionModelParameters> { - public sealed class Arguments : FastForestArgumentsBase + public sealed class Options : FastForestArgumentsBase { [Argument(ArgumentType.LastOccurenceWins, HelpText = "Shuffle the labels on every iteration. " + "Useful probably only if using this tree as a tree leaf featurizer for multiclass.")] @@ -189,7 +189,7 @@ public FastForestRegression(IHostEnvironment env, string labelColumn, string featureColumn, string weightColumn, - Action advancedSettings) + Action advancedSettings) : base(env, TrainerUtils.MakeR4ScalarColumn(labelColumn), featureColumn, weightColumn, null, advancedSettings) { Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); @@ -197,9 +197,9 @@ public FastForestRegression(IHostEnvironment env, } /// - /// Initializes a new instance of by using the legacy class. + /// Initializes a new instance of by using the legacy class. 
/// - public FastForestRegression(IHostEnvironment env, Arguments args) + public FastForestRegression(IHostEnvironment env, Options args) : base(env, args, TrainerUtils.MakeR4ScalarColumn(args.LabelColumn), true) { } @@ -256,14 +256,14 @@ private abstract class ObjectiveFunctionImplBase : RandomForestObjectiveFunction { private readonly float[] _labels; - public static ObjectiveFunctionImplBase Create(Dataset trainData, Arguments args) + public static ObjectiveFunctionImplBase Create(Dataset trainData, Options args) { if (args.ShuffleLabels) return new ShuffleImpl(trainData, args); return new BasicImpl(trainData, args); } - private ObjectiveFunctionImplBase(Dataset trainData, Arguments args) + private ObjectiveFunctionImplBase(Dataset trainData, Options args) : base(trainData, args, double.MaxValue) // No notion of maximum step size. { _labels = FastTreeRegressionTrainer.GetDatasetRegressionLabels(trainData); @@ -283,7 +283,7 @@ private sealed class ShuffleImpl : ObjectiveFunctionImplBase private readonly Random _rgen; private readonly int _labelLim; - public ShuffleImpl(Dataset trainData, Arguments args) + public ShuffleImpl(Dataset trainData, Options args) : base(trainData, args) { Contracts.AssertValue(args); @@ -321,7 +321,7 @@ public override double[] GetGradient(IChannel ch, double[] scores) private sealed class BasicImpl : ObjectiveFunctionImplBase { - public BasicImpl(Dataset trainData, Arguments args) + public BasicImpl(Dataset trainData, Options args) : base(trainData, args) { } @@ -337,14 +337,14 @@ public static partial class FastForest ShortName = FastForestRegression.ShortName, XmlInclude = new[] { @"", @""})] - public static CommonOutputs.RegressionOutput TrainRegression(IHostEnvironment env, FastForestRegression.Arguments input) + public static CommonOutputs.RegressionOutput TrainRegression(IHostEnvironment env, FastForestRegression.Options input) { Contracts.CheckValue(env, nameof(env)); var host = env.Register("TrainFastForest"); 
host.CheckValue(input, nameof(input)); EntryPointUtils.CheckInputArgs(host, input); - return LearnerEntryPointsUtils.Train(host, input, + return LearnerEntryPointsUtils.Train(host, input, () => new FastForestRegression(host, input), () => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumn), () => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.WeightColumn), diff --git a/src/Microsoft.ML.FastTree/TreeTrainersCatalog.cs b/src/Microsoft.ML.FastTree/TreeTrainersCatalog.cs index ebb81eefc9..dc9e0b395d 100644 --- a/src/Microsoft.ML.FastTree/TreeTrainersCatalog.cs +++ b/src/Microsoft.ML.FastTree/TreeTrainersCatalog.cs @@ -50,7 +50,7 @@ public static FastTreeRegressionTrainer FastTree(this RegressionContext.Regressi string labelColumn, string featureColumn, string weights, - Action advancedSettings) + Action advancedSettings) { Contracts.CheckValue(ctx, nameof(ctx)); var env = CatalogUtils.GetEnvironment(ctx); @@ -94,7 +94,7 @@ public static FastTreeBinaryClassificationTrainer FastTree(this BinaryClassifica string labelColumn, string featureColumn, string weights, - Action advancedSettings) + Action advancedSettings) { Contracts.CheckValue(ctx, nameof(ctx)); var env = CatalogUtils.GetEnvironment(ctx); @@ -142,7 +142,7 @@ public static FastTreeRankingTrainer FastTree(this RankingContext.RankingTrainer string featureColumn, string groupId, string weights, - Action advancedSettings) + Action advancedSettings) { Contracts.CheckValue(ctx, nameof(ctx)); var env = CatalogUtils.GetEnvironment(ctx); @@ -236,7 +236,7 @@ public static FastTreeTweedieTrainer FastTreeTweedie(this RegressionContext.Regr string labelColumn, string featureColumn, string weights, - Action advancedSettings = null) + Action advancedSettings = null) { Contracts.CheckValue(ctx, nameof(ctx)); var env = CatalogUtils.GetEnvironment(ctx); @@ -280,7 +280,7 @@ public static FastForestRegression FastForest(this RegressionContext.RegressionT string 
labelColumn, string featureColumn, string weights, - Action advancedSettings = null) + Action advancedSettings = null) { Contracts.CheckValue(ctx, nameof(ctx)); var env = CatalogUtils.GetEnvironment(ctx); @@ -324,7 +324,7 @@ public static FastForestClassification FastForest(this BinaryClassificationConte string labelColumn, string featureColumn, string weights, - Action advancedSettings) + Action advancedSettings) { Contracts.CheckValue(ctx, nameof(ctx)); var env = CatalogUtils.GetEnvironment(ctx); diff --git a/src/Microsoft.ML.StaticPipe/TreeTrainersStatic.cs b/src/Microsoft.ML.StaticPipe/TreeTrainersStatic.cs index 0c81119760..ccfd05a63d 100644 --- a/src/Microsoft.ML.StaticPipe/TreeTrainersStatic.cs +++ b/src/Microsoft.ML.StaticPipe/TreeTrainersStatic.cs @@ -84,7 +84,7 @@ public static Scalar FastTree(this RegressionContext.RegressionTrainers c /// public static Scalar FastTree(this RegressionContext.RegressionTrainers ctx, Scalar label, Vector features, Scalar weights, - Action advancedSettings, + Action advancedSettings, Action onFit = null) { CheckUserValues(label, features, weights, advancedSettings, onFit); @@ -175,7 +175,7 @@ public static (Scalar score, Scalar probability, Scalar pred /// public static (Scalar score, Scalar probability, Scalar predictedLabel) FastTree(this BinaryClassificationContext.BinaryClassificationTrainers ctx, Scalar label, Vector features, Scalar weights, - Action advancedSettings, + Action advancedSettings, Action> onFit = null) { CheckUserValues(label, features, weights, advancedSettings, onFit); @@ -254,7 +254,7 @@ public static Scalar FastTree(this RankingContext.RankingTrainers c /// The Score output column indicating the predicted value. 
public static Scalar FastTree(this RankingContext.RankingTrainers ctx, Scalar label, Vector features, Key groupId, Scalar weights, - Action advancedSettings, + Action advancedSettings, Action onFit = null) { CheckUserValues(label, features, weights, advancedSettings, onFit); diff --git a/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv b/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv index 75f83c5030..afc53acc10 100644 --- a/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv +++ b/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv @@ -43,12 +43,12 @@ Trainers.AveragedPerceptronBinaryClassifier Averaged Perceptron Binary Classifie Trainers.EnsembleBinaryClassifier Train binary ensemble. Microsoft.ML.Ensemble.EntryPoints.Ensemble CreateBinaryEnsemble Microsoft.ML.Ensemble.EnsembleTrainer+Arguments Microsoft.ML.EntryPoints.CommonOutputs+BinaryClassificationOutput Trainers.EnsembleClassification Train multiclass ensemble. Microsoft.ML.Ensemble.EntryPoints.Ensemble CreateMultiClassEnsemble Microsoft.ML.Ensemble.MulticlassDataPartitionEnsembleTrainer+Arguments Microsoft.ML.EntryPoints.CommonOutputs+MulticlassClassificationOutput Trainers.EnsembleRegression Train regression ensemble. Microsoft.ML.Ensemble.EntryPoints.Ensemble CreateRegressionEnsemble Microsoft.ML.Ensemble.RegressionEnsembleTrainer+Arguments Microsoft.ML.EntryPoints.CommonOutputs+RegressionOutput -Trainers.FastForestBinaryClassifier Uses a random forest learner to perform binary classification. Microsoft.ML.Trainers.FastTree.FastForest TrainBinary Microsoft.ML.Trainers.FastTree.FastForestClassification+Arguments Microsoft.ML.EntryPoints.CommonOutputs+BinaryClassificationOutput -Trainers.FastForestRegressor Trains a random forest to fit target values using least-squares. 
Microsoft.ML.Trainers.FastTree.FastForest TrainRegression Microsoft.ML.Trainers.FastTree.FastForestRegression+Arguments Microsoft.ML.EntryPoints.CommonOutputs+RegressionOutput -Trainers.FastTreeBinaryClassifier Uses a logit-boost boosted tree learner to perform binary classification. Microsoft.ML.Trainers.FastTree.FastTree TrainBinary Microsoft.ML.Trainers.FastTree.FastTreeBinaryClassificationTrainer+Arguments Microsoft.ML.EntryPoints.CommonOutputs+BinaryClassificationOutput -Trainers.FastTreeRanker Trains gradient boosted decision trees to the LambdaRank quasi-gradient. Microsoft.ML.Trainers.FastTree.FastTree TrainRanking Microsoft.ML.Trainers.FastTree.FastTreeRankingTrainer+Arguments Microsoft.ML.EntryPoints.CommonOutputs+RankingOutput -Trainers.FastTreeRegressor Trains gradient boosted decision trees to fit target values using least-squares. Microsoft.ML.Trainers.FastTree.FastTree TrainRegression Microsoft.ML.Trainers.FastTree.FastTreeRegressionTrainer+Arguments Microsoft.ML.EntryPoints.CommonOutputs+RegressionOutput -Trainers.FastTreeTweedieRegressor Trains gradient boosted decision trees to fit target values using a Tweedie loss function. This learner is a generalization of Poisson, compound Poisson, and gamma regression. Microsoft.ML.Trainers.FastTree.FastTree TrainTweedieRegression Microsoft.ML.Trainers.FastTree.FastTreeTweedieTrainer+Arguments Microsoft.ML.EntryPoints.CommonOutputs+RegressionOutput +Trainers.FastForestBinaryClassifier Uses a random forest learner to perform binary classification. Microsoft.ML.Trainers.FastTree.FastForest TrainBinary Microsoft.ML.Trainers.FastTree.FastForestClassification+Options Microsoft.ML.EntryPoints.CommonOutputs+BinaryClassificationOutput +Trainers.FastForestRegressor Trains a random forest to fit target values using least-squares. 
Microsoft.ML.Trainers.FastTree.FastForest TrainRegression Microsoft.ML.Trainers.FastTree.FastForestRegression+Options Microsoft.ML.EntryPoints.CommonOutputs+RegressionOutput +Trainers.FastTreeBinaryClassifier Uses a logit-boost boosted tree learner to perform binary classification. Microsoft.ML.Trainers.FastTree.FastTree TrainBinary Microsoft.ML.Trainers.FastTree.FastTreeBinaryClassificationTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+BinaryClassificationOutput +Trainers.FastTreeRanker Trains gradient boosted decision trees to the LambdaRank quasi-gradient. Microsoft.ML.Trainers.FastTree.FastTree TrainRanking Microsoft.ML.Trainers.FastTree.FastTreeRankingTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+RankingOutput +Trainers.FastTreeRegressor Trains gradient boosted decision trees to fit target values using least-squares. Microsoft.ML.Trainers.FastTree.FastTree TrainRegression Microsoft.ML.Trainers.FastTree.FastTreeRegressionTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+RegressionOutput +Trainers.FastTreeTweedieRegressor Trains gradient boosted decision trees to fit target values using a Tweedie loss function. This learner is a generalization of Poisson, compound Poisson, and gamma regression. Microsoft.ML.Trainers.FastTree.FastTree TrainTweedieRegression Microsoft.ML.Trainers.FastTree.FastTreeTweedieTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+RegressionOutput Trainers.FieldAwareFactorizationMachineBinaryClassifier Train a field-aware factorization machine for binary classification Microsoft.ML.FactorizationMachine.FieldAwareFactorizationMachineTrainer TrainBinary Microsoft.ML.FactorizationMachine.FieldAwareFactorizationMachineTrainer+Arguments Microsoft.ML.EntryPoints.CommonOutputs+BinaryClassificationOutput Trainers.GeneralizedAdditiveModelBinaryClassifier Trains a gradient boosted stump per feature, on all features simultaneously, to fit target values using least-squares. It mantains no interactions between features. 
Microsoft.ML.Trainers.FastTree.Gam TrainBinary Microsoft.ML.Trainers.FastTree.BinaryClassificationGamTrainer+Arguments Microsoft.ML.EntryPoints.CommonOutputs+BinaryClassificationOutput Trainers.GeneralizedAdditiveModelRegressor Trains a gradient boosted stump per feature, on all features simultaneously, to fit target values using least-squares. It mantains no interactions between features. Microsoft.ML.Trainers.FastTree.Gam TrainRegression Microsoft.ML.Trainers.FastTree.RegressionGamTrainer+Arguments Microsoft.ML.EntryPoints.CommonOutputs+RegressionOutput diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs index 4540e26326..12bb992b1c 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs @@ -3581,7 +3581,7 @@ public void EntryPointTreeLeafFeaturizer() Column = new[] { new ColumnConcatenatingTransformer.Column { Name = "Features", Source = new[] { "Categories", "NumericFeatures" } } } }); - var fastTree = FastTree.TrainBinary(Env, new FastTreeBinaryClassificationTrainer.Arguments + var fastTree = FastTree.TrainBinary(Env, new FastTreeBinaryClassificationTrainer.Options { FeatureColumn = "Features", NumTrees = 5, diff --git a/test/Microsoft.ML.Predictor.Tests/TestPredictors.cs b/test/Microsoft.ML.Predictor.Tests/TestPredictors.cs index 223fb683cf..2b982dc2d0 100644 --- a/test/Microsoft.ML.Predictor.Tests/TestPredictors.cs +++ b/test/Microsoft.ML.Predictor.Tests/TestPredictors.cs @@ -611,7 +611,7 @@ public void TestTreeEnsembleCombiner() var fastTrees = new PredictorModel[3]; for (int i = 0; i < 3; i++) { - fastTrees[i] = FastTree.TrainBinary(ML, new FastTreeBinaryClassificationTrainer.Arguments + fastTrees[i] = FastTree.TrainBinary(ML, new FastTreeBinaryClassificationTrainer.Options { FeatureColumn = "Features", NumTrees = 5, @@ -634,7 +634,7 @@ public void TestTreeEnsembleCombinerWithCategoricalSplits() var 
fastTrees = new PredictorModel[3]; for (int i = 0; i < 3; i++) { - fastTrees[i] = FastTree.TrainBinary(ML, new FastTreeBinaryClassificationTrainer.Arguments + fastTrees[i] = FastTree.TrainBinary(ML, new FastTreeBinaryClassificationTrainer.Options { FeatureColumn = "Features", NumTrees = 5, @@ -733,7 +733,7 @@ public void TestEnsembleCombiner() var predictors = new PredictorModel[] { - FastTree.TrainBinary(ML, new FastTreeBinaryClassificationTrainer.Arguments + FastTree.TrainBinary(ML, new FastTreeBinaryClassificationTrainer.Options { FeatureColumn = "Features", NumTrees = 5, From b255eb3e7b9bdb8f0fe481bbadd1cec94d724538 Mon Sep 17 00:00:00 2001 From: Abhishek Goswami Date: Mon, 7 Jan 2019 02:34:57 +0000 Subject: [PATCH 4/8] Pass objects as arguments instead of delegate --- src/Microsoft.ML.FastTree/BoostingFastTree.cs | 2 +- src/Microsoft.ML.FastTree/FastTree.cs | 7 +-- .../FastTreeClassification.cs | 4 +- src/Microsoft.ML.FastTree/FastTreeRanking.cs | 4 +- .../FastTreeRegression.cs | 4 +- src/Microsoft.ML.FastTree/FastTreeTweedie.cs | 4 +- src/Microsoft.ML.FastTree/RandomForest.cs | 2 +- .../RandomForestClassification.cs | 4 +- .../RandomForestRegression.cs | 4 +- .../TreeTrainersCatalog.cs | 12 ++--- .../TreeTrainersStatic.cs | 17 ++++--- .../Algorithms/SmacSweeper.cs | 8 +-- test/Microsoft.ML.Tests/Scenarios/OvaTest.cs | 2 +- .../TrainerEstimators/TreeEstimators.cs | 50 +++++++++---------- 14 files changed, 59 insertions(+), 65 deletions(-) diff --git a/src/Microsoft.ML.FastTree/BoostingFastTree.cs b/src/Microsoft.ML.FastTree/BoostingFastTree.cs index dc5d9d9967..963624f2de 100644 --- a/src/Microsoft.ML.FastTree/BoostingFastTree.cs +++ b/src/Microsoft.ML.FastTree/BoostingFastTree.cs @@ -44,7 +44,7 @@ protected BoostingFastTreeTrainerBase(IHostEnvironment env, string featureColumn, string weightColumn, string groupIdColumn, - Action advancedSettings) + TArgs advancedSettings) : base(env, label, featureColumn, weightColumn, groupIdColumn, advancedSettings) { } 
diff --git a/src/Microsoft.ML.FastTree/FastTree.cs b/src/Microsoft.ML.FastTree/FastTree.cs index f5ae2db79f..8fc322722f 100644 --- a/src/Microsoft.ML.FastTree/FastTree.cs +++ b/src/Microsoft.ML.FastTree/FastTree.cs @@ -155,13 +155,10 @@ private protected FastTreeTrainerBase(IHostEnvironment env, string featureColumn, string weightColumn, string groupIdColumn, - Action advancedSettings) + TArgs advancedSettings) : base(Contracts.CheckRef(env, nameof(env)).Register(RegisterName), TrainerUtils.MakeR4VecFeature(featureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(weightColumn), TrainerUtils.MakeU4ScalarColumn(groupIdColumn)) { - Args = new TArgs(); - - //apply the advanced args, if the user supplied any - advancedSettings?.Invoke(Args); + Args = advancedSettings; Args.LabelColumn = label.Name; Args.FeatureColumn = featureColumn; diff --git a/src/Microsoft.ML.FastTree/FastTreeClassification.cs b/src/Microsoft.ML.FastTree/FastTreeClassification.cs index 0502dc8295..04ad335260 100644 --- a/src/Microsoft.ML.FastTree/FastTreeClassification.cs +++ b/src/Microsoft.ML.FastTree/FastTreeClassification.cs @@ -148,12 +148,12 @@ public FastTreeBinaryClassificationTrainer(IHostEnvironment env, /// The name of the label column. /// The name of the feature column. /// The name for the column containing the initial weight. - /// A delegate to apply all the advanced arguments to the algorithm. + /// Advanced arguments to the algorithm. 
public FastTreeBinaryClassificationTrainer(IHostEnvironment env, string labelColumn, string featureColumn, string weightColumn, - Action advancedSettings) + Options advancedSettings) : base(env, TrainerUtils.MakeBoolScalarLabel(labelColumn), featureColumn, weightColumn, null, advancedSettings) { // Set the sigmoid parameter to the 2 * learning rate, for traditional FastTreeClassification loss diff --git a/src/Microsoft.ML.FastTree/FastTreeRanking.cs b/src/Microsoft.ML.FastTree/FastTreeRanking.cs index 703c3985b7..e52bb2557e 100644 --- a/src/Microsoft.ML.FastTree/FastTreeRanking.cs +++ b/src/Microsoft.ML.FastTree/FastTreeRanking.cs @@ -93,13 +93,13 @@ public FastTreeRankingTrainer(IHostEnvironment env, /// The name of the feature column. /// The name for the column containing the group ID. /// The name for the column containing the initial weight. - /// A delegate to apply all the advanced arguments to the algorithm. + /// Advanced arguments to the algorithm. public FastTreeRankingTrainer(IHostEnvironment env, string labelColumn, string featureColumn, string groupIdColumn, string weightColumn, - Action advancedSettings = null) + Options advancedSettings) : base(env, TrainerUtils.MakeR4ScalarColumn(labelColumn), featureColumn, weightColumn, groupIdColumn, advancedSettings) { Host.CheckNonEmpty(groupIdColumn, nameof(groupIdColumn)); diff --git a/src/Microsoft.ML.FastTree/FastTreeRegression.cs b/src/Microsoft.ML.FastTree/FastTreeRegression.cs index a6acd1940a..49d2515c57 100644 --- a/src/Microsoft.ML.FastTree/FastTreeRegression.cs +++ b/src/Microsoft.ML.FastTree/FastTreeRegression.cs @@ -81,12 +81,12 @@ public FastTreeRegressionTrainer(IHostEnvironment env, /// The name of the label column. /// The name of the feature column. /// The name for the column containing the initial weight. - /// A delegate to apply all the advanced arguments to the algorithm. + /// Advanced arguments to the algorithm. 
public FastTreeRegressionTrainer(IHostEnvironment env, string labelColumn, string featureColumn, string weightColumn, - Action advancedSettings = null) + Options advancedSettings) : base(env, TrainerUtils.MakeR4ScalarColumn(labelColumn), featureColumn, weightColumn, null, advancedSettings) { } diff --git a/src/Microsoft.ML.FastTree/FastTreeTweedie.cs b/src/Microsoft.ML.FastTree/FastTreeTweedie.cs index 3fce1f56e3..03aa0bd511 100644 --- a/src/Microsoft.ML.FastTree/FastTreeTweedie.cs +++ b/src/Microsoft.ML.FastTree/FastTreeTweedie.cs @@ -82,12 +82,12 @@ public FastTreeTweedieTrainer(IHostEnvironment env, /// The name of the label column. /// The name of the feature column. /// The name for the column containing the initial weight. - /// A delegate to apply all the advanced arguments to the algorithm. + /// Advanced arguments to the algorithm. public FastTreeTweedieTrainer(IHostEnvironment env, string labelColumn, string featureColumn, string weightColumn, - Action advancedSettings) + Options advancedSettings) : base(env, TrainerUtils.MakeR4ScalarColumn(labelColumn), featureColumn, weightColumn, null, advancedSettings) { Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); diff --git a/src/Microsoft.ML.FastTree/RandomForest.cs b/src/Microsoft.ML.FastTree/RandomForest.cs index ba4762d4e0..c249cab45b 100644 --- a/src/Microsoft.ML.FastTree/RandomForest.cs +++ b/src/Microsoft.ML.FastTree/RandomForest.cs @@ -50,7 +50,7 @@ protected RandomForestTrainerBase(IHostEnvironment env, string featureColumn, string weightColumn, string groupIdColumn, - Action advancedSettings, + TArgs advancedSettings, bool quantileEnabled = false) : base(env, label, featureColumn, weightColumn, null, advancedSettings) { diff --git a/src/Microsoft.ML.FastTree/RandomForestClassification.cs b/src/Microsoft.ML.FastTree/RandomForestClassification.cs index 73566636fc..d5c427fdeb 100644 --- a/src/Microsoft.ML.FastTree/RandomForestClassification.cs +++ 
b/src/Microsoft.ML.FastTree/RandomForestClassification.cs @@ -166,12 +166,12 @@ public FastForestClassification(IHostEnvironment env, /// The name of the label column. /// The name of the feature column. /// The name for the column containing the initial weight. - /// A delegate to apply all the advanced arguments to the algorithm. + /// Advanced arguments to the algorithm. public FastForestClassification(IHostEnvironment env, string labelColumn, string featureColumn, string weightColumn, - Action advancedSettings) + Options advancedSettings) : base(env, TrainerUtils.MakeBoolScalarLabel(labelColumn), featureColumn, weightColumn, null, advancedSettings) { Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); diff --git a/src/Microsoft.ML.FastTree/RandomForestRegression.cs b/src/Microsoft.ML.FastTree/RandomForestRegression.cs index b7cb4c97fa..2c44ee0256 100644 --- a/src/Microsoft.ML.FastTree/RandomForestRegression.cs +++ b/src/Microsoft.ML.FastTree/RandomForestRegression.cs @@ -184,12 +184,12 @@ public FastForestRegression(IHostEnvironment env, /// The name of the label column. /// The name of the feature column. /// The optional name for the column containing the initial weight. - /// A delegate to apply all the advanced arguments to the algorithm. + /// Advanced arguments to the algorithm. 
public FastForestRegression(IHostEnvironment env, string labelColumn, string featureColumn, string weightColumn, - Action advancedSettings) + Options advancedSettings) : base(env, TrainerUtils.MakeR4ScalarColumn(labelColumn), featureColumn, weightColumn, null, advancedSettings) { Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); diff --git a/src/Microsoft.ML.FastTree/TreeTrainersCatalog.cs b/src/Microsoft.ML.FastTree/TreeTrainersCatalog.cs index dc9e0b395d..4f5fa93fa6 100644 --- a/src/Microsoft.ML.FastTree/TreeTrainersCatalog.cs +++ b/src/Microsoft.ML.FastTree/TreeTrainersCatalog.cs @@ -50,7 +50,7 @@ public static FastTreeRegressionTrainer FastTree(this RegressionContext.Regressi string labelColumn, string featureColumn, string weights, - Action advancedSettings) + FastTreeRegressionTrainer.Options advancedSettings) { Contracts.CheckValue(ctx, nameof(ctx)); var env = CatalogUtils.GetEnvironment(ctx); @@ -94,7 +94,7 @@ public static FastTreeBinaryClassificationTrainer FastTree(this BinaryClassifica string labelColumn, string featureColumn, string weights, - Action advancedSettings) + FastTreeBinaryClassificationTrainer.Options advancedSettings) { Contracts.CheckValue(ctx, nameof(ctx)); var env = CatalogUtils.GetEnvironment(ctx); @@ -142,7 +142,7 @@ public static FastTreeRankingTrainer FastTree(this RankingContext.RankingTrainer string featureColumn, string groupId, string weights, - Action advancedSettings) + FastTreeRankingTrainer.Options advancedSettings) { Contracts.CheckValue(ctx, nameof(ctx)); var env = CatalogUtils.GetEnvironment(ctx); @@ -236,7 +236,7 @@ public static FastTreeTweedieTrainer FastTreeTweedie(this RegressionContext.Regr string labelColumn, string featureColumn, string weights, - Action advancedSettings = null) + FastTreeTweedieTrainer.Options advancedSettings = null) { Contracts.CheckValue(ctx, nameof(ctx)); var env = CatalogUtils.GetEnvironment(ctx); @@ -280,7 +280,7 @@ public static FastForestRegression FastForest(this 
RegressionContext.RegressionT string labelColumn, string featureColumn, string weights, - Action advancedSettings = null) + FastForestRegression.Options advancedSettings = null) { Contracts.CheckValue(ctx, nameof(ctx)); var env = CatalogUtils.GetEnvironment(ctx); @@ -324,7 +324,7 @@ public static FastForestClassification FastForest(this BinaryClassificationConte string labelColumn, string featureColumn, string weights, - Action advancedSettings) + FastForestClassification.Options advancedSettings) { Contracts.CheckValue(ctx, nameof(ctx)); var env = CatalogUtils.GetEnvironment(ctx); diff --git a/src/Microsoft.ML.StaticPipe/TreeTrainersStatic.cs b/src/Microsoft.ML.StaticPipe/TreeTrainersStatic.cs index ccfd05a63d..cce3aaa1a5 100644 --- a/src/Microsoft.ML.StaticPipe/TreeTrainersStatic.cs +++ b/src/Microsoft.ML.StaticPipe/TreeTrainersStatic.cs @@ -84,10 +84,11 @@ public static Scalar FastTree(this RegressionContext.RegressionTrainers c /// public static Scalar FastTree(this RegressionContext.RegressionTrainers ctx, Scalar label, Vector features, Scalar weights, - Action advancedSettings, + FastTreeRegressionTrainer.Options advancedSettings, Action onFit = null) { - CheckUserValues(label, features, weights, advancedSettings, onFit); + Contracts.CheckValueOrNull(advancedSettings); + CheckUserValues(label, features, weights, onFit); var rec = new TrainerEstimatorReconciler.Regression( (env, labelName, featuresName, weightsName) => @@ -175,10 +176,11 @@ public static (Scalar score, Scalar probability, Scalar pred /// public static (Scalar score, Scalar probability, Scalar predictedLabel) FastTree(this BinaryClassificationContext.BinaryClassificationTrainers ctx, Scalar label, Vector features, Scalar weights, - Action advancedSettings, + FastTreeBinaryClassificationTrainer.Options advancedSettings, Action> onFit = null) { - CheckUserValues(label, features, weights, advancedSettings, onFit); + Contracts.CheckValueOrNull(advancedSettings); + CheckUserValues(label, features, 
weights, onFit); var rec = new TrainerEstimatorReconciler.BinaryClassifier( (env, labelName, featuresName, weightsName) => @@ -254,10 +256,11 @@ public static Scalar FastTree(this RankingContext.RankingTrainers c /// The Score output column indicating the predicted value. public static Scalar FastTree(this RankingContext.RankingTrainers ctx, Scalar label, Vector features, Key groupId, Scalar weights, - Action advancedSettings, + FastTreeRankingTrainer.Options advancedSettings, Action onFit = null) { - CheckUserValues(label, features, weights, advancedSettings, onFit); + Contracts.CheckValueOrNull(advancedSettings); + CheckUserValues(label, features, weights, onFit); var rec = new TrainerEstimatorReconciler.Ranker( (env, labelName, featuresName, groupIdName, weightsName) => @@ -289,13 +292,11 @@ internal static void CheckUserValues(PipelineColumn label, Vector feature } internal static void CheckUserValues(PipelineColumn label, Vector features, Scalar weights, - Delegate advancedSettings, Delegate onFit) { Contracts.CheckValue(label, nameof(label)); Contracts.CheckValue(features, nameof(features)); Contracts.CheckValueOrNull(weights); - Contracts.CheckValueOrNull(advancedSettings); Contracts.CheckValueOrNull(onFit); } } diff --git a/src/Microsoft.ML.Sweeper/Algorithms/SmacSweeper.cs b/src/Microsoft.ML.Sweeper/Algorithms/SmacSweeper.cs index 5af931160f..7f79578ad8 100644 --- a/src/Microsoft.ML.Sweeper/Algorithms/SmacSweeper.cs +++ b/src/Microsoft.ML.Sweeper/Algorithms/SmacSweeper.cs @@ -136,11 +136,11 @@ private FastForestRegressionModelParameters FitModel(IEnumerable pre // Set relevant random forest arguments. // Train random forest. 
var trainer = new FastForestRegression(_host, DefaultColumnNames.Label, DefaultColumnNames.Features, null, - advancedSettings: s => + new FastForestRegression.Options { - s.FeatureFraction = _args.SplitRatio; - s.NumTrees = _args.NumOfTrees; - s.MinDocumentsInLeafs = _args.NMinForSplit; + FeatureFraction = _args.SplitRatio, + NumTrees = _args.NumOfTrees, + MinDocumentsInLeafs = _args.NMinForSplit, }); var predictor = trainer.Train(data); diff --git a/test/Microsoft.ML.Tests/Scenarios/OvaTest.cs b/test/Microsoft.ML.Tests/Scenarios/OvaTest.cs index e5c73f976b..c0a7a6dd6a 100644 --- a/test/Microsoft.ML.Tests/Scenarios/OvaTest.cs +++ b/test/Microsoft.ML.Tests/Scenarios/OvaTest.cs @@ -105,7 +105,7 @@ public void OvaFastTree() var pipeline = new Ova( mlContext, new FastTreeBinaryClassificationTrainer(mlContext, DefaultColumnNames.Label, DefaultColumnNames.Features, null, - advancedSettings: s => { s.NumThreads = 1; }), + new FastTreeBinaryClassificationTrainer.Options { NumThreads = 1 }), useProbabilities: false); var model = pipeline.Fit(data); diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs b/test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs index d610e26ab2..49149dde32 100644 --- a/test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs +++ b/test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs @@ -25,10 +25,10 @@ public void FastTreeBinaryEstimator() var (pipe, dataView) = GetBinaryClassificationPipeline(); var trainer = new FastTreeBinaryClassificationTrainer(Env, DefaultColumnNames.Label, DefaultColumnNames.Features, null, - advancedSettings: s => { - s.NumThreads = 1; - s.NumTrees = 10; - s.NumLeaves = 5; + new FastTreeBinaryClassificationTrainer.Options { + NumThreads = 1, + NumTrees = 10, + NumLeaves = 5, }); var pipeWithTrainer = pipe.Append(trainer); @@ -83,11 +83,12 @@ public void FastForestClassificationEstimator() { var (pipe, dataView) = GetBinaryClassificationPipeline(); - var trainer = new 
FastForestClassification(Env, DefaultColumnNames.Label, DefaultColumnNames.Features, null, advancedSettings: s => - { - s.NumLeaves = 10; - s.NumTrees = 20; - }); + var trainer = new FastForestClassification(Env, DefaultColumnNames.Label, DefaultColumnNames.Features, null, + new FastForestClassification.Options { + NumLeaves = 10, + NumTrees = 20, + }); + var pipeWithTrainer = pipe.Append(trainer); TestEstimatorCore(pipeWithTrainer, dataView); @@ -104,8 +105,7 @@ public void FastTreeRankerEstimator() { var (pipe, dataView) = GetRankingPipeline(); - var trainer = new FastTreeRankingTrainer(Env, "Label0", "NumericFeatures", "Group", null, - advancedSettings: s => { s.NumTrees = 10; }); + var trainer = new FastTreeRankingTrainer(Env, "Label0", "NumericFeatures", "Group", null, new FastTreeRankingTrainer.Options { NumTrees = 10 }); var pipeWithTrainer = pipe.Append(trainer); TestEstimatorCore(pipeWithTrainer, dataView); @@ -139,12 +139,8 @@ public void LightGBMRankerEstimator() public void FastTreeRegressorEstimator() { var dataView = GetRegressionPipeline(); - var trainer = new FastTreeRegressionTrainer(Env, DefaultColumnNames.Label, DefaultColumnNames.Features, null, advancedSettings: s => - { - s.NumTrees = 10; - s.NumThreads = 1; - s.NumLeaves = 5; - }); + var trainer = new FastTreeRegressionTrainer(Env, DefaultColumnNames.Label, DefaultColumnNames.Features, null, + new FastTreeRegressionTrainer.Options { NumTrees = 10, NumThreads = 1, NumLeaves = 5 }); TestEstimatorCore(trainer, dataView); var model = trainer.Train(dataView, dataView); @@ -196,11 +192,11 @@ public void GAMRegressorEstimator() public void TweedieRegressorEstimator() { var dataView = GetRegressionPipeline(); - var trainer = new FastTreeTweedieTrainer(Env, "Label", "Features", null, advancedSettings: s => - { - s.EntropyCoefficient = 0.3; - s.OptimizationAlgorithm = BoostedTreeArgs.OptimizationAlgorithmType.AcceleratedGradientDescent; - }); + var trainer = new FastTreeTweedieTrainer(Env, "Label", 
"Features", null, + new FastTreeTweedieTrainer.Options { + EntropyCoefficient = 0.3, + OptimizationAlgorithm = BoostedTreeArgs.OptimizationAlgorithmType.AcceleratedGradientDescent, + }); TestEstimatorCore(trainer, dataView); var model = trainer.Train(dataView, dataView); @@ -214,11 +210,11 @@ public void TweedieRegressorEstimator() public void FastForestRegressorEstimator() { var dataView = GetRegressionPipeline(); - var trainer = new FastForestRegression(Env, DefaultColumnNames.Label, DefaultColumnNames.Features, null, advancedSettings: s => - { - s.BaggingSize = 2; - s.NumTrees = 10; - }); + var trainer = new FastForestRegression(Env, DefaultColumnNames.Label, DefaultColumnNames.Features, null, + new FastForestRegression.Options { + BaggingSize = 2, + NumTrees = 10, + }); TestEstimatorCore(trainer, dataView); var model = trainer.Train(dataView, dataView); From 334156149316fd1a084528654292423b1641dde1 Mon Sep 17 00:00:00 2001 From: Abhishek Goswami Date: Tue, 8 Jan 2019 00:07:39 +0000 Subject: [PATCH 5/8] review comments - 1 --- .../EntryPoints/InputBase.cs | 21 ++ src/Microsoft.ML.FastTree/BoostingFastTree.cs | 17 +- src/Microsoft.ML.FastTree/FastTree.cs | 37 +--- .../FastTreeArguments.cs | 181 ++++++++++++++++++ .../FastTreeClassification.cs | 21 +- src/Microsoft.ML.FastTree/FastTreeRanking.cs | 22 +-- .../FastTreeRegression.cs | 19 +- src/Microsoft.ML.FastTree/FastTreeTweedie.cs | 23 +-- src/Microsoft.ML.FastTree/RandomForest.cs | 15 -- .../RandomForestClassification.cs | 19 -- .../RandomForestRegression.cs | 19 -- .../TreeTrainersCatalog.cs | 54 +----- .../TreeTrainersStatic.cs | 6 +- .../Algorithms/SmacSweeper.cs | 2 +- test/Microsoft.ML.Tests/Scenarios/OvaTest.cs | 3 +- .../TrainerEstimators/TreeEstimators.cs | 17 +- 16 files changed, 232 insertions(+), 244 deletions(-) diff --git a/src/Microsoft.ML.Data/EntryPoints/InputBase.cs b/src/Microsoft.ML.Data/EntryPoints/InputBase.cs index b895fcbf91..edfd865e8b 100644 --- 
a/src/Microsoft.ML.Data/EntryPoints/InputBase.cs +++ b/src/Microsoft.ML.Data/EntryPoints/InputBase.cs @@ -35,15 +35,27 @@ public enum CachingOptions [TlcModule.EntryPointKind(typeof(CommonInputs.ITrainerInput))] public abstract class LearnerInputBase { + /// + /// The data to be used for training. + /// [Argument(ArgumentType.Required, ShortName = "data", HelpText = "The data to be used for training", SortOrder = 1, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)] public IDataView TrainingData; + /// + /// Column to use for features. + /// [Argument(ArgumentType.AtMostOnce, HelpText = "Column to use for features", ShortName = "feat", SortOrder = 2, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)] public string FeatureColumn = DefaultColumnNames.Features; + /// + /// Normalize option for the feature column. + /// [Argument(ArgumentType.AtMostOnce, HelpText = "Normalize option for the feature column", ShortName = "norm", SortOrder = 5, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)] public NormalizeOption NormalizeFeatures = NormalizeOption.Auto; + /// + /// Whether learner should cache input training data. + /// [Argument(ArgumentType.LastOccurenceWins, HelpText = "Whether learner should cache input training data", ShortName = "cache", SortOrder = 6, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)] public CachingOptions Caching = CachingOptions.Auto; } @@ -54,6 +66,9 @@ public abstract class LearnerInputBase [TlcModule.EntryPointKind(typeof(CommonInputs.ITrainerInputWithLabel))] public abstract class LearnerInputBaseWithLabel : LearnerInputBase { + /// + /// Column to use for labels. 
+ /// [Argument(ArgumentType.AtMostOnce, HelpText = "Column to use for labels", ShortName = "lab", SortOrder = 3, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)] public string LabelColumn = DefaultColumnNames.Label; } @@ -65,6 +80,9 @@ public abstract class LearnerInputBaseWithLabel : LearnerInputBase [TlcModule.EntryPointKind(typeof(CommonInputs.ITrainerInputWithWeight))] public abstract class LearnerInputBaseWithWeight : LearnerInputBaseWithLabel { + /// + /// Column to use for example weight. + /// [Argument(ArgumentType.AtMostOnce, HelpText = "Column to use for example weight", ShortName = "weight", SortOrder = 4, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)] public Optional WeightColumn = Optional.Implicit(DefaultColumnNames.Weight); } @@ -95,6 +113,9 @@ public abstract class EvaluateInputBase [TlcModule.EntryPointKind(typeof(CommonInputs.ITrainerInputWithGroupId))] public abstract class LearnerInputBaseWithGroupId : LearnerInputBaseWithWeight { + /// + /// Column to use for example groupId. 
+ /// [Argument(ArgumentType.AtMostOnce, HelpText = "Column to use for example groupId", ShortName = "groupId", SortOrder = 5, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)] public Optional GroupIdColumn = Optional.Implicit(DefaultColumnNames.GroupId); } diff --git a/src/Microsoft.ML.FastTree/BoostingFastTree.cs b/src/Microsoft.ML.FastTree/BoostingFastTree.cs index 963624f2de..460f37efda 100644 --- a/src/Microsoft.ML.FastTree/BoostingFastTree.cs +++ b/src/Microsoft.ML.FastTree/BoostingFastTree.cs @@ -31,22 +31,7 @@ protected BoostingFastTreeTrainerBase(IHostEnvironment env, double learningRate) : base(env, label, featureColumn, weightColumn, groupIdColumn, numLeaves, numTrees, minDatapointsInLeaves) { - - if (Args.LearningRates != learningRate) - { - using (var ch = Host.Start($"Setting learning rate to: {learningRate} as supplied in the direct arguments.")) - Args.LearningRates = learningRate; - } - } - - protected BoostingFastTreeTrainerBase(IHostEnvironment env, - SchemaShape.Column label, - string featureColumn, - string weightColumn, - string groupIdColumn, - TArgs advancedSettings) - : base(env, label, featureColumn, weightColumn, groupIdColumn, advancedSettings) - { + Args.LearningRates = learningRate; } protected override void CheckArgs(IChannel ch) diff --git a/src/Microsoft.ML.FastTree/FastTree.cs b/src/Microsoft.ML.FastTree/FastTree.cs index 8fc322722f..d03e836076 100644 --- a/src/Microsoft.ML.FastTree/FastTree.cs +++ b/src/Microsoft.ML.FastTree/FastTree.cs @@ -148,42 +148,7 @@ private protected FastTreeTrainerBase(IHostEnvironment env, } /// - /// Constructor to use when instantiating the classes deriving from here through the API. 
- /// - private protected FastTreeTrainerBase(IHostEnvironment env, - SchemaShape.Column label, - string featureColumn, - string weightColumn, - string groupIdColumn, - TArgs advancedSettings) - : base(Contracts.CheckRef(env, nameof(env)).Register(RegisterName), TrainerUtils.MakeR4VecFeature(featureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(weightColumn), TrainerUtils.MakeU4ScalarColumn(groupIdColumn)) - { - Args = advancedSettings; - - Args.LabelColumn = label.Name; - Args.FeatureColumn = featureColumn; - - if (weightColumn != null) - Args.WeightColumn = Optional.Explicit(weightColumn); - - if (groupIdColumn != null) - Args.GroupIdColumn = Optional.Explicit(groupIdColumn); - - // The discretization step renders this trainer non-parametric, and therefore it does not need normalization. - // Also since it builds its own internal discretized columnar structures, it cannot benefit from caching. - // Finally, even the binary classifiers, being logitboost, tend to not benefit from external calibration. - Info = new TrainerInfo(normalization: false, caching: false, calibration: NeedCalibration, supportValid: true, supportTest: true); - // REVIEW: CLR 4.6 has a bug that is only exposed in Scope, and if we trigger GC.Collect in scope environment - // with memory consumption more than 5GB, GC get stuck in infinite loop. - // Before, we could check a specific type of the environment here, but now it is internal, so we will need another - // mechanism to detect that we are running in Scope. - AllowGC = true; - - Initialize(env); - } - - /// - /// Legacy constructor that is used when invoking the classes deriving from this, through maml. + /// Constructor that is used when invoking the classes deriving from this, through maml. 
/// private protected FastTreeTrainerBase(IHostEnvironment env, TArgs args, SchemaShape.Column label) : base(Contracts.CheckRef(env, nameof(env)).Register(RegisterName), TrainerUtils.MakeR4VecFeature(args.FeatureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(args.WeightColumn, args.WeightColumn.IsExplicit)) diff --git a/src/Microsoft.ML.FastTree/FastTreeArguments.cs b/src/Microsoft.ML.FastTree/FastTreeArguments.cs index 3d66975c30..c265afc02f 100644 --- a/src/Microsoft.ML.FastTree/FastTreeArguments.cs +++ b/src/Microsoft.ML.FastTree/FastTreeArguments.cs @@ -26,6 +26,9 @@ public sealed partial class FastTreeBinaryClassificationTrainer [TlcModule.Component(Name = LoadNameValue, FriendlyName = UserNameValue, Desc = Summary)] public sealed class Options : BoostedTreeArgs, IFastTreeTrainerFactory { + /// + /// Should we use derivatives optimized for unbalanced sets? + /// [Argument(ArgumentType.LastOccurenceWins, HelpText = "Should we use derivatives optimized for unbalanced sets", ShortName = "us")] [TGUI(Label = "Optimize for unbalanced")] public bool UnbalancedSets = false; @@ -149,9 +152,15 @@ internal static class Defaults public abstract class TreeArgs : LearnerInputBaseWithGroupId { + /// + /// Allows to choose Parallel FastTree Learning Algorithm. + /// [Argument(ArgumentType.Multiple, HelpText = "Allows to choose Parallel FastTree Learning Algorithm", ShortName = "parag")] public ISupportParallelTraining ParallelTrainer = new SingleTrainerFactory(); + /// + /// The number of threads to use. + /// [Argument(ArgumentType.LastOccurenceWins, HelpText = "The number of threads to use", ShortName = "t", NullName = "")] public int? NumThreads = null; @@ -162,47 +171,86 @@ public abstract class TreeArgs : LearnerInputBaseWithGroupId // 4. tree learner // 5. bagging provider // 6. emsemble compressor + /// + /// The seed of the random number generator. 
+ /// [Argument(ArgumentType.LastOccurenceWins, HelpText = "The seed of the random number generator", ShortName = "r1")] public int RngSeed = 123; // this random seed is only for active feature selection + /// + /// The seed of the active feature selection. + /// [Argument(ArgumentType.LastOccurenceWins, HelpText = "The seed of the active feature selection", ShortName = "r3", Hide = true)] [TGUI(NotGui = true)] public int FeatureSelectSeed = 123; + /// + /// The entropy (regularization) coefficient between 0 and 1. + /// [Argument(ArgumentType.LastOccurenceWins, HelpText = "The entropy (regularization) coefficient between 0 and 1", ShortName = "e")] public Double EntropyCoefficient; // REVIEW: Different short name from TLC FR arguments. + /// + /// The number of histograms in the pool (between 2 and numLeaves). + /// [Argument(ArgumentType.LastOccurenceWins, HelpText = "The number of histograms in the pool (between 2 and numLeaves)", ShortName = "ps")] public int HistogramPoolSize = -1; + /// + /// Whether to utilize the disk or the data's native transposition facilities (where applicable) when performing the transpose. + /// [Argument(ArgumentType.LastOccurenceWins, HelpText = "Whether to utilize the disk or the data's native transposition facilities (where applicable) when performing the transpose", ShortName = "dt")] public bool? DiskTranspose; + /// + /// Whether to collectivize features during dataset preparation to speed up training. + /// [Argument(ArgumentType.LastOccurenceWins, HelpText = "Whether to collectivize features during dataset preparation to speed up training", ShortName = "flocks", Hide = true)] public bool FeatureFlocks = true; + /// + /// Whether to do split based on multiple categorical feature values. 
+ /// [Argument(ArgumentType.LastOccurenceWins, HelpText = "Whether to do split based on multiple categorical feature values.", ShortName = "cat")] public bool CategoricalSplit = false; + /// + /// Maximum categorical split groups to consider when splitting on a categorical feature. Split groups are a collection of split points. This is used to reduce overfitting when there many categorical features. + /// [Argument(ArgumentType.LastOccurenceWins, HelpText = "Maximum categorical split groups to consider when splitting on a categorical feature. " + "Split groups are a collection of split points. This is used to reduce overfitting when " + "there many categorical features.", ShortName = "mcg")] public int MaxCategoricalGroupsPerNode = 64; + /// + /// Maximum categorical split points to consider when splitting on a categorical feature. + /// [Argument(ArgumentType.LastOccurenceWins, HelpText = "Maximum categorical split points to consider when splitting on a categorical feature.", ShortName = "maxcat")] public int MaxCategoricalSplitPoints = 64; + /// + /// Minimum categorical docs percentage in a bin to consider for a split. + /// [Argument(ArgumentType.LastOccurenceWins, HelpText = "Minimum categorical docs percentage in a bin to consider for a split.", ShortName = "mdop")] public double MinDocsPercentageForCategoricalSplit = 0.001; + /// + /// Minimum categorical doc count in a bin to consider for a split. + /// [Argument(ArgumentType.LastOccurenceWins, HelpText = "Minimum categorical doc count in a bin to consider for a split.", ShortName = "mdo")] public int MinDocsForCategoricalSplit = 100; + /// + /// Bias for calculating gradient for each feature bin for a categorical feature. + /// [Argument(ArgumentType.LastOccurenceWins, HelpText = "Bias for calculating gradient for each feature bin for a categorical feature.", ShortName = "bias")] public double Bias = 0; + /// + /// Bundle low population bins. 
Bundle.None(0): no bundling, Bundle.AggregateLowPopulation(1): Bundle low population, Bundle.Adjacent(2): Neighbor low population bundle. + /// [Argument(ArgumentType.LastOccurenceWins, HelpText = "Bundle low population bins. " + "Bundle.None(0): no bundling, " + "Bundle.AggregateLowPopulation(1): Bundle low population, " + @@ -211,35 +259,62 @@ public abstract class TreeArgs : LearnerInputBaseWithGroupId // REVIEW: Different default from TLC FR. I prefer the TLC FR default of 255. // REVIEW: Reverting back to 255 to make the same defaults of FR. + /// + /// Maximum number of distinct values (bins) per feature. + /// [Argument(ArgumentType.LastOccurenceWins, HelpText = "Maximum number of distinct values (bins) per feature", ShortName = "mb")] public int MaxBins = 255; // save one for undefs + /// + /// Sparsity level needed to use sparse feature representation. + /// [Argument(ArgumentType.LastOccurenceWins, HelpText = "Sparsity level needed to use sparse feature representation", ShortName = "sp")] public Double SparsifyThreshold = 0.7; + /// + /// The feature first use penalty coefficient. + /// [Argument(ArgumentType.LastOccurenceWins, HelpText = "The feature first use penalty coefficient", ShortName = "ffup")] public Double FeatureFirstUsePenalty; + /// + /// The feature re-use penalty (regularization) coefficient. + /// [Argument(ArgumentType.LastOccurenceWins, HelpText = "The feature re-use penalty (regularization) coefficient", ShortName = "frup")] public Double FeatureReusePenalty; /// Only consider a gain if its likelihood versus a random choice gain is above a certain value. /// So 0.95 would mean restricting to gains that have less than a 0.05 change of being generated randomly through choice of a random split. + /// + /// Tree fitting gain confidence requirement (should be in the range [0,1) ). 
+ /// [Argument(ArgumentType.LastOccurenceWins, HelpText = "Tree fitting gain confidence requirement (should be in the range [0,1) ).", ShortName = "gainconf")] public Double GainConfidenceLevel; + /// + /// The temperature of the randomized softmax distribution for choosing the feature. + /// [Argument(ArgumentType.AtMostOnce, HelpText = "The temperature of the randomized softmax distribution for choosing the feature", ShortName = "smtemp")] public Double SoftmaxTemperature; + /// + /// Print execution time breakdown to stdout. + /// [Argument(ArgumentType.AtMostOnce, HelpText = "Print execution time breakdown to stdout", ShortName = "et")] public bool ExecutionTimes; // REVIEW: Different from original FastRank arguments (shortname l vs. nl). Different default from TLC FR Wrapper (20 vs. 20). + /// + /// The max number of leaves in each regression tree. + /// [Argument(ArgumentType.LastOccurenceWins, HelpText = "The max number of leaves in each regression tree", ShortName = "nl", SortOrder = 2)] [TGUI(Description = "The maximum number of leaves per tree", SuggestedSweeps = "2-128;log;inc:4")] [TlcModule.SweepableLongParamAttribute("NumLeaves", 2, 128, isLogScale: true, stepSize: 4)] public int NumLeaves = Defaults.NumLeaves; + /// + /// The minimal number of documents allowed in a leaf of a regression tree, out of the subsampled data. + /// // REVIEW: Arrays not supported in GUI // REVIEW: Different shortname than FastRank module. Same as the TLC FRWrapper. [Argument(ArgumentType.LastOccurenceWins, HelpText = "The minimal number of documents allowed in a leaf of a regression tree, out of the subsampled data", ShortName = "mil", SortOrder = 3)] @@ -247,18 +322,30 @@ public abstract class TreeArgs : LearnerInputBaseWithGroupId [TlcModule.SweepableDiscreteParamAttribute("MinDocumentsInLeafs", new object[] { 1, 10, 50 })] public int MinDocumentsInLeafs = Defaults.MinDocumentsInLeaves; + /// + /// Total number of decision trees to create in the ensemble. 
+ /// // REVIEW: Different shortname than FastRank module. Same as the TLC FRWrapper. [Argument(ArgumentType.LastOccurenceWins, HelpText = "Total number of decision trees to create in the ensemble", ShortName = "iter", SortOrder = 1)] [TGUI(Description = "Total number of trees constructed", SuggestedSweeps = "20,100,500")] [TlcModule.SweepableDiscreteParamAttribute("NumTrees", new object[] { 20, 100, 500 })] public int NumTrees = Defaults.NumTrees; + /// + /// The fraction of features (chosen randomly) to use on each iteration. + /// [Argument(ArgumentType.AtMostOnce, HelpText = "The fraction of features (chosen randomly) to use on each iteration", ShortName = "ff")] public Double FeatureFraction = 1; + /// + /// Number of trees in each bag (0 for disabling bagging). + /// [Argument(ArgumentType.AtMostOnce, HelpText = "Number of trees in each bag (0 for disabling bagging)", ShortName = "bag")] public int BaggingSize; + /// + /// Percentage of training examples used in each bag. + /// [Argument(ArgumentType.AtMostOnce, HelpText = "Percentage of training examples used in each bag", ShortName = "bagfrac")] // REVIEW: sweeping bagfrac doesn't make sense unless 'baggingSize' is non-zero. The 'SuggestedSweeps' here // are used to denote 'sensible range', but the GUI will interpret this as 'you must sweep these values'. So, I'm keeping @@ -266,38 +353,65 @@ public abstract class TreeArgs : LearnerInputBaseWithGroupId // [TGUI(SuggestedSweeps = "0.5,0.7,0.9")] public Double BaggingTrainFraction = 0.7; + /// + /// The fraction of features (chosen randomly) to use on each split. + /// [Argument(ArgumentType.AtMostOnce, HelpText = "The fraction of features (chosen randomly) to use on each split", ShortName = "sf")] public Double SplitFraction = 1; + /// + /// Smoothing parameter for tree regularization. 
+ /// [Argument(ArgumentType.AtMostOnce, HelpText = "Smoothing paramter for tree regularization", ShortName = "s")] public Double Smoothing; + /// + /// When a root split is impossible, allow training to proceed. + /// [Argument(ArgumentType.AtMostOnce, HelpText = "When a root split is impossible, allow training to proceed", ShortName = "allowempty,dummies", Hide = true)] [TGUI(NotGui = true)] public bool AllowEmptyTrees = true; + /// + /// The level of feature compression to use. + /// [Argument(ArgumentType.LastOccurenceWins, HelpText = "The level of feature compression to use", ShortName = "fcomp", Hide = true)] [TGUI(NotGui = true)] public int FeatureCompressionLevel = 1; + /// + /// Compress the tree Ensemble. + /// [Argument(ArgumentType.LastOccurenceWins, HelpText = "Compress the tree Ensemble", ShortName = "cmp", Hide = true)] [TGUI(NotGui = true)] public bool CompressEnsemble; + /// + /// Maximum Number of trees after compression. + /// // REVIEW: Not used. [Argument(ArgumentType.AtMostOnce, HelpText = "Maximum Number of trees after compression", ShortName = "cmpmax", Hide = true)] [TGUI(NotGui = true)] public int MaxTreesAfterCompression = -1; + /// + /// Print metrics graph for the first test set. + /// [Argument(ArgumentType.LastOccurenceWins, HelpText = "Print metrics graph for the first test set", ShortName = "graph", Hide = true)] [TGUI(NotGui = true)] public bool PrintTestGraph; + /// + /// Print Train and Validation metrics in graph. + /// //It is only enabled if printTestGraph is also set [Argument(ArgumentType.LastOccurenceWins, HelpText = "Print Train and Validation metrics in graph", ShortName = "graphtv", Hide = true)] [TGUI(NotGui = true)] public bool PrintTrainValidGraph; + /// + /// Calculate metric values for train/valid/test every k rounds. 
+ /// [Argument(ArgumentType.LastOccurenceWins, HelpText = "Calculate metric values for train/valid/test every k rounds", ShortName = "tf")] public int TestFrequency = int.MaxValue; @@ -333,72 +447,130 @@ public abstract class BoostedTreeArgs : TreeArgs // REVIEW: TLC FR likes to call it bestStepRegressionTrees which might be more appropriate. //Use the second derivative for split gains (not just outputs). Use MaxTreeOutput to "clip" cases where the second derivative is too close to zero. //Turning BSR on makes larger steps in initial stages and converges to better results with fewer trees (though in the end, it asymptotes to the same results). + /// + /// Use best regression step trees? + /// [Argument(ArgumentType.LastOccurenceWins, HelpText = "Use best regression step trees?", ShortName = "bsr")] public bool BestStepRankingRegressionTrees = false; + /// + /// Should we use line search for a step size. + /// [Argument(ArgumentType.LastOccurenceWins, HelpText = "Should we use line search for a step size", ShortName = "ls")] public bool UseLineSearch; + /// + /// Number of post-bracket line search steps. + /// [Argument(ArgumentType.LastOccurenceWins, HelpText = "Number of post-bracket line search steps", ShortName = "lssteps")] public int NumPostBracketSteps; + /// + /// Minimum line search step size. + /// [Argument(ArgumentType.LastOccurenceWins, HelpText = "Minimum line search step size", ShortName = "minstep")] public Double MinStepSize; public enum OptimizationAlgorithmType { GradientDescent, AcceleratedGradientDescent, ConjugateGradientDescent }; + + /// + /// Optimization algorithm to be used (GradientDescent, AcceleratedGradientDescent). + /// [Argument(ArgumentType.LastOccurenceWins, HelpText = "Optimization algorithm to be used (GradientDescent, AcceleratedGradientDescent)", ShortName = "oa")] public OptimizationAlgorithmType OptimizationAlgorithm = OptimizationAlgorithmType.GradientDescent; + /// + /// Early stopping rule. 
(Validation set (/valid) is required). + /// [Argument(ArgumentType.Multiple, HelpText = "Early stopping rule. (Validation set (/valid) is required.)", ShortName = "esr", NullName = "")] [TGUI(Label = "Early Stopping Rule", Description = "Early stopping rule. (Validation set (/valid) is required.)")] public IEarlyStoppingCriterionFactory EarlyStoppingRule; + /// + /// Early stopping metrics. (For regression, 1: L1, 2:L2; for ranking, 1:NDCG@1, 3:NDCG@3). + /// [Argument(ArgumentType.AtMostOnce, HelpText = "Early stopping metrics. (For regression, 1: L1, 2:L2; for ranking, 1:NDCG@1, 3:NDCG@3)", ShortName = "esmt")] [TGUI(Description = "Early stopping metrics. (For regression, 1: L1, 2:L2; for ranking, 1:NDCG@1, 3:NDCG@3)")] public int EarlyStoppingMetrics; + /// + /// Enable post-training pruning to avoid overfitting. (a validation set is required). + /// [Argument(ArgumentType.AtMostOnce, HelpText = "Enable post-training pruning to avoid overfitting. (a validation set is required)", ShortName = "pruning")] public bool EnablePruning; + /// + /// Use window and tolerance for pruning. + /// [Argument(ArgumentType.LastOccurenceWins, HelpText = "Use window and tolerance for pruning", ShortName = "prtol")] public bool UseTolerantPruning; + /// + /// The tolerance threshold for pruning. + /// [Argument(ArgumentType.AtMostOnce, HelpText = "The tolerance threshold for pruning", ShortName = "prth")] [TGUI(Description = "Pruning threshold")] public Double PruningThreshold = 0.004; + /// + /// The moving window size for pruning. + /// [Argument(ArgumentType.AtMostOnce, HelpText = "The moving window size for pruning", ShortName = "prws")] [TGUI(Description = "Pruning window size")] public int PruningWindowSize = 5; + /// + /// The learning rate. 
+ /// [Argument(ArgumentType.LastOccurenceWins, HelpText = "The learning rate", ShortName = "lr", SortOrder = 4)] [TGUI(Label = "Learning Rate", SuggestedSweeps = "0.025-0.4;log")] [TlcModule.SweepableFloatParamAttribute("LearningRates", 0.025f, 0.4f, isLogScale: true)] public Double LearningRates = Defaults.LearningRates; + /// + /// Shrinkage. + /// [Argument(ArgumentType.AtMostOnce, HelpText = "Shrinkage", ShortName = "shrk")] [TGUI(Label = "Shrinkage", SuggestedSweeps = "0.25-4;log")] [TlcModule.SweepableFloatParamAttribute("Shrinkage", 0.025f, 4f, isLogScale: true)] public Double Shrinkage = 1; + /// + /// Dropout rate for tree regularization. + /// [Argument(ArgumentType.AtMostOnce, HelpText = "Dropout rate for tree regularization", ShortName = "tdrop")] [TGUI(SuggestedSweeps = "0,0.000000001,0.05,0.1,0.2")] [TlcModule.SweepableDiscreteParamAttribute("DropoutRate", new object[] { 0.0f, 1E-9f, 0.05f, 0.1f, 0.2f })] public Double DropoutRate = 0; + /// + /// Sample each query 1 in k times in the GetDerivatives function. + /// [Argument(ArgumentType.AtMostOnce, HelpText = "Sample each query 1 in k times in the GetDerivatives function", ShortName = "sr")] public int GetDerivativesSampleRate = 1; + /// + /// Write the last ensemble instead of the one determined by early stopping. + /// [Argument(ArgumentType.AtMostOnce, HelpText = "Write the last ensemble instead of the one determined by early stopping", ShortName = "hl")] public bool WriteLastEnsemble; + /// + /// Upper bound on absolute value of single tree output. + /// [Argument(ArgumentType.AtMostOnce, HelpText = "Upper bound on absolute value of single tree output", ShortName = "mo")] public Double MaxTreeOutput = 100; + /// + /// Training starts from random ordering (determined by /r1). 
+ /// [Argument(ArgumentType.LastOccurenceWins, HelpText = "Training starts from random ordering (determined by /r1)", ShortName = "rs", Hide = true)] [TGUI(NotGui = true)] public bool RandomStart; + /// + /// Filter zero lambdas during training. + /// [Argument(ArgumentType.AtMostOnce, HelpText = "Filter zero lambdas during training", ShortName = "fzl", Hide = true)] [TGUI(NotGui = true)] public bool FilterZeroLambdas; @@ -418,14 +590,23 @@ public enum OptimizationAlgorithmType { GradientDescent, AcceleratedGradientDesc public int forceGCFeatureExtraction = 100; #endif + /// + /// Freeform defining the scores that should be used as the baseline ranker. + /// [Argument(ArgumentType.LastOccurenceWins, HelpText = "Freeform defining the scores that should be used as the baseline ranker", ShortName = "basescores", Hide = true)] [TGUI(NotGui = true)] public string BaselineScoresFormula; + /// + /// Baseline alpha for tradeoffs of risk (0 is normal training). + /// [Argument(ArgumentType.LastOccurenceWins, HelpText = "Baseline alpha for tradeoffs of risk (0 is normal training)", ShortName = "basealpha", Hide = true)] [TGUI(NotGui = true)] public string BaselineAlphaRisk; + /// + /// The discount freeform which specifies the per position discounts of documents in a query (uses a single variable P for position where P=0 is first position). 
+ /// [Argument(ArgumentType.LastOccurenceWins, HelpText = "The discount freeform which specifies the per position discounts of documents in a query (uses a single variable P for position where P=0 is first position)", ShortName = "pdff", Hide = true)] [TGUI(NotGui = true)] diff --git a/src/Microsoft.ML.FastTree/FastTreeClassification.cs b/src/Microsoft.ML.FastTree/FastTreeClassification.cs index 04ad335260..89b21ac7a1 100644 --- a/src/Microsoft.ML.FastTree/FastTreeClassification.cs +++ b/src/Microsoft.ML.FastTree/FastTreeClassification.cs @@ -141,29 +141,10 @@ public FastTreeBinaryClassificationTrainer(IHostEnvironment env, _sigmoidParameter = 2.0 * Args.LearningRates; } - /// - /// Initializes a new instance of - /// - /// The private instance of . - /// The name of the label column. - /// The name of the feature column. - /// The name for the column containing the initial weight. - /// Advanced arguments to the algorithm. - public FastTreeBinaryClassificationTrainer(IHostEnvironment env, - string labelColumn, - string featureColumn, - string weightColumn, - Options advancedSettings) - : base(env, TrainerUtils.MakeBoolScalarLabel(labelColumn), featureColumn, weightColumn, null, advancedSettings) - { - // Set the sigmoid parameter to the 2 * learning rate, for traditional FastTreeClassification loss - _sigmoidParameter = 2.0 * Args.LearningRates; - } - /// /// Initializes a new instance of by using the legacy class. 
/// - internal FastTreeBinaryClassificationTrainer(IHostEnvironment env, Options args) + public FastTreeBinaryClassificationTrainer(IHostEnvironment env, Options args) : base(env, args, TrainerUtils.MakeBoolScalarLabel(args.LabelColumn)) { // Set the sigmoid parameter to the 2 * learning rate, for traditional FastTreeClassification loss diff --git a/src/Microsoft.ML.FastTree/FastTreeRanking.cs b/src/Microsoft.ML.FastTree/FastTreeRanking.cs index e52bb2557e..80205bca4a 100644 --- a/src/Microsoft.ML.FastTree/FastTreeRanking.cs +++ b/src/Microsoft.ML.FastTree/FastTreeRanking.cs @@ -85,30 +85,10 @@ public FastTreeRankingTrainer(IHostEnvironment env, Host.CheckNonEmpty(groupIdColumn, nameof(groupIdColumn)); } - /// - /// Initializes a new instance of - /// - /// The private instance of . - /// The name of the label column. - /// The name of the feature column. - /// The name for the column containing the group ID. - /// The name for the column containing the initial weight. - /// Advanced arguments to the algorithm. - public FastTreeRankingTrainer(IHostEnvironment env, - string labelColumn, - string featureColumn, - string groupIdColumn, - string weightColumn, - Options advancedSettings) - : base(env, TrainerUtils.MakeR4ScalarColumn(labelColumn), featureColumn, weightColumn, groupIdColumn, advancedSettings) - { - Host.CheckNonEmpty(groupIdColumn, nameof(groupIdColumn)); - } - /// /// Initializes a new instance of by using the legacy class. 
/// - internal FastTreeRankingTrainer(IHostEnvironment env, Options args) + public FastTreeRankingTrainer(IHostEnvironment env, Options args) : base(env, args, TrainerUtils.MakeR4ScalarColumn(args.LabelColumn)) { } diff --git a/src/Microsoft.ML.FastTree/FastTreeRegression.cs b/src/Microsoft.ML.FastTree/FastTreeRegression.cs index 49d2515c57..e65985fb94 100644 --- a/src/Microsoft.ML.FastTree/FastTreeRegression.cs +++ b/src/Microsoft.ML.FastTree/FastTreeRegression.cs @@ -74,27 +74,10 @@ public FastTreeRegressionTrainer(IHostEnvironment env, { } - /// - /// Initializes a new instance of - /// - /// The private instance of . - /// The name of the label column. - /// The name of the feature column. - /// The name for the column containing the initial weight. - /// Advanced arguments to the algorithm. - public FastTreeRegressionTrainer(IHostEnvironment env, - string labelColumn, - string featureColumn, - string weightColumn, - Options advancedSettings) - : base(env, TrainerUtils.MakeR4ScalarColumn(labelColumn), featureColumn, weightColumn, null, advancedSettings) - { - } - /// /// Initializes a new instance of by using the legacy class. /// - internal FastTreeRegressionTrainer(IHostEnvironment env, Options args) + public FastTreeRegressionTrainer(IHostEnvironment env, Options args) : base(env, args, TrainerUtils.MakeR4ScalarColumn(args.LabelColumn)) { } diff --git a/src/Microsoft.ML.FastTree/FastTreeTweedie.cs b/src/Microsoft.ML.FastTree/FastTreeTweedie.cs index 03aa0bd511..5b6cbfe33e 100644 --- a/src/Microsoft.ML.FastTree/FastTreeTweedie.cs +++ b/src/Microsoft.ML.FastTree/FastTreeTweedie.cs @@ -75,31 +75,10 @@ public FastTreeTweedieTrainer(IHostEnvironment env, Initialize(); } - /// - /// Initializes a new instance of - /// - /// The private instance of . - /// The name of the label column. - /// The name of the feature column. - /// The name for the column containing the initial weight. - /// Advanced arguments to the algorithm. 
- public FastTreeTweedieTrainer(IHostEnvironment env, - string labelColumn, - string featureColumn, - string weightColumn, - Options advancedSettings) - : base(env, TrainerUtils.MakeR4ScalarColumn(labelColumn), featureColumn, weightColumn, null, advancedSettings) - { - Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); - Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); - - Initialize(); - } - /// /// Initializes a new instance of by using the legacy class. /// - internal FastTreeTweedieTrainer(IHostEnvironment env, Options args) + public FastTreeTweedieTrainer(IHostEnvironment env, Options args) : base(env, args, TrainerUtils.MakeR4ScalarColumn(args.LabelColumn)) { Initialize(); diff --git a/src/Microsoft.ML.FastTree/RandomForest.cs b/src/Microsoft.ML.FastTree/RandomForest.cs index c249cab45b..f1168bdcd7 100644 --- a/src/Microsoft.ML.FastTree/RandomForest.cs +++ b/src/Microsoft.ML.FastTree/RandomForest.cs @@ -42,21 +42,6 @@ protected RandomForestTrainerBase(IHostEnvironment env, _quantileEnabled = quantileEnabled; } - /// - /// Constructor invoked by the API code-path. 
- /// - protected RandomForestTrainerBase(IHostEnvironment env, - SchemaShape.Column label, - string featureColumn, - string weightColumn, - string groupIdColumn, - TArgs advancedSettings, - bool quantileEnabled = false) - : base(env, label, featureColumn, weightColumn, null, advancedSettings) - { - _quantileEnabled = quantileEnabled; - } - protected override OptimizationAlgorithm ConstructOptimizationAlgorithm(IChannel ch) { Host.CheckValue(ch, nameof(ch)); diff --git a/src/Microsoft.ML.FastTree/RandomForestClassification.cs b/src/Microsoft.ML.FastTree/RandomForestClassification.cs index d5c427fdeb..74fd704ab5 100644 --- a/src/Microsoft.ML.FastTree/RandomForestClassification.cs +++ b/src/Microsoft.ML.FastTree/RandomForestClassification.cs @@ -159,25 +159,6 @@ public FastForestClassification(IHostEnvironment env, Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); } - /// - /// Initializes a new instance of - /// - /// The private instance of . - /// The name of the label column. - /// The name of the feature column. - /// The name for the column containing the initial weight. - /// Advanced arguments to the algorithm. - public FastForestClassification(IHostEnvironment env, - string labelColumn, - string featureColumn, - string weightColumn, - Options advancedSettings) - : base(env, TrainerUtils.MakeBoolScalarLabel(labelColumn), featureColumn, weightColumn, null, advancedSettings) - { - Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); - Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); - } - /// /// Initializes a new instance of by using the legacy class. 
/// diff --git a/src/Microsoft.ML.FastTree/RandomForestRegression.cs b/src/Microsoft.ML.FastTree/RandomForestRegression.cs index 2c44ee0256..c4635f48bc 100644 --- a/src/Microsoft.ML.FastTree/RandomForestRegression.cs +++ b/src/Microsoft.ML.FastTree/RandomForestRegression.cs @@ -177,25 +177,6 @@ public FastForestRegression(IHostEnvironment env, Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); } - /// - /// Initializes a new instance of - /// - /// The private instance of . - /// The name of the label column. - /// The name of the feature column. - /// The optional name for the column containing the initial weight. - /// Advanced arguments to the algorithm. - public FastForestRegression(IHostEnvironment env, - string labelColumn, - string featureColumn, - string weightColumn, - Options advancedSettings) - : base(env, TrainerUtils.MakeR4ScalarColumn(labelColumn), featureColumn, weightColumn, null, advancedSettings) - { - Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); - Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); - } - /// /// Initializes a new instance of by using the legacy class. /// diff --git a/src/Microsoft.ML.FastTree/TreeTrainersCatalog.cs b/src/Microsoft.ML.FastTree/TreeTrainersCatalog.cs index 4f5fa93fa6..4fc2a90538 100644 --- a/src/Microsoft.ML.FastTree/TreeTrainersCatalog.cs +++ b/src/Microsoft.ML.FastTree/TreeTrainersCatalog.cs @@ -42,19 +42,13 @@ public static FastTreeRegressionTrainer FastTree(this RegressionContext.Regressi /// Predict a target using a decision tree regression model trained with the . /// /// The . - /// The label column. - /// The feature column. - /// The optional weights column. /// Algorithm advanced settings. 
public static FastTreeRegressionTrainer FastTree(this RegressionContext.RegressionTrainers ctx, - string labelColumn, - string featureColumn, - string weights, FastTreeRegressionTrainer.Options advancedSettings) { Contracts.CheckValue(ctx, nameof(ctx)); var env = CatalogUtils.GetEnvironment(ctx); - return new FastTreeRegressionTrainer(env, labelColumn, featureColumn, weights, advancedSettings); + return new FastTreeRegressionTrainer(env, advancedSettings); } /// @@ -86,19 +80,13 @@ public static FastTreeBinaryClassificationTrainer FastTree(this BinaryClassifica /// Predict a target using a decision tree binary classification model trained with the . /// /// The . - /// The labelColumn column. - /// The featureColumn column. - /// The optional weights column. /// Algorithm advanced settings. public static FastTreeBinaryClassificationTrainer FastTree(this BinaryClassificationContext.BinaryClassificationTrainers ctx, - string labelColumn, - string featureColumn, - string weights, FastTreeBinaryClassificationTrainer.Options advancedSettings) { Contracts.CheckValue(ctx, nameof(ctx)); var env = CatalogUtils.GetEnvironment(ctx); - return new FastTreeBinaryClassificationTrainer(env, labelColumn, featureColumn, weights, advancedSettings); + return new FastTreeBinaryClassificationTrainer(env, advancedSettings); } /// @@ -132,21 +120,13 @@ public static FastTreeRankingTrainer FastTree(this RankingContext.RankingTrainer /// Ranks a series of inputs based on their relevance, training a decision tree ranking model through the . /// /// The . - /// The labelColumn column. - /// The featureColumn column. - /// The groupId column. - /// The optional weights column. /// Algorithm advanced settings. 
public static FastTreeRankingTrainer FastTree(this RankingContext.RankingTrainers ctx, - string labelColumn, - string featureColumn, - string groupId, - string weights, FastTreeRankingTrainer.Options advancedSettings) { Contracts.CheckValue(ctx, nameof(ctx)); var env = CatalogUtils.GetEnvironment(ctx); - return new FastTreeRankingTrainer(env, labelColumn, featureColumn, groupId, weights, advancedSettings); + return new FastTreeRankingTrainer(env, advancedSettings); } /// @@ -228,19 +208,13 @@ public static FastTreeTweedieTrainer FastTreeTweedie(this RegressionContext.Regr /// Predict a target using a decision tree regression model trained with the . /// /// The . - /// The labelColumn column. - /// The featureColumn column. - /// The optional weights column. /// Algorithm advanced settings. public static FastTreeTweedieTrainer FastTreeTweedie(this RegressionContext.RegressionTrainers ctx, - string labelColumn, - string featureColumn, - string weights, - FastTreeTweedieTrainer.Options advancedSettings = null) + FastTreeTweedieTrainer.Options advancedSettings) { Contracts.CheckValue(ctx, nameof(ctx)); var env = CatalogUtils.GetEnvironment(ctx); - return new FastTreeTweedieTrainer(env, labelColumn, featureColumn, weights, advancedSettings); + return new FastTreeTweedieTrainer(env, advancedSettings); } /// @@ -272,19 +246,13 @@ public static FastForestRegression FastForest(this RegressionContext.RegressionT /// Predict a target using a decision tree regression model trained with the . /// /// The . - /// The labelColumn column. - /// The featureColumn column. - /// The optional weights column. /// Algorithm advanced settings. 
public static FastForestRegression FastForest(this RegressionContext.RegressionTrainers ctx, - string labelColumn, - string featureColumn, - string weights, - FastForestRegression.Options advancedSettings = null) + FastForestRegression.Options advancedSettings) { Contracts.CheckValue(ctx, nameof(ctx)); var env = CatalogUtils.GetEnvironment(ctx); - return new FastForestRegression(env, labelColumn, featureColumn, weights, advancedSettings); + return new FastForestRegression(env, advancedSettings); } /// @@ -316,19 +284,13 @@ public static FastForestClassification FastForest(this BinaryClassificationConte /// Predict a target using a decision tree regression model trained with the . /// /// The . - /// The labelColumn column. - /// The featureColumn column. - /// The optional weights column. /// Algorithm advanced settings. public static FastForestClassification FastForest(this BinaryClassificationContext.BinaryClassificationTrainers ctx, - string labelColumn, - string featureColumn, - string weights, FastForestClassification.Options advancedSettings) { Contracts.CheckValue(ctx, nameof(ctx)); var env = CatalogUtils.GetEnvironment(ctx); - return new FastForestClassification(env, labelColumn, featureColumn, weights, advancedSettings); + return new FastForestClassification(env, advancedSettings); } } } diff --git a/src/Microsoft.ML.StaticPipe/TreeTrainersStatic.cs b/src/Microsoft.ML.StaticPipe/TreeTrainersStatic.cs index cce3aaa1a5..e14764a83c 100644 --- a/src/Microsoft.ML.StaticPipe/TreeTrainersStatic.cs +++ b/src/Microsoft.ML.StaticPipe/TreeTrainersStatic.cs @@ -93,7 +93,7 @@ public static Scalar FastTree(this RegressionContext.RegressionTrainers c var rec = new TrainerEstimatorReconciler.Regression( (env, labelName, featuresName, weightsName) => { - var trainer = new FastTreeRegressionTrainer(env, labelName, featuresName, weightsName, advancedSettings); + var trainer = new FastTreeRegressionTrainer(env, advancedSettings); if (onFit != null) return 
trainer.WithOnFitDelegate(trans => onFit(trans.Model)); return trainer; @@ -185,7 +185,7 @@ public static (Scalar score, Scalar probability, Scalar pred var rec = new TrainerEstimatorReconciler.BinaryClassifier( (env, labelName, featuresName, weightsName) => { - var trainer = new FastTreeBinaryClassificationTrainer(env, labelName, featuresName, weightsName, advancedSettings); + var trainer = new FastTreeBinaryClassificationTrainer(env, advancedSettings); if (onFit != null) return trainer.WithOnFitDelegate(trans => onFit(trans.Model)); @@ -265,7 +265,7 @@ public static Scalar FastTree(this RankingContext.RankingTrainers c var rec = new TrainerEstimatorReconciler.Ranker( (env, labelName, featuresName, groupIdName, weightsName) => { - var trainer = new FastTreeRankingTrainer(env, labelName, featuresName, groupIdName, weightsName, advancedSettings); + var trainer = new FastTreeRankingTrainer(env, advancedSettings); if (onFit != null) return trainer.WithOnFitDelegate(trans => onFit(trans.Model)); return trainer; diff --git a/src/Microsoft.ML.Sweeper/Algorithms/SmacSweeper.cs b/src/Microsoft.ML.Sweeper/Algorithms/SmacSweeper.cs index 7f79578ad8..5a80a854e7 100644 --- a/src/Microsoft.ML.Sweeper/Algorithms/SmacSweeper.cs +++ b/src/Microsoft.ML.Sweeper/Algorithms/SmacSweeper.cs @@ -135,7 +135,7 @@ private FastForestRegressionModelParameters FitModel(IEnumerable pre { // Set relevant random forest arguments. // Train random forest. 
- var trainer = new FastForestRegression(_host, DefaultColumnNames.Label, DefaultColumnNames.Features, null, + var trainer = new FastForestRegression(_host, new FastForestRegression.Options { FeatureFraction = _args.SplitRatio, diff --git a/test/Microsoft.ML.Tests/Scenarios/OvaTest.cs b/test/Microsoft.ML.Tests/Scenarios/OvaTest.cs index c0a7a6dd6a..9e4d8f5ad4 100644 --- a/test/Microsoft.ML.Tests/Scenarios/OvaTest.cs +++ b/test/Microsoft.ML.Tests/Scenarios/OvaTest.cs @@ -104,8 +104,7 @@ public void OvaFastTree() // Pipeline var pipeline = new Ova( mlContext, - new FastTreeBinaryClassificationTrainer(mlContext, DefaultColumnNames.Label, DefaultColumnNames.Features, null, - new FastTreeBinaryClassificationTrainer.Options { NumThreads = 1 }), + new FastTreeBinaryClassificationTrainer(mlContext, new FastTreeBinaryClassificationTrainer.Options { NumThreads = 1 }), useProbabilities: false); var model = pipeline.Fit(data); diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs b/test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs index 49149dde32..14a27ac283 100644 --- a/test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs +++ b/test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs @@ -24,7 +24,7 @@ public void FastTreeBinaryEstimator() { var (pipe, dataView) = GetBinaryClassificationPipeline(); - var trainer = new FastTreeBinaryClassificationTrainer(Env, DefaultColumnNames.Label, DefaultColumnNames.Features, null, + var trainer = new FastTreeBinaryClassificationTrainer(Env, new FastTreeBinaryClassificationTrainer.Options { NumThreads = 1, NumTrees = 10, @@ -83,7 +83,7 @@ public void FastForestClassificationEstimator() { var (pipe, dataView) = GetBinaryClassificationPipeline(); - var trainer = new FastForestClassification(Env, DefaultColumnNames.Label, DefaultColumnNames.Features, null, + var trainer = new FastForestClassification(Env, new FastForestClassification.Options { NumLeaves = 10, NumTrees = 20, @@ -105,7 +105,12 @@ public 
void FastTreeRankerEstimator() { var (pipe, dataView) = GetRankingPipeline(); - var trainer = new FastTreeRankingTrainer(Env, "Label0", "NumericFeatures", "Group", null, new FastTreeRankingTrainer.Options { NumTrees = 10 }); + var trainer = new FastTreeRankingTrainer(Env, + new FastTreeRankingTrainer.Options { + FeatureColumn = "NumericFeatures", + NumTrees = 10 + }); + var pipeWithTrainer = pipe.Append(trainer); TestEstimatorCore(pipeWithTrainer, dataView); @@ -139,7 +144,7 @@ public void LightGBMRankerEstimator() public void FastTreeRegressorEstimator() { var dataView = GetRegressionPipeline(); - var trainer = new FastTreeRegressionTrainer(Env, DefaultColumnNames.Label, DefaultColumnNames.Features, null, + var trainer = new FastTreeRegressionTrainer(Env, new FastTreeRegressionTrainer.Options { NumTrees = 10, NumThreads = 1, NumLeaves = 5 }); TestEstimatorCore(trainer, dataView); @@ -192,7 +197,7 @@ public void GAMRegressorEstimator() public void TweedieRegressorEstimator() { var dataView = GetRegressionPipeline(); - var trainer = new FastTreeTweedieTrainer(Env, "Label", "Features", null, + var trainer = new FastTreeTweedieTrainer(Env, new FastTreeTweedieTrainer.Options { EntropyCoefficient = 0.3, OptimizationAlgorithm = BoostedTreeArgs.OptimizationAlgorithmType.AcceleratedGradientDescent, @@ -210,7 +215,7 @@ public void TweedieRegressorEstimator() public void FastForestRegressorEstimator() { var dataView = GetRegressionPipeline(); - var trainer = new FastForestRegression(Env, DefaultColumnNames.Label, DefaultColumnNames.Features, null, + var trainer = new FastForestRegression(Env, new FastForestRegression.Options { BaggingSize = 2, NumTrees = 10, From 05018b8492b6b8d890e78cdd009726525f0db0a5 Mon Sep 17 00:00:00 2001 From: Abhishek Goswami Date: Tue, 8 Jan 2019 18:43:23 +0000 Subject: [PATCH 6/8] review comments - 2. 
updating comments, help summary etc --- src/Microsoft.ML.FastTree/FastTreeArguments.cs | 8 ++++---- src/Microsoft.ML.FastTree/FastTreeClassification.cs | 2 +- src/Microsoft.ML.FastTree/FastTreeRanking.cs | 2 +- src/Microsoft.ML.FastTree/FastTreeRegression.cs | 2 +- src/Microsoft.ML.FastTree/FastTreeTweedie.cs | 2 +- src/Microsoft.ML.FastTree/RandomForestClassification.cs | 2 +- src/Microsoft.ML.FastTree/RandomForestRegression.cs | 2 +- 7 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/Microsoft.ML.FastTree/FastTreeArguments.cs b/src/Microsoft.ML.FastTree/FastTreeArguments.cs index c265afc02f..926ac66bb7 100644 --- a/src/Microsoft.ML.FastTree/FastTreeArguments.cs +++ b/src/Microsoft.ML.FastTree/FastTreeArguments.cs @@ -27,9 +27,9 @@ public sealed partial class FastTreeBinaryClassificationTrainer public sealed class Options : BoostedTreeArgs, IFastTreeTrainerFactory { /// - /// Should we use derivatives optimized for unbalanced sets? + /// Option for using derivatives optimized for unbalanced sets. /// - [Argument(ArgumentType.LastOccurenceWins, HelpText = "Should we use derivatives optimized for unbalanced sets", ShortName = "us")] + [Argument(ArgumentType.LastOccurenceWins, HelpText = "Option for using derivatives optimized for unbalanced sets", ShortName = "us")] [TGUI(Label = "Optimize for unbalanced")] public bool UnbalancedSets = false; @@ -448,9 +448,9 @@ public abstract class BoostedTreeArgs : TreeArgs //Use the second derivative for split gains (not just outputs). Use MaxTreeOutput to "clip" cases where the second derivative is too close to zero. //Turning BSR on makes larger steps in initial stages and converges to better results with fewer trees (though in the end, it asymptotes to the same results). /// - /// Use best regression step trees? + /// Option for using best regression step trees. 
/// - [Argument(ArgumentType.LastOccurenceWins, HelpText = "Use best regression step trees?", ShortName = "bsr")] + [Argument(ArgumentType.LastOccurenceWins, HelpText = "Option for using best regression step trees", ShortName = "bsr")] public bool BestStepRankingRegressionTrees = false; /// diff --git a/src/Microsoft.ML.FastTree/FastTreeClassification.cs b/src/Microsoft.ML.FastTree/FastTreeClassification.cs index 89b21ac7a1..4eeae24097 100644 --- a/src/Microsoft.ML.FastTree/FastTreeClassification.cs +++ b/src/Microsoft.ML.FastTree/FastTreeClassification.cs @@ -142,7 +142,7 @@ public FastTreeBinaryClassificationTrainer(IHostEnvironment env, } /// - /// Initializes a new instance of by using the legacy class. + /// Initializes a new instance of by using the class. /// public FastTreeBinaryClassificationTrainer(IHostEnvironment env, Options args) : base(env, args, TrainerUtils.MakeBoolScalarLabel(args.LabelColumn)) diff --git a/src/Microsoft.ML.FastTree/FastTreeRanking.cs b/src/Microsoft.ML.FastTree/FastTreeRanking.cs index 80205bca4a..3846b3ac51 100644 --- a/src/Microsoft.ML.FastTree/FastTreeRanking.cs +++ b/src/Microsoft.ML.FastTree/FastTreeRanking.cs @@ -86,7 +86,7 @@ public FastTreeRankingTrainer(IHostEnvironment env, } /// - /// Initializes a new instance of by using the legacy class. + /// Initializes a new instance of by using the class. /// public FastTreeRankingTrainer(IHostEnvironment env, Options args) : base(env, args, TrainerUtils.MakeR4ScalarColumn(args.LabelColumn)) diff --git a/src/Microsoft.ML.FastTree/FastTreeRegression.cs b/src/Microsoft.ML.FastTree/FastTreeRegression.cs index e65985fb94..92659112c9 100644 --- a/src/Microsoft.ML.FastTree/FastTreeRegression.cs +++ b/src/Microsoft.ML.FastTree/FastTreeRegression.cs @@ -75,7 +75,7 @@ public FastTreeRegressionTrainer(IHostEnvironment env, } /// - /// Initializes a new instance of by using the legacy class. + /// Initializes a new instance of by using the class. 
/// public FastTreeRegressionTrainer(IHostEnvironment env, Options args) : base(env, args, TrainerUtils.MakeR4ScalarColumn(args.LabelColumn)) diff --git a/src/Microsoft.ML.FastTree/FastTreeTweedie.cs b/src/Microsoft.ML.FastTree/FastTreeTweedie.cs index 5b6cbfe33e..99b47a117c 100644 --- a/src/Microsoft.ML.FastTree/FastTreeTweedie.cs +++ b/src/Microsoft.ML.FastTree/FastTreeTweedie.cs @@ -76,7 +76,7 @@ public FastTreeTweedieTrainer(IHostEnvironment env, } /// - /// Initializes a new instance of by using the legacy class. + /// Initializes a new instance of by using the class. /// public FastTreeTweedieTrainer(IHostEnvironment env, Options args) : base(env, args, TrainerUtils.MakeR4ScalarColumn(args.LabelColumn)) diff --git a/src/Microsoft.ML.FastTree/RandomForestClassification.cs b/src/Microsoft.ML.FastTree/RandomForestClassification.cs index 74fd704ab5..96cce28e59 100644 --- a/src/Microsoft.ML.FastTree/RandomForestClassification.cs +++ b/src/Microsoft.ML.FastTree/RandomForestClassification.cs @@ -160,7 +160,7 @@ public FastForestClassification(IHostEnvironment env, } /// - /// Initializes a new instance of by using the legacy class. + /// Initializes a new instance of by using the class. /// public FastForestClassification(IHostEnvironment env, Options args) : base(env, args, TrainerUtils.MakeBoolScalarLabel(args.LabelColumn)) diff --git a/src/Microsoft.ML.FastTree/RandomForestRegression.cs b/src/Microsoft.ML.FastTree/RandomForestRegression.cs index c4635f48bc..049bcbdaaa 100644 --- a/src/Microsoft.ML.FastTree/RandomForestRegression.cs +++ b/src/Microsoft.ML.FastTree/RandomForestRegression.cs @@ -178,7 +178,7 @@ public FastForestRegression(IHostEnvironment env, } /// - /// Initializes a new instance of by using the legacy class. + /// Initializes a new instance of by using the class. 
/// public FastForestRegression(IHostEnvironment env, Options args) : base(env, args, TrainerUtils.MakeR4ScalarColumn(args.LabelColumn), true) From cb78deec8ce5b88671501203ea55549115d27464 Mon Sep 17 00:00:00 2001 From: Abhishek Goswami Date: Tue, 8 Jan 2019 19:28:06 +0000 Subject: [PATCH 7/8] review comments - 3. Rename Options objects as options (instead of args or advancedSettings used so far) --- .../FastTreeClassification.cs | 6 ++- src/Microsoft.ML.FastTree/FastTreeRanking.cs | 50 ++++++++++--------- .../FastTreeRegression.cs | 22 ++++---- src/Microsoft.ML.FastTree/FastTreeTweedie.cs | 28 ++++++----- .../RandomForestClassification.cs | 10 ++-- .../RandomForestRegression.cs | 32 ++++++------ .../TreeTrainersCatalog.cs | 36 ++++++------- src/Microsoft.ML.Legacy/CSharpApi.cs | 20 ++++---- .../TreeTrainersStatic.cs | 24 ++++----- .../Common/EntryPoints/core_manifest.json | 20 ++++---- 10 files changed, 130 insertions(+), 118 deletions(-) diff --git a/src/Microsoft.ML.FastTree/FastTreeClassification.cs b/src/Microsoft.ML.FastTree/FastTreeClassification.cs index 4eeae24097..2ff3cee1db 100644 --- a/src/Microsoft.ML.FastTree/FastTreeClassification.cs +++ b/src/Microsoft.ML.FastTree/FastTreeClassification.cs @@ -144,8 +144,10 @@ public FastTreeBinaryClassificationTrainer(IHostEnvironment env, /// /// Initializes a new instance of by using the class. /// - public FastTreeBinaryClassificationTrainer(IHostEnvironment env, Options args) - : base(env, args, TrainerUtils.MakeBoolScalarLabel(args.LabelColumn)) + /// The instance of . + /// Algorithm advanced settings. 
+ public FastTreeBinaryClassificationTrainer(IHostEnvironment env, Options options) + : base(env, options, TrainerUtils.MakeBoolScalarLabel(options.LabelColumn)) { // Set the sigmoid parameter to the 2 * learning rate, for traditional FastTreeClassification loss _sigmoidParameter = 2.0 * Args.LearningRates; diff --git a/src/Microsoft.ML.FastTree/FastTreeRanking.cs b/src/Microsoft.ML.FastTree/FastTreeRanking.cs index 3846b3ac51..543404a081 100644 --- a/src/Microsoft.ML.FastTree/FastTreeRanking.cs +++ b/src/Microsoft.ML.FastTree/FastTreeRanking.cs @@ -88,8 +88,10 @@ public FastTreeRankingTrainer(IHostEnvironment env, /// /// Initializes a new instance of by using the class. /// - public FastTreeRankingTrainer(IHostEnvironment env, Options args) - : base(env, args, TrainerUtils.MakeR4ScalarColumn(args.LabelColumn)) + /// The instance of . + /// Algorithm advanced settings. + public FastTreeRankingTrainer(IHostEnvironment env, Options options) + : base(env, options, TrainerUtils.MakeR4ScalarColumn(options.LabelColumn)) { } @@ -546,14 +548,14 @@ private enum DupeIdInfo // Keeps track of labels of top 3 documents per query public short[][] TrainQueriesTopLabels; - public LambdaRankObjectiveFunction(Dataset trainset, short[] labels, Options args, IParallelTraining parallelTraining) + public LambdaRankObjectiveFunction(Dataset trainset, short[] labels, Options options, IParallelTraining parallelTraining) : base(trainset, - args.LearningRates, - args.Shrinkage, - args.MaxTreeOutput, - args.GetDerivativesSampleRate, - args.BestStepRankingRegressionTrees, - args.RngSeed) + options.LearningRates, + options.Shrinkage, + options.MaxTreeOutput, + options.GetDerivativesSampleRate, + options.BestStepRankingRegressionTrees, + options.RngSeed) { _labels = labels; @@ -567,8 +569,8 @@ public LambdaRankObjectiveFunction(Dataset trainset, short[] labels, Options arg _labelCounts[q] = new int[relevancyLevel]; // precomputed arrays - _maxDcgTruncationLevel = args.LambdaMartMaxTruncation; - 
_trainDcg = args.TrainDcg; + _maxDcgTruncationLevel = options.LambdaMartMaxTruncation; + _trainDcg = options.TrainDcg; if (_trainDcg) { _inverseMaxDcgt = new double[Dataset.NumQueries]; @@ -583,7 +585,7 @@ public LambdaRankObjectiveFunction(Dataset trainset, short[] labels, Options arg } _discount = new double[Dataset.MaxDocsPerQuery]; - FillDiscounts(args.PositionDiscountFreeform); + FillDiscounts(options.PositionDiscountFreeform); _oneTwoThree = new int[Dataset.MaxDocsPerQuery]; for (int d = 0; d < Dataset.MaxDocsPerQuery; ++d) @@ -593,7 +595,7 @@ public LambdaRankObjectiveFunction(Dataset trainset, short[] labels, Options arg int numThreads = BlockingThreadPool.NumThreads; _comparers = new DcgPermutationComparer[numThreads]; for (int i = 0; i < numThreads; ++i) - _comparers[i] = DcgPermutationComparerFactory.GetDcgPermutationFactory(args.SortingAlgorithm); + _comparers[i] = DcgPermutationComparerFactory.GetDcgPermutationFactory(options.SortingAlgorithm); _permutationBuffers = new int[numThreads][]; for (int i = 0; i < numThreads; ++i) @@ -603,13 +605,13 @@ public LambdaRankObjectiveFunction(Dataset trainset, short[] labels, Options arg FillGainLabels(); #region parameters - _sigmoidParam = args.LearningRates; - _costFunctionParam = args.CostFunctionParam; - _distanceWeight2 = args.DistanceWeight2; - _normalizeQueryLambdas = args.NormalizeQueryLambdas; + _sigmoidParam = options.LearningRates; + _costFunctionParam = options.CostFunctionParam; + _distanceWeight2 = options.DistanceWeight2; + _normalizeQueryLambdas = options.NormalizeQueryLambdas; - _useShiftedNdcg = args.ShiftedNdcg; - _filterZeroLambdas = args.FilterZeroLambdas; + _useShiftedNdcg = options.ShiftedNdcg; + _filterZeroLambdas = options.FilterZeroLambdas; #endregion _scoresCopy = new double[Dataset.NumDocs]; @@ -620,7 +622,7 @@ public LambdaRankObjectiveFunction(Dataset trainset, short[] labels, Options arg #if OLD_DATALOAD SetupSecondaryGains(cmd); #endif - SetupBaselineRisk(args); + 
SetupBaselineRisk(options); _parallelTraining = parallelTraining; } @@ -644,7 +646,7 @@ private void SetupSecondaryGains(Arguments args) } #endif - private void SetupBaselineRisk(Options args) + private void SetupBaselineRisk(Options options) { double[] scores = Dataset.Skeleton.GetData("BaselineScores"); if (scores == null) @@ -652,10 +654,10 @@ private void SetupBaselineRisk(Options args) // Calculate the DCG with the discounts as they exist in the objective function (this // can differ versus the actual DCG discount) - DcgCalculator calc = new DcgCalculator(Dataset.MaxDocsPerQuery, args.SortingAlgorithm); + DcgCalculator calc = new DcgCalculator(Dataset.MaxDocsPerQuery, options.SortingAlgorithm); _baselineDcg = calc.DcgFromScores(Dataset, scores, _discount); - IniFileParserInterface ffi = IniFileParserInterface.CreateFromFreeform(string.IsNullOrEmpty(args.BaselineAlphaRisk) ? "0" : args.BaselineAlphaRisk); + IniFileParserInterface ffi = IniFileParserInterface.CreateFromFreeform(string.IsNullOrEmpty(options.BaselineAlphaRisk) ? 
"0" : options.BaselineAlphaRisk); IniFileParserInterface.FeatureEvaluator ffe = ffi.GetFeatureEvaluators()[0]; IniFileParserInterface.FeatureMap ffmap = ffi.GetFeatureMap(); string[] ffnames = Enumerable.Range(0, ffmap.RawFeatureCount) @@ -672,7 +674,7 @@ private void SetupBaselineRisk(Options args) uint[] vals = new uint[ffmap.RawFeatureCount]; int iInd = Array.IndexOf(ffnames, "I"); int tInd = Array.IndexOf(ffnames, "T"); - int totalTrees = args.NumTrees; + int totalTrees = options.NumTrees; if (tInd >= 0) vals[tInd] = (uint)totalTrees; _baselineAlpha = Enumerable.Range(0, totalTrees).Select(i => diff --git a/src/Microsoft.ML.FastTree/FastTreeRegression.cs b/src/Microsoft.ML.FastTree/FastTreeRegression.cs index 92659112c9..89c6b3141c 100644 --- a/src/Microsoft.ML.FastTree/FastTreeRegression.cs +++ b/src/Microsoft.ML.FastTree/FastTreeRegression.cs @@ -77,8 +77,10 @@ public FastTreeRegressionTrainer(IHostEnvironment env, /// /// Initializes a new instance of by using the class. /// - public FastTreeRegressionTrainer(IHostEnvironment env, Options args) - : base(env, args, TrainerUtils.MakeR4ScalarColumn(args.LabelColumn)) + /// The instance of . + /// Algorithm advanced settings. 
+ public FastTreeRegressionTrainer(IHostEnvironment env, Options options) + : base(env, options, TrainerUtils.MakeR4ScalarColumn(options.LabelColumn)) { } @@ -397,17 +399,17 @@ public ObjectiveImpl(Dataset trainData, RegressionGamTrainer.Arguments args) : _labels = GetDatasetRegressionLabels(trainData); } - public ObjectiveImpl(Dataset trainData, Options args) + public ObjectiveImpl(Dataset trainData, Options options) : base( trainData, - args.LearningRates, - args.Shrinkage, - args.MaxTreeOutput, - args.GetDerivativesSampleRate, - args.BestStepRankingRegressionTrees, - args.RngSeed) + options.LearningRates, + options.Shrinkage, + options.MaxTreeOutput, + options.GetDerivativesSampleRate, + options.BestStepRankingRegressionTrees, + options.RngSeed) { - if (args.DropoutRate > 0 && LearningRate > 0) // Don't do shrinkage if dropouts are used. + if (options.DropoutRate > 0 && LearningRate > 0) // Don't do shrinkage if dropouts are used. Shrinkage = 1.0 / LearningRate; _labels = GetDatasetRegressionLabels(trainData); diff --git a/src/Microsoft.ML.FastTree/FastTreeTweedie.cs b/src/Microsoft.ML.FastTree/FastTreeTweedie.cs index 99b47a117c..c1b365ec81 100644 --- a/src/Microsoft.ML.FastTree/FastTreeTweedie.cs +++ b/src/Microsoft.ML.FastTree/FastTreeTweedie.cs @@ -78,8 +78,10 @@ public FastTreeTweedieTrainer(IHostEnvironment env, /// /// Initializes a new instance of by using the class. /// - public FastTreeTweedieTrainer(IHostEnvironment env, Options args) - : base(env, args, TrainerUtils.MakeR4ScalarColumn(args.LabelColumn)) + /// The instance of . + /// Algorithm advanced settings. + public FastTreeTweedieTrainer(IHostEnvironment env, Options options) + : base(env, options, TrainerUtils.MakeR4ScalarColumn(options.LabelColumn)) { Initialize(); } @@ -334,17 +336,17 @@ private sealed class ObjectiveImpl : ObjectiveFunctionBase, IStepSearch private readonly Double _index2; // 2 minus the index parameter. 
private readonly Double _maxClamp; - public ObjectiveImpl(Dataset trainData, Options args) + public ObjectiveImpl(Dataset trainData, Options options) : base( trainData, - args.LearningRates, - args.Shrinkage, - args.MaxTreeOutput, - args.GetDerivativesSampleRate, - args.BestStepRankingRegressionTrees, - args.RngSeed) + options.LearningRates, + options.Shrinkage, + options.MaxTreeOutput, + options.GetDerivativesSampleRate, + options.BestStepRankingRegressionTrees, + options.RngSeed) { - if (args.DropoutRate > 0 && LearningRate > 0) // Don't do shrinkage if dropouts are used. + if (options.DropoutRate > 0 && LearningRate > 0) // Don't do shrinkage if dropouts are used. Shrinkage = 1.0 / LearningRate; _labels = GetDatasetRegressionLabels(trainData); @@ -355,9 +357,9 @@ public ObjectiveImpl(Dataset trainData, Options args) _labels[i] = 0; } - _index1 = 1 - args.Index; - _index2 = 2 - args.Index; - _maxClamp = Math.Abs(args.MaxTreeOutput); + _index1 = 1 - options.Index; + _index2 = 2 - options.Index; + _maxClamp = Math.Abs(options.MaxTreeOutput); } public void AdjustTreeOutputs(IChannel ch, RegressionTree tree, DocumentPartitioning partitioning, ScoreTracker trainingScores) diff --git a/src/Microsoft.ML.FastTree/RandomForestClassification.cs b/src/Microsoft.ML.FastTree/RandomForestClassification.cs index 96cce28e59..4275727eff 100644 --- a/src/Microsoft.ML.FastTree/RandomForestClassification.cs +++ b/src/Microsoft.ML.FastTree/RandomForestClassification.cs @@ -162,8 +162,10 @@ public FastForestClassification(IHostEnvironment env, /// /// Initializes a new instance of by using the class. /// - public FastForestClassification(IHostEnvironment env, Options args) - : base(env, args, TrainerUtils.MakeBoolScalarLabel(args.LabelColumn)) + /// The instance of . + /// Algorithm advanced settings. 
+ public FastForestClassification(IHostEnvironment env, Options options) + : base(env, options, TrainerUtils.MakeBoolScalarLabel(options.LabelColumn)) { } @@ -229,8 +231,8 @@ private sealed class ObjectiveFunctionImpl : RandomForestObjectiveFunction { private readonly bool[] _labels; - public ObjectiveFunctionImpl(Dataset trainSet, bool[] trainSetLabels, Options args) - : base(trainSet, args, args.MaxTreeOutput) + public ObjectiveFunctionImpl(Dataset trainSet, bool[] trainSetLabels, Options options) + : base(trainSet, options, options.MaxTreeOutput) { _labels = trainSetLabels; } diff --git a/src/Microsoft.ML.FastTree/RandomForestRegression.cs b/src/Microsoft.ML.FastTree/RandomForestRegression.cs index 049bcbdaaa..ff5a7b8a80 100644 --- a/src/Microsoft.ML.FastTree/RandomForestRegression.cs +++ b/src/Microsoft.ML.FastTree/RandomForestRegression.cs @@ -180,8 +180,10 @@ public FastForestRegression(IHostEnvironment env, /// /// Initializes a new instance of by using the class. /// - public FastForestRegression(IHostEnvironment env, Options args) - : base(env, args, TrainerUtils.MakeR4ScalarColumn(args.LabelColumn), true) + /// The instance of . + /// Algorithm advanced settings. 
+ public FastForestRegression(IHostEnvironment env, Options options) + : base(env, options, TrainerUtils.MakeR4ScalarColumn(options.LabelColumn), true) { } @@ -237,15 +239,15 @@ private abstract class ObjectiveFunctionImplBase : RandomForestObjectiveFunction { private readonly float[] _labels; - public static ObjectiveFunctionImplBase Create(Dataset trainData, Options args) + public static ObjectiveFunctionImplBase Create(Dataset trainData, Options options) { - if (args.ShuffleLabels) - return new ShuffleImpl(trainData, args); - return new BasicImpl(trainData, args); + if (options.ShuffleLabels) + return new ShuffleImpl(trainData, options); + return new BasicImpl(trainData, options); } - private ObjectiveFunctionImplBase(Dataset trainData, Options args) - : base(trainData, args, double.MaxValue) // No notion of maximum step size. + private ObjectiveFunctionImplBase(Dataset trainData, Options options) + : base(trainData, options, double.MaxValue) // No notion of maximum step size. { _labels = FastTreeRegressionTrainer.GetDatasetRegressionLabels(trainData); Contracts.Assert(_labels.Length == trainData.NumDocs); @@ -264,11 +266,11 @@ private sealed class ShuffleImpl : ObjectiveFunctionImplBase private readonly Random _rgen; private readonly int _labelLim; - public ShuffleImpl(Dataset trainData, Options args) - : base(trainData, args) + public ShuffleImpl(Dataset trainData, Options options) + : base(trainData, options) { - Contracts.AssertValue(args); - Contracts.Assert(args.ShuffleLabels); + Contracts.AssertValue(options); + Contracts.Assert(options.ShuffleLabels); _rgen = new Random(0); // Ideally we'd get this from the host. 
@@ -277,7 +279,7 @@ public ShuffleImpl(Dataset trainData, Options args) var lab = _labels[i]; if (!(0 <= lab && lab < Utils.ArrayMaxSize)) { - throw Contracts.ExceptUserArg(nameof(args.ShuffleLabels), + throw Contracts.ExceptUserArg(nameof(options.ShuffleLabels), "Label {0} for example {1} outside of allowed range" + "[0,{2}) when doing shuffled labels", lab, i, Utils.ArrayMaxSize); } @@ -302,8 +304,8 @@ public override double[] GetGradient(IChannel ch, double[] scores) private sealed class BasicImpl : ObjectiveFunctionImplBase { - public BasicImpl(Dataset trainData, Options args) - : base(trainData, args) + public BasicImpl(Dataset trainData, Options options) + : base(trainData, options) { } } diff --git a/src/Microsoft.ML.FastTree/TreeTrainersCatalog.cs b/src/Microsoft.ML.FastTree/TreeTrainersCatalog.cs index 4fc2a90538..b4c5f1a506 100644 --- a/src/Microsoft.ML.FastTree/TreeTrainersCatalog.cs +++ b/src/Microsoft.ML.FastTree/TreeTrainersCatalog.cs @@ -42,13 +42,13 @@ public static FastTreeRegressionTrainer FastTree(this RegressionContext.Regressi /// Predict a target using a decision tree regression model trained with the . /// /// The . - /// Algorithm advanced settings. + /// Algorithm advanced settings. public static FastTreeRegressionTrainer FastTree(this RegressionContext.RegressionTrainers ctx, - FastTreeRegressionTrainer.Options advancedSettings) + FastTreeRegressionTrainer.Options options) { Contracts.CheckValue(ctx, nameof(ctx)); var env = CatalogUtils.GetEnvironment(ctx); - return new FastTreeRegressionTrainer(env, advancedSettings); + return new FastTreeRegressionTrainer(env, options); } /// @@ -80,13 +80,13 @@ public static FastTreeBinaryClassificationTrainer FastTree(this BinaryClassifica /// Predict a target using a decision tree binary classification model trained with the . /// /// The . - /// Algorithm advanced settings. + /// Algorithm advanced settings. 
public static FastTreeBinaryClassificationTrainer FastTree(this BinaryClassificationContext.BinaryClassificationTrainers ctx, - FastTreeBinaryClassificationTrainer.Options advancedSettings) + FastTreeBinaryClassificationTrainer.Options options) { Contracts.CheckValue(ctx, nameof(ctx)); var env = CatalogUtils.GetEnvironment(ctx); - return new FastTreeBinaryClassificationTrainer(env, advancedSettings); + return new FastTreeBinaryClassificationTrainer(env, options); } /// @@ -120,13 +120,13 @@ public static FastTreeRankingTrainer FastTree(this RankingContext.RankingTrainer /// Ranks a series of inputs based on their relevance, training a decision tree ranking model through the . /// /// The . - /// Algorithm advanced settings. + /// Algorithm advanced settings. public static FastTreeRankingTrainer FastTree(this RankingContext.RankingTrainers ctx, - FastTreeRankingTrainer.Options advancedSettings) + FastTreeRankingTrainer.Options options) { Contracts.CheckValue(ctx, nameof(ctx)); var env = CatalogUtils.GetEnvironment(ctx); - return new FastTreeRankingTrainer(env, advancedSettings); + return new FastTreeRankingTrainer(env, options); } /// @@ -208,13 +208,13 @@ public static FastTreeTweedieTrainer FastTreeTweedie(this RegressionContext.Regr /// Predict a target using a decision tree regression model trained with the . /// /// The . - /// Algorithm advanced settings. + /// Algorithm advanced settings. public static FastTreeTweedieTrainer FastTreeTweedie(this RegressionContext.RegressionTrainers ctx, - FastTreeTweedieTrainer.Options advancedSettings) + FastTreeTweedieTrainer.Options options) { Contracts.CheckValue(ctx, nameof(ctx)); var env = CatalogUtils.GetEnvironment(ctx); - return new FastTreeTweedieTrainer(env, advancedSettings); + return new FastTreeTweedieTrainer(env, options); } /// @@ -246,13 +246,13 @@ public static FastForestRegression FastForest(this RegressionContext.RegressionT /// Predict a target using a decision tree regression model trained with the . 
/// /// The . - /// Algorithm advanced settings. + /// Algorithm advanced settings. public static FastForestRegression FastForest(this RegressionContext.RegressionTrainers ctx, - FastForestRegression.Options advancedSettings) + FastForestRegression.Options options) { Contracts.CheckValue(ctx, nameof(ctx)); var env = CatalogUtils.GetEnvironment(ctx); - return new FastForestRegression(env, advancedSettings); + return new FastForestRegression(env, options); } /// @@ -284,13 +284,13 @@ public static FastForestClassification FastForest(this BinaryClassificationConte /// Predict a target using a decision tree regression model trained with the . /// /// The . - /// Algorithm advanced settings. + /// Algorithm advanced settings. public static FastForestClassification FastForest(this BinaryClassificationContext.BinaryClassificationTrainers ctx, - FastForestClassification.Options advancedSettings) + FastForestClassification.Options options) { Contracts.CheckValue(ctx, nameof(ctx)); var env = CatalogUtils.GetEnvironment(ctx); - return new FastForestClassification(env, advancedSettings); + return new FastForestClassification(env, options); } } } diff --git a/src/Microsoft.ML.Legacy/CSharpApi.cs b/src/Microsoft.ML.Legacy/CSharpApi.cs index b607908914..f3d6f11e25 100644 --- a/src/Microsoft.ML.Legacy/CSharpApi.cs +++ b/src/Microsoft.ML.Legacy/CSharpApi.cs @@ -6687,13 +6687,13 @@ public sealed partial class FastTreeBinaryClassifier : Microsoft.ML.EntryPoints. /// - /// Should we use derivatives optimized for unbalanced sets + /// Option for using derivatives optimized for unbalanced sets /// [Obsolete] public bool UnbalancedSets { get; set; } = false; /// - /// Use best regression step trees? 
+ /// Option for using best regression step trees /// [Obsolete] public bool BestStepRankingRegressionTrees { get; set; } = false; @@ -7199,7 +7199,7 @@ public sealed partial class FastTreeRanker : Microsoft.ML.EntryPoints.CommonInpu public bool NormalizeQueryLambdas { get; set; } = false; /// - /// Use best regression step trees? + /// Option for using best regression step trees /// [Obsolete] public bool BestStepRankingRegressionTrees { get; set; } = false; @@ -7657,7 +7657,7 @@ public sealed partial class FastTreeRegressor : Microsoft.ML.EntryPoints.CommonI /// - /// Use best regression step trees? + /// Option for using best regression step trees /// [Obsolete] public bool BestStepRankingRegressionTrees { get; set; } = false; @@ -8120,7 +8120,7 @@ public sealed partial class FastTreeTweedieRegressor : Microsoft.ML.EntryPoints. public double Index { get; set; } = 1.5d; /// - /// Use best regression step trees? + /// Option for using best regression step trees /// [Obsolete] public bool BestStepRankingRegressionTrees { get; set; } = false; @@ -21112,13 +21112,13 @@ public abstract class FastTreeTrainer : ComponentKind {} public sealed class FastTreeBinaryClassificationFastTreeTrainer : FastTreeTrainer { /// - /// Should we use derivatives optimized for unbalanced sets + /// Option for using derivatives optimized for unbalanced sets /// [Obsolete] public bool UnbalancedSets { get; set; } = false; /// - /// Use best regression step trees? + /// Option for using best regression step trees /// [Obsolete] public bool BestStepRankingRegressionTrees { get; set; } = false; @@ -21582,7 +21582,7 @@ public sealed class FastTreeRankingFastTreeTrainer : FastTreeTrainer public bool NormalizeQueryLambdas { get; set; } = false; /// - /// Use best regression step trees? 
+ /// Option for using best regression step trees /// [Obsolete] public bool BestStepRankingRegressionTrees { get; set; } = false; @@ -21998,7 +21998,7 @@ public sealed class FastTreeRankingFastTreeTrainer : FastTreeTrainer public sealed class FastTreeRegressionFastTreeTrainer : FastTreeTrainer { /// - /// Use best regression step trees? + /// Option for using best regression step trees /// [Obsolete] public bool BestStepRankingRegressionTrees { get; set; } = false; @@ -22420,7 +22420,7 @@ public sealed class FastTreeTweedieRegressionFastTreeTrainer : FastTreeTrainer public double Index { get; set; } = 1.5d; /// - /// Use best regression step trees? + /// Option for using best regression step trees /// [Obsolete] public bool BestStepRankingRegressionTrees { get; set; } = false; diff --git a/src/Microsoft.ML.StaticPipe/TreeTrainersStatic.cs b/src/Microsoft.ML.StaticPipe/TreeTrainersStatic.cs index e14764a83c..7bac497fb9 100644 --- a/src/Microsoft.ML.StaticPipe/TreeTrainersStatic.cs +++ b/src/Microsoft.ML.StaticPipe/TreeTrainersStatic.cs @@ -69,7 +69,7 @@ public static Scalar FastTree(this RegressionContext.RegressionTrainers c /// The label column. /// The features column. /// The optional weights column. - /// Algorithm advanced settings. + /// Algorithm advanced settings. /// A delegate that is called every time the /// method is called on the /// instance created out of this. 
This delegate will receive @@ -84,16 +84,16 @@ public static Scalar FastTree(this RegressionContext.RegressionTrainers c /// public static Scalar FastTree(this RegressionContext.RegressionTrainers ctx, Scalar label, Vector features, Scalar weights, - FastTreeRegressionTrainer.Options advancedSettings, + FastTreeRegressionTrainer.Options options, Action onFit = null) { - Contracts.CheckValueOrNull(advancedSettings); + Contracts.CheckValueOrNull(options); CheckUserValues(label, features, weights, onFit); var rec = new TrainerEstimatorReconciler.Regression( (env, labelName, featuresName, weightsName) => { - var trainer = new FastTreeRegressionTrainer(env, advancedSettings); + var trainer = new FastTreeRegressionTrainer(env, options); if (onFit != null) return trainer.WithOnFitDelegate(trans => onFit(trans.Model)); return trainer; @@ -160,7 +160,7 @@ public static (Scalar score, Scalar probability, Scalar pred /// The label column. /// The features column. /// The optional weights column. - /// Algorithm advanced settings. + /// Algorithm advanced settings. /// A delegate that is called every time the /// method is called on the /// instance created out of this. 
This delegate will receive @@ -176,16 +176,16 @@ public static (Scalar score, Scalar probability, Scalar pred /// public static (Scalar score, Scalar probability, Scalar predictedLabel) FastTree(this BinaryClassificationContext.BinaryClassificationTrainers ctx, Scalar label, Vector features, Scalar weights, - FastTreeBinaryClassificationTrainer.Options advancedSettings, + FastTreeBinaryClassificationTrainer.Options options, Action> onFit = null) { - Contracts.CheckValueOrNull(advancedSettings); + Contracts.CheckValueOrNull(options); CheckUserValues(label, features, weights, onFit); var rec = new TrainerEstimatorReconciler.BinaryClassifier( (env, labelName, featuresName, weightsName) => { - var trainer = new FastTreeBinaryClassificationTrainer(env, advancedSettings); + var trainer = new FastTreeBinaryClassificationTrainer(env, options); if (onFit != null) return trainer.WithOnFitDelegate(trans => onFit(trans.Model)); @@ -247,7 +247,7 @@ public static Scalar FastTree(this RankingContext.RankingTrainers c /// The features column. /// The groupId column. /// The optional weights column. - /// Algorithm advanced settings. + /// Algorithm advanced settings. /// A delegate that is called every time the /// method is called on the /// instance created out of this. This delegate will receive @@ -256,16 +256,16 @@ public static Scalar FastTree(this RankingContext.RankingTrainers c /// The Score output column indicating the predicted value. 
public static Scalar FastTree(this RankingContext.RankingTrainers ctx, Scalar label, Vector features, Key groupId, Scalar weights, - FastTreeRankingTrainer.Options advancedSettings, + FastTreeRankingTrainer.Options options, Action onFit = null) { - Contracts.CheckValueOrNull(advancedSettings); + Contracts.CheckValueOrNull(options); CheckUserValues(label, features, weights, onFit); var rec = new TrainerEstimatorReconciler.Ranker( (env, labelName, featuresName, groupIdName, weightsName) => { - var trainer = new FastTreeRankingTrainer(env, advancedSettings); + var trainer = new FastTreeRankingTrainer(env, options); if (onFit != null) return trainer.WithOnFitDelegate(trans => onFit(trans.Model)); return trainer; diff --git a/test/BaselineOutput/Common/EntryPoints/core_manifest.json b/test/BaselineOutput/Common/EntryPoints/core_manifest.json index 3e95b0e09d..974919906e 100644 --- a/test/BaselineOutput/Common/EntryPoints/core_manifest.json +++ b/test/BaselineOutput/Common/EntryPoints/core_manifest.json @@ -6585,7 +6585,7 @@ { "Name": "UnbalancedSets", "Type": "Bool", - "Desc": "Should we use derivatives optimized for unbalanced sets", + "Desc": "Option for using derivatives optimized for unbalanced sets", "Aliases": [ "us" ], @@ -6597,7 +6597,7 @@ { "Name": "BestStepRankingRegressionTrees", "Type": "Bool", - "Desc": "Use best regression step trees?", + "Desc": "Option for using best regression step trees", "Aliases": [ "bsr" ], @@ -7590,7 +7590,7 @@ { "Name": "BestStepRankingRegressionTrees", "Type": "Bool", - "Desc": "Use best regression step trees?", + "Desc": "Option for using best regression step trees", "Aliases": [ "bsr" ], @@ -8490,7 +8490,7 @@ { "Name": "BestStepRankingRegressionTrees", "Type": "Bool", - "Desc": "Use best regression step trees?", + "Desc": "Option for using best regression step trees", "Aliases": [ "bsr" ], @@ -9399,7 +9399,7 @@ { "Name": "BestStepRankingRegressionTrees", "Type": "Bool", - "Desc": "Use best regression step trees?", + "Desc": 
"Option for using best regression step trees", "Aliases": [ "bsr" ], @@ -25154,7 +25154,7 @@ { "Name": "UnbalancedSets", "Type": "Bool", - "Desc": "Should we use derivatives optimized for unbalanced sets", + "Desc": "Option for using derivatives optimized for unbalanced sets", "Aliases": [ "us" ], @@ -25166,7 +25166,7 @@ { "Name": "BestStepRankingRegressionTrees", "Type": "Bool", - "Desc": "Use best regression step trees?", + "Desc": "Option for using best regression step trees", "Aliases": [ "bsr" ], @@ -26141,7 +26141,7 @@ { "Name": "BestStepRankingRegressionTrees", "Type": "Bool", - "Desc": "Use best regression step trees?", + "Desc": "Option for using best regression step trees", "Aliases": [ "bsr" ], @@ -27023,7 +27023,7 @@ { "Name": "BestStepRankingRegressionTrees", "Type": "Bool", - "Desc": "Use best regression step trees?", + "Desc": "Option for using best regression step trees", "Aliases": [ "bsr" ], @@ -27914,7 +27914,7 @@ { "Name": "BestStepRankingRegressionTrees", "Type": "Bool", - "Desc": "Use best regression step trees?", + "Desc": "Option for using best regression step trees", "Aliases": [ "bsr" ], From cd165e6bd95ded7d01037bafb76dbd010cd6aa9f Mon Sep 17 00:00:00 2001 From: Abhishek Goswami Date: Thu, 10 Jan 2019 05:48:06 +0000 Subject: [PATCH 8/8] making the constructors internal --- src/Microsoft.ML.FastTree/FastTreeClassification.cs | 4 ++-- src/Microsoft.ML.FastTree/FastTreeRanking.cs | 4 ++-- src/Microsoft.ML.FastTree/FastTreeRegression.cs | 4 ++-- src/Microsoft.ML.FastTree/FastTreeTweedie.cs | 4 ++-- .../RandomForestClassification.cs | 4 ++-- src/Microsoft.ML.FastTree/RandomForestRegression.cs | 4 ++-- test/Microsoft.ML.Tests/Scenarios/OvaTest.cs | 2 +- .../TrainerEstimators/TreeEstimators.cs | 12 ++++++------ 8 files changed, 19 insertions(+), 19 deletions(-) diff --git a/src/Microsoft.ML.FastTree/FastTreeClassification.cs b/src/Microsoft.ML.FastTree/FastTreeClassification.cs index 2ff3cee1db..1eedfeb5b8 100644 --- 
a/src/Microsoft.ML.FastTree/FastTreeClassification.cs +++ b/src/Microsoft.ML.FastTree/FastTreeClassification.cs @@ -127,7 +127,7 @@ public sealed partial class FastTreeBinaryClassificationTrainer : /// The minimal number of documents allowed in a leaf of a regression tree, out of the subsampled data. /// The max number of leaves in each regression tree. /// Total number of decision trees to create in the ensemble. - public FastTreeBinaryClassificationTrainer(IHostEnvironment env, + internal FastTreeBinaryClassificationTrainer(IHostEnvironment env, string labelColumn = DefaultColumnNames.Label, string featureColumn = DefaultColumnNames.Features, string weightColumn = null, @@ -146,7 +146,7 @@ public FastTreeBinaryClassificationTrainer(IHostEnvironment env, /// /// The instance of . /// Algorithm advanced settings. - public FastTreeBinaryClassificationTrainer(IHostEnvironment env, Options options) + internal FastTreeBinaryClassificationTrainer(IHostEnvironment env, Options options) : base(env, options, TrainerUtils.MakeBoolScalarLabel(options.LabelColumn)) { // Set the sigmoid parameter to the 2 * learning rate, for traditional FastTreeClassification loss diff --git a/src/Microsoft.ML.FastTree/FastTreeRanking.cs b/src/Microsoft.ML.FastTree/FastTreeRanking.cs index 543404a081..4834e23d9e 100644 --- a/src/Microsoft.ML.FastTree/FastTreeRanking.cs +++ b/src/Microsoft.ML.FastTree/FastTreeRanking.cs @@ -71,7 +71,7 @@ public sealed partial class FastTreeRankingTrainer /// Total number of decision trees to create in the ensemble. /// The minimal number of documents allowed in a leaf of a regression tree, out of the subsampled data. /// The learning rate. 
- public FastTreeRankingTrainer(IHostEnvironment env, + internal FastTreeRankingTrainer(IHostEnvironment env, string labelColumn = DefaultColumnNames.Label, string featureColumn = DefaultColumnNames.Features, string groupIdColumn = DefaultColumnNames.GroupId, @@ -90,7 +90,7 @@ public FastTreeRankingTrainer(IHostEnvironment env, /// /// The instance of . /// Algorithm advanced settings. - public FastTreeRankingTrainer(IHostEnvironment env, Options options) + internal FastTreeRankingTrainer(IHostEnvironment env, Options options) : base(env, options, TrainerUtils.MakeR4ScalarColumn(options.LabelColumn)) { } diff --git a/src/Microsoft.ML.FastTree/FastTreeRegression.cs b/src/Microsoft.ML.FastTree/FastTreeRegression.cs index 89c6b3141c..356fd1fb91 100644 --- a/src/Microsoft.ML.FastTree/FastTreeRegression.cs +++ b/src/Microsoft.ML.FastTree/FastTreeRegression.cs @@ -62,7 +62,7 @@ public sealed partial class FastTreeRegressionTrainer /// The minimal number of documents allowed in a leaf of a regression tree, out of the subsampled data. /// The max number of leaves in each regression tree. /// Total number of decision trees to create in the ensemble. - public FastTreeRegressionTrainer(IHostEnvironment env, + internal FastTreeRegressionTrainer(IHostEnvironment env, string labelColumn = DefaultColumnNames.Label, string featureColumn = DefaultColumnNames.Features, string weightColumn = null, @@ -79,7 +79,7 @@ public FastTreeRegressionTrainer(IHostEnvironment env, /// /// The instance of . /// Algorithm advanced settings. 
- public FastTreeRegressionTrainer(IHostEnvironment env, Options options) + internal FastTreeRegressionTrainer(IHostEnvironment env, Options options) : base(env, options, TrainerUtils.MakeR4ScalarColumn(options.LabelColumn)) { } diff --git a/src/Microsoft.ML.FastTree/FastTreeTweedie.cs b/src/Microsoft.ML.FastTree/FastTreeTweedie.cs index c1b365ec81..d5bd8e11f2 100644 --- a/src/Microsoft.ML.FastTree/FastTreeTweedie.cs +++ b/src/Microsoft.ML.FastTree/FastTreeTweedie.cs @@ -59,7 +59,7 @@ public sealed partial class FastTreeTweedieTrainer /// The minimal number of documents allowed in a leaf of a regression tree, out of the subsampled data. /// The max number of leaves in each regression tree. /// Total number of decision trees to create in the ensemble. - public FastTreeTweedieTrainer(IHostEnvironment env, + internal FastTreeTweedieTrainer(IHostEnvironment env, string labelColumn = DefaultColumnNames.Label, string featureColumn = DefaultColumnNames.Features, string weightColumn = null, @@ -80,7 +80,7 @@ public FastTreeTweedieTrainer(IHostEnvironment env, /// /// The instance of . /// Algorithm advanced settings. - public FastTreeTweedieTrainer(IHostEnvironment env, Options options) + internal FastTreeTweedieTrainer(IHostEnvironment env, Options options) : base(env, options, TrainerUtils.MakeR4ScalarColumn(options.LabelColumn)) { Initialize(); diff --git a/src/Microsoft.ML.FastTree/RandomForestClassification.cs b/src/Microsoft.ML.FastTree/RandomForestClassification.cs index 4275727eff..8c01920bbc 100644 --- a/src/Microsoft.ML.FastTree/RandomForestClassification.cs +++ b/src/Microsoft.ML.FastTree/RandomForestClassification.cs @@ -145,7 +145,7 @@ public sealed class Options : FastForestArgumentsBase /// Total number of decision trees to create in the ensemble. /// The minimal number of documents allowed in a leaf of a regression tree, out of the subsampled data. /// The learning rate. 
- public FastForestClassification(IHostEnvironment env, + internal FastForestClassification(IHostEnvironment env, string labelColumn = DefaultColumnNames.Label, string featureColumn = DefaultColumnNames.Features, string weightColumn = null, @@ -164,7 +164,7 @@ public FastForestClassification(IHostEnvironment env, /// /// The instance of . /// Algorithm advanced settings. - public FastForestClassification(IHostEnvironment env, Options options) + internal FastForestClassification(IHostEnvironment env, Options options) : base(env, options, TrainerUtils.MakeBoolScalarLabel(options.LabelColumn)) { } diff --git a/src/Microsoft.ML.FastTree/RandomForestRegression.cs b/src/Microsoft.ML.FastTree/RandomForestRegression.cs index ff5a7b8a80..a111e1314a 100644 --- a/src/Microsoft.ML.FastTree/RandomForestRegression.cs +++ b/src/Microsoft.ML.FastTree/RandomForestRegression.cs @@ -163,7 +163,7 @@ public sealed class Options : FastForestArgumentsBase /// Total number of decision trees to create in the ensemble. /// The minimal number of documents allowed in a leaf of a regression tree, out of the subsampled data. /// The learning rate. - public FastForestRegression(IHostEnvironment env, + internal FastForestRegression(IHostEnvironment env, string labelColumn = DefaultColumnNames.Label, string featureColumn = DefaultColumnNames.Features, string weightColumn = null, @@ -182,7 +182,7 @@ public FastForestRegression(IHostEnvironment env, /// /// The instance of . /// Algorithm advanced settings. 
- public FastForestRegression(IHostEnvironment env, Options options) + internal FastForestRegression(IHostEnvironment env, Options options) : base(env, options, TrainerUtils.MakeR4ScalarColumn(options.LabelColumn), true) { } diff --git a/test/Microsoft.ML.Tests/Scenarios/OvaTest.cs b/test/Microsoft.ML.Tests/Scenarios/OvaTest.cs index 9e4d8f5ad4..2afa97170a 100644 --- a/test/Microsoft.ML.Tests/Scenarios/OvaTest.cs +++ b/test/Microsoft.ML.Tests/Scenarios/OvaTest.cs @@ -104,7 +104,7 @@ public void OvaFastTree() // Pipeline var pipeline = new Ova( mlContext, - new FastTreeBinaryClassificationTrainer(mlContext, new FastTreeBinaryClassificationTrainer.Options { NumThreads = 1 }), + mlContext.BinaryClassification.Trainers.FastTree(new FastTreeBinaryClassificationTrainer.Options { NumThreads = 1 }), useProbabilities: false); var model = pipeline.Fit(data); diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs b/test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs index 14a27ac283..574589a39a 100644 --- a/test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs +++ b/test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs @@ -24,7 +24,7 @@ public void FastTreeBinaryEstimator() { var (pipe, dataView) = GetBinaryClassificationPipeline(); - var trainer = new FastTreeBinaryClassificationTrainer(Env, + var trainer = ML.BinaryClassification.Trainers.FastTree( new FastTreeBinaryClassificationTrainer.Options { NumThreads = 1, NumTrees = 10, @@ -83,7 +83,7 @@ public void FastForestClassificationEstimator() { var (pipe, dataView) = GetBinaryClassificationPipeline(); - var trainer = new FastForestClassification(Env, + var trainer = ML.BinaryClassification.Trainers.FastForest( new FastForestClassification.Options { NumLeaves = 10, NumTrees = 20, @@ -105,7 +105,7 @@ public void FastTreeRankerEstimator() { var (pipe, dataView) = GetRankingPipeline(); - var trainer = new FastTreeRankingTrainer(Env, + var trainer = ML.Ranking.Trainers.FastTree( new 
FastTreeRankingTrainer.Options { FeatureColumn = "NumericFeatures", NumTrees = 10 @@ -144,7 +144,7 @@ public void LightGBMRankerEstimator() public void FastTreeRegressorEstimator() { var dataView = GetRegressionPipeline(); - var trainer = new FastTreeRegressionTrainer(Env, + var trainer = ML.Regression.Trainers.FastTree( new FastTreeRegressionTrainer.Options { NumTrees = 10, NumThreads = 1, NumLeaves = 5 }); TestEstimatorCore(trainer, dataView); @@ -197,7 +197,7 @@ public void GAMRegressorEstimator() public void TweedieRegressorEstimator() { var dataView = GetRegressionPipeline(); - var trainer = new FastTreeTweedieTrainer(Env, + var trainer = ML.Regression.Trainers.FastTreeTweedie( new FastTreeTweedieTrainer.Options { EntropyCoefficient = 0.3, OptimizationAlgorithm = BoostedTreeArgs.OptimizationAlgorithmType.AcceleratedGradientDescent, @@ -215,7 +215,7 @@ public void TweedieRegressorEstimator() public void FastForestRegressorEstimator() { var dataView = GetRegressionPipeline(); - var trainer = new FastForestRegression(Env, + var trainer = ML.Regression.Trainers.FastForest( new FastForestRegression.Options { BaggingSize = 2, NumTrees = 10,