Skip to content

Commit f93ab25

Browse files
Adds the ability to load a pre-trained LightGBM file and import it into ML.Net. (dotnet#6569)
* need to finish multiclass * multiclass * reverting test for now * reverting test for now * added test and fixed objective parsing * minor testing changes
1 parent 1c41ed4 commit f93ab25

9 files changed

+7707
-50
lines changed

src/Microsoft.ML.LightGbm/LightGbmBinaryTrainer.cs

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33
// See the LICENSE file in the project root for more information.
44

5+
using System;
56
using System.Collections.Generic;
7+
using System.IO;
68
using Microsoft.ML;
79
using Microsoft.ML.Calibrators;
810
using Microsoft.ML.CommandLine;
@@ -228,6 +230,26 @@ internal LightGbmBinaryTrainer(IHostEnvironment env,
228230
{
229231
}
230232

233+
/// <summary>
234+
/// Initializes a new instance of <see cref="LightGbmBinaryTrainer"/>
235+
/// </summary>
236+
/// <param name="env">The private instance of <see cref="IHostEnvironment"/>.</param>
237+
/// <param name="lightGbmModel"> A pre-trained <see cref="System.IO.Stream"/> of a LightGBM model file inferencing</param>
238+
/// <param name="featureColumnName">The name of the feature column.</param>
239+
internal LightGbmBinaryTrainer(IHostEnvironment env,
240+
Stream lightGbmModel,
241+
string featureColumnName = DefaultColumnNames.Features)
242+
: base(env,
243+
LoadNameValue,
244+
new Options()
245+
{
246+
FeatureColumnName = featureColumnName,
247+
LightGbmModel = lightGbmModel
248+
},
249+
new SchemaShape.Column())
250+
{
251+
}
252+
231253
private protected override CalibratedModelParametersBase<LightGbmBinaryModelParameters, PlattCalibrator> CreatePredictor()
232254
{
233255
Host.Check(TrainedEnsemble != null, "The predictor cannot be created before training is complete");
@@ -241,11 +263,16 @@ private protected override void CheckDataValid(IChannel ch, RoleMappedData data)
241263
{
242264
Host.AssertValue(ch);
243265
base.CheckDataValid(ch, data);
244-
var labelType = data.Schema.Label.Value.Type;
245-
if (!(labelType is BooleanDataViewType || labelType is KeyDataViewType || labelType == NumberDataViewType.Single))
266+
267+
// If using a pre-trained model file we don't need a label column
268+
if (LightGbmTrainerOptions.LightGbmModel == null)
246269
{
247-
throw ch.ExceptParam(nameof(data),
248-
$"Label column '{data.Schema.Label.Value.Name}' is of type '{labelType.RawType}', but must be unsigned int, boolean or float.");
270+
var labelType = data.Schema.Label.Value.Type;
271+
if (!(labelType is BooleanDataViewType || labelType is KeyDataViewType || labelType == NumberDataViewType.Single))
272+
{
273+
throw ch.ExceptParam(nameof(data),
274+
$"Label column '{data.Schema.Label.Value.Name}' is of type '{labelType.RawType}', but must be unsigned int, boolean or float.");
275+
}
249276
}
250277
}
251278

src/Microsoft.ML.LightGbm/LightGbmCatalog.cs

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33
// See the LICENSE file in the project root for more information.
44

5+
using System.IO;
56
using Microsoft.ML.Data;
67
using Microsoft.ML.Runtime;
78
using Microsoft.ML.Trainers.LightGbm;
@@ -67,6 +68,22 @@ public static LightGbmRegressionTrainer LightGbm(this RegressionCatalog.Regressi
6768
return new LightGbmRegressionTrainer(env, options);
6869
}
6970

71+
/// <summary>
72+
/// Create <see cref="LightGbmRegressionTrainer"/> from a pre-trained LightGBM model, which predicts a target using a gradient boosting decision tree regression.
73+
/// </summary>
74+
/// <param name="catalog">The <see cref="RegressionCatalog"/>.</param>
75+
/// <param name="lightGbmModel"> A pre-trained <see cref="System.IO.Stream"/> of a LightGBM model file inferencing</param>
76+
/// <param name="featureColumnName">The name of the feature column. The column data must be a known-sized vector of <see cref="System.Single"/>.</param>
77+
public static LightGbmRegressionTrainer LightGbm(this RegressionCatalog.RegressionTrainers catalog,
78+
Stream lightGbmModel,
79+
string featureColumnName = DefaultColumnNames.Features
80+
)
81+
{
82+
Contracts.CheckValue(catalog, nameof(catalog));
83+
var env = CatalogUtils.GetEnvironment(catalog);
84+
return new LightGbmRegressionTrainer(env, lightGbmModel, featureColumnName);
85+
}
86+
7087
/// <summary>
7188
/// Create <see cref="LightGbmBinaryTrainer"/>, which predicts a target using a gradient boosting decision tree binary classification.
7289
/// </summary>
@@ -119,6 +136,22 @@ public static LightGbmBinaryTrainer LightGbm(this BinaryClassificationCatalog.Bi
119136
return new LightGbmBinaryTrainer(env, options);
120137
}
121138

139+
/// <summary>
140+
/// Create <see cref="LightGbmBinaryTrainer"/> from a pre-trained LightGBM model, which predicts a target using a gradient boosting decision tree binary classification.
141+
/// </summary>
142+
/// <param name="catalog">The <see cref="BinaryClassificationCatalog"/>.</param>
143+
/// <param name="lightGbmModel"> A pre-trained <see cref="System.IO.Stream"/> of a LightGBM model file inferencing</param>
144+
/// <param name="featureColumnName">The name of the feature column. The column data must be a known-sized vector of <see cref="System.Single"/>.</param>
145+
public static LightGbmBinaryTrainer LightGbm(this BinaryClassificationCatalog.BinaryClassificationTrainers catalog,
146+
Stream lightGbmModel,
147+
string featureColumnName = DefaultColumnNames.Features
148+
)
149+
{
150+
Contracts.CheckValue(catalog, nameof(catalog));
151+
var env = CatalogUtils.GetEnvironment(catalog);
152+
return new LightGbmBinaryTrainer(env, lightGbmModel, featureColumnName);
153+
}
154+
122155
/// <summary>
123156
/// Create <see cref="LightGbmRankingTrainer"/>, which predicts a target using a gradient boosting decision tree ranking model.
124157
/// </summary>
@@ -174,6 +207,22 @@ public static LightGbmRankingTrainer LightGbm(this RankingCatalog.RankingTrainer
174207
return new LightGbmRankingTrainer(env, options);
175208
}
176209

210+
/// <summary>
211+
/// Create <see cref="LightGbmRankingTrainer"/> from a pre-trained LightGBM model, which predicts a target using a gradient boosting decision tree ranking model.
212+
/// </summary>
213+
/// <param name="catalog">The <see cref="RankingCatalog"/>.</param>
214+
/// <param name="lightGbmModel"> A pre-trained <see cref="System.IO.Stream"/> of a LightGBM model file inferencing</param>
215+
/// <param name="featureColumnName">The name of the feature column. The column data must be a known-sized vector of <see cref="System.Single"/>.</param>
216+
public static LightGbmRankingTrainer LightGbm(this RankingCatalog.RankingTrainers catalog,
217+
Stream lightGbmModel,
218+
string featureColumnName = DefaultColumnNames.Features
219+
)
220+
{
221+
Contracts.CheckValue(catalog, nameof(catalog));
222+
var env = CatalogUtils.GetEnvironment(catalog);
223+
return new LightGbmRankingTrainer(env, lightGbmModel, featureColumnName);
224+
}
225+
177226
/// <summary>
178227
/// Create <see cref="LightGbmMulticlassTrainer"/>, which predicts a target using a gradient boosting decision tree multiclass classification model.
179228
/// </summary>
@@ -225,5 +274,21 @@ public static LightGbmMulticlassTrainer LightGbm(this MulticlassClassificationCa
225274
var env = CatalogUtils.GetEnvironment(catalog);
226275
return new LightGbmMulticlassTrainer(env, options);
227276
}
277+
278+
/// <summary>
279+
/// Create <see cref="LightGbmMulticlassTrainer"/> from a pre-trained LightGBM model, which predicts a target using a gradient boosting decision tree multiclass classification model.
280+
/// </summary>
281+
/// <param name="catalog">The <see cref="MulticlassClassificationCatalog"/>.</param>
282+
/// <param name="lightGbmModel"> A pre-trained <see cref="System.IO.Stream"/> of a LightGBM model file inferencing</param>
283+
/// <param name="featureColumnName">The name of the feature column. The column data must be a known-sized vector of <see cref="System.Single"/>.</param>
284+
public static LightGbmMulticlassTrainer LightGbm(this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog,
285+
Stream lightGbmModel,
286+
string featureColumnName = DefaultColumnNames.Features
287+
)
288+
{
289+
Contracts.CheckValue(catalog, nameof(catalog));
290+
var env = CatalogUtils.GetEnvironment(catalog);
291+
return new LightGbmMulticlassTrainer(env, lightGbmModel, featureColumnName);
292+
}
228293
}
229294
}

src/Microsoft.ML.LightGbm/LightGbmMulticlassTrainer.cs

Lines changed: 51 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
using System;
66
using System.Collections.Generic;
7+
using System.IO;
78
using System.Linq;
89
using Microsoft.ML;
910
using Microsoft.ML.Calibrators;
@@ -170,6 +171,26 @@ internal LightGbmMulticlassTrainer(IHostEnvironment env,
170171
{
171172
}
172173

174+
/// <summary>
175+
/// Initializes a new instance of <see cref="LightGbmRankingTrainer"/>
176+
/// </summary>
177+
/// <param name="env">The private instance of <see cref="IHostEnvironment"/>.</param>
178+
/// <param name="lightGbmModel"> A pre-trained <see cref="System.IO.Stream"/> of a LightGBM model file inferencing</param>
179+
/// <param name="featureColumnName">The name of the feature column.</param>
180+
internal LightGbmMulticlassTrainer(IHostEnvironment env,
181+
Stream lightGbmModel,
182+
string featureColumnName = DefaultColumnNames.Features)
183+
: base(env,
184+
LoadNameValue,
185+
new Options()
186+
{
187+
FeatureColumnName = featureColumnName,
188+
LightGbmModel = lightGbmModel
189+
},
190+
new SchemaShape.Column())
191+
{
192+
}
193+
173194
private InternalTreeEnsemble GetBinaryEnsemble(int classID)
174195
{
175196
var res = new InternalTreeEnsemble();
@@ -213,11 +234,15 @@ private protected override void CheckDataValid(IChannel ch, RoleMappedData data)
213234
{
214235
Host.AssertValue(ch);
215236
base.CheckDataValid(ch, data);
216-
var labelType = data.Schema.Label.Value.Type;
217-
if (!(labelType is BooleanDataViewType || labelType is KeyDataViewType || labelType == NumberDataViewType.Single))
237+
// If using a pre-trained model file we don't need a label or group column
238+
if (LightGbmTrainerOptions.LightGbmModel == null)
218239
{
219-
throw ch.ExceptParam(nameof(data),
220-
$"Label column '{data.Schema.Label.Value.Name}' is of type '{labelType.RawType}', but must be of unsigned int, boolean or float.");
240+
var labelType = data.Schema.Label.Value.Type;
241+
if (!(labelType is BooleanDataViewType || labelType is KeyDataViewType || labelType == NumberDataViewType.Single))
242+
{
243+
throw ch.ExceptParam(nameof(data),
244+
$"Label column '{data.Schema.Label.Value.Name}' is of type '{labelType.RawType}', but must be of unsigned int, boolean or float.");
245+
}
221246
}
222247
}
223248

@@ -227,6 +252,21 @@ private protected override void InitializeBeforeTraining()
227252
_numberOfClasses = 0;
228253
}
229254

255+
private protected override void AdditionalLoadPreTrainedModel(string modelText)
256+
{
257+
string[] lines = modelText.Split(new char[] { '\r', '\n' }, StringSplitOptions.RemoveEmptyEntries);
258+
// Jump to the "objective" value in the file. It's at the beginning.
259+
int i = 0;
260+
while (!lines[i].StartsWith("objective"))
261+
i++;
262+
263+
// Format in the file is objective=multiclass num_class:4
264+
var split = lines[i].Split(' ');
265+
_numberOfClassesIncludingNan = int.Parse(split[1].Split(':')[1]);
266+
_numberOfClasses = _numberOfClassesIncludingNan;
267+
}
268+
269+
230270
private protected override void ConvertNaNLabels(IChannel ch, RoleMappedData data, float[] labels)
231271
{
232272
// Only initialize one time.
@@ -317,11 +357,14 @@ private protected override void CheckAndUpdateParametersBeforeTraining(IChannel
317357

318358
private protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema)
319359
{
320-
bool success = inputSchema.TryFindColumn(LabelColumn.Name, out var labelCol);
321-
Contracts.Assert(success);
360+
SchemaShape.Column labelCol = default;
361+
if (LightGbmTrainerOptions.LightGbmModel == null)
362+
{
363+
bool success = inputSchema.TryFindColumn(LabelColumn.Name, out labelCol);
364+
Contracts.Assert(success);
365+
}
322366

323-
var metadata = new SchemaShape(labelCol.Annotations.Where(x => x.Name == AnnotationUtils.Kinds.KeyValues)
324-
.Concat(AnnotationUtils.GetTrainerOutputAnnotation()));
367+
var metadata = LightGbmTrainerOptions.LightGbmModel == null ? new SchemaShape(labelCol.Annotations.Where(x => x.Name == AnnotationUtils.Kinds.KeyValues).Concat(AnnotationUtils.GetTrainerOutputAnnotation())) : new SchemaShape(AnnotationUtils.GetTrainerOutputAnnotation());
325368
return new[]
326369
{
327370
new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Vector, NumberDataViewType.Single, false, new SchemaShape(AnnotationUtils.AnnotationsForMulticlassScoreColumn(labelCol))),

src/Microsoft.ML.LightGbm/LightGbmRankingTrainer.cs

Lines changed: 42 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
using System;
66
using System.Collections.Generic;
7+
using System.IO;
78
using Microsoft.ML;
89
using Microsoft.ML.CommandLine;
910
using Microsoft.ML.Data;
@@ -215,27 +216,52 @@ internal LightGbmRankingTrainer(IHostEnvironment env,
215216
Host.CheckNonEmpty(rowGroupIdColumnName, nameof(rowGroupIdColumnName));
216217
}
217218

219+
/// <summary>
220+
/// Initializes a new instance of <see cref="LightGbmRankingTrainer"/>
221+
/// </summary>
222+
/// <param name="env">The private instance of <see cref="IHostEnvironment"/>.</param>
223+
/// <param name="lightGbmModel"> A pre-trained <see cref="System.IO.Stream"/> of a LightGBM model file inferencing</param>
224+
/// <param name="featureColumnName">The name of the feature column.</param>
225+
internal LightGbmRankingTrainer(IHostEnvironment env,
226+
Stream lightGbmModel,
227+
string featureColumnName = DefaultColumnNames.Features)
228+
: base(env,
229+
LoadNameValue,
230+
new Options()
231+
{
232+
FeatureColumnName = featureColumnName,
233+
LightGbmModel = lightGbmModel
234+
},
235+
new SchemaShape.Column())
236+
{
237+
}
238+
218239
private protected override void CheckDataValid(IChannel ch, RoleMappedData data)
219240
{
220241
Host.AssertValue(ch);
221242
base.CheckDataValid(ch, data);
222-
// Check label types.
223-
var labelCol = data.Schema.Label.Value;
224-
var labelType = labelCol.Type;
225-
if (!(labelType is KeyDataViewType || labelType == NumberDataViewType.Single))
226-
{
227-
throw ch.ExceptParam(nameof(data),
228-
$"Label column '{labelCol.Name}' is of type '{labelType.RawType}', but must be Key or Single.");
229-
}
230-
// Check group types.
231-
if (!data.Schema.Group.HasValue)
232-
throw ch.ExceptValue(nameof(data.Schema.Group), "Group column is missing.");
233-
var groupCol = data.Schema.Group.Value;
234-
var groupType = groupCol.Type;
235-
if (!(groupType == NumberDataViewType.UInt32 || groupType is KeyDataViewType))
243+
244+
// If using a pre-trained model file we don't need a label or group column
245+
if (LightGbmTrainerOptions.LightGbmModel == null)
236246
{
237-
throw ch.ExceptParam(nameof(data),
238-
$"Group column '{groupCol.Name}' is of type '{groupType.RawType}', but must be UInt32 or Key.");
247+
// Check label types.
248+
var labelCol = data.Schema.Label.Value;
249+
var labelType = labelCol.Type;
250+
if (!(labelType is KeyDataViewType || labelType == NumberDataViewType.Single))
251+
{
252+
throw ch.ExceptParam(nameof(data),
253+
$"Label column '{labelCol.Name}' is of type '{labelType.RawType}', but must be Key or Single.");
254+
}
255+
// Check group types.
256+
if (!data.Schema.Group.HasValue)
257+
throw ch.ExceptValue(nameof(data.Schema.Group), "Group column is missing.");
258+
var groupCol = data.Schema.Group.Value;
259+
var groupType = groupCol.Type;
260+
if (!(groupType == NumberDataViewType.UInt32 || groupType is KeyDataViewType))
261+
{
262+
throw ch.ExceptParam(nameof(data),
263+
$"Group column '{groupCol.Name}' is of type '{groupType.RawType}', but must be UInt32 or Key.");
264+
}
239265
}
240266
}
241267

src/Microsoft.ML.LightGbm/LightGbmRegressionTrainer.cs

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
// See the LICENSE file in the project root for more information.
44

55
using System.Collections.Generic;
6+
using System.IO;
67
using Microsoft.ML;
78
using Microsoft.ML.CommandLine;
89
using Microsoft.ML.Data;
@@ -192,6 +193,26 @@ internal LightGbmRegressionTrainer(IHostEnvironment env, Options options)
192193
{
193194
}
194195

196+
/// <summary>
197+
/// Initializes a new instance of <see cref="LightGbmRegressionTrainer"/>
198+
/// </summary>
199+
/// <param name="env">The private instance of <see cref="IHostEnvironment"/>.</param>
200+
/// <param name="lightGbmModel"> A pre-trained <see cref="System.IO.Stream"/> of a LightGBM model file inferencing</param>
201+
/// <param name="featureColumnName">The name of the feature column.</param>
202+
internal LightGbmRegressionTrainer(IHostEnvironment env,
203+
Stream lightGbmModel,
204+
string featureColumnName = DefaultColumnNames.Features)
205+
: base(env,
206+
LoadNameValue,
207+
new Options()
208+
{
209+
FeatureColumnName = featureColumnName,
210+
LightGbmModel = lightGbmModel
211+
},
212+
new SchemaShape.Column())
213+
{
214+
}
215+
195216
private protected override LightGbmRegressionModelParameters CreatePredictor()
196217
{
197218
Host.Check(TrainedEnsemble != null,
@@ -204,11 +225,16 @@ private protected override void CheckDataValid(IChannel ch, RoleMappedData data)
204225
{
205226
Host.AssertValue(ch);
206227
base.CheckDataValid(ch, data);
207-
var labelType = data.Schema.Label.Value.Type;
208-
if (!(labelType is BooleanDataViewType || labelType is KeyDataViewType || labelType == NumberDataViewType.Single))
228+
229+
// If using a pre-trained model file we don't need a label column
230+
if (LightGbmTrainerOptions.LightGbmModel == null)
209231
{
210-
throw ch.ExceptParam(nameof(data),
211-
$"Label column '{data.Schema.Label.Value.Name}' is of type '{labelType.RawType}', but must be an unsigned int, boolean or float.");
232+
var labelType = data.Schema.Label.Value.Type;
233+
if (!(labelType is BooleanDataViewType || labelType is KeyDataViewType || labelType == NumberDataViewType.Single))
234+
{
235+
throw ch.ExceptParam(nameof(data),
236+
$"Label column '{data.Schema.Label.Value.Name}' is of type '{labelType.RawType}', but must be an unsigned int, boolean or float.");
237+
}
212238
}
213239
}
214240

0 commit comments

Comments
 (0)