Skip to content

Fixes #3878. About calling Fit more than once on Multiclass LightGBM trainer. #4608

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 17 commits into from
Jan 7, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -68,13 +68,13 @@ public static void Example()
.Evaluate(transformedTestData);

PrintMetrics(metrics);

// Expected output:
// Micro Accuracy: 0.99
// Macro Accuracy: 0.99
// Log Loss: 0.05
// Log Loss Reduction: 0.95

// Confusion table
// ||========================
// PREDICTED || 0 | 1 | 2 | Recall
Expand Down Expand Up @@ -142,4 +142,3 @@ public static void PrintMetrics(MulticlassClassificationMetrics metrics)
}
}
}

51 changes: 32 additions & 19 deletions src/Microsoft.ML.LightGbm/LightGbmMulticlassTrainer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,13 @@ public sealed class LightGbmMulticlassTrainer : LightGbmTrainerBase<LightGbmMult
private const int _minDataToUseSoftmax = 50000;

private const double _maxNumClass = 1e6;
private int _numClass;
private int _tlcNumClass;

// If there are NaN labels, they are converted to be equal to _numberOfClassesIncludingNan - 1.
// This is done because NaN labels are treated as an extra, distinct class when training the model in the WrappedLightGbmTraining class.
// However, when creating the predictors, only _numberOfClasses is considered, ignoring the "extra class" of NaN labels.
private int _numberOfClassesIncludingNan;
private int _numberOfClasses;

private protected override PredictionKind PredictionKind => PredictionKind.MulticlassClassification;

/// <summary>
Expand Down Expand Up @@ -129,7 +134,7 @@ internal LightGbmMulticlassTrainer(IHostEnvironment env, Options options)
: base(env, LoadNameValue, options, TrainerUtils.MakeU4ScalarColumn(options.LabelColumnName))
{
Contracts.CheckUserArg(options.Sigmoid > 0, nameof(Options.Sigmoid), "must be > 0.");
_numClass = -1;
_numberOfClassesIncludingNan = -1;
}

/// <summary>
Expand Down Expand Up @@ -168,7 +173,7 @@ internal LightGbmMulticlassTrainer(IHostEnvironment env,
private InternalTreeEnsemble GetBinaryEnsemble(int classID)
{
var res = new InternalTreeEnsemble();
for (int i = classID; i < TrainedEnsemble.NumTrees; i += _numClass)
for (int i = classID; i < TrainedEnsemble.NumTrees; i += _numberOfClassesIncludingNan)
{
// Ignore dummy trees.
if (TrainedEnsemble.GetTreeAt(i).NumLeaves > 1)
Expand All @@ -186,12 +191,12 @@ private protected override OneVersusAllModelParameters CreatePredictor()
{
Host.Check(TrainedEnsemble != null, "The predictor cannot be created before training is complete.");

Host.Assert(_numClass > 1, "Must know the number of classes before creating a predictor.");
Host.Assert(TrainedEnsemble.NumTrees % _numClass == 0, "Number of trees should be a multiple of number of classes.");
Host.Assert(_numberOfClassesIncludingNan > 1, "Must know the number of classes before creating a predictor.");
Host.Assert(TrainedEnsemble.NumTrees % _numberOfClassesIncludingNan == 0, "Number of trees should be a multiple of number of classes.");

var innerArgs = LightGbmInterfaceUtils.JoinParameters(GbmOptions);
IPredictorProducing<float>[] predictors = new IPredictorProducing<float>[_tlcNumClass];
for (int i = 0; i < _tlcNumClass; ++i)
IPredictorProducing<float>[] predictors = new IPredictorProducing<float>[_numberOfClasses];
for (int i = 0; i < _numberOfClasses; ++i)
{
var pred = CreateBinaryPredictor(i, innerArgs);
var cali = new PlattCalibrator(Host, -LightGbmTrainerOptions.Sigmoid, 0);
Expand All @@ -216,10 +221,17 @@ private protected override void CheckDataValid(IChannel ch, RoleMappedData data)
}
}

// Resets the per-Fit class-count state so that calling Fit more than once on
// the same trainer recomputes the number of classes from the new data
// (regression fix for #3878: stale counts from a previous Fit produced the
// wrong number of sub-models).
private protected override void InitializeBeforeTraining()
{
    // -1 is the "not yet computed" sentinel; ConvertNaNLabels only
    // initializes the class counts while this value is still negative.
    _numberOfClassesIncludingNan = -1;
    _numberOfClasses = 0;
}

private protected override void ConvertNaNLabels(IChannel ch, RoleMappedData data, float[] labels)
{
// Only initialize one time.
if (_numClass < 0)

if (_numberOfClassesIncludingNan < 0)
{
float minLabel = float.MaxValue;
float maxLabel = float.MinValue;
Expand All @@ -241,21 +253,22 @@ private protected override void ConvertNaNLabels(IChannel ch, RoleMappedData dat
if (data.Schema.Label.Value.Type is KeyDataViewType keyType)
{
if (hasNaNLabel)
_numClass = keyType.GetCountAsInt32(Host) + 1;
_numberOfClassesIncludingNan = keyType.GetCountAsInt32(Host) + 1;
else
_numClass = keyType.GetCountAsInt32(Host);
_tlcNumClass = keyType.GetCountAsInt32(Host);
_numberOfClassesIncludingNan = keyType.GetCountAsInt32(Host);
_numberOfClasses = keyType.GetCountAsInt32(Host);
}
else
{
if (hasNaNLabel)
_numClass = (int)maxLabel + 2;
_numberOfClassesIncludingNan = (int)maxLabel + 2;
else
_numClass = (int)maxLabel + 1;
_tlcNumClass = (int)maxLabel + 1;
_numberOfClassesIncludingNan = (int)maxLabel + 1;
_numberOfClasses = (int)maxLabel + 1;
}
}
float defaultLabel = _numClass - 1;

float defaultLabel = _numberOfClassesIncludingNan - 1;
for (int i = 0; i < labels.Length; ++i)
if (float.IsNaN(labels[i]))
labels[i] = defaultLabel;
Expand All @@ -265,7 +278,7 @@ private protected override void GetDefaultParameters(IChannel ch, int numRow, bo
{
base.GetDefaultParameters(ch, numRow, hasCategorical, totalCats, true);
int numberOfLeaves = (int)GbmOptions["num_leaves"];
int minimumExampleCountPerLeaf = LightGbmTrainerOptions.MinimumExampleCountPerLeaf ?? DefaultMinDataPerLeaf(numRow, numberOfLeaves, _numClass);
int minimumExampleCountPerLeaf = LightGbmTrainerOptions.MinimumExampleCountPerLeaf ?? DefaultMinDataPerLeaf(numRow, numberOfLeaves, _numberOfClassesIncludingNan);
GbmOptions["min_data_per_leaf"] = minimumExampleCountPerLeaf;
if (!hiddenMsg)
{
Expand All @@ -282,8 +295,8 @@ private protected override void CheckAndUpdateParametersBeforeTraining(IChannel
{
Host.AssertValue(ch);
ch.Assert(PredictionKind == PredictionKind.MulticlassClassification);
ch.Assert(_numClass > 1);
GbmOptions["num_class"] = _numClass;
ch.Assert(_numberOfClassesIncludingNan > 1);
GbmOptions["num_class"] = _numberOfClassesIncludingNan;
bool useSoftmax = false;

if (LightGbmTrainerOptions.UseSoftmax.HasValue)
Expand Down
13 changes: 8 additions & 5 deletions src/Microsoft.ML.LightGbm/LightGbmTrainerBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -286,9 +286,9 @@ private sealed class CategoricalMetaData
/// the code is culture agnostic. When retrieving key value from this dictionary as string
/// please convert to string invariant by string.Format(CultureInfo.InvariantCulture, "{0}", Option[key]).
/// </summary>
private protected Dictionary<string, object> GbmOptions;
private protected readonly Dictionary<string, object> GbmOptions;

private protected IParallel ParallelTraining;
private protected readonly IParallel ParallelTraining;

// Store _featureCount and _trainedEnsemble to construct predictor.
private protected int FeatureCount;
Expand Down Expand Up @@ -335,11 +335,15 @@ private protected LightGbmTrainerBase(IHostEnvironment env, string name, TOption
Contracts.CheckUserArg(options.L2CategoricalRegularization >= 0.0, nameof(options.L2CategoricalRegularization), "must be >= 0.");

LightGbmTrainerOptions = options;
ParallelTraining = LightGbmTrainerOptions.ParallelTrainer != null ? LightGbmTrainerOptions.ParallelTrainer.CreateComponent(Host) : new SingleTrainer();
GbmOptions = LightGbmTrainerOptions.ToDictionary(Host);
InitParallelTraining();
}

private protected override TModel TrainModelCore(TrainContext context)
{
InitializeBeforeTraining();

Host.CheckValue(context, nameof(context));

Dataset dtrain = null;
Expand Down Expand Up @@ -371,11 +375,10 @@ private protected override TModel TrainModelCore(TrainContext context)
return CreatePredictor();
}

/// <summary>
/// Hook invoked at the start of <see cref="TrainModelCore"/> before any data
/// processing; derived trainers override it to reset per-Fit state. The base
/// implementation intentionally does nothing.
/// </summary>
private protected virtual void InitializeBeforeTraining()
{
}

private void InitParallelTraining()
{
GbmOptions = LightGbmTrainerOptions.ToDictionary(Host);
ParallelTraining = LightGbmTrainerOptions.ParallelTrainer != null ? LightGbmTrainerOptions.ParallelTrainer.CreateComponent(Host) : new SingleTrainer();

if (ParallelTraining.ParallelType() != "serial" && ParallelTraining.NumMachines() > 1)
{
GbmOptions["tree_learner"] = ParallelTraining.ParallelType();
Expand Down
86 changes: 75 additions & 11 deletions test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
using Microsoft.ML.RunTests;
using Microsoft.ML.Runtime;
using Microsoft.ML.TestFramework.Attributes;
using Microsoft.ML.Trainers;
using Microsoft.ML.Trainers.FastTree;
using Microsoft.ML.Trainers.LightGbm;
using Microsoft.ML.Transforms;
Expand Down Expand Up @@ -491,21 +492,23 @@ private void LightGbmHelper(bool useSoftmax, double sigmoid, out string modelStr
using (var pch = (mlContext as IProgressChannelProvider).StartProgressChannel("Training LightGBM..."))
{
var host = (mlContext as IHostEnvironment).Register("Training LightGBM...");
var gbmNative = WrappedLightGbmTraining.Train(ch, pch, gbmParams, gbmDataSet, numIteration: numberOfTrainingIterations);

int nativeLength = 0;
unsafe
using (var gbmNative = WrappedLightGbmTraining.Train(ch, pch, gbmParams, gbmDataSet, numIteration: numberOfTrainingIterations))
{
fixed (float* data = dataMatrix)
fixed (double* result0 = lgbmProbabilities)
fixed (double* result1 = lgbmRawScores)
int nativeLength = 0;
unsafe
{
WrappedLightGbmInterface.BoosterPredictForMat(gbmNative.Handle, (IntPtr)data, WrappedLightGbmInterface.CApiDType.Float32,
_rowNumber, _columnNumber, 1, (int)WrappedLightGbmInterface.CApiPredictType.Normal, numberOfTrainingIterations, "", ref nativeLength, result0);
WrappedLightGbmInterface.BoosterPredictForMat(gbmNative.Handle, (IntPtr)data, WrappedLightGbmInterface.CApiDType.Float32,
_rowNumber, _columnNumber, 1, (int)WrappedLightGbmInterface.CApiPredictType.Raw, numberOfTrainingIterations, "", ref nativeLength, result1);
fixed (float* data = dataMatrix)
fixed (double* result0 = lgbmProbabilities)
fixed (double* result1 = lgbmRawScores)
{
WrappedLightGbmInterface.BoosterPredictForMat(gbmNative.Handle, (IntPtr)data, WrappedLightGbmInterface.CApiDType.Float32,
_rowNumber, _columnNumber, 1, (int)WrappedLightGbmInterface.CApiPredictType.Normal, numberOfTrainingIterations, "", ref nativeLength, result0);
WrappedLightGbmInterface.BoosterPredictForMat(gbmNative.Handle, (IntPtr)data, WrappedLightGbmInterface.CApiDType.Float32,
_rowNumber, _columnNumber, 1, (int)WrappedLightGbmInterface.CApiPredictType.Raw, numberOfTrainingIterations, "", ref nativeLength, result1);
}
modelString = gbmNative.GetModelString();
}
modelString = gbmNative.GetModelString();
}
}
}
Expand Down Expand Up @@ -677,6 +680,67 @@ public void LightGbmMulticlassEstimatorCompareUnbalanced()
Done();
}

// Simple row type used to synthesize multiclass training data for the
// LightGBM tests in this class.
private class DataPoint
{
    // Class label; uint so it can be mapped to a key column via
    // MapValueToKey in the test pipelines.
    public uint Label { get; set; }

    // Fixed-length feature vector; the attribute declares its size (20)
    // to the ML.NET schema system.
    [VectorType(20)]
    public float[] Features { get; set; }
}

/// <summary>
/// Generates <paramref name="count"/> random multiclass examples. Labels are
/// drawn uniformly from [1, numClasses], and each of the 20 features is a
/// centered random value in (-0.5, 0.5) shifted by 0.2 * label, so classes
/// are linearly separable in expectation.
/// </summary>
/// <param name="count">Number of data points to yield.</param>
/// <param name="seed">Seed for the deterministic pseudo-random sequence.</param>
/// <param name="numClasses">Number of distinct label values to generate.</param>
private static IEnumerable<DataPoint> GenerateRandomDataPoints(int count,
    int seed = 0, int numClasses = 3)
{
    const int featureCount = 20; // must match [VectorType(20)] on DataPoint.Features
    var random = new Random(seed);
    float randomFloat() => (float)(random.NextDouble() - 0.5);
    for (int i = 0; i < count; i++)
    {
        var label = random.Next(1, numClasses + 1);
        // Build the feature vector with an explicit loop. The previous
        // Enumerable.Repeat(label, 20) was used only for its element count —
        // the repeated value was discarded by the Select lambda, which was
        // misleading. The Random call sequence (one Next, then 20
        // NextDouble calls per point) is preserved exactly.
        var features = new float[featureCount];
        for (int j = 0; j < featureCount; j++)
            features[j] = randomFloat() + label * 0.2f;
        yield return new DataPoint
        {
            Label = (uint)label,
            Features = features
        };
    }
}

/// <summary>
/// Regression test for issue #3878: calling Fit more than once on the same
/// multiclass LightGBM pipeline must recompute the number of classes from the
/// new data, producing one sub-model per class on every call.
/// </summary>
[LightGBMFact]
public void LightGbmFitMoreThanOnce()
{
    var mlContext = new MLContext(seed: 0);

    var pipeline =
        mlContext.Transforms.Conversion
        .MapValueToKey(nameof(DataPoint.Label))
        .Append(mlContext.MulticlassClassification.Trainers
            .LightGbm());

    // Fit repeatedly with differing class counts; the sequence 3 -> 4 -> 2
    // exercises both growing and shrinking the number of classes between
    // calls. (Replaces three copy-pasted instances of the same statements.)
    foreach (var numClasses in new[] { 3, 4, 2 })
    {
        var dataPoints = GenerateRandomDataPoints(100, numClasses: numClasses);
        var trainingData = mlContext.Data.LoadFromEnumerable(dataPoints);
        var model = pipeline.Fit(trainingData);

        // Assert the cast explicitly rather than dereferencing the result of
        // 'as', which would raise NullReferenceException on a type mismatch
        // instead of a meaningful test failure.
        var ovaModel = model.LastTransformer.Model as OneVersusAllModelParameters;
        Assert.NotNull(ovaModel);
        Assert.Equal(numClasses, ovaModel.SubModelParameters.Length);
    }

    Done();
}

[LightGBMFact]
public void LightGbmInDifferentCulture()
{
Expand Down