-
Notifications
You must be signed in to change notification settings - Fork 1.9k
TensorFlow estimator #840
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
TensorFlow estimator #840
Conversation
public static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, ISchema inputSchema) | ||
=> Create(env, ctx).MakeRowMapper(inputSchema); | ||
|
||
private TFSession LoadTFSession(byte[] modelBytes, string modelArg = null) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
string modelArg = null [](start = 59, length = 22)
I couldn't find any usage of this, so I tempted to delete this #Resolved
@@ -16,7 +16,7 @@ public CSharpCodeGen(ITestOutputHelper output) : base(output) | |||
{ | |||
} | |||
|
|||
[Fact(Skip = "Execute this test if you want to regenerate CSharpApi file")] | |||
[Fact] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
] [](start = 13, length = 1)
revert #Resolved
@@ -234,7 +234,7 @@ private string GetBuildPrefix() | |||
#endif | |||
} | |||
|
|||
[Fact(Skip = "Execute this test if you want to regenerate ep-list and _manifest.json")] | |||
[Fact] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fact [](start = 9, length = 4)
revert #Resolved
var info = new RowMapperColumnInfo[_parent.Outputs.Length]; | ||
for (int i = 0; i < _parent.Outputs.Length; i++) | ||
info[i] = new RowMapperColumnInfo(_parent.Outputs[i], _parent.OutputTypes[i], null); | ||
return info; | ||
} | ||
|
||
private static (string[], int[], bool[], TFShape[], TFDataType[]) GetInputMetaData(TFGraph graph, string[] source, ISchema inputSchema) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
not needed anymore #Resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
are we doing the shape checks somewhere ? e.g. lines 451-158
In reply to: 215430955 [](ancestors = 215430955)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
They all got moved to TensorFlowMapper constructor.
In reply to: 215438936 [](ancestors = 215438936,215430955)
I have bunch of changes in my PR, is it going to affect this PR? #Resolved |
They all part of this PR. In reply to: 418913567 [](ancestors = 418913567) |
} | ||
|
||
private Delegate MakeGetter<T>(IRow input, ColumnType columnType) | ||
private Delegate[] MakeGetter(IRow input, Func<int, bool> activeOutput) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
MakeGetter [](start = 31, length = 10)
nit: Change to MakeGetters. #Closed
return new TensorFlowTransform(env, File.ReadAllBytes(modelFile), source, names).MakeDataTransform(input); | ||
} | ||
|
||
// Factory method for SignatureLoadModel. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
// Fact [](start = 8, length = 7)
xml style documentation, since it is a public method #ByDesign
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OR we could make that stuff internal
(and change ComponentCatalog
to look for internal
.
In reply to: 215723539 [](ancestors = 215723539)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's public because of flaws in our dependency injection.
In reply to: 215728692 [](ancestors = 215728692,215728658,215723539)
var invalidDataWrongNames = ComponentCreation.CreateDataView(Env, xyData); | ||
var invalidDataWrongTypes = ComponentCreation.CreateDataView(Env, stringData); | ||
var invalidDataWrongVectorSize = ComponentCreation.CreateDataView(Env, sizeData); | ||
TestEstimatorCore(pipe, dataView, invalidInput: invalidDataWrongNames, validSchemaCantFitInput: invalidDataWrongVectorSize); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
validSchemaCantFitInput [](start = 83, length = 23)
I think it's unique enough to not be in the TestEstimatorCore
. You can just check that schema propagation works and Fit
fails in this test. #Closed
var mapper = new TensorFlowMapper(host, input.Schema, modelBytes, args.InputColumns, args.OutputColumn); | ||
return new RowToRowMapperTransform(host, input, mapper); | ||
_host = env.Register(nameof(TensorFlowEstimator)); | ||
_transformer = new TensorFlowTransform(env, File.ReadAllBytes(modelFile), inputs, outputs); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
File [](start = 56, length = 4)
Should we check that the file exists? #Closed
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually, the ctors for the trivial estimator should match the ctors for the transformer.
In reply to: 215728796 [](ancestors = 215728796)
{ | ||
Contracts.CheckValue(env, nameof(env)); | ||
env.CheckValue(ctx, nameof(ctx)); | ||
ctx.CheckAtModel(GetVersionInfo()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
CheckAtModel [](start = 16, length = 12)
where's the binary format comment? #Closed
|
||
ctx.SaveNonEmptyString(_outputColName); | ||
} | ||
private sealed class TensorFlowMapper : IRowMapper |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
TensorFlowMapper [](start = 29, length = 16)
maybe just Mapper
to follow the pattern? #Closed
return TFTensor.Create(_vBufferDense.Values, _tfShape); | ||
} | ||
for (var i = 0; i < _transformer.Outputs.Length; i++) | ||
//IVAN: not sure about VectorKind. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
VectorKind [](start = 39, length = 10)
For now it is Vector. There will soon be a PR for enabling VariableVector as well. #Resolved
public const string UserName = "TensorFlowTransform"; | ||
public const string ShortName = "TFTransform"; | ||
private const string RegistrationName = "TensorFlowTransform"; | ||
public sealed class TensorFlowEstimator : IEstimator<TensorFlowTransform> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
IEstimator [](start = 46, length = 10)
can we do TrivialEstimator
? #Closed
throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", input); | ||
var tfInput = new TFOutput(_transformer.Graph[input]); | ||
if (!TensorFlowUtils.IsTypeSupported(tfInput.OutputType)) | ||
throw _host.Except($"Input type '{tfInput.OutputType}' of input column '{input}' is not supported in TensorFlow"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should this be checked when the graph is loaded? It doesn't look like we need inputSchema
for this check #Closed
if (!TensorFlowUtils.IsTypeSupported(tfInput.OutputType)) | ||
throw _host.Except($"Input type '{tfInput.OutputType}' of input column '{input}' is not supported in TensorFlow"); | ||
var tfShape = _transformer.Graph.GetTensorShape(tfInput); | ||
var shape = tfShape.ToIntArray().Skip(tfShape[0] == -1 ? TensorFlowTransform.BatchSize : 0); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
shape [](start = 20, length = 5)
is shape
used? #Closed
I sign off on the estimator conversion. I did not look at the multiple inputs change, so I'm not marking as approved. |
throw _host.ExceptParam(nameof(outputs), $"Input column '{input}' does not exist in the model"); | ||
var tfInput = new TFOutput(Session.Graph[input]); | ||
if (!TensorFlowUtils.IsTypeSupported(tfInput.OutputType)) | ||
throw _host.Except($"Input type '{tfInput.OutputType}' of input column '{input}' is not supported in TensorFlow"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Except [](start = 32, length = 6)
ExceptParam
on modelStream
#Closed
|
||
public TensorFlowMapper(IHostEnvironment env, ISchema inputSchema, byte[] modelBytes, string[] inputColNames, string outputColName) | ||
private TensorFlowTransform(IHostEnvironment env, byte[] modelStream, string[] inputs, string[] outputs) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
modelStream [](start = 65, length = 11)
nit: I would call that modelBytes
, since it's not actually a stream #Closed
{ | ||
_host.CheckNonWhiteSpace(input, nameof(inputs)); | ||
if (Session.Graph[input] == null) | ||
throw _host.ExceptParam(nameof(outputs), $"Input column '{input}' does not exist in the model"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
outputs [](start = 51, length = 7)
inputs #Resolved
{ | ||
_host.CheckNonWhiteSpace(input, nameof(inputs)); | ||
if (Session.Graph[input] == null) | ||
throw _host.ExceptParam(nameof(outputs), $"Input column '{input}' does not exist in the model"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[](start = 85, length = 2)
HORROR #Resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
src/Microsoft.ML.TensorFlow/doc.xml
Outdated
The TensorflowTransform extracts the specified output from the operation computed on the graph (given the input(s)) using a pre-trained <a href="https://www.tensorflow.org">Tensorflow</a> model. | ||
The transform takes as input the Tensorflow model together with the names of the inputs to the model and name of the operation for which output values will be extracted from the model. | ||
The TensorflowTransform extracts the specified outputs from the operations computed on the graph (given the input(s)) using a pre-trained <a href="https://www.tensorflow.org">Tensorflow</a> model. | ||
The transform takes as input the Tensorflow model together with the names of the inputs to the model and names of the operations for which output values will be extracted from the model. | ||
|
||
This transform requires the <a href="https://dotnet.myget.org/feed/dotnet-core/package/nuget/Microsoft.ML.TensorFlow/0.5.0-preview-26830-5">Microsoft.ML.TensorFlow</a> nuget to be installed. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This [](start = 8, length = 4)
if you want real paragraphs, i think you have to wrap the sentence in #Resolved
src/Microsoft.ML.TensorFlow/doc.xml
Outdated
|
||
This transform requires the <a href="https://dotnet.myget.org/feed/dotnet-core/package/nuget/Microsoft.ML.TensorFlow/0.5.0-preview-26830-5">Microsoft.ML.TensorFlow</a> nuget to be installed. | ||
|
||
The TensorflowTransform has following assumptions regarding the input, output and processing of data. | ||
<list type="number"> | ||
<item> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[](start = 10, length = 6)
not part of your PR, but i have had trouble with list items not displaying if not inside #Resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Converts TensorFlow transform to estimator/transformer + tests for them