Skip to content

Commit b572614

Browse files
committed
Adding a sample for LightGbm Ranking
1 parent 685b064 commit b572614

File tree

7 files changed

+150
-3
lines changed

7 files changed

+150
-3
lines changed

docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/LightGBMBinaryClassificationWithOptions.cs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
using Microsoft.ML.LightGBM;
2-
using Microsoft.ML.Transforms.Categorical;
32
using static Microsoft.ML.LightGBM.Options;
43

54
namespace Microsoft.ML.Samples.Dynamic
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Microsoft.ML.Samples.Dynamic
6+
{
7+
/// <summary>
/// Sample that trains a LightGBM ranker on the featurized MSLR-WEB10K training data,
/// scores the validation data with it and prints the resulting ranking metrics.
/// </summary>
public class LightGbmRanking
8+
{
9+
// This example requires installation of additional nuget package <a href="https://www.nuget.org/packages/Microsoft.ML.LightGBM/">Microsoft.ML.LightGBM</a>.
10+
public static void Example()
11+
{
12+
// Create the MLContext that the data loading, training and evaluation calls below go through.
13+
var mlContext = new MLContext();
14+
15+
// Download and featurize the train and validation datasets.
16+
(var trainData, var validationData) = SamplesUtils.DatasetUtils.LoadFeaturizedMslrWeb10kTrainAndValidate(mlContext);
17+
18+
// Create the Estimator pipeline. For simplicity, we will train a small tree with 4 leaves and 2 boosting iterations.
19+
var pipeline = mlContext.Ranking.Trainers.LightGbm(
20+
labelColumn: "Label",
21+
featureColumn: "Features",
22+
groupIdColumn: "GroupId",
23+
numLeaves: 4,
24+
minDataPerLeaf: 10,
25+
learningRate: 0.1,
26+
numBoostRound: 2);
27+
28+
// Fit this Pipeline to the Training Data.
29+
var model = pipeline.Fit(trainData);
30+
31+
// Evaluate how the model is doing on the validation data.
32+
var dataWithPredictions = model.Transform(validationData);
33+
34+
var metrics = mlContext.Ranking.Evaluate(dataWithPredictions, "Label", "GroupId");
35+
SamplesUtils.ConsoleUtils.PrintMetrics(metrics);
36+
37+
// Output:
38+
// DCG@N: 1.38, 3.11, 4.94
39+
// NDCG@N: 7.13, 10.12, 12.62
40+
}
41+
}
42+
}
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
using Microsoft.ML.LightGBM;
2+
using static Microsoft.ML.LightGBM.Options;
3+
4+
namespace Microsoft.ML.Samples.Dynamic
5+
{
6+
/// <summary>
/// Sample that trains a LightGBM ranker configured through an Options object on the featurized
/// MSLR-WEB10K training data, scores the validation data with it and prints the resulting ranking metrics.
/// </summary>
public class LightGbmRankingWithOptions
7+
{
8+
// This example requires installation of additional nuget package <a href="https://www.nuget.org/packages/Microsoft.ML.LightGBM/">Microsoft.ML.LightGBM</a>.
9+
public static void Example()
10+
{
11+
// Create the MLContext that the data loading, training and evaluation calls below go through.
12+
var mlContext = new MLContext();
13+
14+
// Download and featurize the train and validation datasets.
15+
(var trainData, var validationData) = SamplesUtils.DatasetUtils.LoadFeaturizedMslrWeb10kTrainAndValidate(mlContext);
16+
17+
// Create the Estimator pipeline. For simplicity, we will train a small tree with 4 leaves and 2 boosting iterations.
18+
var pipeline = mlContext.Ranking.Trainers.LightGbm(
19+
new Options
20+
{
21+
LabelColumn = "Label",
22+
FeatureColumn = "Features",
23+
GroupIdColumn = "GroupId",
24+
NumLeaves = 4,
25+
MinDataPerLeaf = 10,
26+
LearningRate = 0.1,
27+
NumBoostRound = 2
28+
});
29+
30+
// Fit this Pipeline to the Training Data.
31+
var model = pipeline.Fit(trainData);
32+
33+
// Evaluate how the model is doing on the validation data.
34+
var dataWithPredictions = model.Transform(validationData);
35+
36+
var metrics = mlContext.Ranking.Evaluate(dataWithPredictions, "Label", "GroupId");
37+
SamplesUtils.ConsoleUtils.PrintMetrics(metrics);
38+
39+
// Output:
40+
// DCG@N: 1.38, 3.11, 4.94
41+
// NDCG@N: 7.13, 10.12, 12.62
42+
}
43+
}
44+
}

docs/samples/Microsoft.ML.Samples/Program.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ internal static class Program
66
{
77
static void Main(string[] args)
88
{
9-
TakeRows.Example();
9+
LightGbmRanking.Example();
1010
}
1111
}
1212
}

src/Microsoft.ML.Data/Evaluators/Metrics/RankerMetrics.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ public sealed class RankerMetrics
1818
/// Array of discounted cumulative gains where the i-th element represents DCG@i.
1919
/// <a href="https://en.wikipedia.org/wiki/Discounted_cumulative_gain">Discounted Cumulative gain</a>
2020
/// is the sum of the gains, for all the instances i, normalized by the natural logarithm of the instance + 1.
21-
/// Note that unline the Wikipedia article, ML.Net uses the natural logarithm.
21+
/// Note that unlike the Wikipedia article, ML.Net uses the natural logarithm.
2222
/// <image src="https://github.com/dotnet/machinelearning/tree/master/docs/images/DCG.png"></image>
2323
/// </summary>
2424
public double[] Dcg { get; }

src/Microsoft.ML.SamplesUtils/ConsoleUtils.cs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using System;
2+
using System.Linq;
23
using Microsoft.ML.Data;
34

45
namespace Microsoft.ML.SamplesUtils
@@ -35,5 +36,15 @@ public static void PrintMetrics(RegressionMetrics metrics)
3536
Console.WriteLine($"RMS: {metrics.Rms:F2}");
3637
Console.WriteLine($"RSquared: {metrics.RSquared:F2}");
3738
}
39+
40+
/// <summary>
41+
/// Pretty-prints a RankerMetrics object: writes one "DCG@N" line and one "NDCG@N" line to the console,
/// each listing the values of the corresponding array rounded to two decimals and joined by ", ".
42+
/// </summary>
43+
/// <param name="metrics">Ranker metrics whose Dcg and Ndcg arrays are printed.</param>
44+
public static void PrintMetrics(RankerMetrics metrics)
45+
{
46+
// NOTE(review): Math.Round does not pad trailing zeros (1.40 prints as 1.4), unlike the F2 format
// used by the regression overload above - confirm this difference is intended.
Console.WriteLine($"DCG@N: {string.Join(", ", metrics.Dcg.Select(d => Math.Round(d, 2)).ToArray())}");
47+
Console.WriteLine($"NDCG@N: {string.Join(", ", metrics.Ndcg.Select(d => Math.Round(d, 2)).ToArray())}");
48+
}
3849
}
3950
}

src/Microsoft.ML.SamplesUtils/SamplesDatasetUtils.cs

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,57 @@ public static IDataView LoadFeaturizedAdultDataset(MLContext mlContext)
146146
return featurizedData;
147147
}
148148

149+
/// <summary>
/// Makes the MSLR-WEB10K training file available locally as "MSLRWeb10KTrain720kRows.tsv".
/// The file is fetched through Download only when no file with that name exists yet.
/// </summary>
/// <returns>The (relative) name of the local training file.</returns>
public static string DownloadMslrWeb10kTrain()
150+
{
151+
var fileName = "MSLRWeb10KTrain720kRows.tsv";
152+
// Skip the download when the file is already present.
if (!File.Exists(fileName))
153+
Download("https://tlcresources.blob.core.windows.net/datasets/MSLR-WEB10K/MSLR-WEB10K_Fold1.TRAIN.500MB_720k-rows.tsv", fileName);
154+
return fileName;
155+
}
156+
157+
/// <summary>
/// Makes the MSLR-WEB10K validation file available locally as "MSLRWeb10KValidate240kRows.tsv".
/// The file is fetched through Download only when no file with that name exists yet.
/// </summary>
/// <returns>The (relative) name of the local validation file.</returns>
public static string DownloadMslrWeb10kValidate()
158+
{
159+
var fileName = "MSLRWeb10KValidate240kRows.tsv";
160+
// Skip the download when the file is already present.
if (!File.Exists(fileName))
161+
Download("https://tlcresources.blob.core.windows.net/datasets/MSLR-WEB10K/MSLR-WEB10K_Fold1.VALIDATE.160MB_240k-rows.tsv", fileName);
162+
return fileName;
163+
}
164+
165+
/// <summary>
/// Downloads (if needed) and loads the MSLR-WEB10K training and validation files, then featurizes both:
/// the GroupId column is hashed and missing values in the Features column are replaced.
/// The featurization pipeline is fitted on the training data only and applied to both datasets.
/// </summary>
/// <param name="mlContext">The MLContext used to create the text loader and the transforms.</param>
/// <returns>A tuple of (transformed training data, transformed validation data), in that order.</returns>
public static (IDataView, IDataView) LoadFeaturizedMslrWeb10kTrainAndValidate(MLContext mlContext)
166+
{
167+
// Download the training and validation files.
168+
string trainDataFile = DownloadMslrWeb10kTrain();
169+
string validationDataFile = DownloadMslrWeb10kValidate();
170+
171+
// Create the reader: Label (R4) from column 0, GroupId (TX) from column 1, Features (R4) from columns 2 through 138.
172+
var reader = mlContext.Data.CreateTextLoader(
173+
columns: new[]
174+
{
175+
new TextLoader.Column("Label", DataKind.R4, 0),
176+
new TextLoader.Column("GroupId", DataKind.TX, 1),
177+
new TextLoader.Column("Features", DataKind.R4, new[] { new TextLoader.Range(2, 138) })
178+
}
179+
);
180+
181+
// Load the raw training and validation datasets with the same reader, so both share one schema.
182+
var trainData = reader.Read(trainDataFile);
183+
var validationData = reader.Read(validationDataFile);
184+
185+
// Create the featurization pipeline. First, hash the GroupId column.
186+
var pipeline = mlContext.Transforms.Conversion.Hash("GroupId")
187+
// Replace missing values in Features column with the default replacement value for its type.
188+
.Append(mlContext.Transforms.ReplaceMissingValues("Features"));
189+
190+
// Fit the pipeline on the training data only; the validation data is not used for fitting.
191+
var fittedPipeline = pipeline.Fit(trainData);
192+
193+
// Use the fitted pipeline to transform the training and validation datasets.
194+
var transformedTrainData = fittedPipeline.Transform(trainData);
195+
var transformedValidationData = fittedPipeline.Transform(validationData);
196+
197+
return (transformedTrainData, transformedValidationData);
198+
}
199+
149200
/// <summary>
150201
/// Downloads the breast cancer dataset from the ML.NET repo.
151202
/// </summary>

0 commit comments

Comments
 (0)