Skip to content

Commit 925bdf0

Browse files
committed
The test pipeline for consuming an ONNX model would fail due to the
Score column being named "Score0". The ONNX model will rename the output columns by design, therefore a different class with the ColumnName of "Score0" is needed. This fixes the test pipeline to address this issue. Fixes dotnet#2981
1 parent 1724da8 commit 925bdf0

File tree

1 file changed

+8
-7
lines changed
  • test/Microsoft.ML.Functional.Tests

1 file changed

+8
-7
lines changed

test/Microsoft.ML.Functional.Tests/ONNX.cs

+8-7
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
// See the LICENSE file in the project root for more information.
44

55
using System.IO;
6+
using Microsoft.ML.Data;
67
using Microsoft.ML.Functional.Tests.Datasets;
78
using Microsoft.ML.RunTests;
89
using Microsoft.ML.TestFramework;
@@ -14,6 +15,12 @@
1415

1516
namespace Microsoft.ML.Functional.Tests
1617
{
18+
internal sealed class OnnxScoreColumn
19+
{
20+
[ColumnName("Score0")]
21+
public float[] Score { get; set; }
22+
}
23+
1724
public class ONNX : BaseTestClass
1825
{
1926
public ONNX(ITestOutputHelper output) : base(output)
@@ -51,15 +58,9 @@ public void SaveOnnxModelLoadAndScoreFastTree()
5158
var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(modelPath);
5259
var onnxModel = onnxEstimator.Fit(data);
5360

54-
// TODO #2980: ONNX outputs don't match the outputs of the model, so we must hand-correct this for now.
55-
// TODO #2981: ONNX models cannot be fit as part of a pipeline, so we must use a workaround like this.
56-
var onnxWorkaroundPipeline = onnxModel.Append(
57-
mlContext.Transforms.CopyColumns("Score", "Score0").Fit(onnxModel.Transform(data)));
58-
5961
// Create prediction engine and test predictions.
6062
var originalPredictionEngine = mlContext.Model.CreatePredictionEngine<HousingRegression, ScoreColumn>(model);
61-
// TODO #2982: ONNX produces vector types and not the original output type.
62-
var onnxPredictionEngine = mlContext.Model.CreatePredictionEngine<HousingRegression, VectorScoreColumn>(onnxWorkaroundPipeline);
63+
var onnxPredictionEngine = mlContext.Model.CreatePredictionEngine<HousingRegression, OnnxScoreColumn>(onnxModel);
6364

6465
// Take a handful of examples out of the dataset and compute predictions.
6566
var dataEnumerator = mlContext.Data.CreateEnumerable<HousingRegression>(mlContext.Data.TakeRows(data, 5), false);

0 commit comments

Comments
 (0)