TensorFlow estimator #840

Merged: 9 commits merged on Sep 7, 2018

Conversation

Contributor

@Ivanidzo4ka Ivanidzo4ka commented Sep 5, 2018

Converts the TensorFlow transform to an estimator/transformer and adds tests for them.

public static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, ISchema inputSchema)
=> Create(env, ctx).MakeRowMapper(inputSchema);

private TFSession LoadTFSession(byte[] modelBytes, string modelArg = null)
Contributor Author

@Ivanidzo4ka Ivanidzo4ka Sep 5, 2018

string modelArg = null [](start = 59, length = 22)

I couldn't find any usage of this, so I'm tempted to delete it. #Resolved

@@ -16,7 +16,7 @@ public CSharpCodeGen(ITestOutputHelper output) : base(output)
{
}

[Fact(Skip = "Execute this test if you want to regenerate CSharpApi file")]
[Fact]
Contributor Author

@Ivanidzo4ka Ivanidzo4ka Sep 5, 2018

] [](start = 13, length = 1)

revert #Resolved

@@ -234,7 +234,7 @@ private string GetBuildPrefix()
#endif
}

[Fact(Skip = "Execute this test if you want to regenerate ep-list and _manifest.json")]
[Fact]
Contributor Author

@Ivanidzo4ka Ivanidzo4ka Sep 5, 2018

Fact [](start = 9, length = 4)

revert #Resolved

var info = new RowMapperColumnInfo[_parent.Outputs.Length];
for (int i = 0; i < _parent.Outputs.Length; i++)
info[i] = new RowMapperColumnInfo(_parent.Outputs[i], _parent.OutputTypes[i], null);
return info;
}

private static (string[], int[], bool[], TFShape[], TFDataType[]) GetInputMetaData(TFGraph graph, string[] source, ISchema inputSchema)
Contributor Author

@Ivanidzo4ka Ivanidzo4ka Sep 5, 2018

not needed anymore #Resolved

Member

Are we doing the shape checks somewhere? E.g. lines 451-158


In reply to: 215430955 [](ancestors = 215430955)

Contributor Author

They all got moved to TensorFlowMapper constructor.


In reply to: 215438936 [](ancestors = 215438936,215430955)

@zeahmed
Contributor

zeahmed commented Sep 5, 2018

I have a bunch of changes in my PR. Is that going to affect this PR? #Resolved

@Ivanidzo4ka Ivanidzo4ka changed the title WIP TensorFlow estimator TensorFlow estimator Sep 6, 2018
@Ivanidzo4ka Ivanidzo4ka self-assigned this Sep 6, 2018
@Ivanidzo4ka Ivanidzo4ka added the API Issues pertaining the friendly API label Sep 6, 2018
@Ivanidzo4ka Ivanidzo4ka added this to the 0918 milestone Sep 6, 2018
@Ivanidzo4ka
Contributor Author

They are all part of this PR.


In reply to: 418913567 [](ancestors = 418913567)

@Zruty0 Zruty0 mentioned this pull request Sep 6, 2018
}

private Delegate MakeGetter<T>(IRow input, ColumnType columnType)
private Delegate[] MakeGetter(IRow input, Func<int, bool> activeOutput)
@yaeldekel yaeldekel Sep 6, 2018

MakeGetter [](start = 31, length = 10)

nit: Change to MakeGetters. #Closed

return new TensorFlowTransform(env, File.ReadAllBytes(modelFile), source, names).MakeDataTransform(input);
}

// Factory method for SignatureLoadModel.
Member

@sfilipi sfilipi Sep 6, 2018

// Fact [](start = 8, length = 7)

Use XML-style documentation, since it is a public method. #ByDesign

Contributor

Or we could make that stuff internal (and change ComponentCatalog to look for internals).


In reply to: 215723539 [](ancestors = 215723539)

Contributor

Maybe later though.


In reply to: 215728658 [](ancestors = 215728658,215723539)

Contributor Author

It's public because of flaws in our dependency injection.


In reply to: 215728692 [](ancestors = 215728692,215728658,215723539)

var invalidDataWrongNames = ComponentCreation.CreateDataView(Env, xyData);
var invalidDataWrongTypes = ComponentCreation.CreateDataView(Env, stringData);
var invalidDataWrongVectorSize = ComponentCreation.CreateDataView(Env, sizeData);
TestEstimatorCore(pipe, dataView, invalidInput: invalidDataWrongNames, validSchemaCantFitInput: invalidDataWrongVectorSize);
Contributor

@Zruty0 Zruty0 Sep 6, 2018

validSchemaCantFitInput [](start = 83, length = 23)

I think it's unique enough to not be in the TestEstimatorCore. You can just check that schema propagation works and Fit fails in this test. #Closed

var mapper = new TensorFlowMapper(host, input.Schema, modelBytes, args.InputColumns, args.OutputColumn);
return new RowToRowMapperTransform(host, input, mapper);
_host = env.Register(nameof(TensorFlowEstimator));
_transformer = new TensorFlowTransform(env, File.ReadAllBytes(modelFile), inputs, outputs);
@yaeldekel yaeldekel Sep 6, 2018

File [](start = 56, length = 4)

Should we check that the file exists? #Closed

Contributor

Actually, the ctors for the trivial estimator should match the ctors for the transformer.


In reply to: 215728796 [](ancestors = 215728796)

{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel(GetVersionInfo());
Contributor

@Zruty0 Zruty0 Sep 6, 2018

CheckAtModel [](start = 16, length = 12)

where's the binary format comment? #Closed


ctx.SaveNonEmptyString(_outputColName);
}
private sealed class TensorFlowMapper : IRowMapper
Contributor

@Zruty0 Zruty0 Sep 6, 2018

TensorFlowMapper [](start = 29, length = 16)

maybe just Mapper to follow the pattern? #Closed

return TFTensor.Create(_vBufferDense.Values, _tfShape);
}
for (var i = 0; i < _transformer.Outputs.Length; i++)
//IVAN: not sure about VectorKind.
@yaeldekel yaeldekel Sep 6, 2018

VectorKind [](start = 39, length = 10)

For now it is Vector. There will soon be a PR for enabling VariableVector as well. #Resolved

public const string UserName = "TensorFlowTransform";
public const string ShortName = "TFTransform";
private const string RegistrationName = "TensorFlowTransform";
public sealed class TensorFlowEstimator : IEstimator<TensorFlowTransform>
Contributor

@Zruty0 Zruty0 Sep 6, 2018

IEstimator [](start = 46, length = 10)

can we do TrivialEstimator ? #Closed

throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", input);
var tfInput = new TFOutput(_transformer.Graph[input]);
if (!TensorFlowUtils.IsTypeSupported(tfInput.OutputType))
throw _host.Except($"Input type '{tfInput.OutputType}' of input column '{input}' is not supported in TensorFlow");
Contributor

@Zruty0 Zruty0 Sep 6, 2018

Should this be checked when the graph is loaded? It doesn't look like we need inputSchema for this check #Closed

if (!TensorFlowUtils.IsTypeSupported(tfInput.OutputType))
throw _host.Except($"Input type '{tfInput.OutputType}' of input column '{input}' is not supported in TensorFlow");
var tfShape = _transformer.Graph.GetTensorShape(tfInput);
var shape = tfShape.ToIntArray().Skip(tfShape[0] == -1 ? TensorFlowTransform.BatchSize : 0);
Contributor

@Zruty0 Zruty0 Sep 6, 2018

shape [](start = 20, length = 5)

is shape used? #Closed

@Zruty0
Contributor

Zruty0 commented Sep 6, 2018

I sign off on the estimator conversion. I did not look at the multiple inputs change, so I'm not marking as approved.

throw _host.ExceptParam(nameof(outputs), $"Input column '{input}' does not exist in the model");
var tfInput = new TFOutput(Session.Graph[input]);
if (!TensorFlowUtils.IsTypeSupported(tfInput.OutputType))
throw _host.Except($"Input type '{tfInput.OutputType}' of input column '{input}' is not supported in TensorFlow");
Contributor

@Zruty0 Zruty0 Sep 6, 2018

Except [](start = 32, length = 6)

ExceptParam on modelStream #Closed


public TensorFlowMapper(IHostEnvironment env, ISchema inputSchema, byte[] modelBytes, string[] inputColNames, string outputColName)
private TensorFlowTransform(IHostEnvironment env, byte[] modelStream, string[] inputs, string[] outputs)
Contributor

@Zruty0 Zruty0 Sep 6, 2018

modelStream [](start = 65, length = 11)

nit: I would call that modelBytes, since it's not actually a stream #Closed

{
_host.CheckNonWhiteSpace(input, nameof(inputs));
if (Session.Graph[input] == null)
throw _host.ExceptParam(nameof(outputs), $"Input column '{input}' does not exist in the model");
Contributor Author

@Ivanidzo4ka Ivanidzo4ka Sep 6, 2018

outputs [](start = 51, length = 7)

inputs #Resolved

{
_host.CheckNonWhiteSpace(input, nameof(inputs));
if (Session.Graph[input] == null)
throw _host.ExceptParam(nameof(outputs), $"Input column '{input}' does not exist in the model");
Contributor Author

@Ivanidzo4ka Ivanidzo4ka Sep 6, 2018

[](start = 85, length = 2)

HORROR #Resolved

@yaeldekel yaeldekel left a comment

:shipit:

The TensorflowTransform extracts the specified output from the operation computed on the graph (given the input(s)) using a pre-trained <a href="https://www.tensorflow.org">Tensorflow</a> model.
The transform takes as input the Tensorflow model together with the names of the inputs to the model and name of the operation for which output values will be extracted from the model.
The TensorflowTransform extracts the specified outputs from the operations computed on the graph (given the input(s)) using a pre-trained <a href="https://www.tensorflow.org">Tensorflow</a> model.
The transform takes as input the Tensorflow model together with the names of the inputs to the model and names of the operations for which output values will be extracted from the model.

This transform requires the <a href="https://dotnet.myget.org/feed/dotnet-core/package/nuget/Microsoft.ML.TensorFlow/0.5.0-preview-26830-5">Microsoft.ML.TensorFlow</a> nuget to be installed.
Member

@sfilipi sfilipi Sep 6, 2018

This [](start = 8, length = 4)

If you want real paragraphs, I think you have to wrap the sentence in `<para>`. #Resolved


This transform requires the <a href="https://dotnet.myget.org/feed/dotnet-core/package/nuget/Microsoft.ML.TensorFlow/0.5.0-preview-26830-5">Microsoft.ML.TensorFlow</a> nuget to be installed.

The TensorflowTransform has following assumptions regarding the input, output and processing of data.
<list type="number">
<item>
Member

@sfilipi sfilipi Sep 6, 2018

[](start = 10, length = 6)

not part of your PR, but i have had trouble with list items not displaying if not inside #Resolved

Contributor

@Zruty0 Zruty0 left a comment

:shipit:

@Ivanidzo4ka Ivanidzo4ka merged commit 44c6e90 into dotnet:master Sep 7, 2018
@ghost ghost locked as resolved and limited conversation to collaborators Mar 29, 2022
Labels
API Issues pertaining the friendly API
6 participants