Update tests (using Iris dataset) to use new API #2008

Merged: 4 commits, Jan 3, 2019
1 change: 0 additions & 1 deletion test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj
@@ -22,7 +22,6 @@
<ProjectReference Include="..\..\src\Microsoft.ML.StaticPipe\Microsoft.ML.StaticPipe.csproj" />
<ProjectReference Include="..\..\src\Microsoft.ML.TensorFlow.StaticPipe\Microsoft.ML.TensorFlow.StaticPipe.csproj" />
<ProjectReference Include="..\..\src\Microsoft.ML.TensorFlow\Microsoft.ML.TensorFlow.csproj" />
<ProjectReference Include="..\..\src\Microsoft.ML.Legacy\Microsoft.ML.Legacy.csproj" />
@codemzs (Member) commented on Jan 3, 2019:

Nice!! #Resolved

<ProjectReference Include="..\..\src\Microsoft.ML.TimeSeries\Microsoft.ML.TimeSeries.csproj" />
<ProjectReference Include="..\Microsoft.ML.Predictor.Tests\Microsoft.ML.Predictor.Tests.csproj" />
<ProjectReference Include="..\Microsoft.ML.TestFramework\Microsoft.ML.TestFramework.csproj" />
93 changes: 35 additions & 58 deletions test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationTests.cs
@@ -3,33 +3,45 @@
// See the LICENSE file in the project root for more information.

using Microsoft.ML.Data;
using Microsoft.ML.Legacy.Models;
using Microsoft.ML.Legacy.Trainers;
using Microsoft.ML.Legacy.Transforms;
using Microsoft.ML.RunTests;
using Xunit;
using TextLoader = Microsoft.ML.Legacy.Data.TextLoader;

namespace Microsoft.ML.Scenarios
{
#pragma warning disable 612, 618
public partial class ScenariosTests
{
[Fact]
public void TrainAndPredictIrisModelTest()
{
string dataPath = GetDataPath("iris.txt");

var pipeline = new Legacy.LearningPipeline(seed: 1, conc: 1);

pipeline.Add(new TextLoader(dataPath).CreateFrom<IrisData>(useHeader: false));
pipeline.Add(new ColumnConcatenator(outputColumn: "Features",
"SepalLength", "SepalWidth", "PetalLength", "PetalWidth"));

pipeline.Add(new StochasticDualCoordinateAscentClassifier());

Legacy.PredictionModel<IrisData, IrisPrediction> model = pipeline.Train<IrisData, IrisPrediction>();

IrisPrediction prediction = model.Predict(new IrisData()
var mlContext = new MLContext(seed: 1, conc: 1);

var reader = mlContext.Data.CreateTextReader(columns: new[]
{
new TextLoader.Column("Label", DataKind.R4, 0),
new TextLoader.Column("SepalLength", DataKind.R4, 1),
new TextLoader.Column("SepalWidth", DataKind.R4, 2),
new TextLoader.Column("PetalLength", DataKind.R4, 3),
new TextLoader.Column("PetalWidth", DataKind.R4, 4)
}
);

var pipe = mlContext.Transforms.Concatenate("Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth")
.Append(mlContext.Transforms.Normalize("Features"))
.AppendCacheCheckpoint(mlContext)
@codemzs (Member) commented on Jan 3, 2019:

Why are we adding the normalizer transform? Just curious. #Resolved

(Member Author) replied:

In the new API, the paradigm is to be explicit about whether normalizers are used in the pipeline. (In the old Legacy API, normalizers were auto-inserted based on the algorithm.) Hence the decision to be explicit about normalizers here. Please see #1604, which discusses removing the 'smarts' (no auto-normalization, no auto-caching, and no auto-calibration) in the new API.

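For reference, a minimal sketch of the explicit-opt-in pattern described above. It reuses only calls that already appear in this diff (the pre-1.0 API surface used by this PR, including the advancedSettings lambda), so the exact names may differ in later releases:

```csharp
// Sketch only: the new API treats normalization and caching as explicit pipeline steps,
// whereas the legacy LearningPipeline auto-inserted them based on the trainer.
var mlContext = new MLContext(seed: 1, conc: 1);

var pipeline = mlContext.Transforms
    .Concatenate("Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth")
    // Opt in to normalization explicitly; omitting this line means the trainer sees raw feature values.
    .Append(mlContext.Transforms.Normalize("Features"))
    // Opt in to caching explicitly as well (no auto-caching in the new API).
    .AppendCacheCheckpoint(mlContext)
    .Append(mlContext.MulticlassClassification.Trainers.StochasticDualCoordinateAscent(
        "Label", "Features", advancedSettings: s => s.NumThreads = 1));
```
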
.Append(mlContext.MulticlassClassification.Trainers.StochasticDualCoordinateAscent("Label", "Features", advancedSettings: s => s.NumThreads = 1));

// Read training and test data sets
string dataPath = GetDataPath(TestDatasets.iris.trainFilename);
string testDataPath = dataPath;
var trainData = reader.Read(dataPath);
var testData = reader.Read(testDataPath);

// Train the pipeline
var trainedModel = pipe.Fit(trainData);

// Make predictions
var predictFunction = trainedModel.CreatePredictionEngine<IrisData, IrisPrediction>(mlContext);
IrisPrediction prediction = predictFunction.Predict(new IrisData()
{
SepalLength = 5.1f,
SepalWidth = 3.3f,
@@ -41,7 +53,7 @@ public void TrainAndPredictIrisModelTest()
Assert.Equal(0, prediction.PredictedLabels[1], 2);
Assert.Equal(0, prediction.PredictedLabels[2], 2);

prediction = model.Predict(new IrisData()
prediction = predictFunction.Predict(new IrisData()
{
SepalLength = 6.4f,
SepalWidth = 3.1f,
@@ -53,7 +65,7 @@ public void TrainAndPredictIrisModelTest()
Assert.Equal(0, prediction.PredictedLabels[1], 2);
Assert.Equal(1, prediction.PredictedLabels[2], 2);

prediction = model.Predict(new IrisData()
prediction = predictFunction.Predict(new IrisData()
{
SepalLength = 4.4f,
SepalWidth = 3.1f,
@@ -65,53 +77,19 @@ public void TrainAndPredictIrisModelTest()
Assert.Equal(.8, prediction.PredictedLabels[1], 1);
Assert.Equal(0, prediction.PredictedLabels[2], 2);

// Note: Testing against the same data set as a simple way to test evaluation.
// This isn't appropriate in real-world scenarios.
string testDataPath = GetDataPath("iris.txt");
var testData = new TextLoader(testDataPath).CreateFrom<IrisData>(useHeader: false);

var evaluator = new ClassificationEvaluator();
evaluator.OutputTopKAcc = 3;
ClassificationMetrics metrics = evaluator.Evaluate(model, testData);
// Evaluate the trained pipeline
var predicted = trainedModel.Transform(testData);
var metrics = mlContext.MulticlassClassification.Evaluate(predicted, topK: 3);

Assert.Equal(.98, metrics.AccuracyMacro);
Assert.Equal(.98, metrics.AccuracyMicro, 2);
Assert.Equal(.06, metrics.LogLoss, 2);
Assert.InRange(metrics.LogLossReduction, 94, 96);
Assert.Equal(1, metrics.TopKAccuracy);

Assert.Equal(3, metrics.PerClassLogLoss.Length);
Assert.Equal(0, metrics.PerClassLogLoss[0], 1);
Assert.Equal(.1, metrics.PerClassLogLoss[1], 1);
Assert.Equal(.1, metrics.PerClassLogLoss[2], 1);

ConfusionMatrix matrix = metrics.ConfusionMatrix;
Assert.Equal(3, matrix.Order);
Assert.Equal(3, matrix.ClassNames.Count);
Assert.Equal("0", matrix.ClassNames[0]);
Assert.Equal("1", matrix.ClassNames[1]);
Assert.Equal("2", matrix.ClassNames[2]);

Assert.Equal(50, matrix[0, 0]);
Assert.Equal(50, matrix["0", "0"]);
Assert.Equal(0, matrix[0, 1]);
Assert.Equal(0, matrix["0", "1"]);
Assert.Equal(0, matrix[0, 2]);
Assert.Equal(0, matrix["0", "2"]);

Assert.Equal(0, matrix[1, 0]);
Assert.Equal(0, matrix["1", "0"]);
Assert.Equal(48, matrix[1, 1]);
Assert.Equal(48, matrix["1", "1"]);
Assert.Equal(2, matrix[1, 2]);
Assert.Equal(2, matrix["1", "2"]);

Assert.Equal(0, matrix[2, 0]);
Assert.Equal(0, matrix["2", "0"]);
Assert.Equal(1, matrix[2, 1]);
Assert.Equal(1, matrix["2", "1"]);
Assert.Equal(49, matrix[2, 2]);
Assert.Equal(49, matrix["2", "2"]);
@codemzs (Member) commented on Jan 3, 2019:

Do we not have a confusion matrix in metrics? #Resolved

(Member Author) replied:

I do not see it inside metrics. Who would be the best person to confirm whether we support ConfusionMatrix in the new API? If we do, is there an example I can follow? I do not see it in the CookBook either.

(Member Author) replied:

Created issue #2009 to bring back support for ConfusionMatrix in the new API.

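Until #2009 is addressed, here is a plain-C# sketch (not an ML.NET API; the helper name is made up for illustration) of what the removed assertions were checking: a confusion matrix tallied by hand from (actual, predicted) class-index pairs. Extracting those per-row indices from the scored IDataView is left out here.

```csharp
using System;
using System.Collections.Generic;

internal static class ConfusionMatrixSketch
{
    // Rows index the actual class and columns the predicted class,
    // matching the matrix[actual, predicted] lookups in the deleted assertions.
    public static int[,] Tally(IReadOnlyList<int> actual, IReadOnlyList<int> predicted, int numClasses)
    {
        if (actual.Count != predicted.Count)
            throw new ArgumentException("actual and predicted must have the same length.");

        var matrix = new int[numClasses, numClasses];
        for (int i = 0; i < actual.Count; i++)
            matrix[actual[i], predicted[i]]++;
        return matrix;
    }
}
```

With the iris predictions in this test, matrix[1, 2] would hold the two versicolor rows scored as virginica that the old Assert.Equal(2, matrix[1, 2]) checks verified.
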
}

public class IrisData
@@ -138,6 +116,5 @@ public class IrisPrediction
public float[] PredictedLabels;
}
}
#pragma warning restore 612, 618
}

@@ -3,87 +3,89 @@
// See the LICENSE file in the project root for more information.

using Microsoft.ML.Data;
using Microsoft.ML.Legacy.Models;
using Microsoft.ML.Legacy.Trainers;
using Microsoft.ML.Legacy.Transforms;
using Xunit;
using TextLoader = Microsoft.ML.Legacy.Data.TextLoader;

namespace Microsoft.ML.Scenarios
{
#pragma warning disable 612, 618
public partial class ScenariosTests
{
[Fact]
public void TrainAndPredictIrisModelWithStringLabelTest()
{
var mlContext = new MLContext(seed: 1, conc: 1);

var reader = mlContext.Data.CreateTextReader(columns: new[]
{
new TextLoader.Column("SepalLength", DataKind.R4, 0),
new TextLoader.Column("SepalWidth", DataKind.R4, 1),
new TextLoader.Column("PetalLength", DataKind.R4, 2),
new TextLoader.Column("PetalWidth", DataKind.R4, 3),
new TextLoader.Column("IrisPlantType", DataKind.TX, 4),
},
separatorChar: ','
);

// Read training and test data sets
string dataPath = GetDataPath("iris.data");

var pipeline = new Legacy.LearningPipeline();

pipeline.Add(new TextLoader(dataPath).CreateFrom<IrisDataWithStringLabel>(useHeader: false, separator: ','));

pipeline.Add(new Dictionarizer("Label")); // "IrisPlantType" is used as "Label" because of column attribute name on the field.

pipeline.Add(new ColumnConcatenator(outputColumn: "Features",
"SepalLength", "SepalWidth", "PetalLength", "PetalWidth"));

pipeline.Add(new StochasticDualCoordinateAscentClassifier());

var model = pipeline.Train<IrisDataWithStringLabel, IrisPrediction>();
string[] scoreLabels;
model.TryGetScoreLabelNames(out scoreLabels);

Assert.NotNull(scoreLabels);
Assert.Equal(3, scoreLabels.Length);
Assert.Equal("Iris-setosa", scoreLabels[0]);
Assert.Equal("Iris-versicolor", scoreLabels[1]);
Assert.Equal("Iris-virginica", scoreLabels[2]);

IrisPrediction prediction = model.Predict(new IrisDataWithStringLabel()
string testDataPath = dataPath;
var trainData = reader.Read(dataPath);
var testData = reader.Read(testDataPath);

// Create Estimator
var pipe = mlContext.Transforms.Concatenate("Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth")
.Append(mlContext.Transforms.Normalize("Features"))
.Append(mlContext.Transforms.Conversion.MapValueToKey("IrisPlantType", "Label"), TransformerScope.TrainTest)
.AppendCacheCheckpoint(mlContext)
.Append(mlContext.MulticlassClassification.Trainers.StochasticDualCoordinateAscent("Label", "Features", advancedSettings: s => s.NumThreads = 1))
.Append(mlContext.Transforms.Conversion.MapKeyToValue(("PredictedLabel", "Plant")));

// Train the pipeline
var trainedModel = pipe.Fit(trainData);

// Make predictions
var predictFunction = trainedModel.CreatePredictionEngine<IrisDataWithStringLabel, IrisPredictionWithStringLabel>(mlContext);
IrisPredictionWithStringLabel prediction = predictFunction.Predict(new IrisDataWithStringLabel()
{
SepalLength = 5.1f,
SepalWidth = 3.3f,
PetalLength = 1.6f,
PetalWidth = 0.2f,
});

Assert.Equal(1, prediction.PredictedLabels[0], 2);
Assert.Equal(0, prediction.PredictedLabels[1], 2);
Assert.Equal(0, prediction.PredictedLabels[2], 2);
Assert.Equal(1, prediction.PredictedScores[0], 2);
Assert.Equal(0, prediction.PredictedScores[1], 2);
Assert.Equal(0, prediction.PredictedScores[2], 2);
Assert.True(prediction.PredictedPlant == "Iris-setosa");

prediction = model.Predict(new IrisDataWithStringLabel()
prediction = predictFunction.Predict(new IrisDataWithStringLabel()
{
SepalLength = 6.4f,
SepalWidth = 3.1f,
PetalLength = 5.5f,
PetalWidth = 2.2f,
});

Assert.Equal(0, prediction.PredictedLabels[0], 2);
Assert.Equal(0, prediction.PredictedLabels[1], 2);
Assert.Equal(1, prediction.PredictedLabels[2], 2);
Assert.Equal(0, prediction.PredictedScores[0], 2);
Assert.Equal(0, prediction.PredictedScores[1], 2);
Assert.Equal(1, prediction.PredictedScores[2], 2);
Assert.True(prediction.PredictedPlant == "Iris-virginica");

prediction = model.Predict(new IrisDataWithStringLabel()
prediction = predictFunction.Predict(new IrisDataWithStringLabel()
{
SepalLength = 4.4f,
SepalWidth = 3.1f,
PetalLength = 2.5f,
PetalWidth = 1.2f,
});

Assert.Equal(.2, prediction.PredictedLabels[0], 1);
Assert.Equal(.8, prediction.PredictedLabels[1], 1);
Assert.Equal(0, prediction.PredictedLabels[2], 2);

// Note: Testing against the same data set as a simple way to test evaluation.
// This isn't appropriate in real-world scenarios.
string testDataPath = GetDataPath("iris.data");
var testData = new TextLoader(testDataPath).CreateFrom<IrisDataWithStringLabel>(useHeader: false, separator: ',');
Assert.Equal(.2, prediction.PredictedScores[0], 1);
Assert.Equal(.8, prediction.PredictedScores[1], 1);
Assert.Equal(0, prediction.PredictedScores[2], 2);
Assert.True(prediction.PredictedPlant == "Iris-versicolor");

var evaluator = new ClassificationEvaluator();
evaluator.OutputTopKAcc = 3;
ClassificationMetrics metrics = evaluator.Evaluate(model, testData);
// Evaluate the trained pipeline
var predicted = trainedModel.Transform(testData);
var metrics = mlContext.MulticlassClassification.Evaluate(predicted, topK: 3);

Assert.Equal(.98, metrics.AccuracyMacro);
Assert.Equal(.98, metrics.AccuracyMicro, 2);
@@ -95,37 +97,9 @@ public void TrainAndPredictIrisModelWithStringLabelTest()
Assert.Equal(0, metrics.PerClassLogLoss[0], 1);
Assert.Equal(.1, metrics.PerClassLogLoss[1], 1);
Assert.Equal(.1, metrics.PerClassLogLoss[2], 1);

ConfusionMatrix matrix = metrics.ConfusionMatrix;
Assert.Equal(3, matrix.Order);
Assert.Equal(3, matrix.ClassNames.Count);
Assert.Equal("Iris-setosa", matrix.ClassNames[0]);
Assert.Equal("Iris-versicolor", matrix.ClassNames[1]);
Assert.Equal("Iris-virginica", matrix.ClassNames[2]);

Assert.Equal(50, matrix[0, 0]);
Assert.Equal(50, matrix["Iris-setosa", "Iris-setosa"]);
Assert.Equal(0, matrix[0, 1]);
Assert.Equal(0, matrix["Iris-setosa", "Iris-versicolor"]);
Assert.Equal(0, matrix[0, 2]);
Assert.Equal(0, matrix["Iris-setosa", "Iris-virginica"]);

Assert.Equal(0, matrix[1, 0]);
Assert.Equal(0, matrix["Iris-versicolor", "Iris-setosa"]);
Assert.Equal(48, matrix[1, 1]);
Assert.Equal(48, matrix["Iris-versicolor", "Iris-versicolor"]);
Assert.Equal(2, matrix[1, 2]);
Assert.Equal(2, matrix["Iris-versicolor", "Iris-virginica"]);

Assert.Equal(0, matrix[2, 0]);
Assert.Equal(0, matrix["Iris-virginica", "Iris-setosa"]);
Assert.Equal(1, matrix[2, 1]);
Assert.Equal(1, matrix["Iris-virginica", "Iris-versicolor"]);
Assert.Equal(49, matrix[2, 2]);
Assert.Equal(49, matrix["Iris-virginica", "Iris-virginica"]);
}

public class IrisDataWithStringLabel
private class IrisDataWithStringLabel
{
[LoadColumn(0)]
public float SepalLength;
@@ -139,9 +113,17 @@ public class IrisDataWithStringLabel
[LoadColumn(3)]
public float PetalWidth;

[LoadColumn(4), ColumnName("Label")]
public string IrisPlantType;
[LoadColumn(4)]
public string IrisPlantType { get; set; }
}

private class IrisPredictionWithStringLabel
{
[ColumnName("Score")]
public float[] PredictedScores { get; set; }

[ColumnName("Plant")]
public string PredictedPlant { get; set; }
}
}
#pragma warning restore 612, 618
}