-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Configurable Threshold for binary models #2969
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -256,6 +256,14 @@ public IReadOnlyList<CrossValidationResult<CalibratedBinaryClassificationMetrics | |
Evaluate(x.Scores, labelColumnName), x.Scores, x.Fold)).ToArray(); | ||
} | ||
|
||
/// <summary>
/// Returns a <see cref="BinaryPredictionTransformer{TModel}"/> configured with the given <paramref name="threshold"/>.
/// If <paramref name="model"/> already has exactly that threshold, the same instance is returned; otherwise a new
/// transformer is constructed that reuses the input's model, training schema, feature column name and threshold column.
/// The input transformer is only read, never modified, by this method.
/// </summary>
/// <typeparam name="TModel">Type of the model wrapped by the prediction transformer.</typeparam>
/// <param name="model">The binary prediction transformer whose threshold should be changed.</param>
/// <param name="threshold">The threshold value the returned transformer should use.</param>
/// <returns>
/// <paramref name="model"/> itself when its threshold equals <paramref name="threshold"/>; otherwise a newly created transformer.
/// </returns>
public BinaryPredictionTransformer<TModel> ChangeModelThreshold<TModel>(BinaryPredictionTransformer<TModel> model, float threshold) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should put XML comments on all public members. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
where TModel : class | ||
{ | ||
// Exact float equality: only an identical threshold short-circuits and hands back the input instance.
if (model.Threshold == threshold) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
do you want to warn here? #WontFix There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we should provide the same warning that C# does when you have a variable like In reply to: 265862991 [](ancestors = 265862991) |
||
return model; | ||
// Otherwise build a new transformer; apart from the threshold (and this catalog's Environment),
// every constructor argument is taken from the input transformer.
return new BinaryPredictionTransformer<TModel>(Environment, model.Model, model.TrainSchema, model.FeatureColumnName, threshold, model.ThresholdColumn); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Technically, Issue #2465 was that we should be able to set the (Note that in practice, we have a |
||
} | ||
|
||
/// <summary> | ||
/// The list of trainers for performing binary classification. | ||
/// </summary> | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,14 +2,30 @@ | |
// The .NET Foundation licenses this file to you under the MIT license. | ||
// See the LICENSE file in the project root for more information. | ||
|
||
using System; | ||
using System.Collections.Generic; | ||
using Microsoft.ML.Calibrators; | ||
using Microsoft.ML.Data; | ||
using Microsoft.ML.Functional.Tests.Datasets; | ||
using Microsoft.ML.RunTests; | ||
using Microsoft.ML.TestFramework; | ||
using Microsoft.ML.Trainers; | ||
using Xunit; | ||
using Xunit.Abstractions; | ||
|
||
namespace Microsoft.ML.Functional.Tests | ||
{ | ||
public class PredictionScenarios | ||
public class PredictionScenarios : BaseTestClass | ||
{ | ||
// Passes the xUnit output helper through to the base class constructor.
public PredictionScenarios(ITestOutputHelper output) : base(output) | ||
{ | ||
} | ||
|
||
// Output type requested from CreatePredictionEngine in the tests of this class;
// exposes the Score and PredictedLabel values that the assertions inspect.
class Prediction | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Please add to |
||
{ | ||
public float Score { get; set; } | ||
public bool PredictedLabel { get; set; } | ||
} | ||
/// <summary> | ||
/// Reconfigurable predictions: The following should be possible: A user trains a binary classifier, | ||
/// and through the test evaluator gets a PR curve, then based on the PR curve picks a new threshold | ||
|
@@ -19,36 +35,64 @@ public class PredictionScenarios | |
[Fact] | ||
public void ReconfigurablePrediction() | ||
{ | ||
var mlContext = new MLContext(seed: 789); | ||
|
||
// Get the dataset, create a train and test | ||
var data = mlContext.Data.CreateTextLoader(TestDatasets.housing.GetLoaderColumns(), | ||
hasHeader: TestDatasets.housing.fileHasHeader, separatorChar: TestDatasets.housing.fileSeparator) | ||
.Load(BaseTestClass.GetDataPath(TestDatasets.housing.trainFilename)); | ||
var split = mlContext.Data.TrainTestSplit(data, testFraction: 0.2); | ||
|
||
// Create a pipeline to train on the housing data | ||
var pipeline = mlContext.Transforms.Concatenate("Features", new string[] { | ||
"CrimesPerCapita", "PercentResidental", "PercentNonRetail", "CharlesRiver", "NitricOxides", "RoomsPerDwelling", | ||
"PercentPre40s", "EmploymentDistance", "HighwayDistance", "TaxRate", "TeacherRatio"}) | ||
.Append(mlContext.Transforms.CopyColumns("Label", "MedianHomeValue")) | ||
.Append(mlContext.Regression.Trainers.Ols()); | ||
|
||
var model = pipeline.Fit(split.TrainSet); | ||
|
||
var scoredTest = model.Transform(split.TestSet); | ||
var metrics = mlContext.Regression.Evaluate(scoredTest); | ||
|
||
Common.AssertMetrics(metrics); | ||
|
||
// Todo #2465: Allow the setting of threshold and thresholdColumn for scoring. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Thank you! #Resolved |
||
// This is no longer possible in the API | ||
//var newModel = new BinaryPredictionTransformer<IPredictorProducing<float>>(ml, model.Model, trainData.Schema, model.FeatureColumnName, threshold: 0.01f, thresholdColumn: DefaultColumnNames.Probability); | ||
//var newScoredTest = newModel.Transform(pipeline.Transform(testData)); | ||
//var newMetrics = mlContext.BinaryClassification.Evaluate(scoredTest); | ||
// And the Threshold and ThresholdColumn properties are not settable. | ||
//var predictor = model.LastTransformer; | ||
//predictor.Threshold = 0.01; // Not possible | ||
var mlContext = new MLContext(seed: 1); | ||
|
||
var data = mlContext.Data.LoadFromTextFile<TweetSentiment>(GetDataPath(TestDatasets.Sentiment.trainFilename), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we try not to load file everywhere? It will be faster to just use in-memory data. #WontFix There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We have standard test datasets saved to files that we use in tests. #ByDesign |
||
hasHeader: TestDatasets.Sentiment.fileHasHeader, | ||
separatorChar: TestDatasets.Sentiment.fileSeparator); | ||
|
||
// Create a training pipeline. | ||
var pipeline = mlContext.Transforms.Text.FeaturizeText("Features", "SentimentText") | ||
.AppendCacheCheckpoint(mlContext) | ||
.Append(mlContext.BinaryClassification.Trainers.LogisticRegression( | ||
new LogisticRegressionBinaryTrainer.Options { NumberOfThreads = 1 })); | ||
|
||
// Train the model. | ||
var model = pipeline.Fit(data); | ||
// Baseline: predict with the trained chain as-is, before any threshold change.
var engine = mlContext.Model.CreatePredictionEngine<TweetSentiment, Prediction>(model); | ||
var pr = engine.Predict(new TweetSentiment() { SentimentText = "Good Bad job" }); | ||
// Score is 0.64 so predicted label is true. | ||
Assert.True(pr.PredictedLabel); | ||
Assert.True(pr.Score > 0); | ||
// Rebuild the chain: copy every transformer except the last one, which is replaced below.
var transformers = new List<ITransformer>(); | ||
foreach (var transform in model) | ||
{ | ||
if (transform != model.LastTransformer) | ||
transformers.Add(transform); | ||
} | ||
// Append the transformer returned by ChangeModelThreshold for the last transformer with threshold 0.7.
transformers.Add(mlContext.BinaryClassification.ChangeModelThreshold(model.LastTransformer, 0.7f)); | ||
var newModel = new TransformerChain<BinaryPredictionTransformer<CalibratedModelParametersBase<LinearBinaryModelParameters, PlattCalibrator>>>(transformers.ToArray()); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Would it be better to do an in-place change rather than making a whole new chain? The new There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
No. We rely upon The reason why they must be immutable is one of the "practical corollaries" to the IDataView design principles. We rely on several places on the Now then, nothing prevents someone from forming another transform |
||
// Predict again on the same input, this time through the rebuilt chain.
var newEngine = mlContext.Model.CreatePredictionEngine<TweetSentiment, Prediction>(newModel); | ||
pr = newEngine.Predict(new TweetSentiment() { SentimentText = "Good Bad job" }); | ||
// Score is still 0.64 but since threshold is no longer 0 but 0.7 predicted label now is false. | ||
|
||
Assert.False(pr.PredictedLabel); | ||
Assert.False(pr.Score > 0.7); | ||
} | ||
|
||
/// <summary>
/// Reconfigurable predictions without a transformer chain: fits a logistic regression binary trainer directly,
/// asks ChangeModelThreshold for a transformer with threshold -2, and checks that the same data point gets
/// PredictedLabel false from the original model and true from the re-thresholded one, while its Score
/// stays non-positive in both cases.
/// </summary>
[Fact] | ||
public void ReconfigurablePredictionNoPipeline() | ||
{ | ||
var mlContext = new MLContext(seed: 1); | ||
|
||
var data = mlContext.Data.LoadFromEnumerable(TypeTestData.GenerateDataset()); | ||
var pipeline = mlContext.BinaryClassification.Trainers.LogisticRegression( | ||
new Trainers.LogisticRegressionBinaryTrainer.Options { NumberOfThreads = 1 }); | ||
var model = pipeline.Fit(data); | ||
// The fitted model is passed to ChangeModelThreshold directly; no chain has to be rebuilt here.
var newModel = mlContext.BinaryClassification.ChangeModelThreshold(model, -2.0f); | ||
// Fixed seed so the generated data point is the same on every run.
var rnd = new Random(1); | ||
var randomDataPoint = TypeTestData.GetRandomInstance(rnd); | ||
var engine = mlContext.Model.CreatePredictionEngine<TypeTestData, Prediction>(model); | ||
var pr = engine.Predict(randomDataPoint); | ||
// Score is -1.38 so predicted label is false. | ||
Assert.False(pr.PredictedLabel); | ||
Assert.True(pr.Score <= 0); | ||
var newEngine = mlContext.Model.CreatePredictionEngine<TypeTestData, Prediction>(newModel); | ||
pr = newEngine.Predict(randomDataPoint); | ||
// Score is still -1.38 but since threshold is no longer 0 but -2 predicted label now is true. | ||
Assert.True(pr.PredictedLabel); | ||
Assert.True(pr.Score <= 0); | ||
} | ||
|
||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
needs documentation