-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Update tests (using Iris dataset) to use new API #2008
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,33 +3,45 @@ | |
// See the LICENSE file in the project root for more information. | ||
|
||
using Microsoft.ML.Data; | ||
using Microsoft.ML.Legacy.Models; | ||
using Microsoft.ML.Legacy.Trainers; | ||
using Microsoft.ML.Legacy.Transforms; | ||
using Microsoft.ML.RunTests; | ||
using Xunit; | ||
using TextLoader = Microsoft.ML.Legacy.Data.TextLoader; | ||
|
||
namespace Microsoft.ML.Scenarios | ||
{ | ||
#pragma warning disable 612, 618 | ||
public partial class ScenariosTests | ||
{ | ||
[Fact] | ||
public void TrainAndPredictIrisModelTest() | ||
{ | ||
string dataPath = GetDataPath("iris.txt"); | ||
|
||
var pipeline = new Legacy.LearningPipeline(seed: 1, conc: 1); | ||
|
||
pipeline.Add(new TextLoader(dataPath).CreateFrom<IrisData>(useHeader: false)); | ||
pipeline.Add(new ColumnConcatenator(outputColumn: "Features", | ||
"SepalLength", "SepalWidth", "PetalLength", "PetalWidth")); | ||
|
||
pipeline.Add(new StochasticDualCoordinateAscentClassifier()); | ||
|
||
Legacy.PredictionModel<IrisData, IrisPrediction> model = pipeline.Train<IrisData, IrisPrediction>(); | ||
|
||
IrisPrediction prediction = model.Predict(new IrisData() | ||
var mlContext = new MLContext(seed: 1, conc: 1); | ||
|
||
var reader = mlContext.Data.CreateTextReader(columns: new[] | ||
{ | ||
new TextLoader.Column("Label", DataKind.R4, 0), | ||
new TextLoader.Column("SepalLength", DataKind.R4, 1), | ||
new TextLoader.Column("SepalWidth", DataKind.R4, 2), | ||
new TextLoader.Column("PetalLength", DataKind.R4, 3), | ||
new TextLoader.Column("PetalWidth", DataKind.R4, 4) | ||
} | ||
); | ||
|
||
var pipe = mlContext.Transforms.Concatenate("Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth") | ||
.Append(mlContext.Transforms.Normalize("Features")) | ||
.AppendCacheCheckpoint(mlContext) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why are we adding normalizer transform? just curious. #Resolved There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think in the new API the paradigm is to be explicit about whether are using Normalizers in the pipeline . (In the old Legacy API normalizers used to be auto-inserted based on the algorithm) Hence decided to be explicit about normalizers. Please see #1604 which talks about removing 'smarts' (i.e. no auto-normalization, no auto-caching and no auto-calibration) in the new API In reply to: 245111952 [](ancestors = 245111952) |
||
.Append(mlContext.MulticlassClassification.Trainers.StochasticDualCoordinateAscent("Label", "Features", advancedSettings: s => s.NumThreads = 1)); | ||
|
||
// Read training and test data sets | ||
string dataPath = GetDataPath(TestDatasets.iris.trainFilename); | ||
string testDataPath = dataPath; | ||
var trainData = reader.Read(dataPath); | ||
var testData = reader.Read(testDataPath); | ||
|
||
// Train the pipeline | ||
var trainedModel = pipe.Fit(trainData); | ||
|
||
// Make predictions | ||
var predictFunction = trainedModel.CreatePredictionEngine<IrisData, IrisPrediction>(mlContext); | ||
IrisPrediction prediction = predictFunction.Predict(new IrisData() | ||
{ | ||
SepalLength = 5.1f, | ||
SepalWidth = 3.3f, | ||
|
@@ -41,7 +53,7 @@ public void TrainAndPredictIrisModelTest() | |
Assert.Equal(0, prediction.PredictedLabels[1], 2); | ||
Assert.Equal(0, prediction.PredictedLabels[2], 2); | ||
|
||
prediction = model.Predict(new IrisData() | ||
prediction = predictFunction.Predict(new IrisData() | ||
{ | ||
SepalLength = 6.4f, | ||
SepalWidth = 3.1f, | ||
|
@@ -53,7 +65,7 @@ public void TrainAndPredictIrisModelTest() | |
Assert.Equal(0, prediction.PredictedLabels[1], 2); | ||
Assert.Equal(1, prediction.PredictedLabels[2], 2); | ||
|
||
prediction = model.Predict(new IrisData() | ||
prediction = predictFunction.Predict(new IrisData() | ||
{ | ||
SepalLength = 4.4f, | ||
SepalWidth = 3.1f, | ||
|
@@ -65,53 +77,19 @@ public void TrainAndPredictIrisModelTest() | |
Assert.Equal(.8, prediction.PredictedLabels[1], 1); | ||
Assert.Equal(0, prediction.PredictedLabels[2], 2); | ||
|
||
// Note: Testing against the same data set as a simple way to test evaluation. | ||
// This isn't appropriate in real-world scenarios. | ||
string testDataPath = GetDataPath("iris.txt"); | ||
var testData = new TextLoader(testDataPath).CreateFrom<IrisData>(useHeader: false); | ||
|
||
var evaluator = new ClassificationEvaluator(); | ||
evaluator.OutputTopKAcc = 3; | ||
ClassificationMetrics metrics = evaluator.Evaluate(model, testData); | ||
// Evaluate the trained pipeline | ||
var predicted = trainedModel.Transform(testData); | ||
var metrics = mlContext.MulticlassClassification.Evaluate(predicted, topK: 3); | ||
|
||
Assert.Equal(.98, metrics.AccuracyMacro); | ||
Assert.Equal(.98, metrics.AccuracyMicro, 2); | ||
Assert.Equal(.06, metrics.LogLoss, 2); | ||
Assert.InRange(metrics.LogLossReduction, 94, 96); | ||
Assert.Equal(1, metrics.TopKAccuracy); | ||
|
||
Assert.Equal(3, metrics.PerClassLogLoss.Length); | ||
Assert.Equal(0, metrics.PerClassLogLoss[0], 1); | ||
Assert.Equal(.1, metrics.PerClassLogLoss[1], 1); | ||
Assert.Equal(.1, metrics.PerClassLogLoss[2], 1); | ||
|
||
ConfusionMatrix matrix = metrics.ConfusionMatrix; | ||
Assert.Equal(3, matrix.Order); | ||
Assert.Equal(3, matrix.ClassNames.Count); | ||
Assert.Equal("0", matrix.ClassNames[0]); | ||
Assert.Equal("1", matrix.ClassNames[1]); | ||
Assert.Equal("2", matrix.ClassNames[2]); | ||
|
||
Assert.Equal(50, matrix[0, 0]); | ||
Assert.Equal(50, matrix["0", "0"]); | ||
Assert.Equal(0, matrix[0, 1]); | ||
Assert.Equal(0, matrix["0", "1"]); | ||
Assert.Equal(0, matrix[0, 2]); | ||
Assert.Equal(0, matrix["0", "2"]); | ||
|
||
Assert.Equal(0, matrix[1, 0]); | ||
Assert.Equal(0, matrix["1", "0"]); | ||
Assert.Equal(48, matrix[1, 1]); | ||
Assert.Equal(48, matrix["1", "1"]); | ||
Assert.Equal(2, matrix[1, 2]); | ||
Assert.Equal(2, matrix["1", "2"]); | ||
|
||
Assert.Equal(0, matrix[2, 0]); | ||
Assert.Equal(0, matrix["2", "0"]); | ||
Assert.Equal(1, matrix[2, 1]); | ||
Assert.Equal(1, matrix["2", "1"]); | ||
Assert.Equal(49, matrix[2, 2]); | ||
Assert.Equal(49, matrix["2", "2"]); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we not have confusion matrix in metrics? #Resolved There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I do not see it inside If we do, is there an example I can follow ? I do not see it in the CookBook either In reply to: 245114114 [](ancestors = 245114114) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
} | ||
|
||
public class IrisData | ||
|
@@ -138,6 +116,5 @@ public class IrisPrediction | |
public float[] PredictedLabels; | ||
} | ||
} | ||
#pragma warning restore 612, 618 | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice!! #Resolved