Configurable Threshold for binary models #2969

Merged

src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs (2 changes: 1 addition & 1 deletion)

@@ -52,7 +52,7 @@ public abstract class PredictionTransformerBase<TModel> : IPredictionTransformer
[BestFriend]
private protected ISchemaBindableMapper BindableMapper;
[BestFriend]
private protected DataViewSchema TrainSchema;
internal DataViewSchema TrainSchema;

/// <summary>
/// Whether a call to <see cref="ITransformer.GetRowToRowMapper(DataViewSchema)"/> should succeed, on an
src/Microsoft.ML.Data/TrainCatalog.cs (8 changes: 8 additions & 0 deletions)

@@ -256,6 +256,14 @@ public IReadOnlyList<CrossValidationResult<CalibratedBinaryClassificationMetrics
Evaluate(x.Scores, labelColumnName), x.Scores, x.Fold)).ToArray();
}

public BinaryPredictionTransformer<TModel> ChangeModelThreshold<TModel>(BinaryPredictionTransformer<TModel> model, float threshold)
Comment (Contributor Author): needs documentation
Comment (Member): We should put XML comments on all public members.

where TModel : class
{
if (model.Threshold == threshold)
Comment (@sfilipi, Member, Mar 15, 2019):
> if (model.Threshold == threshold)
do you want to warn here? #WontFix

Comment (Contributor, in reply): I think we should provide the same warning that C# does when you have a variable like int a = 5 and then assign 5 to it later.

return model;
return new BinaryPredictionTransformer<TModel>(Environment, model.Model, model.TrainSchema, model.FeatureColumnName, threshold, model.ThresholdColumn);
Comment (@rogancarr, Contributor, Mar 25, 2019):
> model.ThresholdColumn
Technically, Issue #2465 was that we should be able to set the Threshold and ThresholdColumn properties of the BinaryPredictionTransformer. That said, I think that the actual use cases that we care about are changing the Threshold; the ThresholdColumn seems more important when we are creating a new BinaryPredictionTransformer. I actually don't think we need to modify that property. I'll update the issue that just setting a Threshold would be nice.

(Note that in practice, we have a Score column and a Probability column; modulo floating point error, these are a 1:1 mapping, so we will get the same results no matter which one we threshold on. Setting a custom threshold column seems more like a backdoor for crazy things: e.g. creating a custom score based on a few models and / or heuristics and then modifying a BinaryPredictionTransformer to score that column instead.) #Resolved

}
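The review thread above asks for XML comments on the new public member. As a rough sketch only (not the wording that was merged), the documentation might look like this:

// Illustrative documentation sketch only; the merged code may use different wording.
/// <summary>
/// Returns a copy of <paramref name="model"/> with its decision threshold replaced by
/// <paramref name="threshold"/>. The original transformer is not modified; a new
/// <see cref="BinaryPredictionTransformer{TModel}"/> is created over the same model parameters,
/// training schema, feature column, and threshold column.
/// </summary>
/// <typeparam name="TModel">The type of the model parameters.</typeparam>
/// <param name="model">The trained binary prediction transformer to reconfigure.</param>
/// <param name="threshold">The new decision threshold.</param>
public BinaryPredictionTransformer<TModel> ChangeModelThreshold<TModel>(BinaryPredictionTransformer<TModel> model, float threshold)
    where TModel : class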

/// <summary>
/// The list of trainers for performing binary classification.
/// </summary>
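On the Score versus Probability point in the thread above: for a Platt-calibrated model the probability is a monotone transform of the raw score, so a threshold on one column always corresponds to a threshold on the other. A small illustrative helper, assuming the plain sigmoid p = 1 / (1 + e^(-score)); the actual PlattCalibrator also involves a fitted slope and offset, and the ThresholdMath name is made up for this sketch:

using System;

static class ThresholdMath
{
    // Sketch only: assumes a plain sigmoid calibration, p = 1 / (1 + e^(-score)).
    // Inverting it gives the raw-score threshold that selects exactly the same
    // examples as a given probability threshold.
    public static double ProbabilityToScoreThreshold(double probabilityThreshold)
        => Math.Log(probabilityThreshold / (1 - probabilityThreshold));
}

// Example: a probability threshold of 0.7 corresponds to a score threshold of
// ln(0.7 / 0.3) ≈ 0.85, so thresholding either column picks out the same predictions.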
test/Microsoft.ML.Functional.Tests/Prediction.cs (106 changes: 75 additions & 31 deletions)

@@ -2,14 +2,30 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using Microsoft.ML.Calibrators;
using Microsoft.ML.Data;
using Microsoft.ML.Functional.Tests.Datasets;
using Microsoft.ML.RunTests;
using Microsoft.ML.TestFramework;
using Microsoft.ML.Trainers;
using Xunit;
using Xunit.Abstractions;

namespace Microsoft.ML.Functional.Tests
{
public class PredictionScenarios
public class PredictionScenarios : BaseTestClass
{
public PredictionScenarios(ITestOutputHelper output) : base(output)
{
}

class Prediction
Comment (Contributor):
> class Prediction
Please add to Datasets/CommonColumns.cs, and call it PredictionColumns.

{
public float Score { get; set; }
public bool PredictedLabel { get; set; }
}
/// <summary>
/// Reconfigurable predictions: The following should be possible: A user trains a binary classifier,
/// and through the test evaluator gets a PR curve, then based on the PR curve picks a new threshold
@@ -19,36 +35,64 @@ public class PredictionScenarios
[Fact]
public void ReconfigurablePrediction()
{
var mlContext = new MLContext(seed: 789);

// Get the dataset, create a train and test
var data = mlContext.Data.CreateTextLoader(TestDatasets.housing.GetLoaderColumns(),
hasHeader: TestDatasets.housing.fileHasHeader, separatorChar: TestDatasets.housing.fileSeparator)
.Load(BaseTestClass.GetDataPath(TestDatasets.housing.trainFilename));
var split = mlContext.Data.TrainTestSplit(data, testFraction: 0.2);

// Create a pipeline to train on the housing data
var pipeline = mlContext.Transforms.Concatenate("Features", new string[] {
"CrimesPerCapita", "PercentResidental", "PercentNonRetail", "CharlesRiver", "NitricOxides", "RoomsPerDwelling",
"PercentPre40s", "EmploymentDistance", "HighwayDistance", "TaxRate", "TeacherRatio"})
.Append(mlContext.Transforms.CopyColumns("Label", "MedianHomeValue"))
.Append(mlContext.Regression.Trainers.Ols());

var model = pipeline.Fit(split.TrainSet);

var scoredTest = model.Transform(split.TestSet);
var metrics = mlContext.Regression.Evaluate(scoredTest);

Common.AssertMetrics(metrics);

// Todo #2465: Allow the setting of threshold and thresholdColumn for scoring.
Comment (@rogancarr, Contributor, Mar 25, 2019):
> Todo #2465:
Thank you! #Resolved

// This is no longer possible in the API
//var newModel = new BinaryPredictionTransformer<IPredictorProducing<float>>(ml, model.Model, trainData.Schema, model.FeatureColumnName, threshold: 0.01f, thresholdColumn: DefaultColumnNames.Probability);
//var newScoredTest = newModel.Transform(pipeline.Transform(testData));
//var newMetrics = mlContext.BinaryClassification.Evaluate(scoredTest);
// And the Threshold and ThresholdColumn properties are not settable.
//var predictor = model.LastTransformer;
//predictor.Threshold = 0.01; // Not possible
var mlContext = new MLContext(seed: 1);

var data = mlContext.Data.LoadFromTextFile<TweetSentiment>(GetDataPath(TestDatasets.Sentiment.trainFilename),
Comment (@wschin, Member, Mar 15, 2019):
Can we try not to load file everywhere? It will be faster to just use in-memory data. #WontFix

Comment (@rogancarr, Contributor, Mar 15, 2019):
We have standard test datasets saved to files that we use in tests. #ByDesign

hasHeader: TestDatasets.Sentiment.fileHasHeader,
separatorChar: TestDatasets.Sentiment.fileSeparator);

// Create a training pipeline.
var pipeline = mlContext.Transforms.Text.FeaturizeText("Features", "SentimentText")
.AppendCacheCheckpoint(mlContext)
.Append(mlContext.BinaryClassification.Trainers.LogisticRegression(
new LogisticRegressionBinaryTrainer.Options { NumberOfThreads = 1 }));

// Train the model.
var model = pipeline.Fit(data);
var engine = mlContext.Model.CreatePredictionEngine<TweetSentiment, Prediction>(model);
var pr = engine.Predict(new TweetSentiment() { SentimentText = "Good Bad job" });
// Score is 0.64 so predicted label is true.
Assert.True(pr.PredictedLabel);
Assert.True(pr.Score > 0);
var transformers = new List<ITransformer>();
foreach (var transform in model)
{
if (transform != model.LastTransformer)
transformers.Add(transform);
}
transformers.Add(mlContext.BinaryClassification.ChangeModelThreshold(model.LastTransformer, 0.7f));
var newModel = new TransformerChain<BinaryPredictionTransformer<CalibratedModelParametersBase<LinearBinaryModelParameters, PlattCalibrator>>>(transformers.ToArray());
Comment (Contributor):
> new TransformerChain<BinaryPredictionTransformer<CalibratedModelParametersBase<LinearBinaryModelParameters, PlattCalibrator>>>
Would it be better to do an in-place change rather than making a whole new chain? The new TransformerChain that comes back after still has references to the previous objects anyways. The only new thing here is BinaryPredictionTransformer.

Comment (Contributor, in reply):
> Would it be better to do an in-place change rather than making a whole new chain?

No. We rely upon ITransformers not being mutable objects in many, many places.

The reason why they must be immutable is one of the "practical corollaries" to the IDataView design principles. We rely in several places on the ITransformer implementors being consistent. (Not least inside the ITransformer implementations themselves!) We often form chains of ITransformers (pipelines, in other words), where each transform is the result of fitting to the result of the last. From a practical, user-facing perspective, if these were suddenly to become mutable objects in their behavior w.r.t. transformation, the basic assumption that underlies why that is reliable at all would be compromised. So if you have a chain of transformers A, B, C, and you can suddenly mutate B's behavior, the assumptions under which C was fit no longer hold, in ways that are impossible to detect. (Transformers, when fit, often depend on the distributional behavior of their inputs.)

Now then, nothing prevents someone from forming another transform B' derived from B, taking the original A and C, and forming the chain A, B', and C. But I think this is understood to be a generally more risky operation. We have, and I think users have, a strong assumption that if an ITransformer (including chains) works once, it shall continue to work if they don't do anything to it, something that would break if we were to allow them to be mutable.

var newEngine = mlContext.Model.CreatePredictionEngine<TweetSentiment, Prediction>(newModel);
pr = newEngine.Predict(new TweetSentiment() { SentimentText = "Good Bad job" });
// Score is still 0.64, but since the threshold is now 0.7 rather than 0, the predicted label is false.

Assert.False(pr.PredictedLabel);
Assert.False(pr.Score > 0.7);
}
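The exchange above treats fitted transformers as immutable, so changing a threshold means rebuilding the chain A, B, C as A, B', C rather than mutating anything in place, which is exactly what the test does by hand. Below is a sketch of that pattern factored into a helper; the ReplaceLastTransformer name and shape are illustrative and not part of the ML.NET API, and the snippet assumes the same usings as the test file above:

// Illustrative helper (not ML.NET API): build a new chain that reuses every transformer
// from an existing chain except the last one, which is swapped for a replacement.
// The original chain is never modified.
private static TransformerChain<TNewLast> ReplaceLastTransformer<TOldLast, TNewLast>(
    TransformerChain<TOldLast> chain, TNewLast newLast)
    where TOldLast : class, ITransformer
    where TNewLast : class, ITransformer
{
    var parts = new List<ITransformer>();
    foreach (var transformer in chain)
    {
        // Keep everything except the final transformer; reference equality is sufficient here.
        if (!ReferenceEquals(transformer, chain.LastTransformer))
            parts.Add(transformer);
    }
    parts.Add(newLast);
    return new TransformerChain<TNewLast>(parts.ToArray());
}

// Usage, mirroring the test above:
// var newModel = ReplaceLastTransformer(model,
//     mlContext.BinaryClassification.ChangeModelThreshold(model.LastTransformer, 0.7f));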

[Fact]
public void ReconfigurablePredictionNoPipeline()
{
var mlContext = new MLContext(seed: 1);

var data = mlContext.Data.LoadFromEnumerable(TypeTestData.GenerateDataset());
var pipeline = mlContext.BinaryClassification.Trainers.LogisticRegression(
new Trainers.LogisticRegressionBinaryTrainer.Options { NumberOfThreads = 1 });
var model = pipeline.Fit(data);
var newModel = mlContext.BinaryClassification.ChangeModelThreshold(model, -2.0f);
var rnd = new Random(1);
var randomDataPoint = TypeTestData.GetRandomInstance(rnd);
var engine = mlContext.Model.CreatePredictionEngine<TypeTestData, Prediction>(model);
var pr = engine.Predict(randomDataPoint);
// Score is -1.38 so predicted label is false.
Assert.False(pr.PredictedLabel);
Assert.True(pr.Score <= 0);
var newEngine = mlContext.Model.CreatePredictionEngine<TypeTestData, Prediction>(newModel);
pr = newEngine.Predict(randomDataPoint);
// Score is still -1.38, but since the threshold is now -2 rather than 0, the predicted label is true.
Assert.True(pr.PredictedLabel);
Assert.True(pr.Score <= 0);
}

}
}