From eec4d95dda5e9bc5423d2bfbfb5612ae47de7a5d Mon Sep 17 00:00:00 2001 From: Antonio Velazquez Date: Wed, 11 Dec 2019 15:52:49 -0800 Subject: [PATCH 01/16] Experiments --- .../MulticlassClassification/LightGbm.cs | 36 +++++++++++++++++-- .../LightGbmMulticlassTrainer.cs | 2 +- 2 files changed, 35 insertions(+), 3 deletions(-) diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/MulticlassClassification/LightGbm.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/MulticlassClassification/LightGbm.cs index 7ab6d5f80e..a7995132e5 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/MulticlassClassification/LightGbm.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/MulticlassClassification/LightGbm.cs @@ -68,13 +68,13 @@ public static void Example() .Evaluate(transformedTestData); PrintMetrics(metrics); - + // Expected output: // Micro Accuracy: 0.99 // Macro Accuracy: 0.99 // Log Loss: 0.05 // Log Loss Reduction: 0.95 - + // Confusion table // ||======================== // PREDICTED || 0 | 1 | 2 | Recall @@ -84,6 +84,15 @@ public static void Example() // 2 || 1 | 0 | 162 | 0.9939 // ||======================== // Precision ||0.9936 |1.0000 |0.9701 | + + var trainingData2 = mlContext.Data + .LoadFromEnumerable(GenerateRandomDataPoints2(500, seed: 123)); + + model = pipeline.Fit(trainingData2); + var transformedTestData2 = model.Transform(trainingData2); + var metrics2 = mlContext.MulticlassClassification + .Evaluate(transformedTestData2); + PrintMetrics(metrics2); } // Generates random uniform doubles in [-0.5, 0.5) @@ -111,6 +120,29 @@ private static IEnumerable GenerateRandomDataPoints(int count, } } + private static IEnumerable GenerateRandomDataPoints2(int count, + int seed = 0) + + { + var random = new Random(seed); + float randomFloat() => (float)(random.NextDouble() - 0.5); + for (int i = 0; i < count; i++) + { + // Generate Labels that are integers 1, 2 or 3 + var label = random.Next(1, 5); + yield return new DataPoint + { + Label = (uint)label, + // Create random features that are correlated with the label. + // The feature values are slightly increased by adding a + // constant multiple of label. + Features = Enumerable.Repeat(label, 20) + .Select(x => randomFloat() + label * 0.2f).ToArray() + + }; + } + } + // Example with label and 20 feature values. A data set is a collection of // such examples. private class DataPoint diff --git a/src/Microsoft.ML.LightGbm/LightGbmMulticlassTrainer.cs b/src/Microsoft.ML.LightGbm/LightGbmMulticlassTrainer.cs index 606c35657f..4ebc72673e 100644 --- a/src/Microsoft.ML.LightGbm/LightGbmMulticlassTrainer.cs +++ b/src/Microsoft.ML.LightGbm/LightGbmMulticlassTrainer.cs @@ -218,7 +218,7 @@ private protected override void CheckDataValid(IChannel ch, RoleMappedData data) private protected override void ConvertNaNLabels(IChannel ch, RoleMappedData data, float[] labels) { // Only initialize one time. - if (_numClass < 0) + if (_numClass != 0) { float minLabel = float.MaxValue; float maxLabel = float.MinValue; From 2c4c5fc20274de9d39f876ac72e194555366c6d5 Mon Sep 17 00:00:00 2001 From: Antonio Velazquez Date: Wed, 18 Dec 2019 10:40:06 -0800 Subject: [PATCH 02/16] Run LightGBM multiclass example in Program.cs --- docs/samples/Microsoft.ML.Samples/Program.cs | 16 ++-------------- 1 file changed, 2 insertions(+), 14 deletions(-) diff --git a/docs/samples/Microsoft.ML.Samples/Program.cs b/docs/samples/Microsoft.ML.Samples/Program.cs index 4c46399421..91b300fdf1 100644 --- a/docs/samples/Microsoft.ML.Samples/Program.cs +++ b/docs/samples/Microsoft.ML.Samples/Program.cs @@ -1,6 +1,7 @@ using System; using System.Reflection; using Samples.Dynamic; +using Samples.Dynamic.Trainers.MulticlassClassification; namespace Microsoft.ML.Samples { @@ -10,20 +11,7 @@ public static class Program internal static void RunAll() { - int samples = 0; - foreach (var type in Assembly.GetExecutingAssembly().GetTypes()) - { - var sample = type.GetMethod("Example", BindingFlags.Public | BindingFlags.Static | BindingFlags.FlattenHierarchy); - - if (sample != null) - { - Console.WriteLine(type.Name); - sample.Invoke(null, null); - samples++; - } - } - - Console.WriteLine("Number of samples that ran without any exception: " + samples); + LightGbm.Example(); } } } From daac7a4205f22d1f037294c3ee50b79ff990650a Mon Sep 17 00:00:00 2001 From: Antonio Velazquez Date: Thu, 19 Dec 2019 11:09:30 -0800 Subject: [PATCH 03/16] Return _numClass if statement to its original form --- src/Microsoft.ML.LightGbm/LightGbmMulticlassTrainer.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Microsoft.ML.LightGbm/LightGbmMulticlassTrainer.cs b/src/Microsoft.ML.LightGbm/LightGbmMulticlassTrainer.cs index 4ebc72673e..606c35657f 100644 --- a/src/Microsoft.ML.LightGbm/LightGbmMulticlassTrainer.cs +++ b/src/Microsoft.ML.LightGbm/LightGbmMulticlassTrainer.cs @@ -218,7 +218,7 @@ private protected override void CheckDataValid(IChannel ch, RoleMappedData data) private protected override void ConvertNaNLabels(IChannel ch, RoleMappedData data, float[] labels) { // Only initialize one time. - if (_numClass != 0) + if (_numClass < 0) { float minLabel = float.MaxValue; float maxLabel = float.MinValue; From 9bfd8fd2406ccf54bd510cf8c562c7235cfce4c0 Mon Sep 17 00:00:00 2001 From: Antonio Velazquez Date: Fri, 20 Dec 2019 13:34:23 -0800 Subject: [PATCH 04/16] Added test to TreeEstimators --- .../TrainerEstimators/TreeEstimators.cs | 62 +++++++++++++++++++ 1 file changed, 62 insertions(+) diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs b/test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs index 3e6d9e2769..e192c38f48 100644 --- a/test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs +++ b/test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs @@ -13,6 +13,7 @@ using Microsoft.ML.RunTests; using Microsoft.ML.Runtime; using Microsoft.ML.TestFramework.Attributes; +using Microsoft.ML.Trainers; using Microsoft.ML.Trainers.FastTree; using Microsoft.ML.Trainers.LightGbm; using Microsoft.ML.Transforms; @@ -677,6 +678,67 @@ public void LightGbmMulticlassEstimatorCompareUnbalanced() Done(); } + private class DataPoint + { + public uint Label { get; set; } + + [VectorType(20)] + public float[] Features { get; set; } + } + + private static IEnumerable GenerateRandomDataPoints(int count, + int seed = 0, int numClasses = 3) + + { + var random = new Random(seed); + float randomFloat() => (float)(random.NextDouble() - 0.5); + for (int i = 0; i < count; i++) + { + var label = random.Next(1, numClasses + 1); + yield return new DataPoint + { + Label = (uint)label, + Features = Enumerable.Repeat(label, 20) + .Select(x => randomFloat() + label * 0.2f).ToArray() + }; + } + } + + [LightGBMFact] + public void LightGbmFitMoreThanOnce() + { + var mlContext = new MLContext(seed: 0); + + var pipeline = + mlContext.Transforms.Conversion + .MapValueToKey(nameof(DataPoint.Label)) + .Append(mlContext.MulticlassClassification.Trainers + .LightGbm()); + + var numClasses = 3; + var dataPoints = GenerateRandomDataPoints(100, numClasses:numClasses); + var trainingData = mlContext.Data.LoadFromEnumerable(dataPoints); + var model = pipeline.Fit(trainingData); + var numOfSubParameters = (model.LastTransformer.Model as OneVersusAllModelParameters).SubModelParameters.Length; + Assert.Equal(numClasses, numOfSubParameters); + + numClasses = 4; + dataPoints = GenerateRandomDataPoints(100, numClasses: numClasses); + trainingData = mlContext.Data.LoadFromEnumerable(dataPoints); + model = pipeline.Fit(trainingData); + numOfSubParameters = (model.LastTransformer.Model as OneVersusAllModelParameters).SubModelParameters.Length; + Assert.Equal(numClasses, numOfSubParameters); + + numClasses = 2; + dataPoints = GenerateRandomDataPoints(100, numClasses: numClasses); + trainingData = mlContext.Data.LoadFromEnumerable(dataPoints); + model = pipeline.Fit(trainingData); + numOfSubParameters = (model.LastTransformer.Model as OneVersusAllModelParameters).SubModelParameters.Length; + Assert.Equal(numClasses, numOfSubParameters); + + Done(); + } + [LightGBMFact] public void LightGbmInDifferentCulture() { From fb3139a5f5ee7f65590710087577f9f04cffae35 Mon Sep 17 00:00:00 2001 From: Antonio Velazquez Date: Fri, 20 Dec 2019 14:07:11 -0800 Subject: [PATCH 05/16] Add InitializeBeforeTraining method --- src/Microsoft.ML.LightGbm/LightGbmMulticlassTrainer.cs | 6 ++++++ src/Microsoft.ML.LightGbm/LightGbmTrainerBase.cs | 6 ++++++ 2 files changed, 12 insertions(+) diff --git a/src/Microsoft.ML.LightGbm/LightGbmMulticlassTrainer.cs b/src/Microsoft.ML.LightGbm/LightGbmMulticlassTrainer.cs index 606c35657f..c066224251 100644 --- a/src/Microsoft.ML.LightGbm/LightGbmMulticlassTrainer.cs +++ b/src/Microsoft.ML.LightGbm/LightGbmMulticlassTrainer.cs @@ -215,6 +215,12 @@ private protected override void CheckDataValid(IChannel ch, RoleMappedData data) } } + private protected override void InitializeBeforeTraining() + { + _numClass = -1; //MYTODO: Include more initializations, of TrainedEnsemble, for example? + _tlcNumClass = 0; + } + private protected override void ConvertNaNLabels(IChannel ch, RoleMappedData data, float[] labels) { // Only initialize one time. diff --git a/src/Microsoft.ML.LightGbm/LightGbmTrainerBase.cs b/src/Microsoft.ML.LightGbm/LightGbmTrainerBase.cs index fb6bbeefde..2086ee83a0 100644 --- a/src/Microsoft.ML.LightGbm/LightGbmTrainerBase.cs +++ b/src/Microsoft.ML.LightGbm/LightGbmTrainerBase.cs @@ -4,6 +4,7 @@ using System; using System.Collections.Generic; +using System.Runtime.CompilerServices; using System.Text; using Microsoft.ML.CommandLine; using Microsoft.ML.Data; @@ -340,6 +341,8 @@ private protected LightGbmTrainerBase(IHostEnvironment env, string name, TOption private protected override TModel TrainModelCore(TrainContext context) { + InitializeBeforeTraining(); + Host.CheckValue(context, nameof(context)); Dataset dtrain = null; @@ -371,6 +374,9 @@ private protected override TModel TrainModelCore(TrainContext context) return CreatePredictor(); } + private protected virtual void InitializeBeforeTraining() + { return; } + private void InitParallelTraining() { GbmOptions = LightGbmTrainerOptions.ToDictionary(Host); From 635fbb96bceda43ff7c59473ba47e83b0dab3896 Mon Sep 17 00:00:00 2001 From: Antonio Velazquez Date: Mon, 30 Dec 2019 15:45:27 -0800 Subject: [PATCH 06/16] Revert to original Sample --- .../MulticlassClassification/LightGbm.cs | 37 +------------------ docs/samples/Microsoft.ML.Samples/Program.cs | 18 +++++++-- 2 files changed, 17 insertions(+), 38 deletions(-) diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/MulticlassClassification/LightGbm.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/MulticlassClassification/LightGbm.cs index a7995132e5..6840db2516 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/MulticlassClassification/LightGbm.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/MulticlassClassification/LightGbm.cs @@ -53,7 +53,7 @@ public static void Example() // Look at 5 predictions foreach (var p in predictions.Take(5)) - Console.WriteLine($"Label: {p.Label}, " + + Console.WriteLine($"Label: {p.Label}, " + $"Prediction: {p.PredictedLabel}"); // Expected output: @@ -84,21 +84,12 @@ public static void Example() // 2 || 1 | 0 | 162 | 0.9939 // ||======================== // Precision ||0.9936 |1.0000 |0.9701 | - - var trainingData2 = mlContext.Data - .LoadFromEnumerable(GenerateRandomDataPoints2(500, seed: 123)); - - model = pipeline.Fit(trainingData2); - var transformedTestData2 = model.Transform(trainingData2); - var metrics2 = mlContext.MulticlassClassification - .Evaluate(transformedTestData2); - PrintMetrics(metrics2); } // Generates random uniform doubles in [-0.5, 0.5) // range with labels 1, 2 or 3. private static IEnumerable GenerateRandomDataPoints(int count, - int seed=0) + int seed = 0) { var random = new Random(seed); @@ -120,29 +111,6 @@ private static IEnumerable GenerateRandomDataPoints(int count, } } - private static IEnumerable GenerateRandomDataPoints2(int count, - int seed = 0) - - { - var random = new Random(seed); - float randomFloat() => (float)(random.NextDouble() - 0.5); - for (int i = 0; i < count; i++) - { - // Generate Labels that are integers 1, 2 or 3 - var label = random.Next(1, 5); - yield return new DataPoint - { - Label = (uint)label, - // Create random features that are correlated with the label. - // The feature values are slightly increased by adding a - // constant multiple of label. - Features = Enumerable.Repeat(label, 20) - .Select(x => randomFloat() + label * 0.2f).ToArray() - - }; - } - } - // Example with label and 20 feature values. A data set is a collection of // such examples. private class DataPoint @@ -174,4 +142,3 @@ public static void PrintMetrics(MulticlassClassificationMetrics metrics) } } } - diff --git a/docs/samples/Microsoft.ML.Samples/Program.cs b/docs/samples/Microsoft.ML.Samples/Program.cs index 91b300fdf1..24674651c7 100644 --- a/docs/samples/Microsoft.ML.Samples/Program.cs +++ b/docs/samples/Microsoft.ML.Samples/Program.cs @@ -1,7 +1,6 @@ using System; using System.Reflection; using Samples.Dynamic; -using Samples.Dynamic.Trainers.MulticlassClassification; namespace Microsoft.ML.Samples { @@ -11,7 +10,20 @@ public static class Program internal static void RunAll() { - LightGbm.Example(); + int samples = 0; + foreach (var type in Assembly.GetExecutingAssembly().GetTypes()) + { + var sample = type.GetMethod("Example", BindingFlags.Public | BindingFlags.Static | BindingFlags.FlattenHierarchy); + + if (sample != null) + { + Console.WriteLine(type.Name); + sample.Invoke(null, null); + samples++; + } + } + + Console.WriteLine("Number of samples that ran without any exception: " + samples); } } -} +} \ No newline at end of file From 8d9005b1578749063d3f57ca038bd435505e9c57 Mon Sep 17 00:00:00 2001 From: Antonio Velazquez Date: Mon, 30 Dec 2019 15:51:54 -0800 Subject: [PATCH 07/16] Typos --- .../Dynamic/Trainers/MulticlassClassification/LightGbm.cs | 4 ++-- docs/samples/Microsoft.ML.Samples/Program.cs | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/MulticlassClassification/LightGbm.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/MulticlassClassification/LightGbm.cs index 6840db2516..54200705c6 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/MulticlassClassification/LightGbm.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/MulticlassClassification/LightGbm.cs @@ -53,7 +53,7 @@ public static void Example() // Look at 5 predictions foreach (var p in predictions.Take(5)) - Console.WriteLine($"Label: {p.Label}, " + + Console.WriteLine($"Label: {p.Label}, " + $"Prediction: {p.PredictedLabel}"); // Expected output: @@ -89,7 +89,7 @@ public static void Example() // Generates random uniform doubles in [-0.5, 0.5) // range with labels 1, 2 or 3. private static IEnumerable GenerateRandomDataPoints(int count, - int seed = 0) + int seed=0) { var random = new Random(seed); diff --git a/docs/samples/Microsoft.ML.Samples/Program.cs b/docs/samples/Microsoft.ML.Samples/Program.cs index 24674651c7..4c46399421 100644 --- a/docs/samples/Microsoft.ML.Samples/Program.cs +++ b/docs/samples/Microsoft.ML.Samples/Program.cs @@ -26,4 +26,4 @@ internal static void RunAll() Console.WriteLine("Number of samples that ran without any exception: " + samples); } } -} \ No newline at end of file +} From 0902927fa412a91783bb6de0f8d3a78a2a441532 Mon Sep 17 00:00:00 2001 From: Antonio Velazquez Date: Mon, 30 Dec 2019 15:53:24 -0800 Subject: [PATCH 08/16] Removing unused using --- src/Microsoft.ML.LightGbm/LightGbmTrainerBase.cs | 1 - 1 file changed, 1 deletion(-) diff --git a/src/Microsoft.ML.LightGbm/LightGbmTrainerBase.cs b/src/Microsoft.ML.LightGbm/LightGbmTrainerBase.cs index 2086ee83a0..ec9d103fc4 100644 --- a/src/Microsoft.ML.LightGbm/LightGbmTrainerBase.cs +++ b/src/Microsoft.ML.LightGbm/LightGbmTrainerBase.cs @@ -4,7 +4,6 @@ using System; using System.Collections.Generic; -using System.Runtime.CompilerServices; using System.Text; using Microsoft.ML.CommandLine; using Microsoft.ML.Data; From 1aaa77f73f237483cfc355a3b0e08ec845fff028 Mon Sep 17 00:00:00 2001 From: Antonio Velazquez Date: Mon, 30 Dec 2019 16:09:58 -0800 Subject: [PATCH 09/16] Modified MYTODO comments --- .../LightGbmMulticlassTrainer.cs | 16 ++++++++++++++-- src/Microsoft.ML.LightGbm/LightGbmTrainerBase.cs | 2 +- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/src/Microsoft.ML.LightGbm/LightGbmMulticlassTrainer.cs b/src/Microsoft.ML.LightGbm/LightGbmMulticlassTrainer.cs index cc70360a6c..6dd554640b 100644 --- a/src/Microsoft.ML.LightGbm/LightGbmMulticlassTrainer.cs +++ b/src/Microsoft.ML.LightGbm/LightGbmMulticlassTrainer.cs @@ -218,13 +218,18 @@ private protected override void CheckDataValid(IChannel ch, RoleMappedData data) private protected override void InitializeBeforeTraining() { - _numClass = -1; //MYTODO: Include more initializations, of TrainedEnsemble, for example? + _numClass = -1; _tlcNumClass = 0; - } + + //MYTODO: Include more initializations, of TrainedEnsemble, for example? + //For example: + //TrainedEnsemble = null; + } private protected override void ConvertNaNLabels(IChannel ch, RoleMappedData data, float[] labels) { // Only initialize one time. + if (_numClass < 0) { float minLabel = float.MaxValue; @@ -261,6 +266,13 @@ private protected override void ConvertNaNLabels(IChannel ch, RoleMappedData dat _tlcNumClass = (int)maxLabel + 1; } } + + // If there are NaN labels, they are converted to be equal to _tlcNumClass (i.e. _numClass - 1). + // This is done because NaN labels are going to be seen as + // an extra different class, and thus, when training the model in the WrappedLightGbmTraining class + // a total of _numClass classes are considered. But, when creating the Predictors, only _tlcNumClass number of + // classes are considered ignoring the extra class of NaN labels. + float defaultLabel = _numClass - 1; for (int i = 0; i < labels.Length; ++i) if (float.IsNaN(labels[i])) diff --git a/src/Microsoft.ML.LightGbm/LightGbmTrainerBase.cs b/src/Microsoft.ML.LightGbm/LightGbmTrainerBase.cs index ec9d103fc4..f77e2fa2f4 100644 --- a/src/Microsoft.ML.LightGbm/LightGbmTrainerBase.cs +++ b/src/Microsoft.ML.LightGbm/LightGbmTrainerBase.cs @@ -374,7 +374,7 @@ private protected override TModel TrainModelCore(TrainContext context) } private protected virtual void InitializeBeforeTraining() - { return; } + { return; } // MYTODO: Is there a better way to avoid having to do this? An abstract method? private void InitParallelTraining() { From ff14b112f0ce28d4a4b46da9ef991652250fdfb7 Mon Sep 17 00:00:00 2001 From: Antonio Velazquez Date: Thu, 2 Jan 2020 10:51:38 -0800 Subject: [PATCH 10/16] Changed names of variables and moved comment --- .../LightGbmMulticlassTrainer.cs | 53 +++++++++---------- 1 file changed, 26 insertions(+), 27 deletions(-) diff --git a/src/Microsoft.ML.LightGbm/LightGbmMulticlassTrainer.cs b/src/Microsoft.ML.LightGbm/LightGbmMulticlassTrainer.cs index 6dd554640b..99167d3f80 100644 --- a/src/Microsoft.ML.LightGbm/LightGbmMulticlassTrainer.cs +++ b/src/Microsoft.ML.LightGbm/LightGbmMulticlassTrainer.cs @@ -58,8 +58,13 @@ public sealed class LightGbmMulticlassTrainer : LightGbmTrainerBase PredictionKind.MulticlassClassification; /// @@ -129,7 +134,7 @@ internal LightGbmMulticlassTrainer(IHostEnvironment env, Options options) : base(env, LoadNameValue, options, TrainerUtils.MakeU4ScalarColumn(options.LabelColumnName)) { Contracts.CheckUserArg(options.Sigmoid > 0, nameof(Options.Sigmoid), "must be > 0."); - _numClass = -1; + _numberOfClassesIncludingNan = -1; } /// @@ -168,7 +173,7 @@ internal LightGbmMulticlassTrainer(IHostEnvironment env, private InternalTreeEnsemble GetBinaryEnsemble(int classID) { var res = new InternalTreeEnsemble(); - for (int i = classID; i < TrainedEnsemble.NumTrees; i += _numClass) + for (int i = classID; i < TrainedEnsemble.NumTrees; i += _numberOfClassesIncludingNan) { // Ignore dummy trees. if (TrainedEnsemble.GetTreeAt(i).NumLeaves > 1) @@ -186,12 +191,12 @@ private protected override OneVersusAllModelParameters CreatePredictor() { Host.Check(TrainedEnsemble != null, "The predictor cannot be created before training is complete."); - Host.Assert(_numClass > 1, "Must know the number of classes before creating a predictor."); - Host.Assert(TrainedEnsemble.NumTrees % _numClass == 0, "Number of trees should be a multiple of number of classes."); + Host.Assert(_numberOfClassesIncludingNan > 1, "Must know the number of classes before creating a predictor."); + Host.Assert(TrainedEnsemble.NumTrees % _numberOfClassesIncludingNan == 0, "Number of trees should be a multiple of number of classes."); var innerArgs = LightGbmInterfaceUtils.JoinParameters(GbmOptions); - IPredictorProducing[] predictors = new IPredictorProducing[_tlcNumClass]; - for (int i = 0; i < _tlcNumClass; ++i) + IPredictorProducing[] predictors = new IPredictorProducing[_numberOfClasses]; + for (int i = 0; i < _numberOfClasses; ++i) { var pred = CreateBinaryPredictor(i, innerArgs); var cali = new PlattCalibrator(Host, -LightGbmTrainerOptions.Sigmoid, 0); @@ -218,8 +223,8 @@ private protected override void CheckDataValid(IChannel ch, RoleMappedData data) private protected override void InitializeBeforeTraining() { - _numClass = -1; - _tlcNumClass = 0; + _numberOfClassesIncludingNan = -1; + _numberOfClasses = 0; //MYTODO: Include more initializations, of TrainedEnsemble, for example? //For example: @@ -230,7 +235,7 @@ private protected override void ConvertNaNLabels(IChannel ch, RoleMappedData dat { // Only initialize one time. - if (_numClass < 0) + if (_numberOfClassesIncludingNan < 0) { float minLabel = float.MaxValue; float maxLabel = float.MinValue; @@ -252,28 +257,22 @@ private protected override void ConvertNaNLabels(IChannel ch, RoleMappedData dat if (data.Schema.Label.Value.Type is KeyDataViewType keyType) { if (hasNaNLabel) - _numClass = keyType.GetCountAsInt32(Host) + 1; + _numberOfClassesIncludingNan = keyType.GetCountAsInt32(Host) + 1; else - _numClass = keyType.GetCountAsInt32(Host); - _tlcNumClass = keyType.GetCountAsInt32(Host); + _numberOfClassesIncludingNan = keyType.GetCountAsInt32(Host); + _numberOfClasses = keyType.GetCountAsInt32(Host); } else { if (hasNaNLabel) - _numClass = (int)maxLabel + 2; + _numberOfClassesIncludingNan = (int)maxLabel + 2; else - _numClass = (int)maxLabel + 1; - _tlcNumClass = (int)maxLabel + 1; + _numberOfClassesIncludingNan = (int)maxLabel + 1; + _numberOfClasses = (int)maxLabel + 1; } } - // If there are NaN labels, they are converted to be equal to _tlcNumClass (i.e. _numClass - 1). - // This is done because NaN labels are going to be seen as - // an extra different class, and thus, when training the model in the WrappedLightGbmTraining class - // a total of _numClass classes are considered. But, when creating the Predictors, only _tlcNumClass number of - // classes are considered ignoring the extra class of NaN labels. - - float defaultLabel = _numClass - 1; + float defaultLabel = _numberOfClassesIncludingNan - 1; for (int i = 0; i < labels.Length; ++i) if (float.IsNaN(labels[i])) labels[i] = defaultLabel; @@ -283,7 +282,7 @@ private protected override void GetDefaultParameters(IChannel ch, int numRow, bo { base.GetDefaultParameters(ch, numRow, hasCategorical, totalCats, true); int numberOfLeaves = (int)GbmOptions["num_leaves"]; - int minimumExampleCountPerLeaf = LightGbmTrainerOptions.MinimumExampleCountPerLeaf ?? DefaultMinDataPerLeaf(numRow, numberOfLeaves, _numClass); + int minimumExampleCountPerLeaf = LightGbmTrainerOptions.MinimumExampleCountPerLeaf ?? DefaultMinDataPerLeaf(numRow, numberOfLeaves, _numberOfClassesIncludingNan); GbmOptions["min_data_per_leaf"] = minimumExampleCountPerLeaf; if (!hiddenMsg) { @@ -300,8 +299,8 @@ private protected override void CheckAndUpdateParametersBeforeTraining(IChannel { Host.AssertValue(ch); ch.Assert(PredictionKind == PredictionKind.MulticlassClassification); - ch.Assert(_numClass > 1); - GbmOptions["num_class"] = _numClass; + ch.Assert(_numberOfClassesIncludingNan > 1); + GbmOptions["num_class"] = _numberOfClassesIncludingNan; bool useSoftmax = false; if (LightGbmTrainerOptions.UseSoftmax.HasValue) From 0c825db5737c328debc877800de8ac663059d2ad Mon Sep 17 00:00:00 2001 From: Antonio Velazquez Date: Thu, 2 Jan 2020 10:53:35 -0800 Subject: [PATCH 11/16] Minor mistake in comment --- src/Microsoft.ML.LightGbm/LightGbmMulticlassTrainer.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Microsoft.ML.LightGbm/LightGbmMulticlassTrainer.cs b/src/Microsoft.ML.LightGbm/LightGbmMulticlassTrainer.cs index 99167d3f80..b3bed4ea34 100644 --- a/src/Microsoft.ML.LightGbm/LightGbmMulticlassTrainer.cs +++ b/src/Microsoft.ML.LightGbm/LightGbmMulticlassTrainer.cs @@ -59,7 +59,7 @@ public sealed class LightGbmMulticlassTrainer : LightGbmTrainerBase Date: Thu, 2 Jan 2020 10:55:00 -0800 Subject: [PATCH 12/16] Erased the content of InitializeBeforeTraining virtual method --- src/Microsoft.ML.LightGbm/LightGbmTrainerBase.cs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/Microsoft.ML.LightGbm/LightGbmTrainerBase.cs b/src/Microsoft.ML.LightGbm/LightGbmTrainerBase.cs index f77e2fa2f4..18b6a9a408 100644 --- a/src/Microsoft.ML.LightGbm/LightGbmTrainerBase.cs +++ b/src/Microsoft.ML.LightGbm/LightGbmTrainerBase.cs @@ -373,8 +373,7 @@ private protected override TModel TrainModelCore(TrainContext context) return CreatePredictor(); } - private protected virtual void InitializeBeforeTraining() - { return; } // MYTODO: Is there a better way to avoid having to do this? An abstract method? + private protected virtual void InitializeBeforeTraining(){} private void InitParallelTraining() { From d265046bd64f4bdd8c2be205c64797446329ca34 Mon Sep 17 00:00:00 2001 From: Antonio Velazquez Date: Thu, 2 Jan 2020 11:01:14 -0800 Subject: [PATCH 13/16] Made ParallelTraining readonly --- src/Microsoft.ML.LightGbm/LightGbmTrainerBase.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Microsoft.ML.LightGbm/LightGbmTrainerBase.cs b/src/Microsoft.ML.LightGbm/LightGbmTrainerBase.cs index 18b6a9a408..1a409e626a 100644 --- a/src/Microsoft.ML.LightGbm/LightGbmTrainerBase.cs +++ b/src/Microsoft.ML.LightGbm/LightGbmTrainerBase.cs @@ -288,7 +288,7 @@ private sealed class CategoricalMetaData /// private protected Dictionary GbmOptions; - private protected IParallel ParallelTraining; + private protected readonly IParallel ParallelTraining; // Store _featureCount and _trainedEnsemble to construct predictor. private protected int FeatureCount; From dc79c6e649f8ce668f50dcbaab38df4172c9fca6 Mon Sep 17 00:00:00 2001 From: Antonio Velazquez Date: Thu, 2 Jan 2020 11:24:40 -0800 Subject: [PATCH 14/16] Minor fix in other test, with a using block --- .../TrainerEstimators/TreeEstimators.cs | 24 ++++++++++--------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs b/test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs index e192c38f48..c119e33c99 100644 --- a/test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs +++ b/test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs @@ -492,21 +492,23 @@ private void LightGbmHelper(bool useSoftmax, double sigmoid, out string modelStr using (var pch = (mlContext as IProgressChannelProvider).StartProgressChannel("Training LightGBM...")) { var host = (mlContext as IHostEnvironment).Register("Training LightGBM..."); - var gbmNative = WrappedLightGbmTraining.Train(ch, pch, gbmParams, gbmDataSet, numIteration: numberOfTrainingIterations); - int nativeLength = 0; - unsafe + using (var gbmNative = WrappedLightGbmTraining.Train(ch, pch, gbmParams, gbmDataSet, numIteration: numberOfTrainingIterations)) { - fixed (float* data = dataMatrix) - fixed (double* result0 = lgbmProbabilities) - fixed (double* result1 = lgbmRawScores) + int nativeLength = 0; + unsafe { - WrappedLightGbmInterface.BoosterPredictForMat(gbmNative.Handle, (IntPtr)data, WrappedLightGbmInterface.CApiDType.Float32, - _rowNumber, _columnNumber, 1, (int)WrappedLightGbmInterface.CApiPredictType.Normal, numberOfTrainingIterations, "", ref nativeLength, result0); - WrappedLightGbmInterface.BoosterPredictForMat(gbmNative.Handle, (IntPtr)data, WrappedLightGbmInterface.CApiDType.Float32, - _rowNumber, _columnNumber, 1, (int)WrappedLightGbmInterface.CApiPredictType.Raw, numberOfTrainingIterations, "", ref nativeLength, result1); + fixed (float* data = dataMatrix) + fixed (double* result0 = lgbmProbabilities) + fixed (double* result1 = lgbmRawScores) + { + WrappedLightGbmInterface.BoosterPredictForMat(gbmNative.Handle, (IntPtr)data, WrappedLightGbmInterface.CApiDType.Float32, + _rowNumber, _columnNumber, 1, (int)WrappedLightGbmInterface.CApiPredictType.Normal, numberOfTrainingIterations, "", ref nativeLength, result0); + WrappedLightGbmInterface.BoosterPredictForMat(gbmNative.Handle, (IntPtr)data, WrappedLightGbmInterface.CApiDType.Float32, + _rowNumber, _columnNumber, 1, (int)WrappedLightGbmInterface.CApiPredictType.Raw, numberOfTrainingIterations, "", ref nativeLength, result1); + } + modelString = gbmNative.GetModelString(); } - modelString = gbmNative.GetModelString(); } } } From beb0ed2b81ea12f0c66ca93d4554b9a653927223 Mon Sep 17 00:00:00 2001 From: Antonio Velazquez Date: Thu, 2 Jan 2020 11:25:31 -0800 Subject: [PATCH 15/16] Made GbmOptions readonly, and moved the initialization of ParallelTraining and GbmOptions --- src/Microsoft.ML.LightGbm/LightGbmTrainerBase.cs | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/Microsoft.ML.LightGbm/LightGbmTrainerBase.cs b/src/Microsoft.ML.LightGbm/LightGbmTrainerBase.cs index 1a409e626a..6a1c7558e7 100644 --- a/src/Microsoft.ML.LightGbm/LightGbmTrainerBase.cs +++ b/src/Microsoft.ML.LightGbm/LightGbmTrainerBase.cs @@ -286,7 +286,7 @@ private sealed class CategoricalMetaData /// the code is culture agnostic. When retrieving key value from this dictionary as string /// please convert to string invariant by string.Format(CultureInfo.InvariantCulture, "{0}", Option[key]). /// - private protected Dictionary GbmOptions; + private protected readonly Dictionary GbmOptions; private protected readonly IParallel ParallelTraining; @@ -335,6 +335,8 @@ private protected LightGbmTrainerBase(IHostEnvironment env, string name, TOption Contracts.CheckUserArg(options.L2CategoricalRegularization >= 0.0, nameof(options.L2CategoricalRegularization), "must be >= 0."); LightGbmTrainerOptions = options; + ParallelTraining = LightGbmTrainerOptions.ParallelTrainer != null ? LightGbmTrainerOptions.ParallelTrainer.CreateComponent(Host) : new SingleTrainer(); + GbmOptions = LightGbmTrainerOptions.ToDictionary(Host); InitParallelTraining(); } @@ -377,9 +379,6 @@ private protected virtual void InitializeBeforeTraining(){} private void InitParallelTraining() { - GbmOptions = LightGbmTrainerOptions.ToDictionary(Host); - ParallelTraining = LightGbmTrainerOptions.ParallelTrainer != null ? LightGbmTrainerOptions.ParallelTrainer.CreateComponent(Host) : new SingleTrainer(); - if (ParallelTraining.ParallelType() != "serial" && ParallelTraining.NumMachines() > 1) { GbmOptions["tree_learner"] = ParallelTraining.ParallelType(); From af2c53ea88b1e36dfd5f035cf047247eb41ba140 Mon Sep 17 00:00:00 2001 From: Antonio Velazquez Date: Mon, 6 Jan 2020 13:59:57 -0800 Subject: [PATCH 16/16] Removing MYTODO comment --- src/Microsoft.ML.LightGbm/LightGbmMulticlassTrainer.cs | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/Microsoft.ML.LightGbm/LightGbmMulticlassTrainer.cs b/src/Microsoft.ML.LightGbm/LightGbmMulticlassTrainer.cs index b3bed4ea34..9ba7490780 100644 --- a/src/Microsoft.ML.LightGbm/LightGbmMulticlassTrainer.cs +++ b/src/Microsoft.ML.LightGbm/LightGbmMulticlassTrainer.cs @@ -225,10 +225,6 @@ private protected override void InitializeBeforeTraining() { _numberOfClassesIncludingNan = -1; _numberOfClasses = 0; - - //MYTODO: Include more initializations, of TrainedEnsemble, for example? - //For example: - //TrainedEnsemble = null; } private protected override void ConvertNaNLabels(IChannel ch, RoleMappedData data, float[] labels)