Skip to content

Commit aea88dc

Browse files
authored
Updating LightGBM Arguments (#2948)
* Breaks down the LightGBM options class into separate options classes, one for each LightGBM trainer.
* Refactors the booster options classes.
* Hides the interface for the booster parameter factory (IBoosterParameterFactory).

Fixes #2559
Fixes #2618
1 parent 1a89468 commit aea88dc

29 files changed

+1159
-1111
lines changed

docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/LightGbmWithOptions.cs

+2-3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
using Microsoft.ML.Trainers.LightGbm;
2-
using static Microsoft.ML.Trainers.LightGbm.Options;
1+
using Microsoft.ML.Trainers.LightGbm;
32

43
namespace Microsoft.ML.Samples.Dynamic.Trainers.BinaryClassification
54
{
@@ -19,7 +18,7 @@ public static void Example()
1918

2019
// Create the pipeline with LightGbm Estimator using advanced options.
2120
var pipeline = mlContext.BinaryClassification.Trainers.LightGbm(
22-
new Options
21+
new LightGbmBinaryTrainer.Options
2322
{
2423
Booster = new GossBooster.Options
2524
{

docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/MulticlassClassification/LightGbmWithOptions.cs

+2-3
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
using Microsoft.ML.Data;
44
using Microsoft.ML.SamplesUtils;
55
using Microsoft.ML.Trainers.LightGbm;
6-
using static Microsoft.ML.Trainers.LightGbm.Options;
76

87
namespace Microsoft.ML.Samples.Dynamic.Trainers.MulticlassClassification
98
{
@@ -33,11 +32,11 @@ public static void Example()
3332
// - Convert the string labels into key types.
3433
// - Apply LightGbm multiclass trainer with advanced options.
3534
var pipeline = mlContext.Transforms.Conversion.MapValueToKey("LabelIndex", "Label")
36-
.Append(mlContext.MulticlassClassification.Trainers.LightGbm(new Options
35+
.Append(mlContext.MulticlassClassification.Trainers.LightGbm(new LightGbmMulticlassTrainer.Options
3736
{
3837
LabelColumnName = "LabelIndex",
3938
FeatureColumnName = "Features",
40-
Booster = new DartBooster.Options
39+
Booster = new DartBooster.Options()
4140
{
4241
TreeDropFraction = 0.15,
4342
XgboostDartMode = false

docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/Ranking/LightGbmWithOptions.cs

+2-3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
using Microsoft.ML.Trainers.LightGbm;
2-
using static Microsoft.ML.Trainers.LightGbm.Options;
32

43
namespace Microsoft.ML.Samples.Dynamic.Trainers.Ranking
54
{
@@ -21,13 +20,13 @@ public static void Example()
2120

2221
// Create the Estimator pipeline. For simplicity, we will train a small tree with 4 leaves and 2 boosting iterations.
2322
var pipeline = mlContext.Ranking.Trainers.LightGbm(
24-
new Options
23+
new LightGbmRankingTrainer.Options
2524
{
2625
NumberOfLeaves = 4,
2726
MinimumExampleCountPerGroup = 10,
2827
LearningRate = 0.1,
2928
NumberOfIterations = 2,
30-
Booster = new TreeBooster.Options
29+
Booster = new GradientBooster.Options
3130
{
3231
FeatureFraction = 0.9
3332
}

docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/Regression/LightGbmWithOptions.cs

+2-3
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
using System.Linq;
33
using Microsoft.ML.Data;
44
using Microsoft.ML.Trainers.LightGbm;
5-
using static Microsoft.ML.Trainers.LightGbm.Options;
65

76
namespace Microsoft.ML.Samples.Dynamic.Trainers.Regression
87
{
@@ -36,13 +35,13 @@ public static void Example()
3635
.Where(name => name != labelName) // Drop the Label
3736
.ToArray();
3837
var pipeline = mlContext.Transforms.Concatenate("Features", featureNames)
39-
.Append(mlContext.Regression.Trainers.LightGbm(new Options
38+
.Append(mlContext.Regression.Trainers.LightGbm(new LightGbmRegressionTrainer.Options
4039
{
4140
LabelColumnName = labelName,
4241
NumberOfLeaves = 4,
4342
MinimumExampleCountPerLeaf = 6,
4443
LearningRate = 0.001,
45-
Booster = new GossBooster.Options
44+
Booster = new GossBooster.Options()
4645
{
4746
TopRate = 0.3,
4847
OtherRate = 0.2

src/Microsoft.ML.LightGbm.StaticPipe/LightGbmStaticExtensions.cs

+17-16
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ public static Scalar<float> LightGbm(this RegressionCatalog.RegressionTrainers c
4242
int? numberOfLeaves = null,
4343
int? minimumExampleCountPerLeaf = null,
4444
double? learningRate = null,
45-
int numberOfIterations = Options.Defaults.NumberOfIterations,
45+
int numberOfIterations = Defaults.NumberOfIterations,
4646
Action<LightGbmRegressionModelParameters> onFit = null)
4747
{
4848
CheckUserValues(label, features, weights, numberOfLeaves, minimumExampleCountPerLeaf, learningRate, numberOfIterations, onFit);
@@ -76,10 +76,11 @@ public static Scalar<float> LightGbm(this RegressionCatalog.RegressionTrainers c
7676
/// <returns>The Score output column indicating the predicted value.</returns>
7777
public static Scalar<float> LightGbm(this RegressionCatalog.RegressionTrainers catalog,
7878
Scalar<float> label, Vector<float> features, Scalar<float> weights,
79-
Options options,
79+
LightGbmRegressionTrainer.Options options,
8080
Action<LightGbmRegressionModelParameters> onFit = null)
8181
{
82-
CheckUserValues(label, features, weights, options, onFit);
82+
Contracts.CheckValue(options, nameof(options));
83+
CheckUserValues(label, features, weights, onFit);
8384

8485
var rec = new TrainerEstimatorReconciler.Regression(
8586
(env, labelName, featuresName, weightsName) =>
@@ -128,7 +129,7 @@ public static (Scalar<float> score, Scalar<float> probability, Scalar<bool> pred
128129
int? numberOfLeaves = null,
129130
int? minimumExampleCountPerLeaf = null,
130131
double? learningRate = null,
131-
int numberOfIterations = Options.Defaults.NumberOfIterations,
132+
int numberOfIterations = Defaults.NumberOfIterations,
132133
Action<CalibratedModelParametersBase<LightGbmBinaryModelParameters, PlattCalibrator>> onFit = null)
133134
{
134135
CheckUserValues(label, features, weights, numberOfLeaves, minimumExampleCountPerLeaf, learningRate, numberOfIterations, onFit);
@@ -165,10 +166,11 @@ public static (Scalar<float> score, Scalar<float> probability, Scalar<bool> pred
165166
/// from negative to positive infinity), the calibrated prediction (from 0 to 1), and the predicted label.</returns>
166167
public static (Scalar<float> score, Scalar<float> probability, Scalar<bool> predictedLabel) LightGbm(this BinaryClassificationCatalog.BinaryClassificationTrainers catalog,
167168
Scalar<bool> label, Vector<float> features, Scalar<float> weights,
168-
Options options,
169+
LightGbmBinaryTrainer.Options options,
169170
Action<CalibratedModelParametersBase<LightGbmBinaryModelParameters, PlattCalibrator>> onFit = null)
170171
{
171-
CheckUserValues(label, features, weights, options, onFit);
172+
Contracts.CheckValue(options, nameof(options));
173+
CheckUserValues(label, features, weights, onFit);
172174

173175
var rec = new TrainerEstimatorReconciler.BinaryClassifier(
174176
(env, labelName, featuresName, weightsName) =>
@@ -215,7 +217,7 @@ public static Scalar<float> LightGbm<TVal>(this RankingCatalog.RankingTrainers c
215217
int? numberOfLeaves = null,
216218
int? minimumExampleCountPerLeaf = null,
217219
double? learningRate = null,
218-
int numberOfIterations = Options.Defaults.NumberOfIterations,
220+
int numberOfIterations = Defaults.NumberOfIterations,
219221
Action<LightGbmRankingModelParameters> onFit = null)
220222
{
221223
CheckUserValues(label, features, weights, numberOfLeaves, minimumExampleCountPerLeaf, learningRate, numberOfIterations, onFit);
@@ -253,10 +255,11 @@ public static Scalar<float> LightGbm<TVal>(this RankingCatalog.RankingTrainers c
253255
/// from negative to positive infinity), the calibrated prediction (from 0 to 1), and the predicted label.</returns>
254256
public static Scalar<float> LightGbm<TVal>(this RankingCatalog.RankingTrainers catalog,
255257
Scalar<float> label, Vector<float> features, Key<uint, TVal> groupId, Scalar<float> weights,
256-
Options options,
258+
LightGbmRankingTrainer.Options options,
257259
Action<LightGbmRankingModelParameters> onFit = null)
258260
{
259-
CheckUserValues(label, features, weights, options, onFit);
261+
Contracts.CheckValue(options, nameof(options));
262+
CheckUserValues(label, features, weights, onFit);
260263
Contracts.CheckValue(groupId, nameof(groupId));
261264

262265
var rec = new TrainerEstimatorReconciler.Ranker<TVal>(
@@ -309,7 +312,7 @@ public static (Vector<float> score, Key<uint, TVal> predictedLabel)
309312
int? numberOfLeaves = null,
310313
int? minimumExampleCountPerLeaf = null,
311314
double? learningRate = null,
312-
int numberOfIterations = Options.Defaults.NumberOfIterations,
315+
int numberOfIterations = Defaults.NumberOfIterations,
313316
Action<OneVersusAllModelParameters> onFit = null)
314317
{
315318
CheckUserValues(label, features, weights, numberOfLeaves, minimumExampleCountPerLeaf, learningRate, numberOfIterations, onFit);
@@ -347,10 +350,11 @@ public static (Vector<float> score, Key<uint, TVal> predictedLabel)
347350
Key<uint, TVal> label,
348351
Vector<float> features,
349352
Scalar<float> weights,
350-
Options options,
353+
LightGbmMulticlassTrainer.Options options,
351354
Action<OneVersusAllModelParameters> onFit = null)
352355
{
353-
CheckUserValues(label, features, weights, options, onFit);
356+
Contracts.CheckValue(options, nameof(options));
357+
CheckUserValues(label, features, weights, onFit);
354358

355359
var rec = new TrainerEstimatorReconciler.MulticlassClassificationReconciler<TVal>(
356360
(env, labelName, featuresName, weightsName) =>
@@ -386,14 +390,11 @@ private static void CheckUserValues(PipelineColumn label, Vector<float> features
386390
Contracts.CheckValueOrNull(onFit);
387391
}
388392

389-
private static void CheckUserValues(PipelineColumn label, Vector<float> features, Scalar<float> weights,
390-
Options options,
391-
Delegate onFit)
393+
private static void CheckUserValues(PipelineColumn label, Vector<float> features, Scalar<float> weights, Delegate onFit)
392394
{
393395
Contracts.CheckValue(label, nameof(label));
394396
Contracts.CheckValue(features, nameof(features));
395397
Contracts.CheckValueOrNull(weights);
396-
Contracts.CheckValue(options, nameof(options));
397398
Contracts.CheckValueOrNull(onFit);
398399
}
399400
}

0 commit comments

Comments
 (0)