From fe9f740a151eb405ef47f80addf9a11d04f239be Mon Sep 17 00:00:00 2001 From: Ivan Matantsev Date: Thu, 14 Jun 2018 12:33:32 -0700 Subject: [PATCH 1/4] add pipelineitem for Ova --- src/Microsoft.ML/Models/CrossValidator.cs | 6 +- src/Microsoft.ML/Models/OneVersusAll.cs | 74 +++++++++++++++++++ src/Microsoft.ML/Models/TrainTestEvaluator.cs | 6 +- .../Scenarios/IrisPlantClassificationTests.cs | 33 +++++++++ 4 files changed, 117 insertions(+), 2 deletions(-) create mode 100644 src/Microsoft.ML/Models/OneVersusAll.cs diff --git a/src/Microsoft.ML/Models/CrossValidator.cs b/src/Microsoft.ML/Models/CrossValidator.cs index ab84f8a715..5f6e374e11 100644 --- a/src/Microsoft.ML/Models/CrossValidator.cs +++ b/src/Microsoft.ML/Models/CrossValidator.cs @@ -1,4 +1,8 @@ -using Microsoft.ML.Runtime; +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.Api; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.EntryPoints; diff --git a/src/Microsoft.ML/Models/OneVersusAll.cs b/src/Microsoft.ML/Models/OneVersusAll.cs new file mode 100644 index 0000000000..2f0a265dad --- /dev/null +++ b/src/Microsoft.ML/Models/OneVersusAll.cs @@ -0,0 +1,74 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using Microsoft.ML.Runtime; +using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.EntryPoints; +using static Microsoft.ML.Runtime.EntryPoints.CommonInputs; + +namespace Microsoft.ML.Models +{ + public sealed partial class OneVersusAll + { + /// <summary> + /// Create OneVersusAll multiclass trainer. + /// </summary> + /// <param name="trainer">Underlying binary trainer</param> + /// <param name="useProbabilities">Use probabilities (vs.
raw outputs) to identify top-score category</param> + public static ILearningPipelineItem With(ITrainerInputWithLabel trainer, bool useProbabilities = true) + { + return new OvaPipelineItem(trainer, useProbabilities); + } + + private class OvaPipelineItem : ILearningPipelineItem + { + private Var _data; + private ITrainerInputWithLabel _trainer; + private bool _useProbabilities; + + public OvaPipelineItem(ITrainerInputWithLabel trainer, bool useProbabilities) + { + _trainer = trainer; + _useProbabilities = useProbabilities; + } + + public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment) + { + using (var env = new TlcEnvironment()) + { + var subgraph = env.CreateExperiment(); + subgraph.Add(_trainer); + var ova = new OneVersusAll(); + if (previousStep != null) + { + if (!(previousStep is ILearningPipelineDataStep dataStep)) + { + throw new InvalidOperationException($"{ nameof(OneVersusAll)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); + } + + _data = dataStep.Data; + ova.TrainingData = dataStep.Data; + ova.UseProbabilities = _useProbabilities; + ova.Nodes = subgraph; + } + Output output = experiment.Add(ova); + return new OvaPipelineStep(output); + } + } + + public Var GetInputData() => _data; + } + + private class OvaPipelineStep : ILearningPipelinePredictorStep + { + public OvaPipelineStep(Output output) + { + Model = output.PredictorModel; + } + + public Var Model { get; } + } + } +} diff --git a/src/Microsoft.ML/Models/TrainTestEvaluator.cs b/src/Microsoft.ML/Models/TrainTestEvaluator.cs index ae00a34de6..b980774dbb 100644 --- a/src/Microsoft.ML/Models/TrainTestEvaluator.cs +++ b/src/Microsoft.ML/Models/TrainTestEvaluator.cs @@ -1,4 +1,8 @@ -using Microsoft.ML.Runtime; +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information.
+ +using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.Api; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.EntryPoints; diff --git a/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationTests.cs b/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationTests.cs index 6612dfea69..ffd65c23fe 100644 --- a/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationTests.cs +++ b/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationTests.cs @@ -136,6 +136,39 @@ public class IrisPrediction [ColumnName("Score")] public float[] PredictedLabels; } + + [Fact] + public void TrainOva() + { + string dataPath = GetDataPath("iris.txt"); + + var pipeline = new LearningPipeline(seed: 1, conc: 1); + pipeline.Add(new TextLoader(dataPath).CreateFrom(useHeader: false)); + pipeline.Add(new ColumnConcatenator(outputColumn: "Features", + "SepalLength", "SepalWidth", "PetalLength", "PetalWidth")); + + pipeline.Add(OneVersusAll.With(new StochasticDualCoordinateAscentBinaryClassifier())); + + PredictionModel model = pipeline.Train(); + + var testData = new TextLoader(dataPath).CreateFrom(useHeader: false); + var evaluator = new ClassificationEvaluator(); + evaluator.OutputTopKAcc = 3; + ClassificationMetrics metrics = evaluator.Evaluate(model, testData); + CheckMetrics(metrics); + + var trainTest = new TrainTestEvaluator() { Kind = MacroUtilsTrainerKinds.SignatureMultiClassClassifierTrainer }.TrainTestEvaluate(pipeline, testData); + CheckMetrics(trainTest.ClassificationMetrics); + } + + private void CheckMetrics(ClassificationMetrics metrics) + { + Assert.Equal(.96, metrics.AccuracyMacro); + Assert.Equal(.96, metrics.AccuracyMicro, 2); + Assert.Equal(.19, metrics.LogLoss, 2); + Assert.InRange(metrics.LogLossReduction, 80, 84); + Assert.Equal(1, metrics.TopKAccuracy); + } } } From 6cf3f0613bdc73ba37ff17d737133535612c18cb Mon Sep 17 00:00:00 2001 From: Ivan Matantsev Date: Thu, 14 Jun 2018 13:55:06 -0700 Subject: [PATCH 2/4] make tests work --- 
src/Microsoft.ML/Models/TrainTestEvaluator.cs | 16 ++++++++-------- .../Scenarios/IrisPlantClassificationTests.cs | 4 +--- 2 files changed, 9 insertions(+), 11 deletions(-) diff --git a/src/Microsoft.ML/Models/TrainTestEvaluator.cs b/src/Microsoft.ML/Models/TrainTestEvaluator.cs index b980774dbb..e3de7a4e50 100644 --- a/src/Microsoft.ML/Models/TrainTestEvaluator.cs +++ b/src/Microsoft.ML/Models/TrainTestEvaluator.cs @@ -114,7 +114,7 @@ public TrainTestEvaluatorOutput TrainTestEvaluate TrainTestEvaluate TrainTestEvaluate predictor; using (var memoryStream = new MemoryStream()) { diff --git a/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationTests.cs b/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationTests.cs index ffd65c23fe..4c19a81c78 100644 --- a/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationTests.cs +++ b/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationTests.cs @@ -149,11 +149,10 @@ public void TrainOva() pipeline.Add(OneVersusAll.With(new StochasticDualCoordinateAscentBinaryClassifier())); - PredictionModel model = pipeline.Train(); + var model = pipeline.Train(); var testData = new TextLoader(dataPath).CreateFrom(useHeader: false); var evaluator = new ClassificationEvaluator(); - evaluator.OutputTopKAcc = 3; ClassificationMetrics metrics = evaluator.Evaluate(model, testData); CheckMetrics(metrics); @@ -167,7 +166,6 @@ private void CheckMetrics(ClassificationMetrics metrics) Assert.Equal(.96, metrics.AccuracyMicro, 2); Assert.Equal(.19, metrics.LogLoss, 2); Assert.InRange(metrics.LogLossReduction, 80, 84); - Assert.Equal(1, metrics.TopKAccuracy); } } } From 40dcc26f28f391d19487ec5a2abc6adcdcfd764b Mon Sep 17 00:00:00 2001 From: Ivan Matantsev Date: Wed, 20 Jun 2018 13:09:17 -0700 Subject: [PATCH 3/4] address zeeshan comment --- .../Scenarios/IrisPlantClassificationTests.cs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationTests.cs 
b/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationTests.cs index 4c19a81c78..406fc3d4ba 100644 --- a/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationTests.cs +++ b/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationTests.cs @@ -18,7 +18,7 @@ public void TrainAndPredictIrisModelTest() { string dataPath = GetDataPath("iris.txt"); - var pipeline = new LearningPipeline(seed:1, conc:1); + var pipeline = new LearningPipeline(seed: 1, conc: 1); pipeline.Add(new TextLoader(dataPath).CreateFrom(useHeader: false)); pipeline.Add(new ColumnConcatenator(outputColumn: "Features", @@ -33,7 +33,7 @@ public void TrainAndPredictIrisModelTest() SepalLength = 3.3f, SepalWidth = 1.6f, PetalLength = 0.2f, - PetalWidth= 5.1f, + PetalWidth = 5.1f, }); Assert.Equal(1, prediction.PredictedLabels[0], 2); @@ -148,21 +148,21 @@ public void TrainOva() "SepalLength", "SepalWidth", "PetalLength", "PetalWidth")); pipeline.Add(OneVersusAll.With(new StochasticDualCoordinateAscentBinaryClassifier())); - + var model = pipeline.Train(); var testData = new TextLoader(dataPath).CreateFrom(useHeader: false); var evaluator = new ClassificationEvaluator(); ClassificationMetrics metrics = evaluator.Evaluate(model, testData); CheckMetrics(metrics); - + var trainTest = new TrainTestEvaluator() { Kind = MacroUtilsTrainerKinds.SignatureMultiClassClassifierTrainer }.TrainTestEvaluate(pipeline, testData); CheckMetrics(trainTest.ClassificationMetrics); } private void CheckMetrics(ClassificationMetrics metrics) { - Assert.Equal(.96, metrics.AccuracyMacro); + Assert.Equal(.96, metrics.AccuracyMacro, 2); Assert.Equal(.96, metrics.AccuracyMicro, 2); Assert.Equal(.19, metrics.LogLoss, 2); Assert.InRange(metrics.LogLossReduction, 80, 84); From 7a3ad19ec710e3cfaa15aa87375065f2b8089d51 Mon Sep 17 00:00:00 2001 From: Ivan Matantsev Date: Wed, 20 Jun 2018 13:31:19 -0700 Subject: [PATCH 4/4] tweak precision --- .../Scenarios/IrisPlantClassificationTests.cs | 2 +- 1 file changed, 1 insertion(+), 1
deletion(-) diff --git a/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationTests.cs b/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationTests.cs index 406fc3d4ba..696fb9e92d 100644 --- a/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationTests.cs +++ b/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationTests.cs @@ -164,7 +164,7 @@ private void CheckMetrics(ClassificationMetrics metrics) { Assert.Equal(.96, metrics.AccuracyMacro, 2); Assert.Equal(.96, metrics.AccuracyMicro, 2); - Assert.Equal(.19, metrics.LogLoss, 2); + Assert.Equal(.19, metrics.LogLoss, 1); Assert.InRange(metrics.LogLossReduction, 80, 84); } }