diff --git a/src/Microsoft.ML.Data/Scorers/BinaryClassifierScorer.cs b/src/Microsoft.ML.Data/Scorers/BinaryClassifierScorer.cs index 9ef81f9047..ba015ebe6f 100644 --- a/src/Microsoft.ML.Data/Scorers/BinaryClassifierScorer.cs +++ b/src/Microsoft.ML.Data/Scorers/BinaryClassifierScorer.cs @@ -184,6 +184,7 @@ private protected override void SaveAsOnnxCore(OnnxContext ctx) { Host.CheckValue(ctx, nameof(ctx)); Host.Assert(Bindable is IBindableCanSaveOnnx); + Host.Assert(Bindings.InfoCount >= 2); if (!ctx.ContainsColumn(DefaultColumnNames.Features)) return; @@ -197,15 +198,28 @@ private protected override void SaveAsOnnxCore(OnnxContext ctx) for (int iinfo = 0; iinfo < Bindings.InfoCount; ++iinfo) outColumnNames[iinfo] = Bindings.GetColumnName(Bindings.MapIinfoToCol(iinfo)); - //Check if "Probability" column was generated by the base class, only then - //label can be predicted. - if (Bindings.InfoCount >= 3 && ctx.ContainsColumn(outColumnNames[2])) + /* If the probability column was generated, then the classification threshold is set to 0.5. Otherwise, + the predicted label is based on the sign of the score. + */ + string opType = "Binarizer"; + OnnxNode node; + var binarizerOutput = ctx.AddIntermediateVariable(null, "BinarizerOutput", true); + + if (Bindings.InfoCount >= 3) { - string opType = "Binarizer"; - var node = ctx.CreateNode(opType, new[] { ctx.GetVariableName(outColumnNames[2]) }, - new[] { ctx.GetVariableName(outColumnNames[0]) }, ctx.GetNodeName(opType)); + Host.Assert(ctx.ContainsColumn(outColumnNames[2])); + node = ctx.CreateNode(opType, ctx.GetVariableName(outColumnNames[2]), binarizerOutput, ctx.GetNodeName(opType)); node.AddAttribute("threshold", 0.5); } + else + { + node = ctx.CreateNode(opType, ctx.GetVariableName(outColumnNames[1]), binarizerOutput, ctx.GetNodeName(opType)); + node.AddAttribute("threshold", 0.0); + } + opType = "Cast"; + node = ctx.CreateNode(opType, binarizerOutput, ctx.GetVariableName(outColumnNames[0]), ctx.GetNodeName(opType), ""); + var t = InternalDataKindExtensions.ToInternalDataKind(DataKind.Boolean).ToType(); + node.AddAttribute("to", t); } private protected override IDataTransform ApplyToDataCore(IHostEnvironment env, IDataView newSource) diff --git a/src/Microsoft.ML.FastTree/FastTree.cs b/src/Microsoft.ML.FastTree/FastTree.cs index a84d18792c..77a3b15fbf 100644 --- a/src/Microsoft.ML.FastTree/FastTree.cs +++ b/src/Microsoft.ML.FastTree/FastTree.cs @@ -3049,6 +3049,7 @@ private enum AggregateFunction private protected virtual bool SaveAsOnnx(OnnxContext ctx, string[] outputNames, string featureColumn) { Host.CheckValue(ctx, nameof(ctx)); + Host.Check(Utils.Size(outputNames) >= 1); //Nodes. var nodesTreeids = new List<long>(); @@ -3111,7 +3112,8 @@ private protected virtual bool SaveAsOnnx(OnnxContext ctx, string[] outputNames, } string opType = "TreeEnsembleRegressor"; - var node = ctx.CreateNode(opType, new[] { featureColumn }, outputNames, ctx.GetNodeName(opType)); + string scoreVarName = (Utils.Size(outputNames) == 2) ? outputNames[1] : outputNames[0]; // Get Score from PredictedLabel and/or Score columns + var node = ctx.CreateNode(opType, new[] { featureColumn }, new[] { scoreVarName }, ctx.GetNodeName(opType)); node.AddAttribute("post_transform", PostTransform.None.GetDescription()); node.AddAttribute("n_targets", 1); diff --git a/src/Microsoft.ML.StandardTrainers/Standard/LinearModelParameters.cs b/src/Microsoft.ML.StandardTrainers/Standard/LinearModelParameters.cs index 1ab5d8b6be..d2dea6f60e 100644 --- a/src/Microsoft.ML.StandardTrainers/Standard/LinearModelParameters.cs +++ b/src/Microsoft.ML.StandardTrainers/Standard/LinearModelParameters.cs @@ -136,9 +136,10 @@ internal LinearModelParameters(IHostEnvironment env, string name, in VBuffer<flo private protected virtual bool SaveAsOnnx(OnnxContext ctx, string[] outputs, string featureColumn) { Host.CheckValue(ctx, nameof(ctx)); - Host.Check(Utils.Size(outputs) == 1); + Host.Check(Utils.Size(outputs) >= 1); string opType = "LinearRegressor"; - var node = ctx.CreateNode(opType, new[] { featureColumn }, outputs, ctx.GetNodeName(opType)); + string scoreVarName = (Utils.Size(outputs) == 2) ? outputs[1] : outputs[0]; // Get Score from PredictedLabel and/or Score columns + var node = ctx.CreateNode(opType, new[] { featureColumn }, new[] { scoreVarName }, ctx.GetNodeName(opType)); // Selection of logit or probit output transform. enum {'NONE', 'LOGIT', 'PROBIT} node.AddAttribute("post_transform", "NONE"); node.AddAttribute("targets", 1); diff --git a/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/ExcludeVariablesInOnnxConversion.txt b/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/ExcludeVariablesInOnnxConversion.txt index 3a56e96dd7..c7ddd56e30 100644 --- a/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/ExcludeVariablesInOnnxConversion.txt +++ b/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/ExcludeVariablesInOnnxConversion.txt @@ -416,7 +416,7 @@ "Probability" ], "output": [ - "PredictedLabel" + "BinarizerOutput" ], "name": "Binarizer", "opType": "Binarizer", @@ -429,6 +429,23 @@ ], "domain": "ai.onnx.ml" }, + { + "input": [ + "BinarizerOutput" + ], + "output": [ + "PredictedLabel" + ], + "name": "Cast1", + "opType": "Cast", + "attribute": [ + { + "name": "to", + "i": "9", + "type": "INT" + } + ] + }, { "input": [ "PredictedLabel" diff --git a/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/ModelWithLessIO.txt b/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/ModelWithLessIO.txt index 5045719a2d..166b713fe7 100644 --- a/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/ModelWithLessIO.txt +++ b/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/ModelWithLessIO.txt @@ -742,7 +742,7 @@ "Probability" ], "output": [ - "PredictedLabel" + "BinarizerOutput" ], "name": "Binarizer", "opType": "Binarizer", @@ -755,6 +755,23 @@ ], "domain": "ai.onnx.ml" }, + { + "input": [ + "BinarizerOutput" + ], + "output": [ + "PredictedLabel" + ], + "name": "Cast1", + "opType": "Cast", + "attribute": [ + { + "name": "to", + "i": "9", + "type": "INT" + } + ] + }, { "input": [ "PredictedLabel" diff --git a/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/OneHotBagPipeline.txt b/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/OneHotBagPipeline.txt index 36ae7a2fd3..d007cd95c0 100644 --- a/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/OneHotBagPipeline.txt +++ b/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/OneHotBagPipeline.txt @@ -369,7 +369,7 @@ "Probability" ], "output": [ - "PredictedLabel" + "BinarizerOutput" ], "name": "Binarizer", "opType": "Binarizer", @@ -382,6 +382,23 @@ ], "domain": "ai.onnx.ml" }, + { + "input": [ + "BinarizerOutput" + ], + "output": [ + "PredictedLabel" + ], + "name": "Cast1", + "opType": "Cast", + "attribute": [ + { + "name": "to", + "i": "9", + "type": "INT" + } + ] + }, { "input": [ "Label" diff --git a/test/Microsoft.ML.Tests/OnnxConversionTest.cs b/test/Microsoft.ML.Tests/OnnxConversionTest.cs index 6c794904b7..0a61508491 100644 --- a/test/Microsoft.ML.Tests/OnnxConversionTest.cs +++ b/test/Microsoft.ML.Tests/OnnxConversionTest.cs @@ -134,6 +134,15 @@ private class BreastCancerMulticlassExample public float[] Features; } + private class BreastCancerBinaryClassification + { + [LoadColumn(0)] + public bool Label; + + [LoadColumn(2, 9), VectorType(8)] + public float[] Features; + } + [LessThanNetCore30OrNotNetCoreFact("netcoreapp3.0 output differs from Baseline. Tracked by https://github.com/dotnet/machinelearning/issues/2087")] public void KmeansOnnxConversionTest() { @@ -202,14 +211,15 @@ public void RegressionTrainersOnnxConversionTest() List<IEstimator<ITransformer>> estimators = new List<IEstimator<ITransformer>>() { mlContext.Regression.Trainers.Sdca("Target","FeatureVector"), - mlContext.Regression.Trainers.Ols("Target","FeatureVector"), + mlContext.Regression.Trainers.Ols("Target","FeatureVector"), mlContext.Regression.Trainers.OnlineGradientDescent("Target","FeatureVector"), mlContext.Regression.Trainers.FastForest("Target", "FeatureVector"), mlContext.Regression.Trainers.FastTree("Target", "FeatureVector"), mlContext.Regression.Trainers.FastTreeTweedie("Target", "FeatureVector"), mlContext.Regression.Trainers.LbfgsPoissonRegression("Target", "FeatureVector"), }; - if (Environment.Is64BitProcess) { + if (Environment.Is64BitProcess) + { estimators.Add(mlContext.Regression.Trainers.LightGbm("Target", "FeatureVector")); } foreach (var estimator in estimators) @@ -232,7 +242,7 @@ public void RegressionTrainersOnnxConversionTest() CompareSelectedR4ScalarColumns(transformedData.Schema[2].Name, outputNames[2], transformedData, onnxResult, 3); // compare score results } // Compare the Onnx graph to a baseline if OnnxRuntime is not supported - else + else { var onnxFileName = $"{estimator.ToString()}.txt"; var subDir = Path.Combine("..", "..", "BaselineOutput", "Common", "Onnx", "Regression", "Adult"); @@ -244,6 +254,58 @@ public void RegressionTrainersOnnxConversionTest() Done(); } + [Fact] + public void BinaryClassificationTrainersOnnxConversionTest() + { + var mlContext = new MLContext(seed: 1); + string dataPath = GetDataPath("breast-cancer.txt"); + // Now read the file (remember though, readers are lazy, so the actual reading will happen when the data is accessed). + var dataView = mlContext.Data.LoadFromTextFile<BreastCancerBinaryClassification>(dataPath, separatorChar: '\t', hasHeader: true); + List<IEstimator<ITransformer>> estimators = new List<IEstimator<ITransformer>>() + { + mlContext.BinaryClassification.Trainers.AveragedPerceptron(), + mlContext.BinaryClassification.Trainers.FastForest(), + mlContext.BinaryClassification.Trainers.FastTree(), + mlContext.BinaryClassification.Trainers.LbfgsLogisticRegression(), + mlContext.BinaryClassification.Trainers.LinearSvm(), + mlContext.BinaryClassification.Trainers.SdcaLogisticRegression(), + mlContext.BinaryClassification.Trainers.SdcaNonCalibrated(), + mlContext.BinaryClassification.Trainers.SgdCalibrated(), + mlContext.BinaryClassification.Trainers.SgdNonCalibrated(), + mlContext.BinaryClassification.Trainers.SymbolicSgdLogisticRegression(), + }; + if (Environment.Is64BitProcess) + { + estimators.Add(mlContext.BinaryClassification.Trainers.LightGbm()); + } + + var initialPipeline = mlContext.Transforms.ReplaceMissingValues("Features"). + Append(mlContext.Transforms.NormalizeMinMax("Features")); + foreach (var estimator in estimators) + { + var pipeline = initialPipeline.Append(estimator); + var model = pipeline.Fit(dataView); + var transformedData = model.Transform(dataView); + var onnxModel = mlContext.Model.ConvertToOnnxProtobuf(model, dataView); + // Compare model scores produced by ML.NET and ONNX's runtime. + if (IsOnnxRuntimeSupported()) + { + var onnxFileName = $"{estimator.ToString()}.onnx"; + var onnxModelPath = GetOutputPath(onnxFileName); + SaveOnnxModel(onnxModel, onnxModelPath, null); + // Evaluate the saved ONNX model using the data used to train the ML.NET pipeline. + string[] inputNames = onnxModel.Graph.Input.Select(valueInfoProto => valueInfoProto.Name).ToArray(); + string[] outputNames = onnxModel.Graph.Output.Select(valueInfoProto => valueInfoProto.Name).ToArray(); + var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(outputNames, inputNames, onnxModelPath); + var onnxTransformer = onnxEstimator.Fit(dataView); + var onnxResult = onnxTransformer.Transform(dataView); + CompareSelectedR4ScalarColumns(transformedData.Schema[5].Name, outputNames[3], transformedData, onnxResult, 3); + CompareSelectedScalarColumns<Boolean>(transformedData.Schema[4].Name, outputNames[2], transformedData, onnxResult); + } + } + Done(); + } + private class DataPoint { [VectorType(3)] @@ -1225,7 +1287,8 @@ private void CreateDummyExamplesToMakeComplierHappy() var dummyExample = new BreastCancerFeatureVector() { Features = null }; var dummyExample1 = new BreastCancerCatFeatureExample() { Label = false, F1 = 0, F2 = "Amy" }; var dummyExample2 = new BreastCancerMulticlassExample() { Label = "Amy", Features = null }; - var dummyExample3 = new SmallSentimentExample() { Tokens = null }; + var dummyExample3 = new BreastCancerBinaryClassification() { Label = false, Features = null }; + var dummyExample4 = new SmallSentimentExample() { Tokens = null }; } private void CompareResults(string leftColumnName, string rightColumnName, IDataView left, IDataView right)