|
3 | 3 | // See the LICENSE file in the project root for more information.
|
4 | 4 |
|
5 | 5 | using System.IO;
|
| 6 | +using Microsoft.ML.Data; |
6 | 7 | using Microsoft.ML.Functional.Tests.Datasets;
|
7 | 8 | using Microsoft.ML.RunTests;
|
8 | 9 | using Microsoft.ML.TestFramework;
|
|
14 | 15 |
|
15 | 16 | namespace Microsoft.ML.Functional.Tests
|
16 | 17 | {
|
| 18 | + internal sealed class OnnxScoreColumn |
| 19 | + { |
| 20 | + [ColumnName("Score0")] |
| 21 | + public float[] Score { get; set; } |
| 22 | + } |
| 23 | + |
17 | 24 | public class ONNX : BaseTestClass
|
18 | 25 | {
|
19 | 26 | public ONNX(ITestOutputHelper output) : base(output)
|
@@ -51,15 +58,9 @@ public void SaveOnnxModelLoadAndScoreFastTree()
|
51 | 58 | var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(modelPath);
|
52 | 59 | var onnxModel = onnxEstimator.Fit(data);
|
53 | 60 |
|
54 |
| - // TODO #2980: ONNX outputs don't match the outputs of the model, so we must hand-correct this for now. |
55 |
| - // TODO #2981: ONNX models cannot be fit as part of a pipeline, so we must use a workaround like this. |
56 |
| - var onnxWorkaroundPipeline = onnxModel.Append( |
57 |
| - mlContext.Transforms.CopyColumns("Score", "Score0").Fit(onnxModel.Transform(data))); |
58 |
| - |
59 | 61 | // Create prediction engine and test predictions.
|
60 | 62 | var originalPredictionEngine = mlContext.Model.CreatePredictionEngine<HousingRegression, ScoreColumn>(model);
|
61 |
| - // TODO #2982: ONNX produces vector types and not the original output type. |
62 |
| - var onnxPredictionEngine = mlContext.Model.CreatePredictionEngine<HousingRegression, VectorScoreColumn>(onnxWorkaroundPipeline); |
| 63 | + var onnxPredictionEngine = mlContext.Model.CreatePredictionEngine<HousingRegression, OnnxScoreColumn>(onnxModel); |
63 | 64 |
|
64 | 65 | // Take a handful of examples out of the dataset and compute predictions.
|
65 | 66 | var dataEnumerator = mlContext.Data.CreateEnumerable<HousingRegression>(mlContext.Data.TakeRows(data, 5), false);
|
|
0 commit comments