Skip to content

Commit 36c75d9

Browse files
najeeb-kazmijustinormont
authored andcommitted
Adding benchmark test for PredictionEngine (#1014)
Adds a benchmark test to measure performance of doing many single predictions with PredictionEngine. Closes #1013
1 parent f6d850f commit 36c75d9

File tree

2 files changed

+196
-0
lines changed

2 files changed

+196
-0
lines changed

test/Microsoft.ML.Benchmarks/Microsoft.ML.Benchmarks.csproj

+3
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,9 @@
3636
<None Include="..\data\wikipedia-detox-250-line-data.tsv" Link="Input\wikipedia-detox-250-line-data.tsv">
3737
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
3838
</None>
39+
<None Include="..\data\breast-cancer.txt" Link="Input\breast-cancer.txt">
40+
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
41+
</None>
3942

4043
<BenchmarkFile Update="@(BenchmarkFile)">
4144
<Link>external\%(Identity)</Link>
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,193 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using BenchmarkDotNet.Attributes;
6+
using Microsoft.ML.Runtime;
7+
using Microsoft.ML.Runtime.Data;
8+
using Microsoft.ML.Runtime.Api;
9+
using Microsoft.ML.Runtime.Learners;
10+
11+
namespace Microsoft.ML.Benchmarks
12+
{
13+
[Config(typeof(PredictConfig))]
14+
public class PredictionEngineBench
15+
{
16+
private IrisData _irisExample;
17+
private PredictionFunction<IrisData, IrisPrediction> _irisModel;
18+
19+
private SentimentData _sentimentExample;
20+
private PredictionFunction<SentimentData, SentimentPrediction> _sentimentModel;
21+
22+
private BreastCancerData _breastCancerExample;
23+
private PredictionFunction<BreastCancerData, BreastCancerPrediction> _breastCancerModel;
24+
25+
[GlobalSetup(Target = nameof(MakeIrisPredictions))]
26+
public void SetupIrisPipeline()
27+
{
28+
_irisExample = new IrisData()
29+
{
30+
SepalLength = 3.3f,
31+
SepalWidth = 1.6f,
32+
PetalLength = 0.2f,
33+
PetalWidth = 5.1f,
34+
};
35+
36+
string _irisDataPath = Program.GetInvariantCultureDataPath("iris.txt");
37+
38+
using (var env = new ConsoleEnvironment(seed: 1, conc: 1, verbose: false, sensitivity: MessageSensitivity.None, outWriter: EmptyWriter.Instance))
39+
{
40+
var reader = new TextLoader(env,
41+
new TextLoader.Arguments()
42+
{
43+
Separator = "\t",
44+
HasHeader = true,
45+
Column = new[]
46+
{
47+
new TextLoader.Column("Label", DataKind.R4, 0),
48+
new TextLoader.Column("SepalLength", DataKind.R4, 1),
49+
new TextLoader.Column("SepalWidth", DataKind.R4, 2),
50+
new TextLoader.Column("PetalLength", DataKind.R4, 3),
51+
new TextLoader.Column("PetalWidth", DataKind.R4, 4),
52+
}
53+
});
54+
55+
IDataView data = reader.Read(new MultiFileSource(_irisDataPath));
56+
57+
var pipeline = new ConcatEstimator(env, "Features", new[] { "SepalLength", "SepalWidth", "PetalLength", "PetalWidth" })
58+
.Append(new SdcaMultiClassTrainer(env, new SdcaMultiClassTrainer.Arguments { NumThreads = 1, ConvergenceTolerance = 1e-2f }, "Features", "Label"));
59+
60+
var model = pipeline.Fit(data);
61+
62+
_irisModel = model.MakePredictionFunction<IrisData, IrisPrediction>(env);
63+
}
64+
}
65+
66+
[GlobalSetup(Target = nameof(MakeSentimentPredictions))]
67+
public void SetupSentimentPipeline()
68+
{
69+
_sentimentExample = new SentimentData()
70+
{
71+
SentimentText = "Not a big fan of this."
72+
};
73+
74+
string _sentimentDataPath = Program.GetInvariantCultureDataPath("wikipedia-detox-250-line-data.tsv");
75+
76+
using (var env = new ConsoleEnvironment(seed: 1, conc: 1, verbose: false, sensitivity: MessageSensitivity.None, outWriter: EmptyWriter.Instance))
77+
{
78+
var reader = new TextLoader(env,
79+
new TextLoader.Arguments()
80+
{
81+
Separator = "\t",
82+
HasHeader = true,
83+
Column = new[]
84+
{
85+
new TextLoader.Column("Label", DataKind.BL, 0),
86+
new TextLoader.Column("SentimentText", DataKind.Text, 1)
87+
}
88+
});
89+
90+
IDataView data = reader.Read(new MultiFileSource(_sentimentDataPath));
91+
92+
var pipeline = new TextTransform(env, "SentimentText", "Features")
93+
.Append(new LinearClassificationTrainer(env, new LinearClassificationTrainer.Arguments { NumThreads = 1, ConvergenceTolerance = 1e-2f }, "Features", "Label"));
94+
95+
var model = pipeline.Fit(data);
96+
97+
_sentimentModel = model.MakePredictionFunction<SentimentData, SentimentPrediction>(env);
98+
}
99+
}
100+
101+
[GlobalSetup(Target = nameof(MakeBreastCancerPredictions))]
102+
public void SetupBreastCancerPipeline()
103+
{
104+
_breastCancerExample = new BreastCancerData()
105+
{
106+
Features = new[] { 5f, 1f, 1f, 1f, 2f, 1f, 3f, 1f, 1f }
107+
};
108+
109+
string _breastCancerDataPath = Program.GetInvariantCultureDataPath("breast-cancer.txt");
110+
111+
using (var env = new ConsoleEnvironment(seed: 1, conc: 1, verbose: false, sensitivity: MessageSensitivity.None, outWriter: EmptyWriter.Instance))
112+
{
113+
var reader = new TextLoader(env,
114+
new TextLoader.Arguments()
115+
{
116+
Separator = "\t",
117+
HasHeader = false,
118+
Column = new[]
119+
{
120+
new TextLoader.Column("Label", DataKind.BL, 0),
121+
new TextLoader.Column("Features", DataKind.R4, new[] { new TextLoader.Range(1, 9) })
122+
}
123+
});
124+
125+
IDataView data = reader.Read(new MultiFileSource(_breastCancerDataPath));
126+
127+
var pipeline = new LinearClassificationTrainer(env, new LinearClassificationTrainer.Arguments { NumThreads = 1, ConvergenceTolerance = 1e-2f }, "Features", "Label");
128+
129+
var model = pipeline.Fit(data);
130+
131+
_breastCancerModel = model.MakePredictionFunction<BreastCancerData, BreastCancerPrediction>(env);
132+
}
133+
}
134+
135+
[Benchmark]
136+
public void MakeIrisPredictions()
137+
{
138+
for (int i = 0; i < 10000; i++)
139+
{
140+
_irisModel.Predict(_irisExample);
141+
}
142+
}
143+
144+
[Benchmark]
145+
public void MakeSentimentPredictions()
146+
{
147+
for (int i = 0; i < 10000; i++)
148+
{
149+
_sentimentModel.Predict(_sentimentExample);
150+
}
151+
}
152+
153+
[Benchmark]
154+
public void MakeBreastCancerPredictions()
155+
{
156+
for (int i = 0; i < 10000; i++)
157+
{
158+
_breastCancerModel.Predict(_breastCancerExample);
159+
}
160+
}
161+
}
162+
163+
public class SentimentData
164+
{
165+
[ColumnName("Label")]
166+
public bool Sentiment;
167+
168+
public string SentimentText;
169+
}
170+
171+
public class SentimentPrediction
172+
{
173+
[ColumnName("PredictedLabel")]
174+
public bool Sentiment;
175+
176+
public float Score;
177+
}
178+
179+
public class BreastCancerData
180+
{
181+
[ColumnName("Label")]
182+
public bool Label;
183+
184+
[ColumnName("Features"), VectorType(9)]
185+
public float[] Features;
186+
}
187+
188+
public class BreastCancerPrediction
189+
{
190+
[ColumnName("Score")]
191+
public float Score;
192+
}
193+
}

0 commit comments

Comments
 (0)