diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/MulticlassClassification/LightGbm.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/MulticlassClassification/LightGbm.cs index 7ab6d5f80e..54200705c6 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/MulticlassClassification/LightGbm.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/MulticlassClassification/LightGbm.cs @@ -68,13 +68,13 @@ public static void Example() .Evaluate(transformedTestData); PrintMetrics(metrics); - + // Expected output: // Micro Accuracy: 0.99 // Macro Accuracy: 0.99 // Log Loss: 0.05 // Log Loss Reduction: 0.95 - + // Confusion table // ||======================== // PREDICTED || 0 | 1 | 2 | Recall @@ -142,4 +142,3 @@ public static void PrintMetrics(MulticlassClassificationMetrics metrics) } } } - diff --git a/src/Microsoft.ML.LightGbm/LightGbmMulticlassTrainer.cs b/src/Microsoft.ML.LightGbm/LightGbmMulticlassTrainer.cs index dff5bbd5fe..9ba7490780 100644 --- a/src/Microsoft.ML.LightGbm/LightGbmMulticlassTrainer.cs +++ b/src/Microsoft.ML.LightGbm/LightGbmMulticlassTrainer.cs @@ -58,8 +58,13 @@ public sealed class LightGbmMulticlassTrainer : LightGbmTrainerBase PredictionKind.MulticlassClassification; /// @@ -129,7 +134,7 @@ internal LightGbmMulticlassTrainer(IHostEnvironment env, Options options) : base(env, LoadNameValue, options, TrainerUtils.MakeU4ScalarColumn(options.LabelColumnName)) { Contracts.CheckUserArg(options.Sigmoid > 0, nameof(Options.Sigmoid), "must be > 0."); - _numClass = -1; + _numberOfClassesIncludingNan = -1; } /// @@ -168,7 +173,7 @@ internal LightGbmMulticlassTrainer(IHostEnvironment env, private InternalTreeEnsemble GetBinaryEnsemble(int classID) { var res = new InternalTreeEnsemble(); - for (int i = classID; i < TrainedEnsemble.NumTrees; i += _numClass) + for (int i = classID; i < TrainedEnsemble.NumTrees; i += _numberOfClassesIncludingNan) { // Ignore dummy trees. if (TrainedEnsemble.GetTreeAt(i).NumLeaves > 1) @@ -186,12 +191,12 @@ private protected override OneVersusAllModelParameters CreatePredictor() { Host.Check(TrainedEnsemble != null, "The predictor cannot be created before training is complete."); - Host.Assert(_numClass > 1, "Must know the number of classes before creating a predictor."); - Host.Assert(TrainedEnsemble.NumTrees % _numClass == 0, "Number of trees should be a multiple of number of classes."); + Host.Assert(_numberOfClassesIncludingNan > 1, "Must know the number of classes before creating a predictor."); + Host.Assert(TrainedEnsemble.NumTrees % _numberOfClassesIncludingNan == 0, "Number of trees should be a multiple of number of classes."); var innerArgs = LightGbmInterfaceUtils.JoinParameters(GbmOptions); - IPredictorProducing[] predictors = new IPredictorProducing[_tlcNumClass]; - for (int i = 0; i < _tlcNumClass; ++i) + IPredictorProducing[] predictors = new IPredictorProducing[_numberOfClasses]; + for (int i = 0; i < _numberOfClasses; ++i) { var pred = CreateBinaryPredictor(i, innerArgs); var cali = new PlattCalibrator(Host, -LightGbmTrainerOptions.Sigmoid, 0); @@ -216,10 +221,17 @@ private protected override void CheckDataValid(IChannel ch, RoleMappedData data) } } + private protected override void InitializeBeforeTraining() + { + _numberOfClassesIncludingNan = -1; + _numberOfClasses = 0; + } + private protected override void ConvertNaNLabels(IChannel ch, RoleMappedData data, float[] labels) { // Only initialize one time. - if (_numClass < 0) + + if (_numberOfClassesIncludingNan < 0) { float minLabel = float.MaxValue; float maxLabel = float.MinValue; @@ -241,21 +253,22 @@ private protected override void ConvertNaNLabels(IChannel ch, RoleMappedData dat if (data.Schema.Label.Value.Type is KeyDataViewType keyType) { if (hasNaNLabel) - _numClass = keyType.GetCountAsInt32(Host) + 1; + _numberOfClassesIncludingNan = keyType.GetCountAsInt32(Host) + 1; else - _numClass = keyType.GetCountAsInt32(Host); - _tlcNumClass = keyType.GetCountAsInt32(Host); + _numberOfClassesIncludingNan = keyType.GetCountAsInt32(Host); + _numberOfClasses = keyType.GetCountAsInt32(Host); } else { if (hasNaNLabel) - _numClass = (int)maxLabel + 2; + _numberOfClassesIncludingNan = (int)maxLabel + 2; else - _numClass = (int)maxLabel + 1; - _tlcNumClass = (int)maxLabel + 1; + _numberOfClassesIncludingNan = (int)maxLabel + 1; + _numberOfClasses = (int)maxLabel + 1; } } - float defaultLabel = _numClass - 1; + + float defaultLabel = _numberOfClassesIncludingNan - 1; for (int i = 0; i < labels.Length; ++i) if (float.IsNaN(labels[i])) labels[i] = defaultLabel; @@ -265,7 +278,7 @@ private protected override void GetDefaultParameters(IChannel ch, int numRow, bo { base.GetDefaultParameters(ch, numRow, hasCategorical, totalCats, true); int numberOfLeaves = (int)GbmOptions["num_leaves"]; - int minimumExampleCountPerLeaf = LightGbmTrainerOptions.MinimumExampleCountPerLeaf ?? DefaultMinDataPerLeaf(numRow, numberOfLeaves, _numClass); + int minimumExampleCountPerLeaf = LightGbmTrainerOptions.MinimumExampleCountPerLeaf ?? DefaultMinDataPerLeaf(numRow, numberOfLeaves, _numberOfClassesIncludingNan); GbmOptions["min_data_per_leaf"] = minimumExampleCountPerLeaf; if (!hiddenMsg) { @@ -282,8 +295,8 @@ private protected override void CheckAndUpdateParametersBeforeTraining(IChannel { Host.AssertValue(ch); ch.Assert(PredictionKind == PredictionKind.MulticlassClassification); - ch.Assert(_numClass > 1); - GbmOptions["num_class"] = _numClass; + ch.Assert(_numberOfClassesIncludingNan > 1); + GbmOptions["num_class"] = _numberOfClassesIncludingNan; bool useSoftmax = false; if (LightGbmTrainerOptions.UseSoftmax.HasValue) diff --git a/src/Microsoft.ML.LightGbm/LightGbmTrainerBase.cs b/src/Microsoft.ML.LightGbm/LightGbmTrainerBase.cs index fb6bbeefde..6a1c7558e7 100644 --- a/src/Microsoft.ML.LightGbm/LightGbmTrainerBase.cs +++ b/src/Microsoft.ML.LightGbm/LightGbmTrainerBase.cs @@ -286,9 +286,9 @@ private sealed class CategoricalMetaData /// the code is culture agnostic. When retrieving key value from this dictionary as string /// please convert to string invariant by string.Format(CultureInfo.InvariantCulture, "{0}", Option[key]). /// - private protected Dictionary GbmOptions; + private protected readonly Dictionary GbmOptions; - private protected IParallel ParallelTraining; + private protected readonly IParallel ParallelTraining; // Store _featureCount and _trainedEnsemble to construct predictor. private protected int FeatureCount; @@ -335,11 +335,15 @@ private protected LightGbmTrainerBase(IHostEnvironment env, string name, TOption Contracts.CheckUserArg(options.L2CategoricalRegularization >= 0.0, nameof(options.L2CategoricalRegularization), "must be >= 0."); LightGbmTrainerOptions = options; + ParallelTraining = LightGbmTrainerOptions.ParallelTrainer != null ? LightGbmTrainerOptions.ParallelTrainer.CreateComponent(Host) : new SingleTrainer(); + GbmOptions = LightGbmTrainerOptions.ToDictionary(Host); InitParallelTraining(); } private protected override TModel TrainModelCore(TrainContext context) { + InitializeBeforeTraining(); + Host.CheckValue(context, nameof(context)); Dataset dtrain = null; @@ -371,11 +375,10 @@ private protected override TModel TrainModelCore(TrainContext context) return CreatePredictor(); } + private protected virtual void InitializeBeforeTraining(){} + private void InitParallelTraining() { - GbmOptions = LightGbmTrainerOptions.ToDictionary(Host); - ParallelTraining = LightGbmTrainerOptions.ParallelTrainer != null ? LightGbmTrainerOptions.ParallelTrainer.CreateComponent(Host) : new SingleTrainer(); - if (ParallelTraining.ParallelType() != "serial" && ParallelTraining.NumMachines() > 1) { GbmOptions["tree_learner"] = ParallelTraining.ParallelType(); diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs b/test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs index 3e6d9e2769..c119e33c99 100644 --- a/test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs +++ b/test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs @@ -13,6 +13,7 @@ using Microsoft.ML.RunTests; using Microsoft.ML.Runtime; using Microsoft.ML.TestFramework.Attributes; +using Microsoft.ML.Trainers; using Microsoft.ML.Trainers.FastTree; using Microsoft.ML.Trainers.LightGbm; using Microsoft.ML.Transforms; @@ -491,21 +492,23 @@ private void LightGbmHelper(bool useSoftmax, double sigmoid, out string modelStr using (var pch = (mlContext as IProgressChannelProvider).StartProgressChannel("Training LightGBM...")) { var host = (mlContext as IHostEnvironment).Register("Training LightGBM..."); - var gbmNative = WrappedLightGbmTraining.Train(ch, pch, gbmParams, gbmDataSet, numIteration: numberOfTrainingIterations); - int nativeLength = 0; - unsafe + using (var gbmNative = WrappedLightGbmTraining.Train(ch, pch, gbmParams, gbmDataSet, numIteration: numberOfTrainingIterations)) { - fixed (float* data = dataMatrix) - fixed (double* result0 = lgbmProbabilities) - fixed (double* result1 = lgbmRawScores) + int nativeLength = 0; + unsafe { - WrappedLightGbmInterface.BoosterPredictForMat(gbmNative.Handle, (IntPtr)data, WrappedLightGbmInterface.CApiDType.Float32, - _rowNumber, _columnNumber, 1, (int)WrappedLightGbmInterface.CApiPredictType.Normal, numberOfTrainingIterations, "", ref nativeLength, result0); - WrappedLightGbmInterface.BoosterPredictForMat(gbmNative.Handle, (IntPtr)data, WrappedLightGbmInterface.CApiDType.Float32, - _rowNumber, _columnNumber, 1, (int)WrappedLightGbmInterface.CApiPredictType.Raw, numberOfTrainingIterations, "", ref nativeLength, result1); + fixed (float* data = dataMatrix) + fixed (double* result0 = lgbmProbabilities) + fixed (double* result1 = lgbmRawScores) + { + WrappedLightGbmInterface.BoosterPredictForMat(gbmNative.Handle, (IntPtr)data, WrappedLightGbmInterface.CApiDType.Float32, + _rowNumber, _columnNumber, 1, (int)WrappedLightGbmInterface.CApiPredictType.Normal, numberOfTrainingIterations, "", ref nativeLength, result0); + WrappedLightGbmInterface.BoosterPredictForMat(gbmNative.Handle, (IntPtr)data, WrappedLightGbmInterface.CApiDType.Float32, + _rowNumber, _columnNumber, 1, (int)WrappedLightGbmInterface.CApiPredictType.Raw, numberOfTrainingIterations, "", ref nativeLength, result1); + } + modelString = gbmNative.GetModelString(); } - modelString = gbmNative.GetModelString(); } } } @@ -677,6 +680,67 @@ public void LightGbmMulticlassEstimatorCompareUnbalanced() Done(); } + private class DataPoint + { + public uint Label { get; set; } + + [VectorType(20)] + public float[] Features { get; set; } + } + + private static IEnumerable GenerateRandomDataPoints(int count, + int seed = 0, int numClasses = 3) + + { + var random = new Random(seed); + float randomFloat() => (float)(random.NextDouble() - 0.5); + for (int i = 0; i < count; i++) + { + var label = random.Next(1, numClasses + 1); + yield return new DataPoint + { + Label = (uint)label, + Features = Enumerable.Repeat(label, 20) + .Select(x => randomFloat() + label * 0.2f).ToArray() + }; + } + } + + [LightGBMFact] + public void LightGbmFitMoreThanOnce() + { + var mlContext = new MLContext(seed: 0); + + var pipeline = + mlContext.Transforms.Conversion + .MapValueToKey(nameof(DataPoint.Label)) + .Append(mlContext.MulticlassClassification.Trainers + .LightGbm()); + + var numClasses = 3; + var dataPoints = GenerateRandomDataPoints(100, numClasses:numClasses); + var trainingData = mlContext.Data.LoadFromEnumerable(dataPoints); + var model = pipeline.Fit(trainingData); + var numOfSubParameters = (model.LastTransformer.Model as OneVersusAllModelParameters).SubModelParameters.Length; + Assert.Equal(numClasses, numOfSubParameters); + + numClasses = 4; + dataPoints = GenerateRandomDataPoints(100, numClasses: numClasses); + trainingData = mlContext.Data.LoadFromEnumerable(dataPoints); + model = pipeline.Fit(trainingData); + numOfSubParameters = (model.LastTransformer.Model as OneVersusAllModelParameters).SubModelParameters.Length; + Assert.Equal(numClasses, numOfSubParameters); + + numClasses = 2; + dataPoints = GenerateRandomDataPoints(100, numClasses: numClasses); + trainingData = mlContext.Data.LoadFromEnumerable(dataPoints); + model = pipeline.Fit(trainingData); + numOfSubParameters = (model.LastTransformer.Model as OneVersusAllModelParameters).SubModelParameters.Length; + Assert.Equal(numClasses, numOfSubParameters); + + Done(); + } + [LightGBMFact] public void LightGbmInDifferentCulture() {