From 5b9ed158147ef7d5b6e4297625bd5d3b08c418e8 Mon Sep 17 00:00:00 2001
From: Ivan Matantsev <ivmatan@microsoft.com>
Date: Thu, 14 Mar 2019 14:58:54 -0700
Subject: [PATCH 1/3] first step

---
 .../Scorers/PredictionTransformer.cs          |  2 +-
 src/Microsoft.ML.Data/TrainCatalog.cs         | 18 +++++
 .../Prediction.cs                             | 69 ++++++++++---------
 3 files changed, 57 insertions(+), 32 deletions(-)

diff --git a/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs b/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs
index be6757bc36..f7deddf9d1 100644
--- a/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs
+++ b/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs
@@ -53,7 +53,7 @@ public abstract class PredictionTransformerBase<TModel> : IPredictionTransformer
         [BestFriend]
         private protected ISchemaBindableMapper BindableMapper;
         [BestFriend]
-        private protected DataViewSchema TrainSchema;
+        internal DataViewSchema TrainSchema;
 
         /// <summary>
         /// Whether a call to <see cref="ITransformer.GetRowToRowMapper(DataViewSchema)"/> should succeed, on an
diff --git a/src/Microsoft.ML.Data/TrainCatalog.cs b/src/Microsoft.ML.Data/TrainCatalog.cs
index 6e8103f828..054200aacb 100644
--- a/src/Microsoft.ML.Data/TrainCatalog.cs
+++ b/src/Microsoft.ML.Data/TrainCatalog.cs
@@ -3,6 +3,7 @@
 // See the LICENSE file in the project root for more information.
 
 using System;
+using System.Collections.Generic;
 using System.Linq;
 using Microsoft.Data.DataView;
 using Microsoft.ML.Calibrators;
@@ -274,6 +275,23 @@ public CrossValidationResult<CalibratedBinaryClassificationMetrics>[] CrossValid
                 Evaluate(x.Scores, labelColumn), x.Scores, x.Fold)).ToArray();
         }
 
+        public TransformerChain<BinaryPredictionTransformer<TModel>> ChangeModelThreshold<TModel>(TransformerChain<BinaryPredictionTransformer<TModel>> chain, float threshold) where TModel : class
+        {
+            if (chain.LastTransformer.Threshold == threshold)
+                return chain;
+            List<ITransformer> transformers = new List<ITransformer>();
+            var predictionTransformer = chain.LastTransformer;
+            foreach (var transform in chain)
+            {
+                if (transform != predictionTransformer)
+                    transformers.Add(transform);
+            }
+
+            var a = new BinaryPredictionTransformer<TModel>(Environment, predictionTransformer.Model, predictionTransformer.TrainSchema, predictionTransformer.FeatureColumn, threshold, predictionTransformer.ThresholdColumn);
+            transformers.Add(a);
+            return new TransformerChain<BinaryPredictionTransformer<TModel>>(transformers.ToArray());
+        }
+
         /// <summary>
         /// The list of trainers for performing binary classification.
         /// </summary>
diff --git a/test/Microsoft.ML.Functional.Tests/Prediction.cs b/test/Microsoft.ML.Functional.Tests/Prediction.cs
index 4605f953bd..74109dfa1b 100644
--- a/test/Microsoft.ML.Functional.Tests/Prediction.cs
+++ b/test/Microsoft.ML.Functional.Tests/Prediction.cs
@@ -2,14 +2,25 @@
 // The .NET Foundation licenses this file to you under the MIT license.
 // See the LICENSE file in the project root for more information.
 
+using Microsoft.ML.Functional.Tests.Datasets;
 using Microsoft.ML.RunTests;
 using Microsoft.ML.TestFramework;
 using Xunit;
+using Xunit.Abstractions;
 
 namespace Microsoft.ML.Functional.Tests
 {
-    public class PredictionScenarios
+    public class PredictionScenarios : BaseTestClass
     {
+        public PredictionScenarios(ITestOutputHelper output) : base(output)
+        {
+        }
+
+        class Answer
+        {
+            public float Score { get; set; }
+            public bool PredictedLabel { get; set; }
+        }
         /// <summary>
         /// Reconfigurable predictions: The following should be possible: A user trains a binary classifier,
         /// and through the test evaluator gets a PR curve, the based on the PR curve picks a new threshold
@@ -19,36 +30,32 @@ public class PredictionScenarios
         [Fact]
         public void ReconfigurablePrediction()
         {
-            var mlContext = new MLContext(seed: 789);
-
-            // Get the dataset, create a train and test
-            var data = mlContext.Data.CreateTextLoader(TestDatasets.housing.GetLoaderColumns(),
-                hasHeader: TestDatasets.housing.fileHasHeader, separatorChar: TestDatasets.housing.fileSeparator)
-                .Load(BaseTestClass.GetDataPath(TestDatasets.housing.trainFilename));
-            var split = mlContext.Data.TrainTestSplit(data, testFraction: 0.2);
-
-            // Create a pipeline to train on the housing data
-            var pipeline = mlContext.Transforms.Concatenate("Features", new string[] {
-                    "CrimesPerCapita", "PercentResidental", "PercentNonRetail", "CharlesRiver", "NitricOxides", "RoomsPerDwelling",
-                    "PercentPre40s", "EmploymentDistance", "HighwayDistance", "TaxRate", "TeacherRatio"})
-                .Append(mlContext.Transforms.CopyColumns("Label", "MedianHomeValue"))
-                .Append(mlContext.Regression.Trainers.Ols());
-
-            var model = pipeline.Fit(split.TrainSet);
-
-            var scoredTest = model.Transform(split.TestSet);
-            var metrics = mlContext.Regression.Evaluate(scoredTest);
-
-            Common.AssertMetrics(metrics);
-
-            // Todo #2465: Allow the setting of threshold and thresholdColumn for scoring.
-            // This is no longer possible in the API
-            //var newModel = new BinaryPredictionTransformer<IPredictorProducing<float>>(ml, model.Model, trainData.Schema, model.FeatureColumnName, threshold: 0.01f, thresholdColumn: DefaultColumnNames.Probability);
-            //var newScoredTest = newModel.Transform(pipeline.Transform(testData));
-            //var newMetrics = mlContext.BinaryClassification.Evaluate(scoredTest);
-            // And the Threshold and ThresholdColumn properties are not settable.
-            //var predictor = model.LastTransformer;
-            //predictor.Threshold = 0.01; // Not possible
+            var mlContext = new MLContext(seed: 1);
+
+            var data = mlContext.Data.LoadFromTextFile<TweetSentiment>(GetDataPath(TestDatasets.Sentiment.trainFilename),
+                hasHeader: TestDatasets.Sentiment.fileHasHeader,
+                separatorChar: TestDatasets.Sentiment.fileSeparator);
+
+            // Create a training pipeline.
+            var pipeline = mlContext.Transforms.Text.FeaturizeText("Features", "SentimentText")
+                .AppendCacheCheckpoint(mlContext)
+                .Append(mlContext.BinaryClassification.Trainers.LogisticRegression(
+                    new Trainers.LogisticRegressionBinaryClassificationTrainer.Options { NumberOfThreads = 1 }));
+
+            // Train the model.
+            var model = pipeline.Fit(data);
+            var engine = model.CreatePredictionEngine<TweetSentiment, Answer>(mlContext);
+            var pr = engine.Predict(new TweetSentiment() { SentimentText = "Good Bad job" });
+            // Score is 0.64 so predicted label is true.
+            Assert.True(pr.PredictedLabel);
+            Assert.True(pr.Score > 0);
+            var newModel = mlContext.BinaryClassification.ChangeModelThreshold(model, 0.7f);
+            var newEngine = newModel.CreatePredictionEngine<TweetSentiment, Answer>(mlContext);
+            pr = newEngine.Predict(new TweetSentiment() { SentimentText = "Good Bad job" });
+            // Score is still 0.64 but since threshold is no longer 0 but 0.7 predicted label now is false.
+
+            Assert.False(pr.PredictedLabel);
+            Assert.False(pr.Score > 0.7);
         }
     }
 }

From 5dff4ee6a0a551ca1e5a94ecea6c433c94d0f66d Mon Sep 17 00:00:00 2001
From: Ivan Matantsev <ivmatan@microsoft.com>
Date: Thu, 14 Mar 2019 16:33:21 -0700
Subject: [PATCH 2/3] and single case

---
 src/Microsoft.ML.Data/TrainCatalog.cs         | 23 +++++++++++++---
 .../Prediction.cs                             | 26 +++++++++++++++++++
 2 files changed, 46 insertions(+), 3 deletions(-)

diff --git a/src/Microsoft.ML.Data/TrainCatalog.cs b/src/Microsoft.ML.Data/TrainCatalog.cs
index 054200aacb..54370a88c9 100644
--- a/src/Microsoft.ML.Data/TrainCatalog.cs
+++ b/src/Microsoft.ML.Data/TrainCatalog.cs
@@ -275,7 +275,15 @@ public CrossValidationResult<CalibratedBinaryClassificationMetrics>[] CrossValid
                 Evaluate(x.Scores, labelColumn), x.Scores, x.Fold)).ToArray();
         }
 
-        public TransformerChain<BinaryPredictionTransformer<TModel>> ChangeModelThreshold<TModel>(TransformerChain<BinaryPredictionTransformer<TModel>> chain, float threshold) where TModel : class
+        /// <summary>
+        /// Change threshold for binary model.
+        /// </summary>
+        /// <typeparam name="TModel">An implementation of the <see cref="IPredictorProducing{TResult}"/></typeparam>
+        /// <param name="chain">Chain of transformers.</param>
+        /// <param name="threshold">New threshold.</param>
+        /// <returns></returns>
+        public TransformerChain<BinaryPredictionTransformer<TModel>> ChangeModelThreshold<TModel>(TransformerChain<BinaryPredictionTransformer<TModel>> chain, float threshold)
+            where TModel : class
         {
             if (chain.LastTransformer.Threshold == threshold)
                 return chain;
@@ -287,11 +295,20 @@ public TransformerChain<BinaryPredictionTransformer<TModel>> ChangeModelThreshol
                     transformers.Add(transform);
             }
 
-            var a = new BinaryPredictionTransformer<TModel>(Environment, predictionTransformer.Model, predictionTransformer.TrainSchema, predictionTransformer.FeatureColumn, threshold, predictionTransformer.ThresholdColumn);
-            transformers.Add(a);
+            transformers.Add(new BinaryPredictionTransformer<TModel>(Environment, predictionTransformer.Model,
+                predictionTransformer.TrainSchema, predictionTransformer.FeatureColumn,
+                threshold, predictionTransformer.ThresholdColumn));
             return new TransformerChain<BinaryPredictionTransformer<TModel>>(transformers.ToArray());
         }
 
+        public BinaryPredictionTransformer<TModel> ChangeModelThreshold<TModel>(BinaryPredictionTransformer<TModel> model, float threshold)
+             where TModel : class
+        {
+            if (model.Threshold == threshold)
+                return model;
+            return new BinaryPredictionTransformer<TModel>(Environment, model.Model, model.TrainSchema, model.FeatureColumn, threshold, model.ThresholdColumn);
+        }
+
         /// <summary>
         /// The list of trainers for performing binary classification.
         /// </summary>
diff --git a/test/Microsoft.ML.Functional.Tests/Prediction.cs b/test/Microsoft.ML.Functional.Tests/Prediction.cs
index 74109dfa1b..82cb99a092 100644
--- a/test/Microsoft.ML.Functional.Tests/Prediction.cs
+++ b/test/Microsoft.ML.Functional.Tests/Prediction.cs
@@ -2,6 +2,7 @@
 // The .NET Foundation licenses this file to you under the MIT license.
 // See the LICENSE file in the project root for more information.
 
+using System;
 using Microsoft.ML.Functional.Tests.Datasets;
 using Microsoft.ML.RunTests;
 using Microsoft.ML.TestFramework;
@@ -57,5 +58,30 @@ public void ReconfigurablePrediction()
             Assert.False(pr.PredictedLabel);
             Assert.False(pr.Score > 0.7);
         }
+
+        [Fact]
+        public void ReconfigurablePredictionNoPipeline()
+        {
+            var mlContext = new MLContext(seed: 1);
+
+            var data = mlContext.Data.LoadFromEnumerable(TypeTestData.GenerateDataset());
+            var pipeline = mlContext.BinaryClassification.Trainers.LogisticRegression(
+                     new Trainers.LogisticRegressionBinaryClassificationTrainer.Options { NumberOfThreads = 1 });
+            var model = pipeline.Fit(data);
+            var newModel = mlContext.BinaryClassification.ChangeModelThreshold(model, -2.0f);
+            var rnd = new Random(1);
+            var randomDataPoint = TypeTestData.GetRandomInstance(rnd);
+            var engine = model.CreatePredictionEngine<TypeTestData, Answer>(mlContext);
+            var pr = engine.Predict(randomDataPoint);
+            // Score is -1.38 so predicted label is false.
+            Assert.False(pr.PredictedLabel);
+            Assert.True(pr.Score <= 0);
+            var newEngine = newModel.CreatePredictionEngine<TypeTestData, Answer>(mlContext);
+            pr = newEngine.Predict(randomDataPoint);
+            // Score is still -1.38 but since threshold is no longer 0 but -2 predicted label now is true.
+            Assert.True(pr.PredictedLabel);
+            Assert.True(pr.Score <= 0);
+        }
+
     }
 }

From 142fa9206fdd007ddc1f2f583142d29addf95cc5 Mon Sep 17 00:00:00 2001
From: Ivan Matantsev <ivmatan@microsoft.com>
Date: Mon, 25 Mar 2019 10:37:37 -0700
Subject: [PATCH 3/3] address some comments

---
 src/Microsoft.ML.Data/TrainCatalog.cs         | 28 +------------------
 .../Prediction.cs                             | 27 ++++++++++++------
 2 files changed, 20 insertions(+), 35 deletions(-)

diff --git a/src/Microsoft.ML.Data/TrainCatalog.cs b/src/Microsoft.ML.Data/TrainCatalog.cs
index 6c857e4a86..944f3dd9e5 100644
--- a/src/Microsoft.ML.Data/TrainCatalog.cs
+++ b/src/Microsoft.ML.Data/TrainCatalog.cs
@@ -256,38 +256,12 @@ public IReadOnlyList<CrossValidationResult<CalibratedBinaryClassificationMetrics
                 Evaluate(x.Scores, labelColumnName), x.Scores, x.Fold)).ToArray();
         }
 
-        /// <summary>
-        /// Change threshold for binary model.
-        /// </summary>
-        /// <typeparam name="TModel">An implementation of the <see cref="IPredictorProducing{TResult}"/></typeparam>
-        /// <param name="chain">Chain of transformers.</param>
-        /// <param name="threshold">New threshold.</param>
-        /// <returns></returns>
-        public TransformerChain<BinaryPredictionTransformer<TModel>> ChangeModelThreshold<TModel>(TransformerChain<BinaryPredictionTransformer<TModel>> chain, float threshold)
-            where TModel : class
-        {
-            if (chain.LastTransformer.Threshold == threshold)
-                return chain;
-            List<ITransformer> transformers = new List<ITransformer>();
-            var predictionTransformer = chain.LastTransformer;
-            foreach (var transform in chain)
-            {
-                if (transform != predictionTransformer)
-                    transformers.Add(transform);
-            }
-
-            transformers.Add(new BinaryPredictionTransformer<TModel>(Environment, predictionTransformer.Model,
-                predictionTransformer.TrainSchema, predictionTransformer.FeatureColumn,
-                threshold, predictionTransformer.ThresholdColumn));
-            return new TransformerChain<BinaryPredictionTransformer<TModel>>(transformers.ToArray());
-        }
-
         public BinaryPredictionTransformer<TModel> ChangeModelThreshold<TModel>(BinaryPredictionTransformer<TModel> model, float threshold)
              where TModel : class
         {
             if (model.Threshold == threshold)
                 return model;
-            return new BinaryPredictionTransformer<TModel>(Environment, model.Model, model.TrainSchema, model.FeatureColumn, threshold, model.ThresholdColumn);
+            return new BinaryPredictionTransformer<TModel>(Environment, model.Model, model.TrainSchema, model.FeatureColumnName, threshold, model.ThresholdColumn);
         }
 
         /// <summary>
diff --git a/test/Microsoft.ML.Functional.Tests/Prediction.cs b/test/Microsoft.ML.Functional.Tests/Prediction.cs
index 82cb99a092..627e06e775 100644
--- a/test/Microsoft.ML.Functional.Tests/Prediction.cs
+++ b/test/Microsoft.ML.Functional.Tests/Prediction.cs
@@ -3,9 +3,13 @@
 // See the LICENSE file in the project root for more information.
 
 using System;
+using System.Collections.Generic;
+using Microsoft.ML.Calibrators;
+using Microsoft.ML.Data;
 using Microsoft.ML.Functional.Tests.Datasets;
 using Microsoft.ML.RunTests;
 using Microsoft.ML.TestFramework;
+using Microsoft.ML.Trainers;
 using Xunit;
 using Xunit.Abstractions;
 
@@ -17,7 +21,7 @@ public PredictionScenarios(ITestOutputHelper output) : base(output)
         {
         }
 
-        class Answer
+        class Prediction
         {
             public float Score { get; set; }
             public bool PredictedLabel { get; set; }
@@ -41,17 +45,24 @@ public void ReconfigurablePrediction()
             var pipeline = mlContext.Transforms.Text.FeaturizeText("Features", "SentimentText")
                 .AppendCacheCheckpoint(mlContext)
                 .Append(mlContext.BinaryClassification.Trainers.LogisticRegression(
-                    new Trainers.LogisticRegressionBinaryClassificationTrainer.Options { NumberOfThreads = 1 }));
+                    new LogisticRegressionBinaryTrainer.Options { NumberOfThreads = 1 }));
 
             // Train the model.
             var model = pipeline.Fit(data);
-            var engine = model.CreatePredictionEngine<TweetSentiment, Answer>(mlContext);
+            var engine = mlContext.Model.CreatePredictionEngine<TweetSentiment, Prediction>(model);
             var pr = engine.Predict(new TweetSentiment() { SentimentText = "Good Bad job" });
             // Score is 0.64 so predicted label is true.
             Assert.True(pr.PredictedLabel);
             Assert.True(pr.Score > 0);
-            var newModel = mlContext.BinaryClassification.ChangeModelThreshold(model, 0.7f);
-            var newEngine = newModel.CreatePredictionEngine<TweetSentiment, Answer>(mlContext);
+            var transformers = new List<ITransformer>();
+            foreach (var transform in model)
+            {
+                if (transform != model.LastTransformer)
+                    transformers.Add(transform);
+            }
+            transformers.Add(mlContext.BinaryClassification.ChangeModelThreshold(model.LastTransformer, 0.7f));
+            var newModel = new TransformerChain<BinaryPredictionTransformer<CalibratedModelParametersBase<LinearBinaryModelParameters, PlattCalibrator>>>(transformers.ToArray());
+            var newEngine = mlContext.Model.CreatePredictionEngine<TweetSentiment, Prediction>(newModel);
             pr = newEngine.Predict(new TweetSentiment() { SentimentText = "Good Bad job" });
             // Score is still 0.64 but since threshold is no longer 0 but 0.7 predicted label now is false.
 
@@ -66,17 +77,17 @@ public void ReconfigurablePredictionNoPipeline()
 
             var data = mlContext.Data.LoadFromEnumerable(TypeTestData.GenerateDataset());
             var pipeline = mlContext.BinaryClassification.Trainers.LogisticRegression(
-                     new Trainers.LogisticRegressionBinaryClassificationTrainer.Options { NumberOfThreads = 1 });
+                     new Trainers.LogisticRegressionBinaryTrainer.Options { NumberOfThreads = 1 });
             var model = pipeline.Fit(data);
             var newModel = mlContext.BinaryClassification.ChangeModelThreshold(model, -2.0f);
             var rnd = new Random(1);
             var randomDataPoint = TypeTestData.GetRandomInstance(rnd);
-            var engine = model.CreatePredictionEngine<TypeTestData, Answer>(mlContext);
+            var engine = mlContext.Model.CreatePredictionEngine<TypeTestData, Prediction>(model);
             var pr = engine.Predict(randomDataPoint);
             // Score is -1.38 so predicted label is false.
             Assert.False(pr.PredictedLabel);
             Assert.True(pr.Score <= 0);
-            var newEngine = newModel.CreatePredictionEngine<TypeTestData, Answer>(mlContext);
+            var newEngine = mlContext.Model.CreatePredictionEngine<TypeTestData, Prediction>(newModel);
             pr = newEngine.Predict(randomDataPoint);
             // Score is still -1.38 but since threshold is no longer 0 but -2 predicted label now is true.
             Assert.True(pr.PredictedLabel);