From 5b9ed158147ef7d5b6e4297625bd5d3b08c418e8 Mon Sep 17 00:00:00 2001 From: Ivan Matantsev <ivmatan@microsoft.com> Date: Thu, 14 Mar 2019 14:58:54 -0700 Subject: [PATCH 1/3] first step --- .../Scorers/PredictionTransformer.cs | 2 +- src/Microsoft.ML.Data/TrainCatalog.cs | 18 +++++ .../Prediction.cs | 69 ++++++++++--------- 3 files changed, 57 insertions(+), 32 deletions(-) diff --git a/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs b/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs index be6757bc36..f7deddf9d1 100644 --- a/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs +++ b/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs @@ -53,7 +53,7 @@ public abstract class PredictionTransformerBase<TModel> : IPredictionTransformer [BestFriend] private protected ISchemaBindableMapper BindableMapper; [BestFriend] - private protected DataViewSchema TrainSchema; + internal DataViewSchema TrainSchema; /// <summary> /// Whether a call to <see cref="ITransformer.GetRowToRowMapper(DataViewSchema)"/> should succeed, on an diff --git a/src/Microsoft.ML.Data/TrainCatalog.cs b/src/Microsoft.ML.Data/TrainCatalog.cs index 6e8103f828..054200aacb 100644 --- a/src/Microsoft.ML.Data/TrainCatalog.cs +++ b/src/Microsoft.ML.Data/TrainCatalog.cs @@ -3,6 +3,7 @@ // See the LICENSE file in the project root for more information. using System; +using System.Collections.Generic; using System.Linq; using Microsoft.Data.DataView; using Microsoft.ML.Calibrators; @@ -274,6 +275,23 @@ public CrossValidationResult<CalibratedBinaryClassificationMetrics>[] CrossValid Evaluate(x.Scores, labelColumn), x.Scores, x.Fold)).ToArray(); } + public TransformerChain<BinaryPredictionTransformer<TModel>> ChangeModelThreshold<TModel>(TransformerChain<BinaryPredictionTransformer<TModel>> chain, float threshold) where TModel : class + { + if (chain.LastTransformer.Threshold == threshold) + return chain; + List<ITransformer> transformers = new List<ITransformer>(); + var predictionTransformer = chain.LastTransformer; + foreach (var transform in chain) + { + if (transform != predictionTransformer) + transformers.Add(transform); + } + + var a = new BinaryPredictionTransformer<TModel>(Environment, predictionTransformer.Model, predictionTransformer.TrainSchema, predictionTransformer.FeatureColumn, threshold, predictionTransformer.ThresholdColumn); + transformers.Add(a); + return new TransformerChain<BinaryPredictionTransformer<TModel>>(transformers.ToArray()); + } + /// <summary> /// The list of trainers for performing binary classification. /// </summary> diff --git a/test/Microsoft.ML.Functional.Tests/Prediction.cs b/test/Microsoft.ML.Functional.Tests/Prediction.cs index 4605f953bd..74109dfa1b 100644 --- a/test/Microsoft.ML.Functional.Tests/Prediction.cs +++ b/test/Microsoft.ML.Functional.Tests/Prediction.cs @@ -2,14 +2,25 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using Microsoft.ML.Functional.Tests.Datasets; using Microsoft.ML.RunTests; using Microsoft.ML.TestFramework; using Xunit; +using Xunit.Abstractions; namespace Microsoft.ML.Functional.Tests { - public class PredictionScenarios + public class PredictionScenarios : BaseTestClass { + public PredictionScenarios(ITestOutputHelper output) : base(output) + { + } + + class Answer + { + public float Score { get; set; } + public bool PredictedLabel { get; set; } + } /// <summary> /// Reconfigurable predictions: The following should be possible: A user trains a binary classifier, /// and through the test evaluator gets a PR curve, the based on the PR curve picks a new threshold @@ -19,36 +30,32 @@ public class PredictionScenarios [Fact] public void ReconfigurablePrediction() { - var mlContext = new MLContext(seed: 789); - - // Get the dataset, create a train and test - var data = mlContext.Data.CreateTextLoader(TestDatasets.housing.GetLoaderColumns(), - hasHeader: TestDatasets.housing.fileHasHeader, separatorChar: TestDatasets.housing.fileSeparator) - .Load(BaseTestClass.GetDataPath(TestDatasets.housing.trainFilename)); - var split = mlContext.Data.TrainTestSplit(data, testFraction: 0.2); - - // Create a pipeline to train on the housing data - var pipeline = mlContext.Transforms.Concatenate("Features", new string[] { - "CrimesPerCapita", "PercentResidental", "PercentNonRetail", "CharlesRiver", "NitricOxides", "RoomsPerDwelling", - "PercentPre40s", "EmploymentDistance", "HighwayDistance", "TaxRate", "TeacherRatio"}) - .Append(mlContext.Transforms.CopyColumns("Label", "MedianHomeValue")) - .Append(mlContext.Regression.Trainers.Ols()); - - var model = pipeline.Fit(split.TrainSet); - - var scoredTest = model.Transform(split.TestSet); - var metrics = mlContext.Regression.Evaluate(scoredTest); - - Common.AssertMetrics(metrics); - - // Todo #2465: Allow the setting of threshold and thresholdColumn for scoring. - // This is no longer possible in the API - //var newModel = new BinaryPredictionTransformer<IPredictorProducing<float>>(ml, model.Model, trainData.Schema, model.FeatureColumnName, threshold: 0.01f, thresholdColumn: DefaultColumnNames.Probability); - //var newScoredTest = newModel.Transform(pipeline.Transform(testData)); - //var newMetrics = mlContext.BinaryClassification.Evaluate(scoredTest); - // And the Threshold and ThresholdColumn properties are not settable. - //var predictor = model.LastTransformer; - //predictor.Threshold = 0.01; // Not possible + var mlContext = new MLContext(seed: 1); + + var data = mlContext.Data.LoadFromTextFile<TweetSentiment>(GetDataPath(TestDatasets.Sentiment.trainFilename), + hasHeader: TestDatasets.Sentiment.fileHasHeader, + separatorChar: TestDatasets.Sentiment.fileSeparator); + + // Create a training pipeline. + var pipeline = mlContext.Transforms.Text.FeaturizeText("Features", "SentimentText") + .AppendCacheCheckpoint(mlContext) + .Append(mlContext.BinaryClassification.Trainers.LogisticRegression( + new Trainers.LogisticRegressionBinaryClassificationTrainer.Options { NumberOfThreads = 1 })); + + // Train the model. + var model = pipeline.Fit(data); + var engine = model.CreatePredictionEngine<TweetSentiment, Answer>(mlContext); + var pr = engine.Predict(new TweetSentiment() { SentimentText = "Good Bad job" }); + // Score is 0.64 so predicted label is true. + Assert.True(pr.PredictedLabel); + Assert.True(pr.Score > 0); + var newModel = mlContext.BinaryClassification.ChangeModelThreshold(model, 0.7f); + var newEngine = newModel.CreatePredictionEngine<TweetSentiment, Answer>(mlContext); + pr = newEngine.Predict(new TweetSentiment() { SentimentText = "Good Bad job" }); + // Score is still 0.64 but since threshold is no longer 0 but 0.7 predicted label now is false. + + Assert.False(pr.PredictedLabel); + Assert.False(pr.Score > 0.7); } } } From 5dff4ee6a0a551ca1e5a94ecea6c433c94d0f66d Mon Sep 17 00:00:00 2001 From: Ivan Matantsev <ivmatan@microsoft.com> Date: Thu, 14 Mar 2019 16:33:21 -0700 Subject: [PATCH 2/3] and single case --- src/Microsoft.ML.Data/TrainCatalog.cs | 23 +++++++++++++--- .../Prediction.cs | 26 +++++++++++++++++++ 2 files changed, 46 insertions(+), 3 deletions(-) diff --git a/src/Microsoft.ML.Data/TrainCatalog.cs b/src/Microsoft.ML.Data/TrainCatalog.cs index 054200aacb..54370a88c9 100644 --- a/src/Microsoft.ML.Data/TrainCatalog.cs +++ b/src/Microsoft.ML.Data/TrainCatalog.cs @@ -275,7 +275,15 @@ public CrossValidationResult<CalibratedBinaryClassificationMetrics>[] CrossValid Evaluate(x.Scores, labelColumn), x.Scores, x.Fold)).ToArray(); } - public TransformerChain<BinaryPredictionTransformer<TModel>> ChangeModelThreshold<TModel>(TransformerChain<BinaryPredictionTransformer<TModel>> chain, float threshold) where TModel : class + /// <summary> + /// Change threshold for binary model. + /// </summary> + /// <typeparam name="TModel">An implementation of the <see cref="IPredictorProducing{TResult}"/></typeparam> + /// <param name="chain">Chain of transformers.</param> + /// <param name="threshold">New threshold.</param> + /// <returns></returns> + public TransformerChain<BinaryPredictionTransformer<TModel>> ChangeModelThreshold<TModel>(TransformerChain<BinaryPredictionTransformer<TModel>> chain, float threshold) + where TModel : class { if (chain.LastTransformer.Threshold == threshold) return chain; @@ -287,11 +295,20 @@ public TransformerChain<BinaryPredictionTransformer<TModel>> ChangeModelThreshol transformers.Add(transform); } - var a = new BinaryPredictionTransformer<TModel>(Environment, predictionTransformer.Model, predictionTransformer.TrainSchema, predictionTransformer.FeatureColumn, threshold, predictionTransformer.ThresholdColumn); - transformers.Add(a); + transformers.Add(new BinaryPredictionTransformer<TModel>(Environment, predictionTransformer.Model, + predictionTransformer.TrainSchema, predictionTransformer.FeatureColumn, + threshold, predictionTransformer.ThresholdColumn)); return new TransformerChain<BinaryPredictionTransformer<TModel>>(transformers.ToArray()); } + public BinaryPredictionTransformer<TModel> ChangeModelThreshold<TModel>(BinaryPredictionTransformer<TModel> model, float threshold) + where TModel : class + { + if (model.Threshold == threshold) + return model; + return new BinaryPredictionTransformer<TModel>(Environment, model.Model, model.TrainSchema, model.FeatureColumn, threshold, model.ThresholdColumn); + } + /// <summary> /// The list of trainers for performing binary classification. /// </summary> diff --git a/test/Microsoft.ML.Functional.Tests/Prediction.cs b/test/Microsoft.ML.Functional.Tests/Prediction.cs index 74109dfa1b..82cb99a092 100644 --- a/test/Microsoft.ML.Functional.Tests/Prediction.cs +++ b/test/Microsoft.ML.Functional.Tests/Prediction.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using System; using Microsoft.ML.Functional.Tests.Datasets; using Microsoft.ML.RunTests; using Microsoft.ML.TestFramework; @@ -57,5 +58,30 @@ public void ReconfigurablePrediction() Assert.False(pr.PredictedLabel); Assert.False(pr.Score > 0.7); } + + [Fact] + public void ReconfigurablePredictionNoPipeline() + { + var mlContext = new MLContext(seed: 1); + + var data = mlContext.Data.LoadFromEnumerable(TypeTestData.GenerateDataset()); + var pipeline = mlContext.BinaryClassification.Trainers.LogisticRegression( + new Trainers.LogisticRegressionBinaryClassificationTrainer.Options { NumberOfThreads = 1 }); + var model = pipeline.Fit(data); + var newModel = mlContext.BinaryClassification.ChangeModelThreshold(model, -2.0f); + var rnd = new Random(1); + var randomDataPoint = TypeTestData.GetRandomInstance(rnd); + var engine = model.CreatePredictionEngine<TypeTestData, Answer>(mlContext); + var pr = engine.Predict(randomDataPoint); + // Score is -1.38 so predicted label is false. + Assert.False(pr.PredictedLabel); + Assert.True(pr.Score <= 0); + var newEngine = newModel.CreatePredictionEngine<TypeTestData, Answer>(mlContext); + pr = newEngine.Predict(randomDataPoint); + // Score is still -1.38 but since threshold is no longer 0 but -2 predicted label now is true. + Assert.True(pr.PredictedLabel); + Assert.True(pr.Score <= 0); + } + } } From 142fa9206fdd007ddc1f2f583142d29addf95cc5 Mon Sep 17 00:00:00 2001 From: Ivan Matantsev <ivmatan@microsoft.com> Date: Mon, 25 Mar 2019 10:37:37 -0700 Subject: [PATCH 3/3] address some comments --- src/Microsoft.ML.Data/TrainCatalog.cs | 28 +------------------ .../Prediction.cs | 27 ++++++++++++------ 2 files changed, 20 insertions(+), 35 deletions(-) diff --git a/src/Microsoft.ML.Data/TrainCatalog.cs b/src/Microsoft.ML.Data/TrainCatalog.cs index 6c857e4a86..944f3dd9e5 100644 --- a/src/Microsoft.ML.Data/TrainCatalog.cs +++ b/src/Microsoft.ML.Data/TrainCatalog.cs @@ -256,38 +256,12 @@ public IReadOnlyList<CrossValidationResult<CalibratedBinaryClassificationMetrics Evaluate(x.Scores, labelColumnName), x.Scores, x.Fold)).ToArray(); } - /// <summary> - /// Change threshold for binary model. - /// </summary> - /// <typeparam name="TModel">An implementation of the <see cref="IPredictorProducing{TResult}"/></typeparam> - /// <param name="chain">Chain of transformers.</param> - /// <param name="threshold">New threshold.</param> - /// <returns></returns> - public TransformerChain<BinaryPredictionTransformer<TModel>> ChangeModelThreshold<TModel>(TransformerChain<BinaryPredictionTransformer<TModel>> chain, float threshold) - where TModel : class - { - if (chain.LastTransformer.Threshold == threshold) - return chain; - List<ITransformer> transformers = new List<ITransformer>(); - var predictionTransformer = chain.LastTransformer; - foreach (var transform in chain) - { - if (transform != predictionTransformer) - transformers.Add(transform); - } - - transformers.Add(new BinaryPredictionTransformer<TModel>(Environment, predictionTransformer.Model, - predictionTransformer.TrainSchema, predictionTransformer.FeatureColumn, - threshold, predictionTransformer.ThresholdColumn)); - return new TransformerChain<BinaryPredictionTransformer<TModel>>(transformers.ToArray()); - } - public BinaryPredictionTransformer<TModel> ChangeModelThreshold<TModel>(BinaryPredictionTransformer<TModel> model, float threshold) where TModel : class { if (model.Threshold == threshold) return model; - return new BinaryPredictionTransformer<TModel>(Environment, model.Model, model.TrainSchema, model.FeatureColumn, threshold, model.ThresholdColumn); + return new BinaryPredictionTransformer<TModel>(Environment, model.Model, model.TrainSchema, model.FeatureColumnName, threshold, model.ThresholdColumn); } /// <summary> diff --git a/test/Microsoft.ML.Functional.Tests/Prediction.cs b/test/Microsoft.ML.Functional.Tests/Prediction.cs index 82cb99a092..627e06e775 100644 --- a/test/Microsoft.ML.Functional.Tests/Prediction.cs +++ b/test/Microsoft.ML.Functional.Tests/Prediction.cs @@ -3,9 +3,13 @@ // See the LICENSE file in the project root for more information. using System; +using System.Collections.Generic; +using Microsoft.ML.Calibrators; +using Microsoft.ML.Data; using Microsoft.ML.Functional.Tests.Datasets; using Microsoft.ML.RunTests; using Microsoft.ML.TestFramework; +using Microsoft.ML.Trainers; using Xunit; using Xunit.Abstractions; @@ -17,7 +21,7 @@ public PredictionScenarios(ITestOutputHelper output) : base(output) { } - class Answer + class Prediction { public float Score { get; set; } public bool PredictedLabel { get; set; } @@ -41,17 +45,24 @@ public void ReconfigurablePrediction() var pipeline = mlContext.Transforms.Text.FeaturizeText("Features", "SentimentText") .AppendCacheCheckpoint(mlContext) .Append(mlContext.BinaryClassification.Trainers.LogisticRegression( - new Trainers.LogisticRegressionBinaryClassificationTrainer.Options { NumberOfThreads = 1 })); + new LogisticRegressionBinaryTrainer.Options { NumberOfThreads = 1 })); // Train the model. var model = pipeline.Fit(data); - var engine = model.CreatePredictionEngine<TweetSentiment, Answer>(mlContext); + var engine = mlContext.Model.CreatePredictionEngine<TweetSentiment, Prediction>(model); var pr = engine.Predict(new TweetSentiment() { SentimentText = "Good Bad job" }); // Score is 0.64 so predicted label is true. Assert.True(pr.PredictedLabel); Assert.True(pr.Score > 0); - var newModel = mlContext.BinaryClassification.ChangeModelThreshold(model, 0.7f); - var newEngine = newModel.CreatePredictionEngine<TweetSentiment, Answer>(mlContext); + var transformers = new List<ITransformer>(); + foreach (var transform in model) + { + if (transform != model.LastTransformer) + transformers.Add(transform); + } + transformers.Add(mlContext.BinaryClassification.ChangeModelThreshold(model.LastTransformer, 0.7f)); + var newModel = new TransformerChain<BinaryPredictionTransformer<CalibratedModelParametersBase<LinearBinaryModelParameters, PlattCalibrator>>>(transformers.ToArray()); + var newEngine = mlContext.Model.CreatePredictionEngine<TweetSentiment, Prediction>(newModel); pr = newEngine.Predict(new TweetSentiment() { SentimentText = "Good Bad job" }); // Score is still 0.64 but since threshold is no longer 0 but 0.7 predicted label now is false. @@ -66,17 +77,17 @@ public void ReconfigurablePredictionNoPipeline() var data = mlContext.Data.LoadFromEnumerable(TypeTestData.GenerateDataset()); var pipeline = mlContext.BinaryClassification.Trainers.LogisticRegression( - new Trainers.LogisticRegressionBinaryClassificationTrainer.Options { NumberOfThreads = 1 }); + new Trainers.LogisticRegressionBinaryTrainer.Options { NumberOfThreads = 1 }); var model = pipeline.Fit(data); var newModel = mlContext.BinaryClassification.ChangeModelThreshold(model, -2.0f); var rnd = new Random(1); var randomDataPoint = TypeTestData.GetRandomInstance(rnd); - var engine = model.CreatePredictionEngine<TypeTestData, Answer>(mlContext); + var engine = mlContext.Model.CreatePredictionEngine<TypeTestData, Prediction>(model); var pr = engine.Predict(randomDataPoint); // Score is -1.38 so predicted label is false. Assert.False(pr.PredictedLabel); Assert.True(pr.Score <= 0); - var newEngine = newModel.CreatePredictionEngine<TypeTestData, Answer>(mlContext); + var newEngine = mlContext.Model.CreatePredictionEngine<TypeTestData, Prediction>(newModel); pr = newEngine.Predict(randomDataPoint); // Score is still -1.38 but since threshold is no longer 0 but -2 predicted label now is true. Assert.True(pr.PredictedLabel);