Skip to content

Commit e5eef19

Browse files
committed
Internalization of TensorFlowUtils.cs and refactored TensorFlowCatalog.
1 parent b2587fa commit e5eef19

File tree

10 files changed

+84
-33
lines changed

10 files changed

+84
-33
lines changed

docs/samples/Microsoft.ML.Samples/Dynamic/TensorFlow/ImageClassification.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ public static void Example()
2020
var idv = mlContext.Data.ReadFromEnumerable(data);
2121

2222
// Create a ML pipeline.
23-
var pipeline = mlContext.Transforms.ScoreTensorFlowModel(
23+
var pipeline = mlContext.Transforms.TensorFlow.ScoreTensorFlowModel(
2424
modelLocation,
2525
new[] { nameof(OutputScores.output) },
2626
new[] { nameof(TensorData.input) });

docs/samples/Microsoft.ML.Samples/Dynamic/TensorFlow/TextClassification.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ public static void Example()
4545
// Load the TensorFlow model once.
4646
// - Use it for quering the schema for input and output in the model
4747
// - Use it for prediction in the pipeline.
48-
var modelInfo = TensorFlowUtils.LoadTensorFlowModel(mlContext, modelLocation);
48+
var modelInfo = mlContext.Transforms.TensorFlow.LoadTensorFlowModel(modelLocation);
4949
var schema = modelInfo.GetModelSchema();
5050
var featuresType = (VectorType)schema["Features"].Type;
5151
Console.WriteLine("Name: {0}, Type: {1}, Shape: (-1, {2})", "Features", featuresType.ItemType.RawType, featuresType.Dimensions[0]);
@@ -72,7 +72,7 @@ public static void Example()
7272
var engine = mlContext.Transforms.Text.TokenizeWords("TokenizedWords", "Sentiment_Text")
7373
.Append(mlContext.Transforms.Conversion.ValueMap(lookupMap, "Words", "Ids", new[] { ("VariableLenghtFeatures", "TokenizedWords") }))
7474
.Append(mlContext.Transforms.CustomMapping(ResizeFeaturesAction, "Resize"))
75-
.Append(mlContext.Transforms.ScoreTensorFlowModel(modelInfo, new[] { "Prediction/Softmax" }, new[] { "Features" }))
75+
.Append(mlContext.Transforms.TensorFlow.ScoreTensorFlowModel(modelInfo, new[] { "Prediction/Softmax" }, new[] { "Features" }))
7676
.Append(mlContext.Transforms.CopyColumns(("Prediction", "Prediction/Softmax")))
7777
.Fit(dataView)
7878
.CreatePredictionEngine<IMDBSentiment, OutputScores>(mlContext);

src/Microsoft.ML.Data/Transforms/TransformsCatalog.cs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,11 @@ public sealed class TransformsCatalog
3737
/// </summary>
3838
public FeatureSelectionTransforms FeatureSelection { get; }
3939

40+
/// <summary>
41+
/// List of operations for using TensorFlow model.
42+
/// </summary>
43+
public TensorFlowTransforms TensorFlow { get; }
44+
4045
internal TransformsCatalog(IHostEnvironment env)
4146
{
4247
Contracts.AssertValue(env);
@@ -47,6 +52,7 @@ internal TransformsCatalog(IHostEnvironment env)
4752
Text = new TextTransforms(this);
4853
Projection = new ProjectionTransforms(this);
4954
FeatureSelection = new FeatureSelectionTransforms(this);
55+
TensorFlow = new TensorFlowTransforms(this);
5056
}
5157

5258
public abstract class SubCatalogBase
@@ -109,5 +115,15 @@ internal FeatureSelectionTransforms(TransformsCatalog owner) : base(owner)
109115
{
110116
}
111117
}
118+
119+
/// <summary>
120+
/// The catalog of TensorFlow operations.
121+
/// </summary>
122+
public sealed class TensorFlowTransforms : SubCatalogBase
123+
{
124+
internal TensorFlowTransforms(TransformsCatalog owner) : base(owner)
125+
{
126+
}
127+
}
112128
}
113129
}

src/Microsoft.ML.DnnAnalyzer/Microsoft.ML.DnnAnalyzer/DnnAnalyzer.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ public static void Main(string[] args)
1717
return;
1818
}
1919

20-
foreach (var (name, opType, type, inputs) in TensorFlowUtils.GetModelNodes(new MLContext(), args[0]))
20+
foreach (var (name, opType, type, inputs) in new MLContext().Transforms.TensorFlow.GetModelNodes(args[0]))
2121
{
2222
var inputsString = inputs.Length == 0 ? "" : $", input nodes: {string.Join(", ", inputs)}";
2323
Console.WriteLine($"Graph node: '{name}', operation type: '{opType}', output type: '{type}'{inputsString}");

src/Microsoft.ML.TensorFlow/TensorFlow/TensorflowUtils.cs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,12 @@ public static class TensorFlowUtils
2121
/// Key to access operator's type (a string) in <see cref="DataViewSchema.Column.Metadata"/>.
2222
/// Its value describes the Tensorflow operator that produces this <see cref="DataViewSchema.Column"/>.
2323
/// </summary>
24-
public const string TensorflowOperatorTypeKind = "TensorflowOperatorType";
24+
internal const string TensorflowOperatorTypeKind = "TensorflowOperatorType";
2525
/// <summary>
2626
/// Key to access upstream operators' names (a string array) in <see cref="DataViewSchema.Column.Metadata"/>.
2727
/// Its value states operators that the associated <see cref="DataViewSchema.Column"/>'s generator depends on.
2828
/// </summary>
29-
public const string TensorflowUpstreamOperatorsKind = "TensorflowUpstreamOperators";
29+
internal const string TensorflowUpstreamOperatorsKind = "TensorflowUpstreamOperators";
3030

3131
internal static DataViewSchema GetModelSchema(IExceptionContext ectx, TFGraph graph, string opType = null)
3232
{
@@ -94,7 +94,7 @@ internal static DataViewSchema GetModelSchema(IExceptionContext ectx, TFGraph gr
9494
/// </summary>
9595
/// <param name="env">The environment to use.</param>
9696
/// <param name="modelPath">Model to load.</param>
97-
public static DataViewSchema GetModelSchema(IHostEnvironment env, string modelPath)
97+
internal static DataViewSchema GetModelSchema(IHostEnvironment env, string modelPath)
9898
{
9999
var model = LoadTensorFlowModel(env, modelPath);
100100
return GetModelSchema(env, model.Session.Graph);
@@ -109,7 +109,7 @@ public static DataViewSchema GetModelSchema(IHostEnvironment env, string modelPa
109109
/// <param name="env">The environment to use.</param>
110110
/// <param name="modelPath">Model to load.</param>
111111
/// <returns></returns>
112-
public static IEnumerable<(string, string, DataViewType, string[])> GetModelNodes(IHostEnvironment env, string modelPath)
112+
internal static IEnumerable<(string, string, DataViewType, string[])> GetModelNodes(IHostEnvironment env, string modelPath)
113113
{
114114
var schema = GetModelSchema(env, modelPath);
115115

@@ -338,7 +338,7 @@ private static void CreateTempDirectoryWithAcl(string folder, string identity)
338338
/// <param name="env">The environment to use.</param>
339339
/// <param name="modelPath">The model to load.</param>
340340
/// <returns></returns>
341-
public static TensorFlowModelInfo LoadTensorFlowModel(IHostEnvironment env, string modelPath)
341+
internal static TensorFlowModelInfo LoadTensorFlowModel(IHostEnvironment env, string modelPath)
342342
{
343343
var session = GetSession(env, modelPath);
344344
return new TensorFlowModelInfo(env, session, modelPath);

src/Microsoft.ML.TensorFlow/TensorFlowModelInfo.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ namespace Microsoft.ML.Transforms
2020
/// </item>
2121
/// </list>
2222
/// </summary>
23-
public class TensorFlowModelInfo
23+
public sealed class TensorFlowModelInfo
2424
{
2525
internal TFSession Session { get; }
2626
public string ModelPath { get; }

src/Microsoft.ML.TensorFlow/TensorflowCatalog.cs

Lines changed: 41 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,11 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33
// See the LICENSE file in the project root for more information.
44

5+
using System.Collections.Generic;
56
using Microsoft.Data.DataView;
67
using Microsoft.ML.Data;
78
using Microsoft.ML.Transforms;
9+
using Microsoft.ML.Transforms.TensorFlow;
810

911
namespace Microsoft.ML
1012
{
@@ -25,7 +27,7 @@ public static class TensorflowCatalog
2527
/// ]]>
2628
/// </format>
2729
/// </example>
28-
public static TensorFlowEstimator ScoreTensorFlowModel(this TransformsCatalog catalog,
30+
public static TensorFlowEstimator ScoreTensorFlowModel(this TransformsCatalog.TensorFlowTransforms catalog,
2931
string modelLocation,
3032
string outputColumnName,
3133
string inputColumnName)
@@ -45,7 +47,7 @@ public static TensorFlowEstimator ScoreTensorFlowModel(this TransformsCatalog ca
4547
/// ]]>
4648
/// </format>
4749
/// </example>
48-
public static TensorFlowEstimator ScoreTensorFlowModel(this TransformsCatalog catalog,
50+
public static TensorFlowEstimator ScoreTensorFlowModel(this TransformsCatalog.TensorFlowTransforms catalog,
4951
string modelLocation,
5052
string[] outputColumnNames,
5153
string[] inputColumnNames)
@@ -58,7 +60,7 @@ public static TensorFlowEstimator ScoreTensorFlowModel(this TransformsCatalog ca
5860
/// <param name="tensorFlowModel">The pre-loaded TensorFlow model.</param>
5961
/// <param name="inputColumnName"> The name of the model input.</param>
6062
/// <param name="outputColumnName">The name of the requested model output.</param>
61-
public static TensorFlowEstimator ScoreTensorFlowModel(this TransformsCatalog catalog,
63+
public static TensorFlowEstimator ScoreTensorFlowModel(this TransformsCatalog.TensorFlowTransforms catalog,
6264
TensorFlowModelInfo tensorFlowModel,
6365
string outputColumnName,
6466
string inputColumnName)
@@ -78,7 +80,7 @@ public static TensorFlowEstimator ScoreTensorFlowModel(this TransformsCatalog ca
7880
/// ]]>
7981
/// </format>
8082
/// </example>
81-
public static TensorFlowEstimator ScoreTensorFlowModel(this TransformsCatalog catalog,
83+
public static TensorFlowEstimator ScoreTensorFlowModel(this TransformsCatalog.TensorFlowTransforms catalog,
8284
TensorFlowModelInfo tensorFlowModel,
8385
string[] outputColumnNames,
8486
string[] inputColumnNames)
@@ -90,7 +92,7 @@ public static TensorFlowEstimator ScoreTensorFlowModel(this TransformsCatalog ca
9092
/// </summary>
9193
/// <param name="catalog">The transform's catalog.</param>
9294
/// <param name="options">The <see cref="TensorFlowEstimator.Options"/> specifying the inputs and the settings of the <see cref="TensorFlowEstimator"/>.</param>
93-
public static TensorFlowEstimator TensorFlow(this TransformsCatalog catalog,
95+
public static TensorFlowEstimator TensorFlow(this TransformsCatalog.TensorFlowTransforms catalog,
9496
TensorFlowEstimator.Options options)
9597
=> new TensorFlowEstimator(CatalogUtils.GetEnvironment(catalog), options);
9698

@@ -100,9 +102,42 @@ public static TensorFlowEstimator TensorFlow(this TransformsCatalog catalog,
100102
/// <param name="catalog">The transform's catalog.</param>
101103
/// <param name="options">The <see cref="TensorFlowEstimator.Options"/> specifying the inputs and the settings of the <see cref="TensorFlowEstimator"/>.</param>
102104
/// <param name="tensorFlowModel">The pre-loaded TensorFlow model.</param>
103-
public static TensorFlowEstimator TensorFlow(this TransformsCatalog catalog,
105+
public static TensorFlowEstimator TensorFlow(this TransformsCatalog.TensorFlowTransforms catalog,
104106
TensorFlowEstimator.Options options,
105107
TensorFlowModelInfo tensorFlowModel)
106108
=> new TensorFlowEstimator(CatalogUtils.GetEnvironment(catalog), options, tensorFlowModel);
109+
110+
/// <summary>
111+
/// This method retrieves the information about the graph nodes of a TensorFlow model as an <see cref="DataViewSchema"/>.
112+
/// For every node in the graph that has an output type that is compatible with the types supported by
113+
/// <see cref="TensorFlowEstimator"/>, the output schema contains a column with the name of that node, and the
114+
/// type of its output (including the item type and the shape, if it is known). Every column also contains metadata
115+
/// of kind <see cref="TensorFlowUtils.TensorflowOperatorTypeKind"/>, indicating the operation type of the node, and if that node has inputs in the graph,
116+
/// it contains metadata of kind <see cref="TensorFlowUtils.TensorflowUpstreamOperatorsKind"/>, indicating the names of the input nodes.
117+
/// </summary>
118+
/// <param name="catalog">The transform's catalog.</param>
119+
/// <param name="modelLocation">Location of the TensorFlow model.</param>
120+
public static DataViewSchema GetModelSchema(this TransformsCatalog.TensorFlowTransforms catalog, string modelLocation)
121+
=> TensorFlowUtils.GetModelSchema(CatalogUtils.GetEnvironment(catalog), modelLocation);
122+
123+
/// <summary>
124+
/// This is a convenience method for iterating over the nodes of a TensorFlow model graph. It
125+
/// iterates over the columns of the <see cref="DataViewSchema"/> returned by <see cref="GetModelSchema(TransformsCatalog.TensorFlowTransforms, string)"/>,
126+
/// and for each one it returns a tuple containing the name, operation type, column type and an array of input node names.
127+
/// This method is convenient for filtering nodes based on certain criteria, for example, by the operation type.
128+
/// </summary>
129+
/// <param name="catalog">The transform's catalog.</param>
130+
/// <param name="modelLocation">Location of the TensorFlow model.</param>
131+
public static IEnumerable<(string, string, DataViewType, string[])> GetModelNodes(this TransformsCatalog.TensorFlowTransforms catalog, string modelLocation)
132+
=> TensorFlowUtils.GetModelNodes(CatalogUtils.GetEnvironment(catalog), modelLocation);
133+
134+
/// <summary>
135+
/// Load TensorFlow model into memory. This is the convenience method that allows the model to be loaded once and subsequently use it for querying schema and creation of
136+
/// <see cref="TensorFlowEstimator"/> using <see cref="TensorFlow(TransformsCatalog.TensorFlowTransforms, TensorFlowEstimator.Options, TensorFlowModelInfo)"/>.
137+
/// </summary>
138+
/// <param name="catalog">The transform's catalog.</param>
139+
/// <param name="modelLocation">Location of the TensorFlow model.</param>
140+
public static TensorFlowModelInfo LoadTensorFlowModel(this TransformsCatalog.TensorFlowTransforms catalog, string modelLocation)
141+
=> TensorFlowUtils.LoadTensorFlowModel(CatalogUtils.GetEnvironment(catalog), modelLocation);
107142
}
108143
}

test/Microsoft.ML.Tests/Scenarios/TensorflowTests.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ public void TensorFlowTransforCifarEndToEndTest()
3636
var pipeEstimator = new ImageLoadingEstimator(mlContext, imageFolder, ("ImageReal", "ImagePath"))
3737
.Append(new ImageResizingEstimator(mlContext, "ImageCropped", imageHeight, imageWidth, "ImageReal"))
3838
.Append(new ImagePixelExtractingEstimator(mlContext, "Input", "ImageCropped", interleave: true))
39-
.Append(mlContext.Transforms.ScoreTensorFlowModel(model_location, "Output", "Input"))
39+
.Append(mlContext.Transforms.TensorFlow.ScoreTensorFlowModel(model_location, "Output", "Input"))
4040
.Append(new ColumnConcatenatingEstimator(mlContext, "Features", "Output"))
4141
.Append(new ValueToKeyMappingEstimator(mlContext, "Label"))
4242
.AppendCacheCheckpoint(mlContext)

0 commit comments

Comments
 (0)