From cd043003902f88c67d282a889fbdc0aa6dcf1ac4 Mon Sep 17 00:00:00 2001 From: feiyun0112 Date: Fri, 26 Oct 2018 20:40:37 +0800 Subject: [PATCH 1/3] FastTreeRankingTrainer expose non-advanced args(#1246) --- src/Microsoft.ML.FastTree/BoostingFastTree.cs | 16 +++++++++++++--- src/Microsoft.ML.FastTree/FastTree.cs | 16 ++++++++++++++-- src/Microsoft.ML.FastTree/FastTreeCatalog.cs | 2 +- .../FastTreeClassification.cs | 2 +- src/Microsoft.ML.FastTree/FastTreeRanking.cs | 18 +++++++++++++++--- .../FastTreeRegression.cs | 2 +- src/Microsoft.ML.FastTree/FastTreeStatic.cs | 2 +- src/Microsoft.ML.FastTree/FastTreeTweedie.cs | 2 +- src/Microsoft.ML.FastTree/RandomForest.cs | 2 +- 9 files changed, 48 insertions(+), 14 deletions(-) diff --git a/src/Microsoft.ML.FastTree/BoostingFastTree.cs b/src/Microsoft.ML.FastTree/BoostingFastTree.cs index 7a68902038..b8a1fbff31 100644 --- a/src/Microsoft.ML.FastTree/BoostingFastTree.cs +++ b/src/Microsoft.ML.FastTree/BoostingFastTree.cs @@ -21,10 +21,20 @@ protected BoostingFastTreeTrainerBase(IHostEnvironment env, TArgs args, SchemaSh { } - protected BoostingFastTreeTrainerBase(IHostEnvironment env, SchemaShape.Column label, string featureColumn, - string weightColumn = null, string groupIdColumn = null, Action advancedSettings = null) - : base(env, label, featureColumn, weightColumn, groupIdColumn, advancedSettings) + protected BoostingFastTreeTrainerBase(IHostEnvironment env, + SchemaShape.Column label, + string featureColumn, + string weightColumn = null, + string groupIdColumn = null, + int numLeaves = Defaults.NumLeaves, + int numTrees = Defaults.NumTrees, + int minDocumentsInLeafs = Defaults.MinDocumentsInLeafs, + double learningRate = Defaults.LearningRates, + Action advancedSettings = null) + : base(env, label, featureColumn, weightColumn, groupIdColumn, numLeaves, numTrees, minDocumentsInLeafs, advancedSettings) { + //override with the directly provided values. + Args.LearningRates = learningRate; } protected override void CheckArgs(IChannel ch) diff --git a/src/Microsoft.ML.FastTree/FastTree.cs b/src/Microsoft.ML.FastTree/FastTree.cs index 0aa360cd62..7d6ba197b4 100644 --- a/src/Microsoft.ML.FastTree/FastTree.cs +++ b/src/Microsoft.ML.FastTree/FastTree.cs @@ -92,8 +92,15 @@ public abstract class FastTreeTrainerBase : /// /// Constructor to use when instantiating the classes deriving from here through the API. /// - private protected FastTreeTrainerBase(IHostEnvironment env, SchemaShape.Column label, string featureColumn, - string weightColumn = null, string groupIdColumn = null, Action advancedSettings = null) + private protected FastTreeTrainerBase(IHostEnvironment env, + SchemaShape.Column label, + string featureColumn, + string weightColumn = null, + string groupIdColumn = null, + int numLeaves = Defaults.NumLeaves, + int numTrees = Defaults.NumTrees, + int minDocumentsInLeafs = Defaults.MinDocumentsInLeafs, + Action advancedSettings = null) : base(Contracts.CheckRef(env, nameof(env)).Register(RegisterName), TrainerUtils.MakeR4VecFeature(featureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(weightColumn)) { Args = new TArgs(); @@ -113,6 +120,11 @@ private protected FastTreeTrainerBase(IHostEnvironment env, SchemaShape.Column l if (groupIdColumn != null) Args.GroupIdColumn = groupIdColumn; + //override with the directly provided values. 
+ Args.NumLeaves = numLeaves; + Args.NumTrees = numTrees; + Args.MinDocumentsInLeafs = minDocumentsInLeafs; + // The discretization step renders this trainer non-parametric, and therefore it does not need normalization. // Also since it builds its own internal discretized columnar structures, it cannot benefit from caching. // Finally, even the binary classifiers, being logitboost, tend to not benefit from external calibration. diff --git a/src/Microsoft.ML.FastTree/FastTreeCatalog.cs b/src/Microsoft.ML.FastTree/FastTreeCatalog.cs index 103a5676f1..9e439903dc 100644 --- a/src/Microsoft.ML.FastTree/FastTreeCatalog.cs +++ b/src/Microsoft.ML.FastTree/FastTreeCatalog.cs @@ -94,7 +94,7 @@ public static FastTreeRankingTrainer FastTree(this RankingContext.RankingTrainer { Contracts.CheckValue(ctx, nameof(ctx)); var env = CatalogUtils.GetEnvironment(ctx); - return new FastTreeRankingTrainer(env, label, features, groupId, weights, advancedSettings); + return new FastTreeRankingTrainer(env, label, features, groupId, weights, advancedSettings: advancedSettings); } } } diff --git a/src/Microsoft.ML.FastTree/FastTreeClassification.cs b/src/Microsoft.ML.FastTree/FastTreeClassification.cs index 7f7080281d..871dff657f 100644 --- a/src/Microsoft.ML.FastTree/FastTreeClassification.cs +++ b/src/Microsoft.ML.FastTree/FastTreeClassification.cs @@ -136,7 +136,7 @@ public FastTreeBinaryClassificationTrainer(IHostEnvironment env, int minDocumentsInLeafs = Defaults.MinDocumentsInLeafs, double learningRate = Defaults.LearningRates, Action advancedSettings = null) - : base(env, TrainerUtils.MakeBoolScalarLabel(labelColumn), featureColumn, weightColumn, null, advancedSettings) + : base(env, TrainerUtils.MakeBoolScalarLabel(labelColumn), featureColumn, weightColumn, null, numLeaves, numTrees, minDocumentsInLeafs, learningRate, advancedSettings) { Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); diff --git a/src/Microsoft.ML.FastTree/FastTreeRanking.cs b/src/Microsoft.ML.FastTree/FastTreeRanking.cs index d96f82f741..4250cecfe2 100644 --- a/src/Microsoft.ML.FastTree/FastTreeRanking.cs +++ b/src/Microsoft.ML.FastTree/FastTreeRanking.cs @@ -67,10 +67,22 @@ public sealed partial class FastTreeRankingTrainer /// The name of the feature column. /// The name for the column containing the group ID. /// The name for the column containing the initial weight. + /// The max number of leaves in each regression tree. + /// Total number of decision trees to create in the ensemble. + /// The minimal number of documents allowed in a leaf of a regression tree, out of the subsampled data. + /// The learning rate. /// A delegate to apply all the advanced arguments to the algorithm. 
- public FastTreeRankingTrainer(IHostEnvironment env, string labelColumn, string featureColumn, string groupIdColumn, - string weightColumn = null, Action advancedSettings = null) - : base(env, TrainerUtils.MakeR4ScalarLabel(labelColumn), featureColumn, weightColumn, groupIdColumn, advancedSettings: advancedSettings) + public FastTreeRankingTrainer(IHostEnvironment env, + string labelColumn, + string featureColumn, + string groupIdColumn, + string weightColumn = null, + int numLeaves = Defaults.NumLeaves, + int numTrees = Defaults.NumTrees, + int minDocumentsInLeafs = Defaults.MinDocumentsInLeafs, + double learningRate = Defaults.LearningRates, + Action advancedSettings = null) + : base(env, TrainerUtils.MakeR4ScalarLabel(labelColumn), featureColumn, weightColumn, groupIdColumn, numLeaves, numTrees, minDocumentsInLeafs, learningRate, advancedSettings: advancedSettings) { Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); diff --git a/src/Microsoft.ML.FastTree/FastTreeRegression.cs b/src/Microsoft.ML.FastTree/FastTreeRegression.cs index 4cc09c9243..d3f2ac00dc 100644 --- a/src/Microsoft.ML.FastTree/FastTreeRegression.cs +++ b/src/Microsoft.ML.FastTree/FastTreeRegression.cs @@ -72,7 +72,7 @@ public FastTreeRegressionTrainer(IHostEnvironment env, int minDocumentsInLeafs = Defaults.MinDocumentsInLeafs, double learningRate = Defaults.LearningRates, Action advancedSettings = null) - : base(env, TrainerUtils.MakeR4ScalarLabel(labelColumn), featureColumn, weightColumn, null, advancedSettings) + : base(env, TrainerUtils.MakeR4ScalarLabel(labelColumn), featureColumn, weightColumn, null, numLeaves, numTrees, minDocumentsInLeafs, learningRate, advancedSettings) { Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); diff --git a/src/Microsoft.ML.FastTree/FastTreeStatic.cs b/src/Microsoft.ML.FastTree/FastTreeStatic.cs index 1aa8528b3e..9fd7f0fa9b 100644 --- a/src/Microsoft.ML.FastTree/FastTreeStatic.cs +++ b/src/Microsoft.ML.FastTree/FastTreeStatic.cs @@ -153,7 +153,7 @@ public static Scalar FastTree(this RankingContext.RankingTrainers c var rec = new TrainerEstimatorReconciler.Ranker( (env, labelName, featuresName, groupIdName, weightsName) => { - var trainer = new FastTreeRankingTrainer(env, labelName, featuresName, groupIdName, weightsName, advancedSettings); + var trainer = new FastTreeRankingTrainer(env, labelName, featuresName, groupIdName, weightsName, numLeaves, numTrees, minDatapointsInLeafs, learningRate, advancedSettings); if (onFit != null) return trainer.WithOnFitDelegate(trans => onFit(trans.Model)); return trainer; diff --git a/src/Microsoft.ML.FastTree/FastTreeTweedie.cs b/src/Microsoft.ML.FastTree/FastTreeTweedie.cs index 479d83399f..e7f4c097d3 100644 --- a/src/Microsoft.ML.FastTree/FastTreeTweedie.cs +++ b/src/Microsoft.ML.FastTree/FastTreeTweedie.cs @@ -59,7 +59,7 @@ public sealed partial class FastTreeTweedieTrainer /// A delegate to apply all the advanced arguments to the algorithm. 
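// Usage sketch (not part of the patch): with the FastTreeRankingTrainer constructor introduced in this
// patch, the most commonly tuned FastTree arguments can be passed directly instead of through the
// advancedSettings delegate. `env` is assumed to be an existing IHostEnvironment instance and the
// column names are hypothetical.
var ranker = new FastTreeRankingTrainer(env, "Label", "Features", "GroupId",
    numLeaves: 20, numTrees: 100, minDocumentsInLeafs: 10, learningRate: 0.2);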
public FastTreeTweedieTrainer(IHostEnvironment env, string labelColumn, string featureColumn, string groupIdColumn = null, string weightColumn = null, Action advancedSettings = null) - : base(env, TrainerUtils.MakeR4ScalarLabel(labelColumn), featureColumn, weightColumn, groupIdColumn, advancedSettings) + : base(env, TrainerUtils.MakeR4ScalarLabel(labelColumn), featureColumn, weightColumn, groupIdColumn, advancedSettings: advancedSettings) { Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); diff --git a/src/Microsoft.ML.FastTree/RandomForest.cs b/src/Microsoft.ML.FastTree/RandomForest.cs index 6043774b43..19f750a68a 100644 --- a/src/Microsoft.ML.FastTree/RandomForest.cs +++ b/src/Microsoft.ML.FastTree/RandomForest.cs @@ -30,7 +30,7 @@ protected RandomForestTrainerBase(IHostEnvironment env, TArgs args, SchemaShape. /// protected RandomForestTrainerBase(IHostEnvironment env, SchemaShape.Column label, string featureColumn, string weightColumn = null, string groupIdColumn = null, bool quantileEnabled = false, Action advancedSettings = null) - : base(env, label, featureColumn, weightColumn, groupIdColumn, advancedSettings) + : base(env, label, featureColumn, weightColumn, groupIdColumn, advancedSettings: advancedSettings) { _quantileEnabled = quantileEnabled; } From 3876f0f83665eea3d14db523eb93d4b993244b09 Mon Sep 17 00:00:00 2001 From: feiyun0112 Date: Sat, 27 Oct 2018 22:13:20 +0800 Subject: [PATCH 2/3] merge --- .vsts-dotnet-ci.yml | 10 +- README.md | 2 +- build/ci/phase-template.yml | 4 +- docs/building/unix-instructions.md | 2 +- init-tools.cmd | 11 +- .../Microsoft.ML.TensorFlow.Redist.nupkgproj | 2 +- run.cmd | 2 +- .../Training/TrainerEstimatorBase.cs | 37 +++- .../Training/TrainerUtils.cs | 73 ++----- .../Microsoft.ML.Ensemble.csproj | 2 +- src/Microsoft.ML.FastTree/BoostingFastTree.cs | 22 ++- src/Microsoft.ML.FastTree/FastTree.cs | 58 ++---- src/Microsoft.ML.FastTree/FastTreeCatalog.cs | 100 ---------- .../FastTreeClassification.cs | 14 +- src/Microsoft.ML.FastTree/FastTreeRanking.cs | 6 +- .../FastTreeRegression.cs | 7 +- src/Microsoft.ML.FastTree/FastTreeTweedie.cs | 26 ++- .../GamClassification.cs | 15 +- src/Microsoft.ML.FastTree/GamRegression.cs | 18 +- src/Microsoft.ML.FastTree/GamTrainer.cs | 17 +- src/Microsoft.ML.FastTree/RandomForest.cs | 15 +- .../RandomForestClassification.cs | 24 ++- .../RandomForestRegression.cs | 18 +- .../TreeTrainersCatalog.cs | 181 ++++++++++++++++++ ...astTreeStatic.cs => TreeTrainersStatic.cs} | 36 ++-- .../OlsLinearRegression.cs | 4 +- .../LightGbmBinaryTrainer.cs | 18 +- src/Microsoft.ML.LightGBM/LightGbmCatalog.cs | 83 +++++++- .../LightGbmMulticlassTrainer.cs | 25 ++- .../LightGbmRankingTrainer.cs | 29 ++- .../LightGbmRegressionTrainer.cs | 18 +- src/Microsoft.ML.LightGBM/LightGbmStatic.cs | 115 +++++++++-- .../LightGbmTrainerBase.cs | 58 +++--- .../FactorizationMachineCatalog.cs | 6 +- .../FactorizationMachineStatic.cs | 9 +- .../Standard/LinearClassificationTrainer.cs | 45 +---- .../LogisticRegression/LbfgsPredictorBase.cs | 84 +++++--- .../MulticlassLogisticRegression.cs | 4 +- .../MultiClass/MultiClassNaiveBayesTrainer.cs | 4 +- .../Standard/SdcaCatalog.cs | 17 +- .../Standard/SdcaMultiClass.cs | 9 +- .../Standard/SdcaRegression.cs | 5 +- .../Standard/SdcaStatic.cs | 30 +-- .../BenchmarksTest.cs | 4 + .../Training.cs | 79 ++++++++ .../BaseTestBaseline.cs | 2 + .../TensorflowTests.cs | 2 +- .../TensorFlowEstimatorTests.cs | 2 +- 48 files changed, 852 insertions(+), 502 
deletions(-) delete mode 100644 src/Microsoft.ML.FastTree/FastTreeCatalog.cs create mode 100644 src/Microsoft.ML.FastTree/TreeTrainersCatalog.cs rename src/Microsoft.ML.FastTree/{FastTreeStatic.cs => TreeTrainersStatic.cs} (90%) diff --git a/.vsts-dotnet-ci.yml b/.vsts-dotnet-ci.yml index fa7c7a139c..69ff9b9004 100644 --- a/.vsts-dotnet-ci.yml +++ b/.vsts-dotnet-ci.yml @@ -18,7 +18,15 @@ phases: - template: /build/ci/phase-template.yml parameters: - name: Windows_NT + name: Windows_x64 + buildScript: build.cmd + queue: + name: Hosted VS2017 + +- template: /build/ci/phase-template.yml + parameters: + name: Windows_x86 + architecture: x86 buildScript: build.cmd queue: name: Hosted VS2017 diff --git a/README.md b/README.md index 2c4eb066c7..9ecead077a 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,7 @@ ML.NET allows .NET developers to develop their own models and infuse custom mach ML.NET was originally developed in Microsoft Research, and evolved into a significant framework over the last decade and is used across many product groups in Microsoft like Windows, Bing, PowerPoint, Excel and more. -With this first preview release, ML.NET enables machine learning tasks like classification (for example: support text classification, sentiment analysis) and regression (for example, price-prediction). +ML.NET enables machine learning tasks like classification (for example: support text classification, sentiment analysis) and regression (for example, price-prediction). Along with these ML capabilities, this first release of ML.NET also brings the first draft of .NET APIs for training models, using models for predictions, as well as the core components of this framework such as learning algorithms, transforms, and ML data structures. diff --git a/build/ci/phase-template.yml b/build/ci/phase-template.yml index 69b030db4c..72a2d83051 100644 --- a/build/ci/phase-template.yml +++ b/build/ci/phase-template.yml @@ -1,5 +1,6 @@ parameters: name: '' + architecture: x64 buildScript: '' queue: {} @@ -8,6 +9,7 @@ phases: variables: _buildScript: ${{ parameters.buildScript }} _phaseName: ${{ parameters.name }} + _arch: ${{ parameters.architecture }} queue: parallel: 99 matrix: @@ -17,7 +19,7 @@ phases: _configuration: Release ${{ insert }}: ${{ parameters.queue }} steps: - - script: $(_buildScript) -$(_configuration) + - script: $(_buildScript) -$(_configuration) -buildArch=$(_arch) displayName: Build - ${{ if eq(parameters.name, 'MacOS') }}: - script: brew install libomp diff --git a/docs/building/unix-instructions.md b/docs/building/unix-instructions.md index 95ca3124ad..7919e97d63 100644 --- a/docs/building/unix-instructions.md +++ b/docs/building/unix-instructions.md @@ -3,7 +3,7 @@ Building ML.NET on Linux and macOS ## Building 1. Install the prerequisites ([Linux](#user-content-linux), [macOS](#user-content-macos)) -2. Clone the machine learning repo `git clone https://github.com/dotnet/machinelearning.git` +2. Clone the machine learning repo `git clone --recursive https://github.com/dotnet/machinelearning.git` 3. Navigate to the `machinelearning` directory 4. 
Run the build script `./build.sh` diff --git a/init-tools.cmd b/init-tools.cmd index 0e6621ef52..3743cb413d 100644 --- a/init-tools.cmd +++ b/init-tools.cmd @@ -12,6 +12,7 @@ set BUILD_TOOLS_PATH=%PACKAGES_DIR%\Microsoft.DotNet.BuildTools\%BUILDTOOLS_VERS set INIT_TOOLS_RESTORE_PROJECT=%~dp0init-tools.msbuild set BUILD_TOOLS_SEMAPHORE_DIR=%TOOLRUNTIME_DIR%\%BUILDTOOLS_VERSION% set BUILD_TOOLS_SEMAPHORE=%BUILD_TOOLS_SEMAPHORE_DIR%\init-tools.completed +set ARCH=x64 :: if force option is specified then clean the tool runtime and build tools package directory to force it to get recreated if [%1]==[force] ( @@ -47,9 +48,17 @@ echo Running %0 > "%INIT_TOOLS_LOG%" set /p DOTNET_VERSION=< "%~dp0DotnetCLIVersion.txt" if exist "%DOTNET_CMD%" goto :afterdotnetrestore +:Arg_Loop +if [%1] == [] goto :ArchSet +if /i [%1] == [x86] ( set ARCH=x86&&goto ArchSet) +shift +goto :Arg_Loop + +:ArchSet + echo Installing dotnet cli... if NOT exist "%DOTNET_PATH%" mkdir "%DOTNET_PATH%" -set DOTNET_ZIP_NAME=dotnet-sdk-%DOTNET_VERSION%-win-x64.zip +set DOTNET_ZIP_NAME=dotnet-sdk-%DOTNET_VERSION%-win-%ARCH%.zip set DOTNET_REMOTE_PATH=https://dotnetcli.azureedge.net/dotnet/Sdk/%DOTNET_VERSION%/%DOTNET_ZIP_NAME% set DOTNET_LOCAL_PATH=%DOTNET_PATH%%DOTNET_ZIP_NAME% echo Installing '%DOTNET_REMOTE_PATH%' to '%DOTNET_LOCAL_PATH%' >> "%INIT_TOOLS_LOG%" diff --git a/pkg/Microsoft.ML.TensorFlow.Redist/Microsoft.ML.TensorFlow.Redist.nupkgproj b/pkg/Microsoft.ML.TensorFlow.Redist/Microsoft.ML.TensorFlow.Redist.nupkgproj index 0de839bf08..e5856c63f5 100644 --- a/pkg/Microsoft.ML.TensorFlow.Redist/Microsoft.ML.TensorFlow.Redist.nupkgproj +++ b/pkg/Microsoft.ML.TensorFlow.Redist/Microsoft.ML.TensorFlow.Redist.nupkgproj @@ -16,6 +16,6 @@ - + diff --git a/run.cmd b/run.cmd index d4e8ec8121..aa6ed53658 100644 --- a/run.cmd +++ b/run.cmd @@ -11,7 +11,7 @@ set DOTNET_SKIP_FIRST_TIME_EXPERIENCE=1 set DOTNET_MULTILEVEL_LOOKUP=0 :: Restore the Tools directory -call "%~dp0init-tools.cmd" +call "%~dp0init-tools.cmd" %* if NOT [%ERRORLEVEL%]==[0] exit /b 1 set _toolRuntime=%~dp0Tools diff --git a/src/Microsoft.ML.Data/Training/TrainerEstimatorBase.cs b/src/Microsoft.ML.Data/Training/TrainerEstimatorBase.cs index ad074fd737..ad2916368d 100644 --- a/src/Microsoft.ML.Data/Training/TrainerEstimatorBase.cs +++ b/src/Microsoft.ML.Data/Training/TrainerEstimatorBase.cs @@ -50,7 +50,10 @@ public abstract class TrainerEstimatorBase : ITrainerEstim public abstract PredictionKind PredictionKind { get; } - public TrainerEstimatorBase(IHost host, SchemaShape.Column feature, SchemaShape.Column label, SchemaShape.Column weight = null) + public TrainerEstimatorBase(IHost host, + SchemaShape.Column feature, + SchemaShape.Column label, + SchemaShape.Column weight = null) { Contracts.CheckValue(host, nameof(host)); Host = host; @@ -149,9 +152,39 @@ protected TTransformer TrainTransformer(IDataView trainSet, protected abstract TTransformer MakeTransformer(TModel model, Schema trainSchema); - private RoleMappedData MakeRoles(IDataView data) => + protected virtual RoleMappedData MakeRoles(IDataView data) => new RoleMappedData(data, label: LabelColumn?.Name, feature: FeatureColumn.Name, weight: WeightColumn?.Name); IPredictor ITrainer.Train(TrainContext context) => Train(context); } + + /// + /// This represents a basic class for 'simple trainer'. + /// A 'simple trainer' accepts one feature column and one label column, also optionally a weight column. + /// It produces a 'prediction transformer'. 
+ /// + public abstract class TrainerEstimatorBaseWithGroupId : TrainerEstimatorBase + where TTransformer : ISingleFeaturePredictionTransformer + where TModel : IPredictor + { + /// + /// The optional groupID column that the ranking trainers expects. + /// + public readonly SchemaShape.Column GroupIdColumn; + + public TrainerEstimatorBaseWithGroupId(IHost host, + SchemaShape.Column feature, + SchemaShape.Column label, + SchemaShape.Column weight = null, + SchemaShape.Column groupId = null) + :base(host, feature, label, weight) + { + Host.CheckValueOrNull(groupId); + GroupIdColumn = groupId; + } + + protected override RoleMappedData MakeRoles(IDataView data) => + new RoleMappedData(data, label: LabelColumn?.Name, feature: FeatureColumn.Name, group: GroupIdColumn?.Name, weight: WeightColumn?.Name); + + } } diff --git a/src/Microsoft.ML.Data/Training/TrainerUtils.cs b/src/Microsoft.ML.Data/Training/TrainerUtils.cs index dff5748635..e966687a9e 100644 --- a/src/Microsoft.ML.Data/Training/TrainerUtils.cs +++ b/src/Microsoft.ML.Data/Training/TrainerUtils.cs @@ -362,9 +362,14 @@ public static SchemaShape.Column MakeR4ScalarLabel(string labelColumn) /// /// The for the label column for regression tasks. /// - /// name of the weight column - public static SchemaShape.Column MakeU4ScalarLabel(string labelColumn) - => new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.U4, true); + /// name of the weight column + public static SchemaShape.Column MakeU4ScalarColumn(string columnName) + { + if (columnName == null) + return null; + + return new SchemaShape.Column(columnName, SchemaShape.Column.VectorKind.Scalar, NumberType.U4, true); + } /// /// The for the feature column. @@ -377,69 +382,13 @@ public static SchemaShape.Column MakeR4VecFeature(string featureColumn) /// The for the weight column. /// /// name of the weight column - public static SchemaShape.Column MakeR4ScalarWeightColumn(string weightColumn) + /// whether the column is implicitly, or explicitly defined + public static SchemaShape.Column MakeR4ScalarWeightColumn(string weightColumn, bool isExplicit = true) { - if (weightColumn == null) + if (weightColumn == null || !isExplicit) return null; return new SchemaShape.Column(weightColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false); } - - private static void CheckArgColName(IHostEnvironment host, string defaultColName, string argValue) - { - if (argValue != defaultColName) - throw host.Except($"Don't supply a value for the {defaultColName} column in the arguments, as it will be ignored. Specify them in the loader, or constructor instead instead."); - } - - /// - /// Check that the label, feature, weights, groupId column names are not supplied in the args of the constructor, through the advancedSettings parameter, - /// for cases when the public constructor is called. - /// The recommendation is to set the column names directly. 
- /// - public static void CheckArgsHaveDefaultColNames(IHostEnvironment host, LearnerInputBaseWithGroupId args) - { - // check that the users didn't specify different label, group, feature, weights in the args, from what they supplied directly - CheckArgColName(host, DefaultColumnNames.Label, args.LabelColumn); - CheckArgColName(host, DefaultColumnNames.Features, args.FeatureColumn); - CheckArgColName(host, DefaultColumnNames.Weight, args.WeightColumn); - - if (args.GroupIdColumn != null) - CheckArgColName(host, DefaultColumnNames.GroupId, args.GroupIdColumn); - } - - /// - /// Check that the label, feature, and weights column names are not supplied in the args of the constructor, through the advancedSettings parameter, - /// for cases when the public constructor is called. - /// The recommendation is to set the column names directly. - /// - public static void CheckArgsHaveDefaultColNames(IHostEnvironment host, LearnerInputBaseWithWeight args) - { - // check that the users didn't specify different label, group, feature, weights in the args, from what they supplied directly - CheckArgColName(host, DefaultColumnNames.Label, args.LabelColumn); - CheckArgColName(host, DefaultColumnNames.Features, args.FeatureColumn); - CheckArgColName(host, DefaultColumnNames.Weight, args.WeightColumn); - } - - /// - /// Check that the label and feature column names are not supplied in the args of the constructor, through the advancedSettings parameter, - /// for cases when the public constructor is called. - /// The recommendation is to set the column names directly. - /// - public static void CheckArgsHaveDefaultColNames(IHostEnvironment host, LearnerInputBaseWithLabel args) - { - // check that the users didn't specify different label, group, feature, weights in the args, from what they supplied directly - CheckArgColName(host, DefaultColumnNames.Label, args.LabelColumn); - CheckArgColName(host, DefaultColumnNames.Features, args.FeatureColumn); - } - - /// - /// If, after applying the advancedArgs delegate, the args are different that the default value - /// and are also different than the value supplied directly to the xtension method, warn the user. - /// - public static void CheckArgsAndAdvancedSettingMismatch(IChannel channel, T methodParam, T defaultVal, T setting, string argName) - { - if (!setting.Equals(defaultVal) && !setting.Equals(methodParam)) - channel.Warning($"The value supplied to advanced settings , is different than the value supplied directly. 
Using value {setting} for {argName}"); - } } /// diff --git a/src/Microsoft.ML.Ensemble/Microsoft.ML.Ensemble.csproj b/src/Microsoft.ML.Ensemble/Microsoft.ML.Ensemble.csproj index ef0ad01a40..ac67f8db1d 100644 --- a/src/Microsoft.ML.Ensemble/Microsoft.ML.Ensemble.csproj +++ b/src/Microsoft.ML.Ensemble/Microsoft.ML.Ensemble.csproj @@ -2,7 +2,7 @@ netstandard2.0 - Microsoft.ML.Ensemble + Microsoft.ML CORECLR diff --git a/src/Microsoft.ML.FastTree/BoostingFastTree.cs b/src/Microsoft.ML.FastTree/BoostingFastTree.cs index b8a1fbff31..99658d2a89 100644 --- a/src/Microsoft.ML.FastTree/BoostingFastTree.cs +++ b/src/Microsoft.ML.FastTree/BoostingFastTree.cs @@ -24,17 +24,21 @@ protected BoostingFastTreeTrainerBase(IHostEnvironment env, TArgs args, SchemaSh protected BoostingFastTreeTrainerBase(IHostEnvironment env, SchemaShape.Column label, string featureColumn, - string weightColumn = null, - string groupIdColumn = null, - int numLeaves = Defaults.NumLeaves, - int numTrees = Defaults.NumTrees, - int minDocumentsInLeafs = Defaults.MinDocumentsInLeafs, - double learningRate = Defaults.LearningRates, - Action advancedSettings = null) + string weightColumn, + string groupIdColumn, + int numLeaves, + int numTrees, + int minDocumentsInLeafs, + double learningRate, + Action advancedSettings) : base(env, label, featureColumn, weightColumn, groupIdColumn, numLeaves, numTrees, minDocumentsInLeafs, advancedSettings) { - //override with the directly provided values. - Args.LearningRates = learningRate; + + if (Args.LearningRates != learningRate) + { + using (var ch = Host.Start($"Setting learning rate to: {learningRate} as supplied in the direct arguments.")) + Args.LearningRates = learningRate; + } } protected override void CheckArgs(IChannel ch) diff --git a/src/Microsoft.ML.FastTree/FastTree.cs b/src/Microsoft.ML.FastTree/FastTree.cs index 7d6ba197b4..b73b01faa1 100644 --- a/src/Microsoft.ML.FastTree/FastTree.cs +++ b/src/Microsoft.ML.FastTree/FastTree.cs @@ -7,6 +7,7 @@ using Microsoft.ML.Runtime.CommandLine; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.Data.Conversion; +using Microsoft.ML.Runtime.EntryPoints; using Microsoft.ML.Runtime.Internal.Calibration; using Microsoft.ML.Runtime.Internal.Internallearn; using Microsoft.ML.Runtime.Internal.Utilities; @@ -45,7 +46,7 @@ internal static class FastTreeShared } public abstract class FastTreeTrainerBase : - TrainerEstimatorBase + TrainerEstimatorBaseWithGroupId where TTransformer: ISingleFeaturePredictionTransformer where TArgs : TreeArgs, new() where TModel : IPredictorProducing @@ -95,30 +96,33 @@ public abstract class FastTreeTrainerBase : private protected FastTreeTrainerBase(IHostEnvironment env, SchemaShape.Column label, string featureColumn, - string weightColumn = null, - string groupIdColumn = null, - int numLeaves = Defaults.NumLeaves, - int numTrees = Defaults.NumTrees, - int minDocumentsInLeafs = Defaults.MinDocumentsInLeafs, - Action advancedSettings = null) - : base(Contracts.CheckRef(env, nameof(env)).Register(RegisterName), TrainerUtils.MakeR4VecFeature(featureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(weightColumn)) + string weightColumn, + string groupIdColumn, + int numLeaves, + int numTrees, + int minDocumentsInLeafs, + Action advancedSettings) + : base(Contracts.CheckRef(env, nameof(env)).Register(RegisterName), TrainerUtils.MakeR4VecFeature(featureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(weightColumn), TrainerUtils.MakeU4ScalarColumn(groupIdColumn)) { Args = new TArgs(); + // set up the directly 
provided values + // override with the directly provided values. + Args.NumLeaves = numLeaves; + Args.NumTrees = numTrees; + Args.MinDocumentsInLeafs = minDocumentsInLeafs; + //apply the advanced args, if the user supplied any advancedSettings?.Invoke(Args); - // check that the users didn't specify different label, group, feature, weights in the args, from what they supplied directly - TrainerUtils.CheckArgsHaveDefaultColNames(Host, Args); - Args.LabelColumn = label.Name; Args.FeatureColumn = featureColumn; if (weightColumn != null) - Args.WeightColumn = weightColumn; + Args.WeightColumn = Optional.Explicit(weightColumn); ; if (groupIdColumn != null) - Args.GroupIdColumn = groupIdColumn; + Args.GroupIdColumn = Optional.Explicit(groupIdColumn); ; //override with the directly provided values. Args.NumLeaves = numLeaves; @@ -140,7 +144,7 @@ private protected FastTreeTrainerBase(IHostEnvironment env, /// Legacy constructor that is used when invoking the classes deriving from this, through maml. /// private protected FastTreeTrainerBase(IHostEnvironment env, TArgs args, SchemaShape.Column label) - : base(Contracts.CheckRef(env, nameof(env)).Register(RegisterName), TrainerUtils.MakeR4VecFeature(args.FeatureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(args.WeightColumn)) + : base(Contracts.CheckRef(env, nameof(env)).Register(RegisterName), TrainerUtils.MakeR4VecFeature(args.FeatureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(args.WeightColumn, args.WeightColumn.IsExplicit)) { Host.CheckValue(args, nameof(args)); Args = args; @@ -171,32 +175,6 @@ protected virtual Float GetMaxLabel() return Float.PositiveInfinity; } - /// - /// If, after applying the advancedSettings delegate, the args are different that the default value - /// and are also different than the value supplied directly to the xtension method, warn the user - /// about which value is being used. - /// The parameters that appear here, numTrees, minDocumentsInLeafs, numLeaves, learningRate are the ones the users are most likely to tune. - /// This list should follow the one in the constructor, and the extension methods on the . - /// REVIEW: we should somehow mark the arguments that are set apart in those two places. Currently they stand out by their sort order annotation. - /// - protected void CheckArgsAndAdvancedSettingMismatch(int numLeaves, - int numTrees, - int minDocumentsInLeafs, - double learningRate, - BoostedTreeArgs snapshot, - BoostedTreeArgs currentArgs) - { - using (var ch = Host.Start("Comparing advanced settings with the directly provided values.")) - { - - // Check that the user didn't supply different parameters in the args, from what it specified directly. - TrainerUtils.CheckArgsAndAdvancedSettingMismatch(ch, numLeaves, snapshot.NumLeaves, currentArgs.NumLeaves, nameof(numLeaves)); - TrainerUtils.CheckArgsAndAdvancedSettingMismatch(ch, numTrees, snapshot.NumTrees, currentArgs.NumTrees, nameof(numTrees)); - TrainerUtils.CheckArgsAndAdvancedSettingMismatch(ch, minDocumentsInLeafs, snapshot.MinDocumentsInLeafs, currentArgs.MinDocumentsInLeafs, nameof(minDocumentsInLeafs)); - TrainerUtils.CheckArgsAndAdvancedSettingMismatch(ch, learningRate, snapshot.LearningRates, currentArgs.LearningRates, nameof(learningRate)); - } - } - private void Initialize(IHostEnvironment env) { int numThreads = Args.NumThreads ?? 
Environment.ProcessorCount; diff --git a/src/Microsoft.ML.FastTree/FastTreeCatalog.cs b/src/Microsoft.ML.FastTree/FastTreeCatalog.cs deleted file mode 100644 index 9e439903dc..0000000000 --- a/src/Microsoft.ML.FastTree/FastTreeCatalog.cs +++ /dev/null @@ -1,100 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Trainers.FastTree; -using System; - -namespace Microsoft.ML -{ - /// - /// FastTree extension methods. - /// - public static class FastTreeRegressionExtensions - { - /// - /// Predict a target using a decision tree regression model trained with the . - /// - /// The . - /// The label column. - /// The features column. - /// The optional weights column. - /// Total number of decision trees to create in the ensemble. - /// The maximum number of leaves per decision tree. - /// The minimal number of datapoints allowed in a leaf of a regression tree, out of the subsampled data. - /// The learning rate. - /// Algorithm advanced settings. - public static FastTreeRegressionTrainer FastTree(this RegressionContext.RegressionTrainers ctx, - string label = DefaultColumnNames.Label, - string features = DefaultColumnNames.Features, - string weights = null, - int numLeaves = Defaults.NumLeaves, - int numTrees = Defaults.NumTrees, - int minDatapointsInLeafs = Defaults.MinDocumentsInLeafs, - double learningRate = Defaults.LearningRates, - Action advancedSettings = null) - { - Contracts.CheckValue(ctx, nameof(ctx)); - var env = CatalogUtils.GetEnvironment(ctx); - return new FastTreeRegressionTrainer(env, label, features, weights, numLeaves, numTrees, minDatapointsInLeafs, learningRate, advancedSettings); - } - } - - public static class FastTreeBinaryClassificationExtensions - { - - /// - /// Predict a target using a decision tree binary classification model trained with the . - /// - /// The . - /// The label column. - /// The features column. - /// The optional weights column. - /// Total number of decision trees to create in the ensemble. - /// The maximum number of leaves per decision tree. - /// The minimal number of datapoints allowed in a leaf of the tree, out of the subsampled data. - /// The learning rate. - /// Algorithm advanced settings. - public static FastTreeBinaryClassificationTrainer FastTree(this BinaryClassificationContext.BinaryClassificationTrainers ctx, - string label = DefaultColumnNames.Label, - string features = DefaultColumnNames.Features, - string weights = null, - int numLeaves = Defaults.NumLeaves, - int numTrees = Defaults.NumTrees, - int minDatapointsInLeafs = Defaults.MinDocumentsInLeafs, - double learningRate = Defaults.LearningRates, - Action advancedSettings = null) - { - Contracts.CheckValue(ctx, nameof(ctx)); - var env = CatalogUtils.GetEnvironment(ctx); - return new FastTreeBinaryClassificationTrainer(env, label, features, weights, numLeaves, numTrees, minDatapointsInLeafs, learningRate, advancedSettings); - } - } - - public static class FastTreeRankingExtensions - { - - /// - /// Ranks a series of inputs based on their relevance, training a decision tree ranking model through the . - /// - /// The . - /// The label column. - /// The features column. - /// The groupId column. - /// The optional weights column. - /// Algorithm advanced settings. 
- public static FastTreeRankingTrainer FastTree(this RankingContext.RankingTrainers ctx, - string label = DefaultColumnNames.Label, - string groupId = DefaultColumnNames.GroupId, - string features = DefaultColumnNames.Features, - string weights = null, - Action advancedSettings = null) - { - Contracts.CheckValue(ctx, nameof(ctx)); - var env = CatalogUtils.GetEnvironment(ctx); - return new FastTreeRankingTrainer(env, label, features, groupId, weights, advancedSettings: advancedSettings); - } - } -} diff --git a/src/Microsoft.ML.FastTree/FastTreeClassification.cs b/src/Microsoft.ML.FastTree/FastTreeClassification.cs index 871dff657f..b864b484ca 100644 --- a/src/Microsoft.ML.FastTree/FastTreeClassification.cs +++ b/src/Microsoft.ML.FastTree/FastTreeClassification.cs @@ -122,11 +122,11 @@ public sealed partial class FastTreeBinaryClassificationTrainer : /// The name of the label column. /// The name of the feature column. /// The name for the column containing the initial weight. - /// A delegate to apply all the advanced arguments to the algorithm. /// The learning rate. /// The minimal number of documents allowed in a leaf of a regression tree, out of the subsampled data. /// The max number of leaves in each regression tree. /// Total number of decision trees to create in the ensemble. + /// A delegate to apply all the advanced arguments to the algorithm. public FastTreeBinaryClassificationTrainer(IHostEnvironment env, string labelColumn, string featureColumn, @@ -138,20 +138,8 @@ public FastTreeBinaryClassificationTrainer(IHostEnvironment env, Action advancedSettings = null) : base(env, TrainerUtils.MakeBoolScalarLabel(labelColumn), featureColumn, weightColumn, null, numLeaves, numTrees, minDocumentsInLeafs, learningRate, advancedSettings) { - Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); - Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); - // Set the sigmoid parameter to the 2 * learning rate, for traditional FastTreeClassification loss _sigmoidParameter = 2.0 * Args.LearningRates; - - if (advancedSettings != null) - CheckArgsAndAdvancedSettingMismatch(numLeaves, numTrees, minDocumentsInLeafs, learningRate, new Arguments(), Args); - - //override with the directly provided values. - Args.NumLeaves = numLeaves; - Args.NumTrees = numTrees; - Args.MinDocumentsInLeafs = minDocumentsInLeafs; - Args.LearningRates = learningRate; } /// diff --git a/src/Microsoft.ML.FastTree/FastTreeRanking.cs b/src/Microsoft.ML.FastTree/FastTreeRanking.cs index 4250cecfe2..b7baf5999e 100644 --- a/src/Microsoft.ML.FastTree/FastTreeRanking.cs +++ b/src/Microsoft.ML.FastTree/FastTreeRanking.cs @@ -82,10 +82,8 @@ public FastTreeRankingTrainer(IHostEnvironment env, int minDocumentsInLeafs = Defaults.MinDocumentsInLeafs, double learningRate = Defaults.LearningRates, Action advancedSettings = null) - : base(env, TrainerUtils.MakeR4ScalarLabel(labelColumn), featureColumn, weightColumn, groupIdColumn, numLeaves, numTrees, minDocumentsInLeafs, learningRate, advancedSettings: advancedSettings) + : base(env, TrainerUtils.MakeR4ScalarLabel(labelColumn), featureColumn, weightColumn, groupIdColumn, numLeaves, numTrees, minDocumentsInLeafs, learningRate, advancedSettings) { - Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); - Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); Host.CheckNonEmpty(groupIdColumn, nameof(groupIdColumn)); } @@ -93,7 +91,7 @@ public FastTreeRankingTrainer(IHostEnvironment env, /// Initializes a new instance of by using the legacy class. 
/// internal FastTreeRankingTrainer(IHostEnvironment env, Arguments args) - : base(env, args, TrainerUtils.MakeR4ScalarLabel(args.LabelColumn)) + : base(env, args, TrainerUtils.MakeR4ScalarLabel(args.LabelColumn)) { } diff --git a/src/Microsoft.ML.FastTree/FastTreeRegression.cs b/src/Microsoft.ML.FastTree/FastTreeRegression.cs index d3f2ac00dc..8186ef2bb8 100644 --- a/src/Microsoft.ML.FastTree/FastTreeRegression.cs +++ b/src/Microsoft.ML.FastTree/FastTreeRegression.cs @@ -58,11 +58,11 @@ public sealed partial class FastTreeRegressionTrainer /// The name of the label column. /// The name of the feature column. /// The name for the column containing the initial weight. - /// A delegate to apply all the advanced arguments to the algorithm. /// The learning rate. /// The minimal number of documents allowed in a leaf of a regression tree, out of the subsampled data. /// The max number of leaves in each regression tree. /// Total number of decision trees to create in the ensemble. + /// A delegate to apply all the advanced arguments to the algorithm. public FastTreeRegressionTrainer(IHostEnvironment env, string labelColumn, string featureColumn, @@ -74,11 +74,6 @@ public FastTreeRegressionTrainer(IHostEnvironment env, Action advancedSettings = null) : base(env, TrainerUtils.MakeR4ScalarLabel(labelColumn), featureColumn, weightColumn, null, numLeaves, numTrees, minDocumentsInLeafs, learningRate, advancedSettings) { - Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); - Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); - - if (advancedSettings != null) - CheckArgsAndAdvancedSettingMismatch(numLeaves, numTrees, minDocumentsInLeafs, learningRate, new Arguments(), Args); } /// diff --git a/src/Microsoft.ML.FastTree/FastTreeTweedie.cs b/src/Microsoft.ML.FastTree/FastTreeTweedie.cs index e7f4c097d3..66079ae207 100644 --- a/src/Microsoft.ML.FastTree/FastTreeTweedie.cs +++ b/src/Microsoft.ML.FastTree/FastTreeTweedie.cs @@ -35,10 +35,10 @@ namespace Microsoft.ML.Trainers.FastTree public sealed partial class FastTreeTweedieTrainer : BoostingFastTreeTrainerBase, FastTreeTweediePredictor> { - public const string LoadNameValue = "FastTreeTweedieRegression"; - public const string UserNameValue = "FastTree (Boosted Trees) Tweedie Regression"; - public const string Summary = "Trains gradient boosted decision trees to fit target values using a Tweedie loss function. This learner is a generalization of Poisson, compound Poisson, and gamma regression."; - public const string ShortName = "fttweedie"; + internal const string LoadNameValue = "FastTreeTweedieRegression"; + internal const string UserNameValue = "FastTree (Boosted Trees) Tweedie Regression"; + internal const string Summary = "Trains gradient boosted decision trees to fit target values using a Tweedie loss function. This learner is a generalization of Poisson, compound Poisson, and gamma regression."; + internal const string ShortName = "fttweedie"; private TestHistory _firstTestSetHistory; private Test _trainRegressionTest; @@ -54,12 +54,22 @@ public sealed partial class FastTreeTweedieTrainer /// The private instance of . /// The name of the label column. /// The name of the feature column. - /// The name for the column containing the group ID. /// The name for the column containing the initial weight. + /// The learning rate. + /// The minimal number of documents allowed in a leaf of a regression tree, out of the subsampled data. + /// The max number of leaves in each regression tree. 
+ /// Total number of decision trees to create in the ensemble. /// A delegate to apply all the advanced arguments to the algorithm. - public FastTreeTweedieTrainer(IHostEnvironment env, string labelColumn, string featureColumn, - string groupIdColumn = null, string weightColumn = null, Action advancedSettings = null) - : base(env, TrainerUtils.MakeR4ScalarLabel(labelColumn), featureColumn, weightColumn, groupIdColumn, advancedSettings: advancedSettings) + public FastTreeTweedieTrainer(IHostEnvironment env, + string labelColumn, + string featureColumn, + string weightColumn = null, + int numLeaves = Defaults.NumLeaves, + int numTrees = Defaults.NumTrees, + int minDocumentsInLeafs = Defaults.MinDocumentsInLeafs, + double learningRate = Defaults.LearningRates, + Action advancedSettings = null) + : base(env, TrainerUtils.MakeR4ScalarLabel(labelColumn), featureColumn, weightColumn, null, numLeaves, numTrees, minDocumentsInLeafs, learningRate, advancedSettings) { Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); diff --git a/src/Microsoft.ML.FastTree/GamClassification.cs b/src/Microsoft.ML.FastTree/GamClassification.cs index 9672bcfcdf..57c3f37eeb 100644 --- a/src/Microsoft.ML.FastTree/GamClassification.cs +++ b/src/Microsoft.ML.FastTree/GamClassification.cs @@ -62,13 +62,18 @@ internal BinaryClassificationGamTrainer(IHostEnvironment env, Arguments args) /// The name of the label column. /// The name of the feature column. /// The name for the column containing the initial weight. + /// The learning rate. + /// The minimal number of documents allowed in a leaf of a regression tree, out of the subsampled data. /// A delegate to apply all the advanced arguments to the algorithm. - public BinaryClassificationGamTrainer(IHostEnvironment env, string labelColumn, string featureColumn, string weightColumn = null, Action advancedSettings = null) - : base(env, LoadNameValue, TrainerUtils.MakeBoolScalarLabel(labelColumn), featureColumn, weightColumn, advancedSettings) + public BinaryClassificationGamTrainer(IHostEnvironment env, + string labelColumn, + string featureColumn, + string weightColumn = null, + int minDocumentsInLeafs = Defaults.MinDocumentsInLeafs, + double learningRate = Defaults.LearningRates, + Action advancedSettings = null) + : base(env, LoadNameValue, TrainerUtils.MakeBoolScalarLabel(labelColumn), featureColumn, weightColumn, minDocumentsInLeafs, learningRate, advancedSettings) { - Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); - Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); - _sigmoidParameter = 1; } diff --git a/src/Microsoft.ML.FastTree/GamRegression.cs b/src/Microsoft.ML.FastTree/GamRegression.cs index 40a0ff0c93..481a991a8e 100644 --- a/src/Microsoft.ML.FastTree/GamRegression.cs +++ b/src/Microsoft.ML.FastTree/GamRegression.cs @@ -6,11 +6,11 @@ using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.CommandLine; using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Trainers.FastTree; -using Microsoft.ML.Trainers.FastTree.Internal; using Microsoft.ML.Runtime.Internal.Internallearn; using Microsoft.ML.Runtime.Model; using Microsoft.ML.Runtime.Training; +using Microsoft.ML.Trainers.FastTree; +using Microsoft.ML.Trainers.FastTree.Internal; using System; [assembly: LoadableClass(RegressionGamTrainer.Summary, @@ -51,12 +51,18 @@ internal RegressionGamTrainer(IHostEnvironment env, Arguments args) /// The name of the label column. /// The name of the feature column. 
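// Usage sketch (not part of the patch): the reworked FastTreeTweedieTrainer constructor above drops the
// groupId column and exposes the common tree arguments directly. `env` is an assumed IHostEnvironment;
// column names and values are hypothetical.
var tweedie = new FastTreeTweedieTrainer(env, "Label", "Features",
    numLeaves: 20, numTrees: 100, minDocumentsInLeafs: 10, learningRate: 0.2);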
/// The name for the column containing the initial weight. + /// The minimal number of documents allowed in a leaf of a regression tree, out of the subsampled data. + /// The learning rate. /// A delegate to apply all the advanced arguments to the algorithm. - public RegressionGamTrainer(IHostEnvironment env, string labelColumn, string featureColumn, string weightColumn = null, Action advancedSettings = null) - : base(env, LoadNameValue, TrainerUtils.MakeR4ScalarLabel(labelColumn), featureColumn, weightColumn, advancedSettings) + public RegressionGamTrainer(IHostEnvironment env, + string labelColumn, + string featureColumn, + string weightColumn = null, + int minDocumentsInLeafs = Defaults.MinDocumentsInLeafs, + double learningRate = Defaults.LearningRates, + Action advancedSettings = null) + : base(env, LoadNameValue, TrainerUtils.MakeR4ScalarLabel(labelColumn), featureColumn, weightColumn, minDocumentsInLeafs, learningRate, advancedSettings) { - Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); - Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); } internal override void CheckLabel(RoleMappedData data) diff --git a/src/Microsoft.ML.FastTree/GamTrainer.cs b/src/Microsoft.ML.FastTree/GamTrainer.cs index e44196b04c..e4318187d0 100644 --- a/src/Microsoft.ML.FastTree/GamTrainer.cs +++ b/src/Microsoft.ML.FastTree/GamTrainer.cs @@ -132,15 +132,26 @@ public abstract class ArgumentsBase : LearnerInputBaseWithWeight protected IParallelTraining ParallelTraining; - private protected GamTrainerBase(IHostEnvironment env, string name, SchemaShape.Column label, string featureColumn, - string weightColumn = null, Action advancedSettings = null) + private protected GamTrainerBase(IHostEnvironment env, + string name, + SchemaShape.Column label, + string featureColumn, + string weightColumn, + int minDocumentsInLeafs, + double learningRate, + Action advancedSettings) : base(Contracts.CheckRef(env, nameof(env)).Register(name), TrainerUtils.MakeR4VecFeature(featureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(weightColumn)) { Args = new TArgs(); + Args.MinDocuments = minDocumentsInLeafs; + Args.LearningRates = learningRate; + //apply the advanced args, if the user supplied any advancedSettings?.Invoke(Args); + Args.LabelColumn = label.Name; + Args.FeatureColumn = featureColumn; if (weightColumn != null) Args.WeightColumn = weightColumn; @@ -154,7 +165,7 @@ private protected GamTrainerBase(IHostEnvironment env, string name, SchemaShape. private protected GamTrainerBase(IHostEnvironment env, TArgs args, string name, SchemaShape.Column label) : base(Contracts.CheckRef(env, nameof(env)).Register(name), TrainerUtils.MakeR4VecFeature(args.FeatureColumn), - label, TrainerUtils.MakeR4ScalarWeightColumn(args.WeightColumn)) + label, TrainerUtils.MakeR4ScalarWeightColumn(args.WeightColumn, args.WeightColumn.IsExplicit)) { Contracts.CheckValue(env, nameof(env)); Host.CheckValue(args, nameof(args)); diff --git a/src/Microsoft.ML.FastTree/RandomForest.cs b/src/Microsoft.ML.FastTree/RandomForest.cs index 19f750a68a..f29383473c 100644 --- a/src/Microsoft.ML.FastTree/RandomForest.cs +++ b/src/Microsoft.ML.FastTree/RandomForest.cs @@ -28,9 +28,18 @@ protected RandomForestTrainerBase(IHostEnvironment env, TArgs args, SchemaShape. /// /// Constructor invoked by the API code-path. 
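// Usage sketch (not part of the patch): both GAM trainers now accept minDocumentsInLeafs and learningRate
// directly. `env` is an assumed IHostEnvironment; column names and values are hypothetical.
var gamRegression = new RegressionGamTrainer(env, "Label", "Features",
    minDocumentsInLeafs: 10, learningRate: 0.002);
var gamBinary = new BinaryClassificationGamTrainer(env, "Label", "Features",
    minDocumentsInLeafs: 10, learningRate: 0.002);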
/// - protected RandomForestTrainerBase(IHostEnvironment env, SchemaShape.Column label, string featureColumn, - string weightColumn = null, string groupIdColumn = null, bool quantileEnabled = false, Action advancedSettings = null) - : base(env, label, featureColumn, weightColumn, groupIdColumn, advancedSettings: advancedSettings) + protected RandomForestTrainerBase(IHostEnvironment env, + SchemaShape.Column label, + string featureColumn, + string weightColumn, + string groupIdColumn, + int numLeaves, + int numTrees, + int minDocumentsInLeafs, + double learningRate, + Action advancedSettings, + bool quantileEnabled = false) + : base(env, label, featureColumn, weightColumn, null, numLeaves, numTrees, minDocumentsInLeafs, advancedSettings) { _quantileEnabled = quantileEnabled; } diff --git a/src/Microsoft.ML.FastTree/RandomForestClassification.cs b/src/Microsoft.ML.FastTree/RandomForestClassification.cs index cc4f91d895..bfa5efa460 100644 --- a/src/Microsoft.ML.FastTree/RandomForestClassification.cs +++ b/src/Microsoft.ML.FastTree/RandomForestClassification.cs @@ -7,12 +7,12 @@ using Microsoft.ML.Runtime.CommandLine; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.EntryPoints; -using Microsoft.ML.Trainers.FastTree; -using Microsoft.ML.Trainers.FastTree.Internal; using Microsoft.ML.Runtime.Internal.Calibration; using Microsoft.ML.Runtime.Internal.Internallearn; using Microsoft.ML.Runtime.Model; using Microsoft.ML.Runtime.Training; +using Microsoft.ML.Trainers.FastTree; +using Microsoft.ML.Trainers.FastTree.Internal; using System; using System.Linq; @@ -80,7 +80,7 @@ private static VersionInfo GetVersionInfo() public FastForestClassificationPredictor(IHostEnvironment env, Ensemble trainedEnsemble, int featureCount, string innerArgs) : base(env, RegistrationName, trainedEnsemble, featureCount, innerArgs) - { } + { } private FastForestClassificationPredictor(IHostEnvironment env, ModelLoadContext ctx) : base(env, RegistrationName, ctx, GetVersionInfo()) @@ -139,12 +139,22 @@ public sealed class Arguments : FastForestArgumentsBase /// The private instance of . /// The name of the label column. /// The name of the feature column. - /// The name for the column containing the group ID. /// The name for the column containing the initial weight. + /// The max number of leaves in each regression tree. + /// Total number of decision trees to create in the ensemble. + /// The minimal number of documents allowed in a leaf of a regression tree, out of the subsampled data. + /// The learning rate. /// A delegate to apply all the advanced arguments to the algorithm. 
- public FastForestClassification(IHostEnvironment env, string labelColumn, string featureColumn, - string groupIdColumn = null, string weightColumn = null, Action advancedSettings = null) - : base(env, TrainerUtils.MakeBoolScalarLabel(labelColumn), featureColumn, weightColumn, groupIdColumn, advancedSettings: advancedSettings) + public FastForestClassification(IHostEnvironment env, + string labelColumn, + string featureColumn, + string weightColumn = null, + int numLeaves = Defaults.NumLeaves, + int numTrees = Defaults.NumTrees, + int minDocumentsInLeafs = Defaults.MinDocumentsInLeafs, + double learningRate = Defaults.LearningRates, + Action advancedSettings = null) + : base(env, TrainerUtils.MakeBoolScalarLabel(labelColumn), featureColumn, weightColumn, null, numLeaves, numTrees, minDocumentsInLeafs, learningRate, advancedSettings) { Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); diff --git a/src/Microsoft.ML.FastTree/RandomForestRegression.cs b/src/Microsoft.ML.FastTree/RandomForestRegression.cs index 3af5440d13..23b110f072 100644 --- a/src/Microsoft.ML.FastTree/RandomForestRegression.cs +++ b/src/Microsoft.ML.FastTree/RandomForestRegression.cs @@ -160,12 +160,22 @@ public sealed class Arguments : FastForestArgumentsBase /// The private instance of . /// The name of the label column. /// The name of the feature column. - /// The name for the column containing the group ID. /// The name for the column containing the initial weight. + /// The learning rate. + /// The minimal number of documents allowed in a leaf of a regression tree, out of the subsampled data. + /// The max number of leaves in each regression tree. + /// Total number of decision trees to create in the ensemble. /// A delegate to apply all the advanced arguments to the algorithm. - public FastForestRegression(IHostEnvironment env, string labelColumn, string featureColumn, - string groupIdColumn = null, string weightColumn = null, Action advancedSettings = null) - : base(env, TrainerUtils.MakeR4ScalarLabel(labelColumn), featureColumn, weightColumn, groupIdColumn, true, advancedSettings) + public FastForestRegression(IHostEnvironment env, + string labelColumn, + string featureColumn, + string weightColumn = null, + int numLeaves = Defaults.NumLeaves, + int numTrees = Defaults.NumTrees, + int minDocumentsInLeafs = Defaults.MinDocumentsInLeafs, + double learningRate = Defaults.LearningRates, + Action advancedSettings = null) + : base(env, TrainerUtils.MakeR4ScalarLabel(labelColumn), featureColumn, weightColumn, null, numLeaves, numTrees, minDocumentsInLeafs, learningRate, advancedSettings) { Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); diff --git a/src/Microsoft.ML.FastTree/TreeTrainersCatalog.cs b/src/Microsoft.ML.FastTree/TreeTrainersCatalog.cs new file mode 100644 index 0000000000..2d6d1c25ef --- /dev/null +++ b/src/Microsoft.ML.FastTree/TreeTrainersCatalog.cs @@ -0,0 +1,181 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.ML.Runtime; +using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Trainers.FastTree; +using System; + +namespace Microsoft.ML +{ + /// + /// FastTree extension methods. 
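// Usage sketch (not part of the patch): the random-forest trainers follow the same pattern, taking the
// common tree arguments directly and no longer taking a groupId column. `env` is an assumed
// IHostEnvironment; column names and values are hypothetical.
var forestRegression = new FastForestRegression(env, "Label", "Features",
    numLeaves: 20, numTrees: 100, minDocumentsInLeafs: 10);
var forestBinary = new FastForestClassification(env, "Label", "Features",
    numLeaves: 20, numTrees: 100, minDocumentsInLeafs: 10);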
+ /// + public static class TreeExtensions + { + /// + /// Predict a target using a decision tree regression model trained with the . + /// + /// The . + /// The label column. + /// The features column. + /// The optional weights column. + /// Total number of decision trees to create in the ensemble. + /// The maximum number of leaves per decision tree. + /// The minimal number of datapoints allowed in a leaf of a regression tree, out of the subsampled data. + /// The learning rate. + /// Algorithm advanced settings. + public static FastTreeRegressionTrainer FastTree(this RegressionContext.RegressionTrainers ctx, + string label = DefaultColumnNames.Label, + string features = DefaultColumnNames.Features, + string weights = null, + int numLeaves = Defaults.NumLeaves, + int numTrees = Defaults.NumTrees, + int minDatapointsInLeafs = Defaults.MinDocumentsInLeafs, + double learningRate = Defaults.LearningRates, + Action advancedSettings = null) + { + Contracts.CheckValue(ctx, nameof(ctx)); + var env = CatalogUtils.GetEnvironment(ctx); + return new FastTreeRegressionTrainer(env, label, features, weights, numLeaves, numTrees, minDatapointsInLeafs, learningRate, advancedSettings); + } + + /// + /// Predict a target using a decision tree binary classification model trained with the . + /// + /// The . + /// The label column. + /// The features column. + /// The optional weights column. + /// Total number of decision trees to create in the ensemble. + /// The maximum number of leaves per decision tree. + /// The minimal number of datapoints allowed in a leaf of the tree, out of the subsampled data. + /// The learning rate. + /// Algorithm advanced settings. + public static FastTreeBinaryClassificationTrainer FastTree(this BinaryClassificationContext.BinaryClassificationTrainers ctx, + string label = DefaultColumnNames.Label, + string features = DefaultColumnNames.Features, + string weights = null, + int numLeaves = Defaults.NumLeaves, + int numTrees = Defaults.NumTrees, + int minDatapointsInLeafs = Defaults.MinDocumentsInLeafs, + double learningRate = Defaults.LearningRates, + Action advancedSettings = null) + { + Contracts.CheckValue(ctx, nameof(ctx)); + var env = CatalogUtils.GetEnvironment(ctx); + return new FastTreeBinaryClassificationTrainer(env, label, features, weights, numLeaves, numTrees, minDatapointsInLeafs, learningRate, advancedSettings); + } + + /// + /// Ranks a series of inputs based on their relevance, training a decision tree ranking model through the . + /// + /// The . + /// The label column. + /// The features column. + /// The groupId column. + /// The optional weights column. + /// Total number of decision trees to create in the ensemble. + /// The maximum number of leaves per decision tree. + /// The minimal number of datapoints allowed in a leaf of the tree, out of the subsampled data. + /// The learning rate. + /// Algorithm advanced settings. 
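// Usage sketch (not part of the patch): the regression and binary-classification catalog extensions above
// expose the same arguments from a training context. `regressionCtx` and `binaryCtx` are assumed to be
// existing RegressionContext and BinaryClassificationContext instances.
var fastTreeRegression = regressionCtx.Trainers.FastTree("Label", "Features",
    numLeaves: 20, numTrees: 100, minDatapointsInLeafs: 10, learningRate: 0.2);
var fastTreeBinary = binaryCtx.Trainers.FastTree("Label", "Features",
    numLeaves: 20, numTrees: 100, minDatapointsInLeafs: 10, learningRate: 0.2);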
+ public static FastTreeRankingTrainer FastTree(this RankingContext.RankingTrainers ctx, + string label = DefaultColumnNames.Label, + string groupId = DefaultColumnNames.GroupId, + string features = DefaultColumnNames.Features, + string weights = null, + int numLeaves = Defaults.NumLeaves, + int numTrees = Defaults.NumTrees, + int minDatapointsInLeafs = Defaults.MinDocumentsInLeafs, + double learningRate = Defaults.LearningRates, + Action advancedSettings = null) + { + Contracts.CheckValue(ctx, nameof(ctx)); + var env = CatalogUtils.GetEnvironment(ctx); + return new FastTreeRankingTrainer(env, label, features, groupId, weights, numLeaves, numTrees, minDatapointsInLeafs, learningRate, advancedSettings); + } + + /// + /// Predict a target using a decision tree regression model trained with the . + /// + /// The . + /// The label column. + /// The features column. + /// The optional weights column. + /// Total number of decision trees to create in the ensemble. + /// The maximum number of leaves per decision tree. + /// The minimal number of datapoints allowed in a leaf of the tree, out of the subsampled data. + /// The learning rate. + /// Algorithm advanced settings. + public static BinaryClassificationGamTrainer GeneralizedAdditiveMethods(this RegressionContext.RegressionTrainers ctx, + string label = DefaultColumnNames.Label, + string features = DefaultColumnNames.Features, + string weights = null, + int numLeaves = Defaults.NumLeaves, + int numTrees = Defaults.NumTrees, + int minDatapointsInLeafs = Defaults.MinDocumentsInLeafs, + double learningRate = Defaults.LearningRates, + Action advancedSettings = null) + { + Contracts.CheckValue(ctx, nameof(ctx)); + var env = CatalogUtils.GetEnvironment(ctx); + return new BinaryClassificationGamTrainer(env, label, features, weights, minDatapointsInLeafs, learningRate, advancedSettings); + } + + /// + /// Predict a target using a decision tree binary classification model trained with the . + /// + /// The . + /// The label column. + /// The features column. + /// The optional weights column. + /// Total number of decision trees to create in the ensemble. + /// The maximum number of leaves per decision tree. + /// The minimal number of datapoints allowed in a leaf of the tree, out of the subsampled data. + /// The learning rate. + /// Algorithm advanced settings. + public static RegressionGamTrainer GeneralizedAdditiveMethods(this BinaryClassificationContext.BinaryClassificationTrainers ctx, + string label = DefaultColumnNames.Label, + string features = DefaultColumnNames.Features, + string weights = null, + int numLeaves = Defaults.NumLeaves, + int numTrees = Defaults.NumTrees, + int minDatapointsInLeafs = Defaults.MinDocumentsInLeafs, + double learningRate = Defaults.LearningRates, + Action advancedSettings = null) + { + Contracts.CheckValue(ctx, nameof(ctx)); + var env = CatalogUtils.GetEnvironment(ctx); + return new RegressionGamTrainer(env, label, features, weights, minDatapointsInLeafs, learningRate, advancedSettings); + } + + /// + /// Predict a target using a decision tree regression model trained with the . + /// + /// The . + /// The label column. + /// The features column. + /// The optional weights column. + /// Total number of decision trees to create in the ensemble. + /// The maximum number of leaves per decision tree. + /// The minimal number of datapoints allowed in a leaf of the tree, out of the subsampled data. + /// The learning rate. + /// Algorithm advanced settings. 
+ public static FastTreeTweedieTrainer FastTreeTweedie(this RegressionContext.RegressionTrainers ctx, + string label = DefaultColumnNames.Label, + string features = DefaultColumnNames.Features, + string weights = null, + int numLeaves = Defaults.NumLeaves, + int numTrees = Defaults.NumTrees, + int minDatapointsInLeafs = Defaults.MinDocumentsInLeafs, + double learningRate = Defaults.LearningRates, + Action advancedSettings = null) + { + Contracts.CheckValue(ctx, nameof(ctx)); + var env = CatalogUtils.GetEnvironment(ctx); + return new FastTreeTweedieTrainer(env, label, features, weights, numLeaves, numTrees, minDatapointsInLeafs, learningRate, advancedSettings); + } + } +} diff --git a/src/Microsoft.ML.FastTree/FastTreeStatic.cs b/src/Microsoft.ML.FastTree/TreeTrainersStatic.cs similarity index 90% rename from src/Microsoft.ML.FastTree/FastTreeStatic.cs rename to src/Microsoft.ML.FastTree/TreeTrainersStatic.cs index 9fd7f0fa9b..f3008957f4 100644 --- a/src/Microsoft.ML.FastTree/FastTreeStatic.cs +++ b/src/Microsoft.ML.FastTree/TreeTrainersStatic.cs @@ -14,7 +14,7 @@ namespace Microsoft.ML.StaticPipe /// /// FastTree extension methods. /// - public static class FastTreeRegressionExtensions + public static class TreeRegressionExtensions { /// /// FastTree extension method. @@ -50,7 +50,7 @@ public static Scalar FastTree(this RegressionContext.RegressionTrainers c Action advancedSettings = null, Action onFit = null) { - FastTreeStaticsUtils.CheckUserValues(label, features, weights, numLeaves, numTrees, minDatapointsInLeafs, learningRate, advancedSettings, onFit); + CheckUserValues(label, features, weights, numLeaves, numTrees, minDatapointsInLeafs, learningRate, advancedSettings, onFit); var rec = new TrainerEstimatorReconciler.Regression( (env, labelName, featuresName, weightsName) => @@ -64,10 +64,6 @@ public static Scalar FastTree(this RegressionContext.RegressionTrainers c return rec.Score; } - } - - public static class FastTreeBinaryClassificationExtensions - { /// /// FastTree extension method. @@ -98,7 +94,7 @@ public static (Scalar score, Scalar probability, Scalar pred Action advancedSettings = null, Action> onFit = null) { - FastTreeStaticsUtils.CheckUserValues(label, features, weights, numLeaves, numTrees, minDatapointsInLeafs, learningRate, advancedSettings, onFit); + CheckUserValues(label, features, weights, numLeaves, numTrees, minDatapointsInLeafs, learningRate, advancedSettings, onFit); var rec = new TrainerEstimatorReconciler.BinaryClassifier( (env, labelName, featuresName, weightsName) => @@ -114,10 +110,6 @@ public static (Scalar score, Scalar probability, Scalar pred return rec.Output; } - } - - public static class FastTreeRankingExtensions - { /// /// FastTree . @@ -139,7 +131,7 @@ public static class FastTreeRankingExtensions /// the linear model that was trained. Note that this action cannot change the result in any way; /// it is only a way for the caller to be informed about what was learnt. /// The Score output column indicating the predicted value. 
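[Reviewer note] The TreeTrainersCatalog.cs methods above expose numTrees, numLeaves, minDatapointsInLeafs and learningRate directly on the dynamic trainer catalogs. A minimal usage sketch follows; it is not part of this patch, assumes an IHostEnvironment named env and that RegressionContext is constructed the same way as the contexts in the tests at the end of this patch, and uses arbitrary placeholder values rather than recommended settings.

    // Hypothetical caller code, illustrating the new non-advanced arguments only.
    var regressionContext = new RegressionContext(env);
    var fastTree = regressionContext.Trainers.FastTree(
        label: "Label",
        features: "Features",
        numTrees: 100,
        numLeaves: 20,
        minDatapointsInLeafs: 10,
        learningRate: 0.2);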
- public static Scalar FastTree(this RankingContext.RankingTrainers ctx, + public static Scalar FastTree(this RankingContext.RankingTrainers ctx, Scalar label, Vector features, Key groupId, Scalar weights = null, int numLeaves = Defaults.NumLeaves, int numTrees = Defaults.NumTrees, @@ -148,12 +140,13 @@ public static Scalar FastTree(this RankingContext.RankingTrainers c Action advancedSettings = null, Action onFit = null) { - FastTreeStaticsUtils.CheckUserValues(label, features, weights, numLeaves, numTrees, minDatapointsInLeafs, learningRate, advancedSettings, onFit); + CheckUserValues(label, features, weights, numLeaves, numTrees, minDatapointsInLeafs, learningRate, advancedSettings, onFit); var rec = new TrainerEstimatorReconciler.Ranker( (env, labelName, featuresName, groupIdName, weightsName) => { - var trainer = new FastTreeRankingTrainer(env, labelName, featuresName, groupIdName, weightsName, numLeaves, numTrees, minDatapointsInLeafs, learningRate, advancedSettings); + var trainer = new FastTreeRankingTrainer(env, labelName, featuresName, groupIdName, weightsName, numLeaves, + numTrees, minDatapointsInLeafs, learningRate, advancedSettings); if (onFit != null) return trainer.WithOnFitDelegate(trans => onFit(trans.Model)); return trainer; @@ -161,17 +154,14 @@ public static Scalar FastTree(this RankingContext.RankingTrainers c return rec.Score; } - } - internal class FastTreeStaticsUtils - { internal static void CheckUserValues(PipelineColumn label, Vector features, Scalar weights, - int numLeaves, - int numTrees, - int minDatapointsInLeafs, - double learningRate, - Delegate advancedSettings, - Delegate onFit) + int numLeaves, + int numTrees, + int minDatapointsInLeafs, + double learningRate, + Delegate advancedSettings, + Delegate onFit) { Contracts.CheckValue(label, nameof(label)); Contracts.CheckValue(features, nameof(features)); diff --git a/src/Microsoft.ML.HalLearners/OlsLinearRegression.cs b/src/Microsoft.ML.HalLearners/OlsLinearRegression.cs index 502b6f8e61..cba842a635 100644 --- a/src/Microsoft.ML.HalLearners/OlsLinearRegression.cs +++ b/src/Microsoft.ML.HalLearners/OlsLinearRegression.cs @@ -76,8 +76,6 @@ public OlsLinearRegressionTrainer(IHostEnvironment env, string featureColumn, st string weightColumn = null, Action advancedSettings = null) : this(env, ArgsInit(featureColumn, labelColumn, weightColumn, advancedSettings)) { - Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); - Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); } /// @@ -85,7 +83,7 @@ public OlsLinearRegressionTrainer(IHostEnvironment env, string featureColumn, st /// internal OlsLinearRegressionTrainer(IHostEnvironment env, Arguments args) : base(Contracts.CheckRef(env, nameof(env)).Register(LoadNameValue), TrainerUtils.MakeR4VecFeature(args.FeatureColumn), - TrainerUtils.MakeR4ScalarLabel(args.LabelColumn), TrainerUtils.MakeR4ScalarWeightColumn(args.WeightColumn)) + TrainerUtils.MakeR4ScalarLabel(args.LabelColumn), TrainerUtils.MakeR4ScalarWeightColumn(args.WeightColumn, args.WeightColumn.IsExplicit)) { Host.CheckValue(args, nameof(args)); Host.CheckUserArg(args.L2Weight >= 0, nameof(args.L2Weight), "L2 regularization term cannot be negative"); diff --git a/src/Microsoft.ML.LightGBM/LightGbmBinaryTrainer.cs b/src/Microsoft.ML.LightGBM/LightGbmBinaryTrainer.cs index fe770b80a4..bd9adbebbc 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmBinaryTrainer.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmBinaryTrainer.cs @@ -106,11 +106,14 @@ internal LightGbmBinaryTrainer(IHostEnvironment env, 
LightGbmArguments args) /// The name of the label column. /// The name of the feature column. /// The name for the column containing the initial weight. - /// A delegate to apply all the advanced arguments to the algorithm. /// The number of leaves to use. /// Number of iterations. /// The minimal number of documents allowed in a leaf of the tree, out of the subsampled data. /// The learning rate. + /// A delegate to set more settings. + /// The settings here will override the ones provided in the direct signature, + /// if both are present and have different values. + /// The columns names, however need to be provided directly, not through the . public LightGbmBinaryTrainer(IHostEnvironment env, string labelColumn, string featureColumn, string weightColumn = null, int? numLeaves = null, @@ -118,19 +121,8 @@ public LightGbmBinaryTrainer(IHostEnvironment env, string labelColumn, string fe double? learningRate = null, int numBoostRound = LightGbmArguments.Defaults.NumBoostRound, Action advancedSettings = null) - : base(env, LoadNameValue, TrainerUtils.MakeBoolScalarLabel(labelColumn), featureColumn, weightColumn, null, advancedSettings) + : base(env, LoadNameValue, TrainerUtils.MakeBoolScalarLabel(labelColumn), featureColumn, weightColumn, null, numLeaves, minDataPerLeaf, learningRate, numBoostRound, advancedSettings) { - Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); - Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); - - if (advancedSettings != null) - CheckArgsAndAdvancedSettingMismatch(numLeaves, minDataPerLeaf, learningRate, numBoostRound, new LightGbmArguments(), Args); - - // override with the directly provided values - Args.NumBoostRound = numBoostRound; - Args.NumLeaves = numLeaves ?? Args.NumLeaves; - Args.LearningRate = learningRate ?? Args.LearningRate; - Args.MinDataPerLeaf = minDataPerLeaf ?? Args.MinDataPerLeaf; } private protected override IPredictorWithFeatureWeights CreatePredictor() diff --git a/src/Microsoft.ML.LightGBM/LightGbmCatalog.cs b/src/Microsoft.ML.LightGBM/LightGbmCatalog.cs index 31ae1a0f53..56a892be4a 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmCatalog.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmCatalog.cs @@ -12,7 +12,7 @@ namespace Microsoft.ML /// /// Regression trainer estimators. /// - public static class LightGbmRegressionExtensions + public static class LightGbmExtensions { /// /// Predict a target using a decision tree regression model trained with the . @@ -25,7 +25,10 @@ public static class LightGbmRegressionExtensions /// Number of iterations. /// The minimal number of documents allowed in a leaf of the tree, out of the subsampled data. /// The learning rate. - /// Algorithm advanced settings. + /// A delegate to set more settings. + /// The settings here will override the ones provided in the direct signature, + /// if both are present and have different values. + /// The columns names, however need to be provided directly, not through the . public static LightGbmRegressorTrainer LightGbm(this RegressionContext.RegressionTrainers ctx, string label = DefaultColumnNames.Label, string features = DefaultColumnNames.Features, @@ -40,13 +43,6 @@ public static LightGbmRegressorTrainer LightGbm(this RegressionContext.Regressio var env = CatalogUtils.GetEnvironment(ctx); return new LightGbmRegressorTrainer(env, label, features, weights, numLeaves, minDataPerLeaf, learningRate, numBoostRound, advancedSettings); } - } - - /// - /// Binary Classification trainer estimators. 
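[Reviewer note] With this change the LightGBM binary trainer takes the common hyperparameters directly, and the new doc text states that advancedSettings is applied last and therefore overrides them. A hedged usage sketch, not part of this patch; env is assumed to be an IHostEnvironment and the numeric values are placeholders.

    // Hypothetical caller code for the new constructor shape.
    var lgbmBinary = new LightGbmBinaryTrainer(env, "Label", "Features",
        numLeaves: 64,
        minDataPerLeaf: 20,
        learningRate: 0.05,
        numBoostRound: 100,
        // Applied last by the base class, so this value wins over numBoostRound above.
        advancedSettings: args => args.NumBoostRound = 200);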
- /// - public static class LightGbmClassificationExtensions - { /// /// Predict a target using a decision tree binary classification model trained with the . @@ -59,7 +55,10 @@ public static class LightGbmClassificationExtensions /// Number of iterations. /// The minimal number of documents allowed in a leaf of the tree, out of the subsampled data. /// The learning rate. - /// Algorithm advanced settings. + /// A delegate to set more settings. + /// The settings here will override the ones provided in the direct signature, + /// if both are present and have different values. + /// The columns names, however need to be provided directly, not through the . public static LightGbmBinaryTrainer LightGbm(this BinaryClassificationContext.BinaryClassificationTrainers ctx, string label = DefaultColumnNames.Label, string features = DefaultColumnNames.Features, @@ -75,5 +74,69 @@ public static LightGbmBinaryTrainer LightGbm(this BinaryClassificationContext.Bi return new LightGbmBinaryTrainer(env, label, features, weights, numLeaves, minDataPerLeaf, learningRate, numBoostRound, advancedSettings); } + + /// + /// Predict a target using a decision tree binary classification model trained with the . + /// + /// The . + /// The label column. + /// The features column. + /// The weights column. + /// The groupId column. + /// The number of leaves to use. + /// Number of iterations. + /// The minimal number of documents allowed in a leaf of the tree, out of the subsampled data. + /// The learning rate. + /// A delegate to set more settings. + /// The settings here will override the ones provided in the direct signature, + /// if both are present and have different values. + /// The columns names, however need to be provided directly, not through the . + public static LightGbmRankingTrainer LightGbm(this RankingContext.RankingTrainers ctx, + string label = DefaultColumnNames.Label, + string features = DefaultColumnNames.Features, + string groupId = DefaultColumnNames.GroupId, + string weights = null, + int? numLeaves = null, + int? minDataPerLeaf = null, + double? learningRate = null, + int numBoostRound = LightGbmArguments.Defaults.NumBoostRound, + Action advancedSettings = null) + { + Contracts.CheckValue(ctx, nameof(ctx)); + var env = CatalogUtils.GetEnvironment(ctx); + return new LightGbmRankingTrainer(env, label, features, groupId, weights, numLeaves, minDataPerLeaf, learningRate, numBoostRound, advancedSettings); + + } + + /// + /// Predict a target using a decision tree binary classification model trained with the . + /// + /// The . + /// The label column. + /// The features column. + /// The weights column. + /// The number of leaves to use. + /// Number of iterations. + /// The minimal number of documents allowed in a leaf of the tree, out of the subsampled data. + /// The learning rate. + /// A delegate to set more settings. + /// The settings here will override the ones provided in the direct signature, + /// if both are present and have different values. + /// The columns names, however need to be provided directly, not through the . + public static LightGbmMulticlassTrainer LightGbm(this MulticlassClassificationContext.MulticlassClassificationTrainers ctx, + string label = DefaultColumnNames.Label, + string features = DefaultColumnNames.Features, + string weights = null, + int? numLeaves = null, + int? minDataPerLeaf = null, + double? 
learningRate = null, + int numBoostRound = LightGbmArguments.Defaults.NumBoostRound, + Action advancedSettings = null) + { + Contracts.CheckValue(ctx, nameof(ctx)); + var env = CatalogUtils.GetEnvironment(ctx); + return new LightGbmMulticlassTrainer(env, label, features, weights, numLeaves, minDataPerLeaf, learningRate, numBoostRound, advancedSettings); + + } } } diff --git a/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs b/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs index 579c612579..bc73e3a41f 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs @@ -46,15 +46,26 @@ internal LightGbmMulticlassTrainer(IHostEnvironment env, LightGbmArguments args) /// The private instance of . /// The name of the label column. /// The name of the feature column. - /// The name for the column containing the group ID. /// The name for the column containing the initial weight. - /// A delegate to apply all the advanced arguments to the algorithm. - public LightGbmMulticlassTrainer(IHostEnvironment env, string labelColumn, string featureColumn, - string groupIdColumn = null, string weightColumn = null, Action advancedSettings = null) - : base(env, LoadNameValue, TrainerUtils.MakeBoolScalarLabel(labelColumn), featureColumn, weightColumn, groupIdColumn, advancedSettings) + /// The number of leaves to use. + /// Number of iterations. + /// The minimal number of documents allowed in a leaf of the tree, out of the subsampled data. + /// The learning rate. + /// A delegate to set more settings. + /// The settings here will override the ones provided in the direct signature, + /// if both are present and have different values. + /// The columns names, however need to be provided directly, not through the . + public LightGbmMulticlassTrainer(IHostEnvironment env, + string labelColumn, + string featureColumn, + string weightColumn = null, + int? numLeaves = null, + int? minDataPerLeaf = null, + double? learningRate = null, + int numBoostRound = LightGbmArguments.Defaults.NumBoostRound, + Action advancedSettings = null) + : base(env, LoadNameValue, TrainerUtils.MakeU4ScalarColumn(labelColumn), featureColumn, weightColumn, null, numLeaves, minDataPerLeaf, learningRate, numBoostRound, advancedSettings) { - Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); - Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); _numClass = -1; } diff --git a/src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs b/src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs index dfd1652394..fdd4b09959 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs @@ -92,15 +92,28 @@ internal LightGbmRankingTrainer(IHostEnvironment env, LightGbmArguments args) /// The private instance of . /// The name of the label column. /// The name of the feature column. - /// The name for the column containing the group ID. - /// The name for the column containing the initial weight. - /// A delegate to apply all the advanced arguments to the algorithm. - public LightGbmRankingTrainer(IHostEnvironment env, string labelColumn, string featureColumn, - string groupIdColumn, string weightColumn = null, Action advancedSettings = null) - : base(env, LoadNameValue, TrainerUtils.MakeR4ScalarLabel(labelColumn), featureColumn, weightColumn, groupIdColumn, advancedSettings) + /// The name of the column containing the group ID. + /// The name of the column containing the initial weight. 
+ /// The number of leaves to use. + /// Number of iterations. + /// The minimal number of documents allowed in a leaf of the tree, out of the subsampled data. + /// The learning rate. + /// A delegate to set more settings. + /// The settings here will override the ones provided in the direct signature, + /// if both are present and have different values. + /// The columns names, however need to be provided directly, not through the . + public LightGbmRankingTrainer(IHostEnvironment env, + string labelColumn, + string featureColumn, + string groupIdColumn, + string weightColumn = null, + int? numLeaves = null, + int? minDataPerLeaf = null, + double? learningRate = null, + int numBoostRound = LightGbmArguments.Defaults.NumBoostRound, + Action advancedSettings = null) + : base(env, LoadNameValue, TrainerUtils.MakeR4ScalarLabel(labelColumn), featureColumn, weightColumn, groupIdColumn, numLeaves, minDataPerLeaf, learningRate, numBoostRound, advancedSettings) { - Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); - Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); Host.CheckNonEmpty(groupIdColumn, nameof(groupIdColumn)); } diff --git a/src/Microsoft.ML.LightGBM/LightGbmRegressionTrainer.cs b/src/Microsoft.ML.LightGBM/LightGbmRegressionTrainer.cs index 1d94738f79..612ef15f6d 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmRegressionTrainer.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmRegressionTrainer.cs @@ -92,11 +92,14 @@ public sealed class LightGbmRegressorTrainer : LightGbmTrainerBaseThe name of the label column. /// The name of the feature column. /// The name for the column containing the initial weight. - /// A delegate to apply all the advanced arguments to the algorithm. /// The number of leaves to use. /// Number of iterations. /// The minimal number of documents allowed in a leaf of the tree, out of the subsampled data. /// The learning rate. + /// A delegate to set more settings. + /// The settings here will override the ones provided in the direct signature, + /// if both are present and have different values. + /// The columns names, however need to be provided directly, not through the . public LightGbmRegressorTrainer(IHostEnvironment env, string labelColumn, string featureColumn, string weightColumn = null, int? numLeaves = null, @@ -104,19 +107,8 @@ public LightGbmRegressorTrainer(IHostEnvironment env, string labelColumn, string double? learningRate = null, int numBoostRound = LightGbmArguments.Defaults.NumBoostRound, Action advancedSettings = null) - : base(env, LoadNameValue, TrainerUtils.MakeR4ScalarLabel(labelColumn), featureColumn, weightColumn, null, advancedSettings) + : base(env, LoadNameValue, TrainerUtils.MakeR4ScalarLabel(labelColumn), featureColumn, weightColumn, null, numLeaves, minDataPerLeaf, learningRate, numBoostRound, advancedSettings) { - Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); - Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); - - if (advancedSettings != null) - CheckArgsAndAdvancedSettingMismatch(numLeaves, minDataPerLeaf, learningRate, numBoostRound, new LightGbmArguments(), Args); - - // override with the directly provided values - Args.NumBoostRound = numBoostRound; - Args.NumLeaves = numLeaves ?? Args.NumLeaves; - Args.LearningRate = learningRate ?? Args.LearningRate; - Args.MinDataPerLeaf = minDataPerLeaf ?? 
Args.MinDataPerLeaf; } internal LightGbmRegressorTrainer(IHostEnvironment env, LightGbmArguments args) diff --git a/src/Microsoft.ML.LightGBM/LightGbmStatic.cs b/src/Microsoft.ML.LightGBM/LightGbmStatic.cs index 361816389d..445a1ad6ad 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmStatic.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmStatic.cs @@ -5,6 +5,7 @@ using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.Internal.Internallearn; +using Microsoft.ML.Runtime.Learners; using Microsoft.ML.Runtime.LightGBM; using Microsoft.ML.StaticPipe.Runtime; using System; @@ -14,10 +15,10 @@ namespace Microsoft.ML.StaticPipe /// /// Regression trainer estimators. /// - public static partial class RegressionTrainers + public static class LightGbmTrainers { /// - /// LightGbm extension method. + /// Predict a target using a tree regression model trained with the . /// /// The . /// The label column. @@ -49,7 +50,7 @@ public static Scalar LightGbm(this RegressionContext.RegressionTrainers c Action advancedSettings = null, Action onFit = null) { - LightGbmStaticsUtils.CheckUserValues(label, features, weights, numLeaves, minDataPerLeaf, learningRate, numBoostRound, advancedSettings, onFit); + CheckUserValues(label, features, weights, numLeaves, minDataPerLeaf, learningRate, numBoostRound, advancedSettings, onFit); var rec = new TrainerEstimatorReconciler.Regression( (env, labelName, featuresName, weightsName) => @@ -63,15 +64,9 @@ public static Scalar LightGbm(this RegressionContext.RegressionTrainers c return rec.Score; } - } - - /// - /// Binary Classification trainer estimators. - /// - public static partial class ClassificationTrainers { /// - /// LightGbm extension method. + /// Predict a target using a tree binary classification model trained with the . /// /// The . /// The label column. @@ -98,7 +93,7 @@ public static (Scalar score, Scalar probability, Scalar pred Action advancedSettings = null, Action> onFit = null) { - LightGbmStaticsUtils.CheckUserValues(label, features, weights, numLeaves, minDataPerLeaf, learningRate, numBoostRound, advancedSettings, onFit); + CheckUserValues(label, features, weights, numLeaves, minDataPerLeaf, learningRate, numBoostRound, advancedSettings, onFit); var rec = new TrainerEstimatorReconciler.BinaryClassifier( (env, labelName, featuresName, weightsName) => @@ -114,11 +109,103 @@ public static (Scalar score, Scalar probability, Scalar pred return rec.Output; } - } - internal static class LightGbmStaticsUtils { + /// + /// Ranks a series of inputs based on their relevance, training a decision tree ranking model through the . + /// + /// The . + /// The label column. + /// The features column. + /// The groupId column. + /// The weights column. + /// The number of leaves to use. + /// Number of iterations. + /// The minimal number of documents allowed in a leaf of the tree, out of the subsampled data. + /// The learning rate. + /// Algorithm advanced settings. + /// A delegate that is called every time the + /// method is called on the + /// instance created out of this. This delegate will receive + /// the linear model that was trained. Note that this action cannot change the result in any way; + /// it is only a way for the caller to be informed about what was learnt. + /// The set of output columns including in order the predicted binary classification score (which will range + /// from negative to positive infinity), the calibrated prediction (from 0 to 1), and the predicted label. 
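[Reviewer note] The renamed LightGbmTrainers static class keeps the same call shape as the tests added at the end of this patch. A trimmed regression sketch, not part of this patch, assuming a RegressionContext named ctx and a static-pipe reader whose rows expose label as Scalar<float> and features as Vector<float>:

    // Hypothetical static-pipeline usage of the regression LightGbm extension.
    var est = reader.MakeNewEstimator()
        .Append(r => (r.label, score: ctx.Trainers.LightGbm(r.label, r.features,
            numLeaves: 20, minDataPerLeaf: 10, learningRate: 0.1, numBoostRound: 100)));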
+ public static Scalar LightGbm(this RankingContext.RankingTrainers ctx, + Scalar label, Vector features, Key groupId, Scalar weights = null, + int? numLeaves = null, + int? minDataPerLeaf = null, + double? learningRate = null, + int numBoostRound = LightGbmArguments.Defaults.NumBoostRound, + Action advancedSettings = null, + Action onFit = null) + { + CheckUserValues(label, features, weights, numLeaves, minDataPerLeaf, learningRate, numBoostRound, advancedSettings, onFit); + Contracts.CheckValue(groupId, nameof(groupId)); + + var rec = new TrainerEstimatorReconciler.Ranker( + (env, labelName, featuresName, groupIdName, weightsName) => + { + var trainer = new LightGbmRankingTrainer(env, labelName, featuresName, groupIdName, weightsName, numLeaves, + minDataPerLeaf, learningRate, numBoostRound, advancedSettings); + + if (onFit != null) + return trainer.WithOnFitDelegate(trans => onFit(trans.Model)); + return trainer; + }, label, features, groupId, weights); + + return rec.Score; + } + + /// + /// Predict a target using a tree multiclass classification model trained with the . + /// + /// The multiclass classification context trainer object. + /// The label, or dependent variable. + /// The features, or independent variables. + /// The weights column. + /// The number of leaves to use. + /// Number of iterations. + /// The minimal number of documents allowed in a leaf of the tree, out of the subsampled data. + /// The learning rate. + /// A delegate to set more settings. + /// The settings here will override the ones provided in the direct method signature, + /// if both are present and have different values. + /// The columns names, however need to be provided directly, not through the . + /// A delegate that is called every time the + /// method is called on the + /// instance created out of this. This delegate will receive + /// the linear model that was trained. Note that this action cannot change the + /// result in any way; it is only a way for the caller to be informed about what was learnt. + /// The set of output columns including in order the predicted per-class likelihoods (between 0 and 1, and summing up to 1), and the predicted label. + public static (Vector score, Key predictedLabel) + LightGbm(this MulticlassClassificationContext.MulticlassClassificationTrainers ctx, + Key label, + Vector features, + Scalar weights = null, + int? numLeaves = null, + int? minDataPerLeaf = null, + double? learningRate = null, + int numBoostRound = LightGbmArguments.Defaults.NumBoostRound, + Action advancedSettings = null, + Action onFit = null) + { + CheckUserValues(label, features, weights, numLeaves, minDataPerLeaf, learningRate, numBoostRound, advancedSettings, onFit); + + var rec = new TrainerEstimatorReconciler.MulticlassClassifier( + (env, labelName, featuresName, weightsName) => + { + var trainer = new LightGbmMulticlassTrainer(env, labelName, featuresName, weightsName, numLeaves, + minDataPerLeaf, learningRate, numBoostRound, advancedSettings); + + if (onFit != null) + return trainer.WithOnFitDelegate(trans => onFit(trans.Model)); + return trainer; + }, label, features, weights); + + return rec.Output; + } - internal static void CheckUserValues(PipelineColumn label, Vector features, Scalar weights, + private static void CheckUserValues(PipelineColumn label, Vector features, Scalar weights, int? numLeaves, int? minDataPerLeaf, double? 
learningRate, diff --git a/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs b/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs index fb4d21dbd9..0d6313a20c 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs @@ -4,6 +4,7 @@ using Microsoft.ML.Core.Data; using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.EntryPoints; using Microsoft.ML.Runtime.Internal.Utilities; using Microsoft.ML.Runtime.Training; using Microsoft.ML.Trainers.FastTree.Internal; @@ -26,7 +27,7 @@ internal static class LightGbmShared /// /// Base class for all training with LightGBM. /// - public abstract class LightGbmTrainerBase : TrainerEstimatorBase + public abstract class LightGbmTrainerBase : TrainerEstimatorBaseWithGroupId where TTransformer : ISingleFeaturePredictionTransformer where TModel : IPredictorProducing { @@ -57,32 +58,43 @@ private sealed class CategoricalMetaData private static readonly TrainerInfo _info = new TrainerInfo(normalization: false, caching: false, supportValid: true); public override TrainerInfo Info => _info; - private protected LightGbmTrainerBase(IHostEnvironment env, string name, SchemaShape.Column label, string featureColumn, - string weightColumn = null, string groupIdColumn = null, Action advancedSettings = null) - : base(Contracts.CheckRef(env, nameof(env)).Register(name), TrainerUtils.MakeR4VecFeature(featureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(weightColumn)) + private protected LightGbmTrainerBase(IHostEnvironment env, + string name, + SchemaShape.Column label, + string featureColumn, + string weightColumn, + string groupIdColumn, + int? numLeaves, + int? minDataPerLeaf, + double? learningRate, + int numBoostRound, + Action advancedSettings) + : base(Contracts.CheckRef(env, nameof(env)).Register(name), TrainerUtils.MakeR4VecFeature(featureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(weightColumn), TrainerUtils.MakeU4ScalarColumn(groupIdColumn)) { Args = new LightGbmArguments(); + Args.NumLeaves = numLeaves; + Args.MinDataPerLeaf = minDataPerLeaf; + Args.LearningRate = learningRate; + Args.NumBoostRound = numBoostRound; + //apply the advanced args, if the user supplied any advancedSettings?.Invoke(Args); - // check that the users didn't specify different label, group, feature, weights in the args, from what they supplied directly - TrainerUtils.CheckArgsHaveDefaultColNames(Host, Args); - Args.LabelColumn = label.Name; Args.FeatureColumn = featureColumn; if (weightColumn != null) - Args.WeightColumn = weightColumn; + Args.WeightColumn = Optional.Explicit(weightColumn); if (groupIdColumn != null) - Args.GroupIdColumn = groupIdColumn; + Args.GroupIdColumn = Optional.Explicit(groupIdColumn); InitParallelTraining(); } private protected LightGbmTrainerBase(IHostEnvironment env, string name, LightGbmArguments args, SchemaShape.Column label) - : base(Contracts.CheckRef(env, nameof(env)).Register(name), TrainerUtils.MakeR4VecFeature(args.FeatureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(args.WeightColumn)) + : base(Contracts.CheckRef(env, nameof(env)).Register(name), TrainerUtils.MakeR4VecFeature(args.FeatureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(args.WeightColumn, args.WeightColumn.IsExplicit)) { Host.CheckValue(args, nameof(args)); @@ -161,32 +173,6 @@ protected virtual void CheckDataValid(IChannel ch, RoleMappedData data) ch.CheckParam(data.Schema.Label != null, nameof(data), "Need a label column"); } - /// - /// If, after applying the advancedSettings delegate, the 
args are different that the default value - /// and are also different than the value supplied directly to the xtension method, warn the user - /// about which value is being used. - /// The parameters that appear here, numTrees, minDocumentsInLeafs, numLeaves, learningRate are the ones the users are most likely to tune. - /// This list should follow the one in the constructor, and the extension methods on the . - /// REVIEW: we should somehow mark the arguments that are set apart in those two places. Currently they stand out by their sort order annotation. - /// - protected void CheckArgsAndAdvancedSettingMismatch(int? numLeaves, - int? minDataPerLeaf, - double? learningRate, - int numBoostRound, - LightGbmArguments snapshot, - LightGbmArguments currentArgs) - { - using (var ch = Host.Start("Comparing advanced settings with the directly provided values.")) - { - - // Check that the user didn't supply different parameters in the args, from what it specified directly. - TrainerUtils.CheckArgsAndAdvancedSettingMismatch(ch, numLeaves, snapshot.NumLeaves, currentArgs.NumLeaves, nameof(numLeaves)); - TrainerUtils.CheckArgsAndAdvancedSettingMismatch(ch, numBoostRound, snapshot.NumBoostRound, currentArgs.NumBoostRound, nameof(numBoostRound)); - TrainerUtils.CheckArgsAndAdvancedSettingMismatch(ch, minDataPerLeaf, snapshot.MinDataPerLeaf, currentArgs.MinDataPerLeaf, nameof(minDataPerLeaf)); - TrainerUtils.CheckArgsAndAdvancedSettingMismatch(ch, learningRate, snapshot.LearningRate, currentArgs.LearningRate, nameof(learningRate)); - } - } - protected virtual void GetDefaultParameters(IChannel ch, int numRow, bool hasCategarical, int totalCats, bool hiddenMsg=false) { double learningRate = Args.LearningRate ?? DefaultLearningRate(numRow, hasCategarical, totalCats); diff --git a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineCatalog.cs b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineCatalog.cs index 564347d9c7..f94511ec76 100644 --- a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineCatalog.cs +++ b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineCatalog.cs @@ -21,8 +21,10 @@ public static class FactorizationMachineExtensions /// The label, or dependent variable. /// The features, or independent variables. /// The optional example weights. - /// A delegate to set more settings. - /// + /// A delegate to set more settings. + /// The settings here will override the ones provided in the direct method signature, + /// if both are present and have different values. + /// The columns names, however need to be provided directly, not through the . public static FieldAwareFactorizationMachineTrainer FieldAwareFactorizationMachine(this BinaryClassificationContext.BinaryClassificationTrainers ctx, string label, string[] features, string weights = null, diff --git a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineStatic.cs b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineStatic.cs index 733c98d28a..2a95df5dd7 100644 --- a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineStatic.cs +++ b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineStatic.cs @@ -28,8 +28,10 @@ public static class FactorizationMachineExtensions /// Initial learning rate. /// Number of training iterations. /// Latent space dimensions. - /// A delegate to set more settings. 
- /// A delegate that is called every time the + /// A delegate to set more settings. + /// The settings here will override the ones provided in the direct method signature, + /// if both are present and have different values. + /// The columns names, however need to be provided directly, not through the ./// A delegate that is called every time the /// method is called on the /// instance created out of this. This delegate will receive /// the model that was trained. Note that this action cannot change the result in any way; it is only a way for the caller to @@ -57,10 +59,11 @@ public static (Scalar score, Scalar predictedLabel) FieldAwareFacto var trainer = new FieldAwareFactorizationMachineTrainer(env, labelCol, featureCols, advancedSettings: args => { - advancedSettings?.Invoke(args); args.LearningRate = learningRate; args.Iters = numIterations; args.LatentDim = numLatentDimensions; + + advancedSettings?.Invoke(args); }); if (onFit != null) return trainer.WithOnFitDelegate(trans => onFit(trans.Model)); diff --git a/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs b/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs index 3907131f6b..3c3e0ea13b 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs @@ -1433,7 +1433,10 @@ internal override void Check(IHostEnvironment env) /// The L2 regularization hyperparameter. /// The L1 regularization hyperparameter. Higher values will tend to lead to more sparse model. /// The maximum number of passes to perform over the data. - /// A delegate to set more settings. + /// A delegate to set more settings. + /// The settings here will override the ones provided in the direct method signature, + /// if both are present and have different values. + /// The columns names, however need to be provided directly, not through the . public LinearClassificationTrainer(IHostEnvironment env, string featureColumn, string labelColumn, @@ -1682,21 +1685,17 @@ public StochasticGradientDescentClassificationTrainer(IHostEnvironment env, stri Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); _args = new Arguments(); - advancedSettings?.Invoke(_args); - - // check that the users didn't specify different label, group, feature, weights in the args, from what they supplied directly - TrainerUtils.CheckArgsHaveDefaultColNames(Host, _args); - - if (advancedSettings != null) - CheckArgsAndAdvancedSettingMismatch(maxIterations, initLearningRate, l2Weight, loss, new Arguments(), _args); + _args.MaxIterations = maxIterations; + _args.InitLearningRate = initLearningRate; + _args.L2Weight = l2Weight; // Apply the advanced args, if the user supplied any. + advancedSettings?.Invoke(_args); + _args.FeatureColumn = featureColumn; _args.LabelColumn = labelColumn; _args.WeightColumn = weightColumn; - _args.MaxIterations = maxIterations; - _args.InitLearningRate = initLearningRate; - _args.L2Weight = l2Weight; + if (loss != null) _args.LossFunction = loss; _args.Check(env); @@ -1719,30 +1718,6 @@ internal StochasticGradientDescentClassificationTrainer(IHostEnvironment env, Ar _args = args; } - /// - /// If, after applying the advancedSettings delegate, the args are different that the default value - /// and are also different than the value supplied directly to the xtension method, warn the user - /// about which value is being used. 
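[Reviewer note] The hunks above (the factorization machine static extension and the SGD constructor) and the LightGBM and L-BFGS base classes elsewhere in this patch now share one ordering: the directly supplied values are written into the arguments object first and the advancedSettings delegate is invoked last, so a value set in the delegate wins on conflict. A self-contained sketch of that ordering using a hypothetical arguments type (not any real trainer's Arguments class):

    using System;

    sealed class HypotheticalArguments
    {
        public double LearningRate = 0.2;
        public int MaxIterations = 10;
    }

    static class OverrideOrderSketch
    {
        public static HypotheticalArguments Build(double learningRate, Action<HypotheticalArguments> advancedSettings)
        {
            var args = new HypotheticalArguments();
            args.LearningRate = learningRate;  // 1. direct parameter applied first
            advancedSettings?.Invoke(args);    // 2. delegate runs last; its settings override the direct values
            return args;
        }
    }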
- /// The parameters that appear here, numTrees, minDocumentsInLeafs, numLeaves, learningRate are the ones the users are most likely to tune. - /// This list should follow the one in the constructor, and the extension methods on the . - /// - internal void CheckArgsAndAdvancedSettingMismatch(int maxIterations, - double initLearningRate, - float l2Weight, - ISupportClassificationLossFactory loss, - Arguments snapshot, - Arguments currentArgs) - { - using (var ch = Host.Start("Comparing advanced settings with the directly provided values.")) - { - // Check that the user didn't supply different parameters in the args, from what it specified directly. - TrainerUtils.CheckArgsAndAdvancedSettingMismatch(ch, maxIterations, snapshot.MaxIterations, currentArgs.MaxIterations, nameof(maxIterations)); - TrainerUtils.CheckArgsAndAdvancedSettingMismatch(ch, initLearningRate, snapshot.InitLearningRate, currentArgs.InitLearningRate, nameof(initLearningRate)); - TrainerUtils.CheckArgsAndAdvancedSettingMismatch(ch, l2Weight, snapshot.L2Weight, currentArgs.L2Weight, nameof(l2Weight)); - TrainerUtils.CheckArgsAndAdvancedSettingMismatch(ch, loss, snapshot.LossFunction, currentArgs.LossFunction, nameof(loss)); - } - } - protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema) { return new[] diff --git a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsPredictorBase.cs b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsPredictorBase.cs index 5f7a40c5d2..048dc11d34 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsPredictorBase.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsPredictorBase.cs @@ -151,31 +151,49 @@ internal static class Defaults private static readonly TrainerInfo _info = new TrainerInfo(caching: true, supportIncrementalTrain: true); public override TrainerInfo Info => _info; - internal LbfgsTrainerBase(IHostEnvironment env, string featureColumn, SchemaShape.Column labelColumn, - string weightColumn, Action advancedSettings, float l1Weight, + internal LbfgsTrainerBase(IHostEnvironment env, + string featureColumn, + SchemaShape.Column labelColumn, + string weightColumn, + Action advancedSettings, + float l1Weight, float l2Weight, float optimizationTolerance, int memorySize, bool enforceNoNegativity) - : this(env, ArgsInit(featureColumn, labelColumn, weightColumn, advancedSettings), labelColumn, - l1Weight, l2Weight, optimizationTolerance, memorySize, enforceNoNegativity) + : this(env, new TArgs + { + FeatureColumn = featureColumn, + LabelColumn = labelColumn.Name, + WeightColumn = weightColumn ?? Optional.Explicit(weightColumn), + L1Weight = l1Weight, + L2Weight = l2Weight, + OptTol = optimizationTolerance, + MemorySize = memorySize, + EnforceNonNegativity = enforceNoNegativity + }, + labelColumn, advancedSettings) { } - internal LbfgsTrainerBase(IHostEnvironment env, TArgs args, SchemaShape.Column labelColumn, - float? l1Weight = null, - float? l2Weight = null, - float? optimizationTolerance = null, - int? memorySize = null, - bool? 
enforceNoNegativity = null) + internal LbfgsTrainerBase(IHostEnvironment env, + TArgs args, + SchemaShape.Column labelColumn, + Action advancedSettings = null) : base(Contracts.CheckRef(env, nameof(env)).Register(RegisterName), TrainerUtils.MakeR4VecFeature(args.FeatureColumn), - labelColumn, TrainerUtils.MakeR4ScalarWeightColumn(args.WeightColumn)) + labelColumn, TrainerUtils.MakeR4ScalarWeightColumn(args.WeightColumn, args.WeightColumn.IsExplicit)) { Host.CheckValue(args, nameof(args)); Args = args; + // Apply the advanced args, if the user supplied any. + advancedSettings?.Invoke(args); + + args.FeatureColumn = FeatureColumn.Name; + args.LabelColumn = LabelColumn.Name; + args.WeightColumn = WeightColumn?.Name; Host.CheckUserArg(!Args.UseThreads || Args.NumThreads > 0 || Args.NumThreads == null, - nameof(Args.NumThreads), "numThreads must be positive (or empty for default)"); + nameof(Args.NumThreads), "numThreads must be positive (or empty for default)"); Host.CheckUserArg(Args.L2Weight >= 0, nameof(Args.L2Weight), "Must be non-negative"); Host.CheckUserArg(Args.L1Weight >= 0, nameof(Args.L1Weight), "Must be non-negative"); Host.CheckUserArg(Args.OptTol > 0, nameof(Args.OptTol), "Must be positive"); @@ -184,16 +202,15 @@ internal LbfgsTrainerBase(IHostEnvironment env, TArgs args, SchemaShape.Column l Host.CheckUserArg(Args.SgdInitializationTolerance >= 0, nameof(Args.SgdInitializationTolerance), "Must be non-negative"); Host.CheckUserArg(Args.NumThreads == null || Args.NumThreads.Value >= 0, nameof(Args.NumThreads), "Must be non-negative"); - Host.CheckParam(!(l2Weight < 0), nameof(l2Weight), "Must be non-negative, if provided."); - Host.CheckParam(!(l1Weight < 0), nameof(l1Weight), "Must be non-negative, if provided"); - Host.CheckParam(!(optimizationTolerance <= 0), nameof(optimizationTolerance), "Must be positive, if provided."); - Host.CheckParam(!(memorySize <= 0), nameof(memorySize), "Must be positive, if provided."); + Host.CheckParam(!(Args.L2Weight < 0), nameof(Args.L2Weight), "Must be non-negative, if provided."); + Host.CheckParam(!(Args.L1Weight < 0), nameof(Args.L1Weight), "Must be non-negative, if provided"); + Host.CheckParam(!(Args.OptTol <= 0), nameof(Args.OptTol), "Must be positive, if provided."); + Host.CheckParam(!(Args.MemorySize <= 0), nameof(Args.MemorySize), "Must be positive, if provided."); - // Review: Warn about the overriding behavior - L2Weight = l2Weight ?? Args.L2Weight; - L1Weight = l1Weight ?? Args.L1Weight; - OptTol = optimizationTolerance ?? Args.OptTol; - MemorySize = memorySize ?? Args.MemorySize; + L2Weight = Args.L2Weight; + L1Weight = Args.L1Weight; + OptTol = Args.OptTol; + MemorySize =Args.MemorySize; MaxIterations = Args.MaxIterations; SgdInitializationTolerance = Args.SgdInitializationTolerance; Quiet = Args.Quiet; @@ -201,7 +218,7 @@ internal LbfgsTrainerBase(IHostEnvironment env, TArgs args, SchemaShape.Column l UseThreads = Args.UseThreads; NumThreads = Args.NumThreads; DenseOptimizer = Args.DenseOptimizer; - EnforceNonNegativity = enforceNoNegativity ?? 
Args.EnforceNonNegativity; + EnforceNonNegativity = Args.EnforceNonNegativity; if (EnforceNonNegativity && ShowTrainingStats) { @@ -217,14 +234,25 @@ internal LbfgsTrainerBase(IHostEnvironment env, TArgs args, SchemaShape.Column l } private static TArgs ArgsInit(string featureColumn, SchemaShape.Column labelColumn, - string weightColumn, Action advancedSettings) + string weightColumn, + float l1Weight, + float l2Weight, + float optimizationTolerance, + int memorySize, + bool enforceNoNegativity) { - var args = new TArgs(); + var args = new TArgs + { + FeatureColumn = featureColumn, + LabelColumn = labelColumn.Name, + WeightColumn = weightColumn ?? Optional.Explicit(weightColumn), + L1Weight = l1Weight, + L2Weight = l2Weight, + OptTol = optimizationTolerance, + MemorySize = memorySize, + EnforceNonNegativity = enforceNoNegativity + }; - // Apply the advanced args, if the user supplied any. - advancedSettings?.Invoke(args); - args.FeatureColumn = featureColumn; - args.LabelColumn = labelColumn.Name; args.WeightColumn = weightColumn; return args; } diff --git a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs index 5e2003f3e1..70e896feef 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs @@ -91,7 +91,7 @@ public MulticlassLogisticRegression(IHostEnvironment env, string featureColumn, int memorySize = Arguments.Defaults.MemorySize, bool enforceNoNegativity = Arguments.Defaults.EnforceNonNegativity, Action advancedSettings = null) - : base(env, featureColumn, TrainerUtils.MakeU4ScalarLabel(labelColumn), weightColumn, advancedSettings, + : base(env, featureColumn, TrainerUtils.MakeU4ScalarColumn(labelColumn), weightColumn, advancedSettings, l1Weight, l2Weight, optimizationTolerance, memorySize, enforceNoNegativity) { Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); @@ -104,7 +104,7 @@ public MulticlassLogisticRegression(IHostEnvironment env, string featureColumn, /// Initializes a new instance of /// internal MulticlassLogisticRegression(IHostEnvironment env, Arguments args) - : base(env, args, TrainerUtils.MakeU4ScalarLabel(args.LabelColumn)) + : base(env, args, TrainerUtils.MakeU4ScalarColumn(args.LabelColumn)) { ShowTrainingStats = Args.ShowTrainingStats; } diff --git a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MultiClassNaiveBayesTrainer.cs b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MultiClassNaiveBayesTrainer.cs index 1713affce5..72dd984328 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MultiClassNaiveBayesTrainer.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MultiClassNaiveBayesTrainer.cs @@ -53,7 +53,7 @@ public sealed class Arguments : LearnerInputBaseWithLabel /// The name of the feature column. 
public MultiClassNaiveBayesTrainer(IHostEnvironment env, string featureColumn, string labelColumn) : base(Contracts.CheckRef(env, nameof(env)).Register(LoadName), TrainerUtils.MakeR4VecFeature(featureColumn), - TrainerUtils.MakeU4ScalarLabel(labelColumn)) + TrainerUtils.MakeU4ScalarColumn(labelColumn)) { Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); @@ -64,7 +64,7 @@ public MultiClassNaiveBayesTrainer(IHostEnvironment env, string featureColumn, s /// internal MultiClassNaiveBayesTrainer(IHostEnvironment env, Arguments args) : base(Contracts.CheckRef(env, nameof(env)).Register(LoadName), TrainerUtils.MakeR4VecFeature(args.FeatureColumn), - TrainerUtils.MakeU4ScalarLabel(args.LabelColumn)) + TrainerUtils.MakeU4ScalarColumn(args.LabelColumn)) { Host.CheckValue(args, nameof(args)); } diff --git a/src/Microsoft.ML.StandardLearners/Standard/SdcaCatalog.cs b/src/Microsoft.ML.StandardLearners/Standard/SdcaCatalog.cs index 517d79e7b8..6e3ebe5fa7 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/SdcaCatalog.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/SdcaCatalog.cs @@ -26,9 +26,12 @@ public static class SdcaRegressionExtensions /// The L1 regularization hyperparameter. Higher values will tend to lead to more sparse model. /// The maximum number of passes to perform over the data. /// The custom loss, if unspecified will be . - /// A delegate to set more settings. + /// A delegate to set more settings. + /// The settings here will override the ones provided in the direct method signature, + /// if both are present and have different values. + /// The columns names, however need to be provided directly, not through the . public static SdcaRegressionTrainer StochasticDualCoordinateAscent(this RegressionContext.RegressionTrainers ctx, - string label = DefaultColumnNames.Label, string features = DefaultColumnNames.Features, string weights = null, + string label = DefaultColumnNames.Label, string features = DefaultColumnNames.Features, string weights = null, ISupportSdcaRegressionLoss loss = null, float? l2Const = null, float? l1Threshold = null, @@ -54,7 +57,10 @@ public static class SdcaBinaryClassificationExtensions /// The L2 regularization hyperparameter. /// The L1 regularization hyperparameter. Higher values will tend to lead to more sparse model. /// The maximum number of passes to perform over the data. - /// A delegate to set more settings. + /// A delegate to set more settings. + /// The settings here will override the ones provided in the direct method signature, + /// if both are present and have different values. + /// The columns names, however need to be provided directly, not through the . /// /// /// The L2 regularization hyperparameter. /// The L1 regularization hyperparameter. Higher values will tend to lead to more sparse model. /// The maximum number of passes to perform over the data. - /// A delegate to set more settings. + /// A delegate to set more settings. + /// The settings here will override the ones provided in the direct method signature, + /// if both are present and have different values. + /// The columns names, however need to be provided directly, not through the . 
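[Reviewer note] The SDCA catalog doc comments above now spell out the same override behavior for advancedSettings. A usage sketch for the regression extension just shown, not part of this patch; it assumes a RegressionContext named regressionContext and uses placeholder values.

    // Hypothetical caller code for the SDCA regression catalog extension.
    var sdca = regressionContext.Trainers.StochasticDualCoordinateAscent(
        label: "Label",
        features: "Features",
        l2Const: 1e-4f,
        l1Threshold: 0f,
        maxIterations: 20);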
public static SdcaMultiClassTrainer StochasticDualCoordinateAscent(this MulticlassClassificationContext.MulticlassClassificationTrainers ctx, string label = DefaultColumnNames.Label, string features = DefaultColumnNames.Features, diff --git a/src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs b/src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs index 3b774e261c..0ec4ccd878 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs @@ -58,7 +58,10 @@ public sealed class Arguments : ArgumentsBase /// The L2 regularization hyperparameter. /// The L1 regularization hyperparameter. Higher values will tend to lead to more sparse model. /// The maximum number of passes to perform over the data. - /// A delegate to set more settings. + /// A delegate to set more settings. + /// The settings here will override the ones provided in the direct method signature, + /// if both are present and have different values. + /// The columns names, however need to be provided directly, not through the . public SdcaMultiClassTrainer(IHostEnvironment env, string featureColumn, string labelColumn, @@ -68,7 +71,7 @@ public SdcaMultiClassTrainer(IHostEnvironment env, float? l1Threshold = null, int? maxIterations = null, Action advancedSettings = null) - : base(env, featureColumn, TrainerUtils.MakeU4ScalarLabel(labelColumn), TrainerUtils.MakeR4ScalarWeightColumn(weightColumn), advancedSettings, + : base(env, featureColumn, TrainerUtils.MakeU4ScalarColumn(labelColumn), TrainerUtils.MakeR4ScalarWeightColumn(weightColumn), advancedSettings, l2Const, l1Threshold, maxIterations) { Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); @@ -79,7 +82,7 @@ public SdcaMultiClassTrainer(IHostEnvironment env, internal SdcaMultiClassTrainer(IHostEnvironment env, Arguments args, string featureColumn, string labelColumn, string weightColumn = null) - : base(env, args, TrainerUtils.MakeU4ScalarLabel(labelColumn), TrainerUtils.MakeR4ScalarWeightColumn(weightColumn)) + : base(env, args, TrainerUtils.MakeU4ScalarColumn(labelColumn), TrainerUtils.MakeR4ScalarWeightColumn(weightColumn)) { Host.CheckValue(labelColumn, nameof(labelColumn)); Host.CheckValue(featureColumn, nameof(featureColumn)); diff --git a/src/Microsoft.ML.StandardLearners/Standard/SdcaRegression.cs b/src/Microsoft.ML.StandardLearners/Standard/SdcaRegression.cs index 84e1241a69..9d3e1205cc 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/SdcaRegression.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/SdcaRegression.cs @@ -63,7 +63,10 @@ public Arguments() /// The L2 regularization hyperparameter. /// The L1 regularization hyperparameter. Higher values will tend to lead to more sparse model. /// The maximum number of passes to perform over the data. - /// A delegate to set more settings. + /// A delegate to set more settings. + /// The settings here will override the ones provided in the direct method signature, + /// if both are present and have different values. + /// The columns names, however need to be provided directly, not through the . 
public SdcaRegressionTrainer(IHostEnvironment env, string featureColumn, string labelColumn, diff --git a/src/Microsoft.ML.StandardLearners/Standard/SdcaStatic.cs b/src/Microsoft.ML.StandardLearners/Standard/SdcaStatic.cs index d8e15cc460..b803c74d36 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/SdcaStatic.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/SdcaStatic.cs @@ -15,7 +15,7 @@ namespace Microsoft.ML.StaticPipe /// /// Extension methods and utilities for instantiating SDCA trainer estimators inside statically typed pipelines. /// - public static class SdcaRegressionExtensions + public static class SdcaExtensions { /// /// Predict a target using a linear regression model trained with the SDCA trainer. @@ -28,7 +28,10 @@ public static class SdcaRegressionExtensions /// The L1 regularization hyperparameter. Higher values will tend to lead to more sparse model. /// The maximum number of passes to perform over the data. /// The custom loss, if unspecified will be . - /// A delegate to set more settings. + /// A delegate to set more settings. + /// The settings here will override the ones provided in the direct method signature, + /// if both are present and have different values. + /// The columns names, however need to be provided directly, not through the . /// A delegate that is called every time the /// method is called on the /// instance created out of this. This delegate will receive @@ -71,10 +74,6 @@ public static Scalar Sdca(this RegressionContext.RegressionTrainers ctx, return rec.Score; } - } - - public static class SdcaBinaryClassificationExtensions - { /// /// Predict a target using a linear binary classification model trained with the SDCA trainer, and log-loss. @@ -86,7 +85,10 @@ public static class SdcaBinaryClassificationExtensions /// The L2 regularization hyperparameter. /// The L1 regularization hyperparameter. Higher values will tend to lead to more sparse model. /// The maximum number of passes to perform over the data. - /// A delegate to set more settings. + /// A delegate to set more settings. + /// The settings here will override the ones provided in the direct method signature, + /// if both are present and have different values. + /// The columns names, however need to be provided directly, not through the . /// A delegate that is called every time the /// method is called on the /// instance created out of this. This delegate will receive @@ -146,7 +148,10 @@ public static (Scalar score, Scalar probability, Scalar pred /// The L2 regularization hyperparameter. /// The L1 regularization hyperparameter. Higher values will tend to lead to more sparse model. /// The maximum number of passes to perform over the data. - /// A delegate to set more settings. + /// A delegate to set more settings. + /// The settings here will override the ones provided in the direct method signature, + /// if both are present and have different values. + /// The columns names, however need to be provided directly, not through the . /// A delegate that is called every time the /// method is called on the /// instance created out of this. This delegate will receive @@ -199,10 +204,6 @@ public static (Scalar score, Scalar predictedLabel) Sdca( return rec.Output; } - } - - public static class SdcaMulticlassExtensions - { /// /// Predict a target using a linear multiclass classification model trained with the SDCA trainer. @@ -215,7 +216,10 @@ public static class SdcaMulticlassExtensions /// The L2 regularization hyperparameter. /// The L1 regularization hyperparameter. 
diff --git a/test/Microsoft.ML.Benchmarks.Tests/BenchmarksTest.cs b/test/Microsoft.ML.Benchmarks.Tests/BenchmarksTest.cs
index 79bbd601da..96f4a6bd8f 100644
--- a/test/Microsoft.ML.Benchmarks.Tests/BenchmarksTest.cs
+++ b/test/Microsoft.ML.Benchmarks.Tests/BenchmarksTest.cs
@@ -39,7 +39,11 @@ public class BenchmarksTest
 
         private ITestOutputHelper Output { get; }
 
+#if DEBUG
         [Fact(Skip = SkipTheDebug)]
+#else
+        [ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))]
+#endif
         public void BenchmarksProjectIsNotBroken()
         {
             var summary = BenchmarkRunner.Run(new TestConfig().With(new OutputLogger(Output)));
diff --git a/test/Microsoft.ML.StaticPipelineTesting/Training.cs b/test/Microsoft.ML.StaticPipelineTesting/Training.cs
index 458f777d2b..e689384392 100644
--- a/test/Microsoft.ML.StaticPipelineTesting/Training.cs
+++ b/test/Microsoft.ML.StaticPipelineTesting/Training.cs
@@ -754,6 +754,85 @@ public void FastTreeRanking()
             Assert.InRange(metrics.Ndcg[2], 36.5, 37);
         }
 
+        [Fact]
+        public void LightGBMRanking()
+        {
+            var env = new ConsoleEnvironment(seed: 0);
+            var dataPath = GetDataPath(TestDatasets.adultRanking.trainFilename);
+            var dataSource = new MultiFileSource(dataPath);
+
+            var ctx = new RankingContext(env);
+
+            var reader = TextLoader.CreateReader(env,
+                c => (label: c.LoadFloat(0), features: c.LoadFloat(9, 14), groupId: c.LoadText(1)),
+                separator: '\t', hasHeader: true);
+
+            LightGbmRankingPredictor pred = null;
+
+            var est = reader.MakeNewEstimator()
+                .Append(r => (r.label, r.features, groupId: r.groupId.ToKey()))
+                .Append(r => (r.label, r.groupId, score: ctx.Trainers.LightGbm(r.label, r.features, r.groupId, onFit: (p) => { pred = p; })));
+
+            var pipe = reader.Append(est);
+
+            Assert.Null(pred);
+            var model = pipe.Fit(dataSource);
+            Assert.NotNull(pred);
+
+            var data = model.Read(dataSource);
+
+            var metrics = ctx.Evaluate(data, r => r.label, r => r.groupId, r => r.score);
+            Assert.NotNull(metrics);
+
+            Assert.True(metrics.Ndcg.Length == metrics.Dcg.Length && metrics.Dcg.Length == 3);
+
+            Assert.InRange(metrics.Dcg[0], 1.4, 1.6);
+            Assert.InRange(metrics.Dcg[1], 1.4, 1.8);
+            Assert.InRange(metrics.Dcg[2], 1.4, 1.8);
+
+            Assert.InRange(metrics.Ndcg[0], 36.5, 37);
+            Assert.InRange(metrics.Ndcg[1], 36.5, 37);
+            Assert.InRange(metrics.Ndcg[2], 36.5, 37);
+        }
+
+        [Fact]
+        public void MultiClassLightGBM()
+        {
+            var env = new ConsoleEnvironment(seed: 0);
+            var dataPath = GetDataPath(TestDatasets.iris.trainFilename);
+            var dataSource = new MultiFileSource(dataPath);
+
+            var ctx = new MulticlassClassificationContext(env);
+            var reader = TextLoader.CreateReader(env,
+                c => (label: c.LoadText(0), features: c.LoadFloat(1, 4)));
+
+            OvaPredictor pred = null;
+
+            // With a custom loss function we no longer get calibrated predictions.
+            var est = reader.MakeNewEstimator()
+                .Append(r => (label: r.label.ToKey(), r.features))
+                .Append(r => (r.label, preds: ctx.Trainers.LightGbm(
+                    r.label,
+                    r.features, onFit: p => pred = p)));
+
+            var pipe = reader.Append(est);
+
+            Assert.Null(pred);
+            var model = pipe.Fit(dataSource);
+            Assert.NotNull(pred);
+
+            var data = model.Read(dataSource);
+
+            // Just output some data on the schema for fun.
+            var schema = data.AsDynamic.Schema;
+            for (int c = 0; c < schema.ColumnCount; ++c)
+                Console.WriteLine($"{schema.GetColumnName(c)}, {schema.GetColumnType(c)}");
+
+            var metrics = ctx.Evaluate(data, r => r.label, r => r.preds, 2);
+            Assert.True(metrics.LogLoss > 0);
+            Assert.True(metrics.TopKAccuracy > 0);
+        }
+
         [Fact]
         public void MultiClassNaiveBayesTrainer()
         {
diff --git a/test/Microsoft.ML.TestFramework/BaseTestBaseline.cs b/test/Microsoft.ML.TestFramework/BaseTestBaseline.cs
index 0102f8d4e4..5613765e77 100644
--- a/test/Microsoft.ML.TestFramework/BaseTestBaseline.cs
+++ b/test/Microsoft.ML.TestFramework/BaseTestBaseline.cs
@@ -236,6 +236,7 @@ protected void DoNotEverUseInvertPass()
         private static readonly Regex _matchTime = new Regex(@"[0-9]{2}:[0-9]{2}:[0-9]{2}(\.[0-9]+)?", RegexOptions.Compiled);
         private static readonly Regex _matchShortTime = new Regex(@"\([0-9]{2}:[0-9]{2}(\.[0-9]+)?\)", RegexOptions.Compiled);
         private static readonly Regex _matchMemory = new Regex(@"memory usage\(MB\): [0-9]+", RegexOptions.Compiled);
+        private static readonly Regex _matchReservedMemory = new Regex(@": [0-9]+ bytes", RegexOptions.Compiled);
         private static readonly Regex _matchElapsed = new Regex(@"Time elapsed\(s\): [0-9.]+", RegexOptions.Compiled);
         private static readonly Regex _matchTimes = new Regex(@"Instances caching time\(s\): [0-9\.]+", RegexOptions.Compiled);
         private static readonly Regex _matchUpdatesPerSec = new Regex(@", ([0-9\.]+|Infinity)M WeightUpdates/sec", RegexOptions.Compiled);
@@ -284,6 +285,7 @@ protected void Normalize(string path)
             line = _matchShortTime.Replace(line, "(%Time%)");
             line = _matchElapsed.Replace(line, "Time elapsed(s): %Number%");
             line = _matchMemory.Replace(line, "memory usage(MB): %Number%");
+            line = _matchReservedMemory.Replace(line, ": %Number% bytes");
             line = _matchTimes.Replace(line, "Instances caching time(s): %Number%");
             line = _matchUpdatesPerSec.Replace(line, ", %Number%M WeightUpdates/sec");
             line = _matchParameterT.Replace(line, "=PARAM:/t:%Number%");
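The new _matchReservedMemory regex joins the existing normalizers so that raw byte counts in log output do not churn the baseline files. A small illustration of the substitution it performs; the sample log line is invented for the example.

    using System.Text.RegularExpressions;

    var matchReservedMemory = new Regex(@": [0-9]+ bytes", RegexOptions.Compiled);
    var line = "Reserved memory: 4096 bytes";
    line = matchReservedMemory.Replace(line, ": %Number% bytes");
    // line is now "Reserved memory: %Number% bytes"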
diff --git a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs
index d8d76527a7..857cd20822 100644
--- a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs
+++ b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs
@@ -906,7 +906,7 @@ public void TensorFlowTransformCifarSavedModel()
             }
         }
 
-        [Fact]
+        [ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))]
        public void TensorFlowTransformCifarInvalidShape()
         {
             var model_location = "cifar_model/frozen_model.pb";
diff --git a/test/Microsoft.ML.Tests/TensorFlowEstimatorTests.cs b/test/Microsoft.ML.Tests/TensorFlowEstimatorTests.cs
index 13236d6658..690bd0dd2f 100644
--- a/test/Microsoft.ML.Tests/TensorFlowEstimatorTests.cs
+++ b/test/Microsoft.ML.Tests/TensorFlowEstimatorTests.cs
@@ -183,7 +183,7 @@ public void TestTensorFlowStatic()
             }
         }
 
-        [Fact]
+        [ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))]
         public void TestTensorFlowStaticWithSchema()
         {
             var modelLocation = "cifar_model/frozen_model.pb";

From 4319f8665b6285daaaf7dd17d742e22300076a9b Mon Sep 17 00:00:00 2001
From: feiyun0112
Date: Sat, 27 Oct 2018 22:22:58 +0800
Subject: [PATCH 3/3] Resolve conflict

---
 src/Microsoft.ML.FastTree/FastTree.cs | 5 -----
 1 file changed, 5 deletions(-)

diff --git a/src/Microsoft.ML.FastTree/FastTree.cs b/src/Microsoft.ML.FastTree/FastTree.cs
index b73b01faa1..23bed0acef 100644
--- a/src/Microsoft.ML.FastTree/FastTree.cs
+++ b/src/Microsoft.ML.FastTree/FastTree.cs
@@ -124,11 +124,6 @@ private protected FastTreeTrainerBase(IHostEnvironment env,
             if (groupIdColumn != null)
                 Args.GroupIdColumn = Optional.Explicit(groupIdColumn); ;
 
-            //override with the directly provided values.
-            Args.NumLeaves = numLeaves;
-            Args.NumTrees = numTrees;
-            Args.MinDocumentsInLeafs = minDocumentsInLeafs;
-
             // The discretization step renders this trainer non-parametric, and therefore it does not need normalization.
             // Also since it builds its own internal discretized columnar structures, it cannot benefit from caching.
             // Finally, even the binary classifiers, being logitboost, tend to not benefit from external calibration.