Skip to content

Commit 3170ab0

Browse files
najeeb-kazmiZruty0
authored andcommitted
Adding prediction benchmarks using legacy LearningPipeline API (#1126)
See #1013 for the benchmark results
1 parent 3e3bd50 commit 3170ab0

File tree

2 files changed

+114
-3
lines changed

2 files changed

+114
-3
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using BenchmarkDotNet.Attributes;
6+
using Microsoft.ML.Runtime;
7+
using Microsoft.ML.Runtime.Api;
8+
using Microsoft.ML.Runtime.Learners;
9+
using Microsoft.ML.Legacy;
10+
using Microsoft.ML.Legacy.Data;
11+
using Microsoft.ML.Legacy.Transforms;
12+
using Microsoft.ML.Legacy.Trainers;
13+
14+
namespace Microsoft.ML.Benchmarks
15+
{
16+
public class LegacyPredictionEngineBench
17+
{
18+
private IrisData _irisExample;
19+
private PredictionModel<IrisData, IrisPrediction> _irisModel;
20+
21+
private SentimentData _sentimentExample;
22+
private PredictionModel<SentimentData, SentimentPrediction> _sentimentModel;
23+
24+
private BreastCancerData _breastCancerExample;
25+
private PredictionModel<BreastCancerData, BreastCancerPrediction> _breastCancerModel;
26+
27+
[GlobalSetup(Target = nameof(MakeIrisPredictions))]
28+
public void SetupIrisPipeline()
29+
{
30+
_irisExample = new IrisData()
31+
{
32+
SepalLength = 3.3f,
33+
SepalWidth = 1.6f,
34+
PetalLength = 0.2f,
35+
PetalWidth = 5.1f,
36+
};
37+
38+
string _irisDataPath = Program.GetInvariantCultureDataPath("iris.txt");
39+
40+
var pipeline = new LearningPipeline();
41+
pipeline.Add(new TextLoader(_irisDataPath).CreateFrom<IrisData>(useHeader: true, separator: '\t'));
42+
pipeline.Add(new ColumnConcatenator("Features", new[] { "SepalLength", "SepalWidth", "PetalLength", "PetalWidth" }));
43+
pipeline.Add(new StochasticDualCoordinateAscentClassifier() { NumThreads = 1, ConvergenceTolerance = 1e-2f });
44+
45+
_irisModel = pipeline.Train<IrisData, IrisPrediction>();
46+
}
47+
48+
[GlobalSetup(Target = nameof(MakeSentimentPredictions))]
49+
public void SetupSentimentPipeline()
50+
{
51+
_sentimentExample = new SentimentData()
52+
{
53+
SentimentText = "Not a big fan of this."
54+
};
55+
56+
string _sentimentDataPath = Program.GetInvariantCultureDataPath("wikipedia-detox-250-line-data.tsv");
57+
58+
var pipeline = new LearningPipeline();
59+
pipeline.Add(new TextLoader(_sentimentDataPath).CreateFrom<SentimentData>(useHeader: true, separator: '\t'));
60+
pipeline.Add(new TextFeaturizer("Features", "SentimentText"));
61+
pipeline.Add(new StochasticDualCoordinateAscentBinaryClassifier() { NumThreads = 1, ConvergenceTolerance = 1e-2f });
62+
63+
_sentimentModel = pipeline.Train<SentimentData, SentimentPrediction>();
64+
}
65+
66+
[GlobalSetup(Target = nameof(MakeBreastCancerPredictions))]
67+
public void SetupBreastCancerPipeline()
68+
{
69+
_breastCancerExample = new BreastCancerData()
70+
{
71+
Features = new[] { 5f, 1f, 1f, 1f, 2f, 1f, 3f, 1f, 1f }
72+
};
73+
74+
string _breastCancerDataPath = Program.GetInvariantCultureDataPath("breast-cancer.txt");
75+
76+
var pipeline = new LearningPipeline();
77+
pipeline.Add(new TextLoader(_breastCancerDataPath).CreateFrom<BreastCancerData>(useHeader: false, separator: '\t'));
78+
pipeline.Add(new StochasticDualCoordinateAscentBinaryClassifier() { NumThreads = 1, ConvergenceTolerance = 1e-2f });
79+
80+
_breastCancerModel = pipeline.Train<BreastCancerData, BreastCancerPrediction>();
81+
}
82+
83+
[Benchmark]
84+
public void MakeIrisPredictions()
85+
{
86+
for (int i = 0; i < 10000; i++)
87+
{
88+
_irisModel.Predict(_irisExample);
89+
}
90+
}
91+
92+
[Benchmark]
93+
public void MakeSentimentPredictions()
94+
{
95+
for (int i = 0; i < 10000; i++)
96+
{
97+
_sentimentModel.Predict(_sentimentExample);
98+
}
99+
}
100+
101+
[Benchmark]
102+
public void MakeBreastCancerPredictions()
103+
{
104+
for (int i = 0; i < 10000; i++)
105+
{
106+
_breastCancerModel.Predict(_breastCancerExample);
107+
}
108+
}
109+
}
110+
}

test/Microsoft.ML.Benchmarks/PredictionEngineBench.cs

+4-3
Original file line numberDiff line numberDiff line change
@@ -161,9 +161,10 @@ public void MakeBreastCancerPredictions()
161161

162162
public class SentimentData
163163
{
164-
[ColumnName("Label")]
164+
[ColumnName("Label"), Column("0")]
165165
public bool Sentiment;
166166

167+
[Column("1")]
167168
public string SentimentText;
168169
}
169170

@@ -177,10 +178,10 @@ public class SentimentPrediction
177178

178179
public class BreastCancerData
179180
{
180-
[ColumnName("Label")]
181+
[ColumnName("Label"), Column("0")]
181182
public bool Label;
182183

183-
[ColumnName("Features"), VectorType(9)]
184+
[ColumnName("Features"), Column("1-9"), VectorType(9)]
184185
public float[] Features;
185186
}
186187

0 commit comments

Comments
 (0)