Skip to content

Commit c449625

Browse files
authored
ApplyOnnxModel sample (#3349)
* Rename to ApplyOnnxModel * simplify * fix comments * fix comments
1 parent 797e87d commit c449625

File tree

1 file changed

+15
-40
lines changed

1 file changed

+15
-40
lines changed

docs/samples/Microsoft.ML.Samples/Dynamic/OnnxTransform.cs renamed to docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/ApplyOnnxModel.cs

+15-40
Original file line numberDiff line numberDiff line change
@@ -2,45 +2,29 @@
22
using System.Linq;
33
using Microsoft.ML;
44
using Microsoft.ML.Data;
5-
using Microsoft.ML.OnnxRuntime;
65

76
namespace Samples.Dynamic
87
{
9-
public static class OnnxTransformExample
8+
public static class ApplyOnnxModel
109
{
11-
/// <summary>
12-
/// Example use of OnnxEstimator in an ML.NET pipeline
13-
/// </summary>
1410
public static void Example()
1511
{
1612
// Download the squeeznet image model from ONNX model zoo, version 1.2
1713
// https://github.com/onnx/models/tree/master/squeezenet or use
1814
// Microsoft.ML.Onnx.TestModels nuget.
1915
var modelPath = @"squeezenet\00000001\model.onnx";
2016

21-
// Inspect the model's inputs and outputs
22-
var session = new InferenceSession(modelPath);
23-
var inputInfo = session.InputMetadata.First();
24-
var outputInfo = session.OutputMetadata.First();
25-
Console.WriteLine($"Input Name is {String.Join(",", inputInfo.Key)}");
26-
Console.WriteLine($"Input Dimensions are {String.Join(",", inputInfo.Value.Dimensions)}");
27-
Console.WriteLine($"Output Name is {String.Join(",", outputInfo.Key)}");
28-
Console.WriteLine($"Output Dimensions are {String.Join(",", outputInfo.Value.Dimensions)}");
29-
// Results..
30-
// Input Name is data_0
31-
// Input Dimensions are 1,3,224,224
32-
// Output Name is softmaxout_1
33-
// Output Dimensions are 1,1000,1,1
34-
3517
// Create ML pipeline to score the data using OnnxScoringEstimator
3618
var mlContext = new MLContext();
37-
var data = GetTensorData();
38-
var idv = mlContext.Data.LoadFromEnumerable(data);
39-
var pipeline = mlContext.Transforms.ApplyOnnxModel(new[] { outputInfo.Key }, new[] { inputInfo.Key }, modelPath);
40-
41-
// Run the pipeline and get the transformed values
42-
var transformedValues = pipeline.Fit(idv).Transform(idv);
4319

20+
// Generate sample test data.
21+
var samples = GetTensorData();
22+
// Convert training data to IDataView, the general data type used in ML.NET.
23+
var data = mlContext.Data.LoadFromEnumerable(samples);
24+
// Create the pipeline to score using provided onnx model.
25+
var pipeline = mlContext.Transforms.ApplyOnnxModel(modelPath);
26+
// Fit the pipeline and get the transformed values
27+
var transformedValues = pipeline.Fit(data).Transform(data);
4428
// Retrieve model scores into Prediction class
4529
var predictions = mlContext.Data.CreateEnumerable<Prediction>(transformedValues, reuseRowObject: false);
4630

@@ -66,25 +50,18 @@ public static void Example()
6650
// ----------
6751
}
6852

69-
/// <summary>
70-
/// inputSize is the overall dimensions of the model input tensor.
71-
/// </summary>
53+
// inputSize is the overall dimensions of the model input tensor.
7254
private const int inputSize = 224 * 224 * 3;
7355

74-
/// <summary>
75-
/// A class to hold sample tensor data. Member name should match
76-
/// the inputs that the model expects (in this case, data_0)
77-
/// </summary>
56+
// A class to hold sample tensor data. Member name should match
57+
// the inputs that the model expects (in this case, data_0)
7858
public class TensorData
7959
{
8060
[VectorType(inputSize)]
8161
public float[] data_0 { get; set; }
8262
}
8363

84-
/// <summary>
85-
/// Method to generate sample test data. Returns 2 sample rows.
86-
/// </summary>
87-
/// <returns></returns>
64+
// Method to generate sample test data. Returns 2 sample rows.
8865
public static TensorData[] GetTensorData()
8966
{
9067
// This can be any numerical data. Assume image pixel values.
@@ -93,10 +70,8 @@ public static TensorData[] GetTensorData()
9370
return new TensorData[] { new TensorData() { data_0 = image1 }, new TensorData() { data_0 = image2 } };
9471
}
9572

96-
/// <summary>
97-
/// Class to contain the output values from the transformation.
98-
/// This model generates a vector of 1000 floats.
99-
/// </summary>
73+
// Class to contain the output values from the transformation.
74+
// This model generates a vector of 1000 floats.
10075
class Prediction
10176
{
10277
[VectorType(1000)]

0 commit comments

Comments
 (0)