-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Enabled option to get multiple outputs from TF graphs #814
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Update Fork...
private readonly string[] _outputColNames; | ||
private readonly ColumnType[] _outputColTypes; | ||
private readonly TFDataType[] _tfOutputTypes; | ||
private IDictionary<string, TFTensor> _cachedOutputs; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@yaeldekel, is it the right place to cache data here? #Resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The dictionary holding the cached outputs needs to be instantiated on each call to CreateGetters.
In reply to: 215089636 [](ancestors = 215089636)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@@ -326,7 +384,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV | |||
host.CheckUserArg(File.Exists(args.ModelFile), nameof(args.ModelFile)); | |||
|
|||
var modelBytes = File.ReadAllBytes(args.ModelFile); | |||
var mapper = new TensorFlowMapper(host, input.Schema, modelBytes, args.InputColumns, args.OutputColumn); | |||
var mapper = new TensorFlowMapper(host, input.Schema, modelBytes, args.InputColumns, args.OutputColumns); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OutputColumns [](start = 102, length = 13)
Do you want to validate uniqueness of output columns? #Resolved
_host.Check(inputColNames.All(name => _session.Graph[name] != null), "One of the input does not exist in the model"); | ||
_host.Check(outputColNames.All(name => _session.Graph[name] != null), "One of the output does not exist in the model"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[name] [](start = 69, length = 6)
Could this throw if "name" is null? If it does, then we need to add a check of each item in outputColNames, in addition to the CheckNonEmpty on the array itself. #Closed
Contracts.AssertNonEmpty(_outputColNames); | ||
ctx.Writer.Write(_outputColNames.Length); | ||
foreach (var colName in _outputColNames) | ||
ctx.SaveNonEmptyString(colName); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We need to update the version. #Closed
@@ -116,7 +127,10 @@ public void Save(ModelSaveContext ctx) | |||
foreach (var colName in _inputColNames) | |||
ctx.SaveNonEmptyString(colName); | |||
|
|||
ctx.SaveNonEmptyString(_outputColName); | |||
Contracts.AssertNonEmpty(_outputColNames); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Contracts [](start = 16, length = 9)
_host. #Closed
} | ||
|
||
private static (ColumnType, TFDataType) GetOutputTypes(TFGraph graph, string columnName) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
static [](start = 20, length = 6)
Can this still be static? #Resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am using the _session.Graph[name]
in this method.
In reply to: 215388432 [](ancestors = 215388432)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can use the graph that is passed as an argument instead.
In reply to: 215390484 [](ancestors = 215390484,215388432)
return new[] { new RowMapperColumnInfo(_outputColName, _outputColType, null) }; | ||
var info = new RowMapperColumnInfo[_outputColNames.Length]; | ||
for (int i = 0; i < _outputColNames.Length; i++) | ||
{ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
{ [](start = 16, length = 1)
nit: don't need the curly braces. #Closed
This should be updated to reflect that if any of the output columns are active, then all the inputs need to active. #Closed Refers to: src/Microsoft.ML.TensorFlow/TensorflowTransform.cs:251 in b57a5c9. [](commit_id = b57a5c9, deletion_comment = False) |
@@ -187,28 +217,30 @@ private Delegate MakeGetter<T>(IRow input, ColumnType columnType) | |||
runner.AddInput(inputName, srcTensorGetters[i].GetTensor()); | |||
} | |||
|
|||
var tensors = runner.Fetch(_outputColName).Run(); | |||
|
|||
var tensors = runner.Fetch(_outputColNames).Run(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
_outputColNames [](start = 47, length = 15)
Does it make sense to create an array that only has the active outputs so that we don't need to compute and cache all of them if only some are active? #Closed
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is basically an array passed by the user. I think any output column requested by the user must be active. Do you think a scenario where user specified column wont be active?
In reply to: 215390927 [](ancestors = 215390927)
values = new T[_outputColTypes[iinfo].VectorSize]; | ||
|
||
TensorFlowUtils.FetchData<T>(_cachedOutputs[_outputColNames[iinfo]].Data, values); | ||
dst = new VBuffer<T>(values.Length, values); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
); [](start = 62, length = 2)
Pass dst.Indices here too. #Closed
private readonly ColumnType[] _outputColTypes; | ||
private readonly TFDataType[] _tfOutputTypes; | ||
private IDictionary<string, TFTensor> _cachedOutputs; | ||
private long _cachedPosition; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
question: what do the _cachedOutputs and _cachedPosition represent ?
Are they only needed for the multiple inputs/outputs scenario ? #Closed
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually, output getters work independent of each other and in each getter we have call to compute operations on TF graph.
What we are doing here is for each row in the input dataview, cache the results of all the requested operations on the first output getter call. Then for subsequent output getter calls for that row, serve the results from this _cachedOutputs
dictionary.
If the row position is changed the computation is done again and the current position is saved in cachedPosition
.
In reply to: 215393239 [](ancestors = 215393239)
@dotnet-bot, test it please. #Closed |
Should we add a simple test showcasing the multiple input / multiple output scenario -- perhaps just taking 2 numbers as inputs and producing their {sum, product} as outputs ? #Closed Refers to: test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs:2 in b57a5c9. [](commit_id = b57a5c9, deletion_comment = False) |
I have already changed the test In reply to: 418871731 [](ancestors = 418871731) Refers to: test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs:2 in b57a5c9. [](commit_id = b57a5c9, deletion_comment = False) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am planning to do it in next PR because it will cause the model nuget to be updated. I have created an issue for this #850 In reply to: 418903839 [](ancestors = 418903839,418871731) Refers to: test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs:2 in b57a5c9. [](commit_id = b57a5c9, deletion_comment = False) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This PR fixes #712.