Skip to content

Enabled option to get multiple outputs from TF graphs #814

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 21 commits into from
Sep 6, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
21d2b48
Merge pull request #1 from dotnet/master
zeahmed May 23, 2018
405ffcf
Merge remote-tracking branch 'upstream/master'
zeahmed May 24, 2018
966b091
Merge remote-tracking branch 'upstream/master'
zeahmed May 29, 2018
0a8ec57
Merge remote-tracking branch 'upstream/master'
zeahmed May 30, 2018
a5e9a1d
Merge remote-tracking branch 'upstream/master'
zeahmed Jun 4, 2018
37959d0
Merge remote-tracking branch 'upstream/master'
zeahmed Jun 6, 2018
fba10a5
Merge remote-tracking branch 'upstream/master'
zeahmed Jun 29, 2018
bb1e63d
Merge remote-tracking branch 'upstream/master'
zeahmed Jul 17, 2018
740ad4b
Merge remote-tracking branch 'upstream/master'
zeahmed Jul 21, 2018
9f5c45d
Merge remote-tracking branch 'upstream/master'
zeahmed Aug 2, 2018
4bbf76a
Merge remote-tracking branch 'upstream/master'
zeahmed Aug 10, 2018
d6248be
Merge remote-tracking branch 'upstream/master'
zeahmed Aug 31, 2018
55683f7
Enabled option to get multiple outputs from TF graphs
zeahmed Sep 4, 2018
1826769
Merge remote-tracking branch 'upstream/master' into TFTransform_multi…
zeahmed Sep 4, 2018
2ad6470
Updated CSharpAPI and relevants tests to reflect multi-ouputs from TF…
zeahmed Sep 4, 2018
ca7f38c
Merge remote-tracking branch 'upstream/master' into TFTransform_multi…
zeahmed Sep 4, 2018
e2e0e14
Merge remote-tracking branch 'upstream/master' into TFTransform_multi…
zeahmed Sep 4, 2018
5316655
Updated cache position to -1 when creating getters in TensorflowTrans…
zeahmed Sep 4, 2018
b57a5c9
Updated doc.xml to reflect the multi-outputs.
zeahmed Sep 5, 2018
76ea346
Addressed reviewers' comments.
zeahmed Sep 5, 2018
23677c8
Addressed reviewers' comments.
zeahmed Sep 6, 2018
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
179 changes: 132 additions & 47 deletions src/Microsoft.ML.TensorFlow/TensorflowTransform.cs

Large diffs are not rendered by default.

14 changes: 7 additions & 7 deletions src/Microsoft.ML.TensorFlow/doc.xml
Original file line number Diff line number Diff line change
Expand Up @@ -7,27 +7,27 @@
Extracts hidden layers' values from a pre-trained Tensorflow model.
</summary>
<remarks>
The TensorflowTransform extracts the specified output from the operation computed on the graph (given the input(s)) using a pre-trained <a href="https://www.tensorflow.org">Tensorflow</a> model.
The transform takes as input the Tensorflow model together with the names of the inputs to the model and name of the operation for which output values will be extracted from the model.
The TensorflowTransform extracts the specified outputs from the operations computed on the graph (given the input(s)) using a pre-trained <a href="https://www.tensorflow.org">Tensorflow</a> model.
The transform takes as input the Tensorflow model together with the names of the inputs to the model and names of the operations for which output values will be extracted from the model.

This transform requires the <a href="https://dotnet.myget.org/feed/dotnet-core/package/nuget/Microsoft.ML.TensorFlow/0.5.0-preview-26830-5">Microsoft.ML.TensorFlow</a> nuget to be installed.

The TensorflowTransform has following assumptions regarding the input, output and processing of data.
<list type="number">
<item>
The transform currently accepts the <a href="https://www.tensorflow.org/mobile/prepare_models">frozen TensorFlow model</a> file as input.
</item>
<item>The transform supports scoring only one example at a time.</item>
<item>The name of input column(s) should match the name of input(s) in Tensorflow model.</item>
<item>The name of the output column should match one of the operations in the Tensorflow graph.</item>
<item>The name of each output column should match one of the operations in the Tensorflow graph.</item>
<item>Currently, float and double are the only acceptable data types for input/output.</item>
<item>
Upon success, the transform will introduce a new column in <see cref="IDataView"/> based on the name of the output column specified.
Upon success, the transform will introduce a new column in <see cref="IDataView"/> corresponding to each output column specified.
</item>
</list>

The inputs and outputs of a TensorFlow model can be obtained using the <a href="https://github.com/tensorflow/tensorflow/tree/master/tensorflow/tools/graph_transforms/README.md#inspecting-graphs"><code>summarize_graph</code> tool</a>.

</remarks>
</member>
<example name="TensorflowTransform">
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML/CSharpApi.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15939,7 +15939,7 @@ public sealed partial class TensorFlowScorer : Microsoft.ML.Runtime.EntryPoints.
/// <summary>
/// The name of the output
/// </summary>
public string OutputColumn { get; set; }
public string[] OutputColumns { get; set; }

/// <summary>
/// Input dataset
Expand Down
7 changes: 5 additions & 2 deletions test/BaselineOutput/Common/EntryPoints/core_manifest.json
Original file line number Diff line number Diff line change
Expand Up @@ -21919,8 +21919,11 @@
"IsNullable": false
},
{
"Name": "OutputColumn",
"Type": "String",
"Name": "OutputColumns",
"Type": {
"Kind": "Array",
"ItemType": "String"
},
"Desc": "The name of the output",
"Aliases": [
"output"
Expand Down
2 changes: 1 addition & 1 deletion test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs
Original file line number Diff line number Diff line change
Expand Up @@ -978,7 +978,7 @@ public void TestTensorFlowEntryPoint()
{
Data = importOutput.Data,
InputColumns = new[] { "Placeholder" },
OutputColumn = "Softmax",
OutputColumns = new[] { "Softmax" },
ModelFile = "mnist_model/frozen_saved_model.pb"
};
var tfTransformOutput = experiment.Add(tfTransformInput);
Expand Down
2 changes: 1 addition & 1 deletion test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3742,7 +3742,7 @@ public void EntryPointTensorFlowTransform()
{
@"'InputColumns': [ 'Placeholder' ],
'ModelFile': 'mnist_model/frozen_saved_model.pb',
'OutputColumn': 'Softmax'"
'OutputColumns': [ 'Softmax' ]"
});
}
}
Expand Down
2 changes: 1 addition & 1 deletion test/Microsoft.ML.Tests/Scenarios/TensorflowTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ public void TensorFlowTransformCifarLearningPipelineTest()
{
ModelFile = model_location,
InputColumns = new[] { "Input" },
OutputColumn = "Output"
OutputColumns = new[] { "Output" }
});

pipeline.Add(new ColumnConcatenator(outputColumn: "Features", "Output"));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,13 @@ public void TensorFlowTransformMNISTConvTest()
}
}, new MultiFileSource(dataPath));

IDataView trans = TensorFlowTransform.Create(env, loader, model_location, "Softmax", "Placeholder");
trans = new ConcatTransform(env, trans, "reshape_input", "Placeholder");
trans = TensorFlowTransform.Create(env, trans, model_location, "dense/Relu", "reshape_input");
IDataView trans = CopyColumnsTransform.Create(env, new CopyColumnsTransform.Arguments()
{
Column = new[] { new CopyColumnsTransform.Column()
{ Name = "reshape_input", Source = "Placeholder" }
}
}, loader);
trans = TensorFlowTransform.Create(env, trans, model_location, new[] { "Softmax", "dense/Relu" }, new[] { "Placeholder", "reshape_input" });
trans = new ConcatTransform(env, trans, "Features", "Softmax", "dense/Relu");

var trainer = new LightGbmMulticlassTrainer(env, new LightGbmArguments());
Expand Down