Skip to content

Commit d8d4c2e

Browse files
committed
Added tests for the new API where components (Loaders/Transforms/Learners) are directly instantiated.
1 parent 62095cb commit d8d4c2e

File tree

2 files changed

+328
-0
lines changed

2 files changed

+328
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,210 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using Microsoft.ML.Models;
6+
using Microsoft.ML.Runtime;
7+
using Microsoft.ML.Runtime.Api;
8+
using Microsoft.ML.Runtime.Data;
9+
using Microsoft.ML.Runtime.Learners;
10+
using Microsoft.ML.Runtime.Model;
11+
using System;
12+
using System.IO;
13+
using Xunit;
14+
15+
namespace Microsoft.ML.Scenarios
16+
{
17+
public partial class ScenariosTests
18+
{
19+
/// <summary>
/// Trains a multiclass SDCA model on the iris data by instantiating the loader, the transforms
/// and the trainer directly, then checks the evaluation metrics, a few sample predictions and
/// the first entry of the predictor summary.
/// </summary>
[Fact]
public void TrainAndPredictIrisModelUsingDirectInstantiationTest()
{
    string dataPath = GetDataPath("iris.txt");
    // The same file is used both for training and for scoring.
    string testDataPath = dataPath;

    // Builds an R4 column that reads exactly one source column (Min == Max == index).
    Func<string, int, TextLoader.Column> scalarColumn = (name, index) => new TextLoader.Column()
    {
        Name = name,
        Source = new[] { new TextLoader.Range() { Min = index, Max = index } },
        Type = DataKind.R4
    };

    using (var env = new TlcEnvironment(seed: 1, conc: 1))
    {
        // Pipeline: text loader -> concatenated "Features" column -> min-max normalizer.
        var loaderArgs = new TextLoader.Arguments()
        {
            HasHeader = false,
            Column = new[]
            {
                scalarColumn("Label", 0),
                scalarColumn("SepalLength", 1),
                scalarColumn("SepalWidth", 2),
                scalarColumn("PetalLength", 3),
                scalarColumn("PetalWidth", 4)
            }
        };
        var loader = new TextLoader(env, loaderArgs, new MultiFileSource(dataPath));

        IDataTransform concatenated = new ConcatTransform(env, loader, "Features",
            "SepalLength", "SepalWidth", "PetalLength", "PetalWidth");

        // The normalizer is not added automatically even though the trainer has
        // 'NormalizeFeatures' set to On/Auto, so it is appended by hand.
        IDataTransform pipeline = NormalizeTransform.CreateMinMaxNormalizer(env, concatenated, "Features");

        // Train
        var learner = new SdcaMultiClassTrainer(env, new SdcaMultiClassTrainer.Arguments());

        // Explicitly wrapping the pipeline in a CacheDataView, since caching does not kick in
        // on its own even though the trainer has 'Caching' set to On/Auto.
        var cachedView = new CacheDataView(env, pipeline, prefetch: null);
        var trainingData = TrainUtils.CreateExamples(cachedView, label: "Label", feature: "Features");
        learner.Train(trainingData);

        // Build a scorer over the test data and evaluate its predictions.
        var predictor = learner.CreatePredictor();
        IDataScorerTransform scorer = GetScorer(env, pipeline, predictor, testDataPath);
        var evaluation = Evaluate(env, scorer);
        CompareMatrics(evaluation);

        // Create a prediction engine and check individual predictions.
        var engine = env.CreatePredictionEngine<IrisData, IrisPrediction>(scorer);
        ComparePredictions(engine);

        // Feature importance, i.e. the weight vector reported by the predictor summary.
        var weights = ((MulticlassLogisticRegressionPredictor)predictor).GetSummaryInKeyValuePairs(trainingData.Schema);
        Assert.Equal(7.757867, Convert.ToDouble(weights[0].Value), 5);
    }
}
95+
96+
/// <summary>
/// Runs three hand-written iris samples through <paramref name="model"/> and asserts the
/// per-class values in <c>PredictedLabels</c> for each of them.
/// </summary>
/// <param name="model">Prediction engine produced by the trained iris pipeline.</param>
private void ComparePredictions(PredictionEngine<IrisData, IrisPrediction> model)
{
    // Sample expected to land in the first class.
    var firstSample = new IrisData
    {
        SepalLength = 3.3f,
        SepalWidth = 1.6f,
        PetalLength = 0.2f,
        PetalWidth = 5.1f,
    };
    IrisPrediction firstResult = model.Predict(firstSample);

    Assert.Equal(1, firstResult.PredictedLabels[0], 2);
    Assert.Equal(0, firstResult.PredictedLabels[1], 2);
    Assert.Equal(0, firstResult.PredictedLabels[2], 2);

    // Sample expected to land in the third class.
    var secondSample = new IrisData
    {
        SepalLength = 3.1f,
        SepalWidth = 5.5f,
        PetalLength = 2.2f,
        PetalWidth = 6.4f,
    };
    IrisPrediction secondResult = model.Predict(secondSample);

    Assert.Equal(0, secondResult.PredictedLabels[0], 2);
    Assert.Equal(0, secondResult.PredictedLabels[1], 2);
    Assert.Equal(1, secondResult.PredictedLabels[2], 2);

    // Sample split between the first two classes; those are only checked to one decimal.
    var thirdSample = new IrisData
    {
        SepalLength = 3.1f,
        SepalWidth = 2.5f,
        PetalLength = 1.2f,
        PetalWidth = 4.4f,
    };
    IrisPrediction thirdResult = model.Predict(thirdSample);

    Assert.Equal(.2, thirdResult.PredictedLabels[0], 1);
    Assert.Equal(.8, thirdResult.PredictedLabels[1], 1);
    Assert.Equal(0, thirdResult.PredictedLabels[2], 2);
}
134+
135+
/// <summary>
/// Asserts the expected multiclass metrics for the iris model: overall accuracy and log-loss,
/// per-class log-loss, and every cell of the 3x3 confusion matrix (addressed both by index and
/// by class name).
/// </summary>
/// <param name="metrics">Metrics produced by evaluating the scored iris data.</param>
private void CompareMatrics(ClassificationMetrics metrics)
{
    // Compare with a tolerance rather than exact equality: the value is a computed double,
    // and this matches the precision already used for AccuracyMicro below.
    Assert.Equal(.98, metrics.AccuracyMacro, 2);
    Assert.Equal(.98, metrics.AccuracyMicro, 2);
    Assert.Equal(.06, metrics.LogLoss, 2);
    Assert.InRange(metrics.LogLossReduction, 94, 96);
    Assert.Equal(1, metrics.TopKAccuracy);

    Assert.Equal(3, metrics.PerClassLogLoss.Length);
    Assert.Equal(0, metrics.PerClassLogLoss[0], 1);
    Assert.Equal(.1, metrics.PerClassLogLoss[1], 1);
    Assert.Equal(.1, metrics.PerClassLogLoss[2], 1);

    ConfusionMatrix matrix = metrics.ConfusionMatrix;
    Assert.Equal(3, matrix.Order);
    Assert.Equal(3, matrix.ClassNames.Count);
    Assert.Equal("0", matrix.ClassNames[0]);
    Assert.Equal("1", matrix.ClassNames[1]);
    Assert.Equal("2", matrix.ClassNames[2]);

    // Row 0: checked through both the integer and the class-name indexer.
    Assert.Equal(50, matrix[0, 0]);
    Assert.Equal(50, matrix["0", "0"]);
    Assert.Equal(0, matrix[0, 1]);
    Assert.Equal(0, matrix["0", "1"]);
    Assert.Equal(0, matrix[0, 2]);
    Assert.Equal(0, matrix["0", "2"]);

    // Row 1.
    Assert.Equal(0, matrix[1, 0]);
    Assert.Equal(0, matrix["1", "0"]);
    Assert.Equal(48, matrix[1, 1]);
    Assert.Equal(48, matrix["1", "1"]);
    Assert.Equal(2, matrix[1, 2]);
    Assert.Equal(2, matrix["1", "2"]);

    // Row 2.
    Assert.Equal(0, matrix[2, 0]);
    Assert.Equal(0, matrix["2", "0"]);
    Assert.Equal(1, matrix[2, 1]);
    Assert.Equal(1, matrix["2", "1"]);
    Assert.Equal(49, matrix[2, 2]);
    Assert.Equal(49, matrix["2", "2"]);
}
176+
177+
/// <summary>
/// Evaluates scored multiclass data with <c>MultiClassMamlEvaluator</c> and returns the first
/// <see cref="ClassificationMetrics"/> entry built from its output.
/// </summary>
/// <param name="env">Host environment passed to the evaluator and the metrics conversion.</param>
/// <param name="scoredData">Scorer output to evaluate; read with "Label" and "Features" roles.</param>
/// <returns>The first metrics object created from the "OverallMetrics" and "ConfusionMatrix" views.</returns>
private ClassificationMetrics Evaluate(IHostEnvironment env, IDataView scoredData)
{
    var roleMapped = TrainUtils.CreateExamplesOpt(scoredData, label: "Label", feature: "Features");

    // MultiClassClassifierEvaluator was tried here first, i.e.
    //   new MultiClassClassifierEvaluator(env, new MultiClassClassifierEvaluator.Arguments() { OutputTopKAcc = 3 });
    // but it does not work: Evaluate throws "Failed to find 'Score' column".
    // The Maml evaluator is used instead.
    var evaluatorArgs = new MultiClassMamlEvaluator.Arguments() { OutputTopKAcc = 3 };
    var mamlEvaluator = new MultiClassMamlEvaluator(env, evaluatorArgs);
    var resultViews = mamlEvaluator.Evaluate(roleMapped);

    var overall = resultViews["OverallMetrics"];
    var confusion = resultViews["ConfusionMatrix"];
    var allMetrics = ClassificationMetrics.FromMetrics(env, overall, confusion);
    return allMetrics[0];
}
190+
191+
/// <summary>
/// Saves the predictor together with its training pipeline into an in-memory model, reloads
/// the pipeline from that model against <paramref name="testDataPath"/>, and returns a scorer
/// over the reloaded data.
/// </summary>
/// <param name="env">Host environment used for the channel, the model I/O and the scorer.</param>
/// <param name="transforms">Training pipeline to save alongside the predictor.</param>
/// <param name="pred">Trained predictor to save and score with.</param>
/// <param name="testDataPath">Path of the file the reloaded loader should read.</param>
/// <returns>A scorer transform that applies <paramref name="pred"/> to the test data.</returns>
private IDataScorerTransform GetScorer(IHostEnvironment env, IDataView transforms, IPredictor pred, string testDataPath = null)
{
    using (var channel = env.Start("Saving model"))
    {
        using (var modelStream = new MemoryStream())
        {
            // Model cannot be saved with CacheDataView, so the roles are built on the
            // pipeline that is passed in.
            var savedRoles = TrainUtils.CreateExamples(transforms, label: "Label", feature: "Features");
            TrainUtils.SaveModel(env, channel, modelStream, pred, savedRoles);

            // Rewind so the repository reader starts at the beginning of the saved model.
            modelStream.Position = 0;

            using (var repository = RepositoryReader.Open(modelStream, channel))
            {
                var testSource = new MultiFileSource(testDataPath);
                IDataLoader reloadedPipe = ModelFileUtils.LoadLoader(env, repository, testSource, true);
                RoleMappedData scoringRoles = TrainUtils.CreateExamples(reloadedPipe, label: "Label", feature: "Features");
                return ScoreUtils.GetScorer(pred, scoringRoles, env, scoringRoles.Schema);
            }
        }
    }
}
209+
}
210+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using Microsoft.ML.Models;
6+
using Microsoft.ML.Runtime;
7+
using Microsoft.ML.Runtime.Api;
8+
using Microsoft.ML.Runtime.Data;
9+
using Microsoft.ML.Runtime.FastTree;
10+
using Microsoft.ML.Runtime.Internal.Calibration;
11+
using Microsoft.ML.Runtime.Model;
12+
using Microsoft.ML.Trainers;
13+
using Microsoft.ML.Transforms;
14+
using System.Collections.Generic;
15+
using System.IO;
16+
using System.Linq;
17+
using Xunit;
18+
19+
namespace Microsoft.ML.Scenarios
20+
{
21+
public partial class ScenariosTests
22+
{
23+
/// <summary>
/// Trains a FastTree binary classifier on the sentiment data by instantiating the loader,
/// the text transform and the trainer directly, then checks the evaluation metrics, two
/// batch predictions and the first entry of the predictor summary.
/// </summary>
[Fact]
public void TrainAndPredictSentimentModelWithDirectionInstantiationTest()
{
    var dataPath = GetDataPath(SentimentDataPath);
    var testDataPath = GetDataPath(SentimentTestPath);

    using (var env = new TlcEnvironment(seed: 1, conc: 1))
    {
        // Pipeline
        var loader = new TextLoader(env,
            new TextLoader.Arguments()
            {
                Separator = "tab",
                HasHeader = true,
                Column = new[]
                {
                    new TextLoader.Column()
                    {
                        Name = "Label",
                        Source = new [] { new TextLoader.Range() { Min=0, Max=0} },
                        Type = DataKind.Num
                    },

                    new TextLoader.Column()
                    {
                        Name = "SentimentText",
                        Source = new [] { new TextLoader.Range() { Min=1, Max=1} },
                        Type = DataKind.Text
                    }
                }
            }, new MultiFileSource(dataPath));

        // Featurize the raw text into the "Features" column.
        var trans = TextTransform.Create(env, new TextTransform.Arguments()
        {
            Column = new TextTransform.Column
            {
                Name = "Features",
                Source = new[] { "SentimentText" }
            },
            KeepDiacritics = false,
            KeepPunctuations = false,
            TextCase = Runtime.TextAnalytics.TextNormalizerTransform.CaseNormalizationMode.Lower,
            OutputTokens = true,
            StopWordsRemover = new Runtime.TextAnalytics.PredefinedStopWordsRemoverFactory(),
            VectorNormalizer = TextTransform.TextNormKind.L2,
            CharFeatureExtractor = new NgramExtractorTransform.NgramExtractorArguments() { NgramLength = 3, AllLengths = false },
            WordFeatureExtractor = new NgramExtractorTransform.NgramExtractorArguments() { NgramLength = 2, AllLengths = true },
        },
        loader);

        // Train
        var trainer = new FastTreeBinaryClassificationTrainer(env, new FastTreeBinaryClassificationTrainer.Arguments()
        {
            NumLeaves = 5,
            NumTrees = 5,
            MinDocumentsInLeafs = 2
        });

        var trainRoles = TrainUtils.CreateExamples(trans, label: "Label", feature: "Features");
        trainer.Train(trainRoles);

        // Get scorer and evaluate the predictions from test data
        var pred = trainer.CreatePredictor();
        IDataScorerTransform testDataScorer = GetScorer(env, trans, pred, testDataPath);
        var metrics = EvaluateBinary(env, testDataScorer);
        ValidateBinaryMetrics(metrics);

        // Create prediction engine and test predictions
        var model = env.CreateBatchPredictionEngine<SentimentData, SentimentPrediction>(testDataScorer);
        var sentiments = GetTestData();

        // Materialize the results once so the assertions below do not enumerate the
        // prediction sequence three separate times (Count + two ElementAt calls).
        // Row objects are not reused (second argument is false), so buffering is safe.
        var predictions = model.Predict(sentiments, false).ToArray();
        Assert.Equal(2, predictions.Length);
        Assert.True(predictions[0].Sentiment.IsFalse);
        Assert.True(predictions[1].Sentiment.IsTrue);

        // Get feature importance based on feature gain during training
        var summary = ((FeatureWeightsCalibratedPredictor)pred).GetSummaryInKeyValuePairs(trainRoles.Schema);
        Assert.Equal(1.0, (double)summary[0].Value, 1);
    }
}
103+
104+
/// <summary>
/// Evaluates scored binary-classification data with <c>BinaryClassifierMamlEvaluator</c> and
/// returns the first <see cref="BinaryClassificationMetrics"/> entry built from its output.
/// </summary>
/// <param name="env">Host environment passed to the evaluator and the metrics conversion.</param>
/// <param name="scoredData">Scorer output to evaluate; read with "Label" and "Features" roles.</param>
/// <returns>The first metrics object created from the "OverallMetrics" and "ConfusionMatrix" views.</returns>
private BinaryClassificationMetrics EvaluateBinary(IHostEnvironment env, IDataView scoredData)
{
    var roleMapped = TrainUtils.CreateExamplesOpt(scoredData, label: "Label", feature: "Features");

    // BinaryClassifierEvaluator was tried here first, i.e.
    //   new BinaryClassifierEvaluator(env, new BinaryClassifierEvaluator.Arguments());
    // but it does not work: Evaluate throws "Failed to find 'Score' column".
    // The Maml evaluator is used instead.
    var evaluatorArgs = new BinaryClassifierMamlEvaluator.Arguments();
    var mamlEvaluator = new BinaryClassifierMamlEvaluator(env, evaluatorArgs);
    var resultViews = mamlEvaluator.Evaluate(roleMapped);

    var overall = resultViews["OverallMetrics"];
    var confusion = resultViews["ConfusionMatrix"];
    var allMetrics = BinaryClassificationMetrics.FromMetrics(env, overall, confusion);
    return allMetrics[0];
}
117+
}
118+
}

0 commit comments

Comments
 (0)