Skip to content

Commit 447ae1d

Browse files
Treat TensorFlow output as non-batched. (#5634)
* Output can now optionally be treated as non-batched. * Updated comments based on PR feedback. * Fixed saving/loading with the new parameter. * Updates based on PR comments. * Update src/Microsoft.ML.TensorFlow/TensorflowUtils.cs Co-authored-by: Eric Erhardt <[email protected]> * Reverted accidental test changes. * Fixes based on PR comments. Co-authored-by: Eric Erhardt <[email protected]>
1 parent f93fa09 commit 447ae1d

File tree

7 files changed

+167
-35
lines changed

7 files changed

+167
-35
lines changed

src/Microsoft.ML.TensorFlow/TensorFlowModel.cs

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,7 @@ public sealed class TensorFlowModel : IDisposable
1818
{
1919
internal Session Session { get; }
2020
internal string ModelPath { get; }
21+
internal bool TreatOutputAsBatched { get; }
2122

2223
private readonly IHostEnvironment _env;
2324

@@ -27,10 +28,12 @@ public sealed class TensorFlowModel : IDisposable
2728
/// <param name="env">An <see cref="IHostEnvironment"/> object.</param>
2829
/// <param name="session">TensorFlow session object.</param>
2930
/// <param name="modelLocation">Location of the model from where <paramref name="session"/> was loaded.</param>
30-
internal TensorFlowModel(IHostEnvironment env, Session session, string modelLocation)
31+
/// <param name="treatOutputAsBatched">If the first dimension of the output is unknown, should it be treated as batched or not.</param>
32+
internal TensorFlowModel(IHostEnvironment env, Session session, string modelLocation, bool treatOutputAsBatched = true)
3133
{
3234
Session = session;
3335
ModelPath = modelLocation;
36+
TreatOutputAsBatched = treatOutputAsBatched;
3437
_env = env;
3538
_disposed = false;
3639
}
@@ -40,7 +43,7 @@ internal TensorFlowModel(IHostEnvironment env, Session session, string modelLoca
4043
/// </summary>
4144
public DataViewSchema GetModelSchema()
4245
{
43-
return TensorFlowUtils.GetModelSchema(_env, Session.graph);
46+
return TensorFlowUtils.GetModelSchema(_env, Session.graph, TreatOutputAsBatched);
4447
}
4548

4649
/// <summary>
@@ -49,7 +52,7 @@ public DataViewSchema GetModelSchema()
4952
/// </summary>
5053
public DataViewSchema GetInputSchema()
5154
{
52-
return TensorFlowUtils.GetModelSchema(_env, Session.graph, "Placeholder");
55+
return TensorFlowUtils.GetModelSchema(_env, Session.graph, TreatOutputAsBatched, "Placeholder");
5356
}
5457

5558
/// <summary>

src/Microsoft.ML.TensorFlow/TensorflowCatalog.cs

Lines changed: 26 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -35,5 +35,31 @@ public static class TensorflowCatalog
3535
/// </example>
3636
public static TensorFlowModel LoadTensorFlowModel(this ModelOperationsCatalog catalog, string modelLocation)
3737
=> TensorFlowUtils.LoadTensorFlowModel(CatalogUtils.GetEnvironment(catalog), modelLocation);
38+
39+
/// <summary>
40+
/// Load TensorFlow model into memory. This is the convenience method that allows the model to be loaded once and subsequently use it for querying schema and creation of
41+
/// <see cref="TensorFlowEstimator"/> using <see cref="TensorFlowModel.ScoreTensorFlowModel(string, string, bool)"/>.
42+
/// Usage of this API requires additional NuGet dependencies on the TensorFlow redist; see the linked document for more information.
43+
/// <see cref="TensorFlowModel"/> also holds references to unmanaged resources that need to be freed either with an explicit
44+
/// call to Dispose() or implicitly by declaring the variable with the "using" syntax.
45+
///
46+
/// <format type="text/markdown">
47+
/// <![CDATA[
48+
/// [!include[io](~/../docs/samples/docs/api-reference/tensorflow-usage.md)]
49+
/// ]]>
50+
/// </format>
51+
/// </summary>
52+
/// <param name="catalog">The transform's catalog.</param>
53+
/// <param name="modelLocation">Location of the TensorFlow model.</param>
54+
/// <param name="treatOutputAsBatched">If the first dimension of the output is unknown, should it be treated as batched or not.</param>
55+
/// <example>
56+
/// <format type="text/markdown">
57+
/// <![CDATA[
58+
/// [!code-csharp[LoadTensorFlowModel](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/TensorFlow/TextClassification.cs)]
59+
/// ]]>
60+
/// </format>
61+
/// </example>
62+
public static TensorFlowModel LoadTensorFlowModel(this ModelOperationsCatalog catalog, string modelLocation, bool treatOutputAsBatched)
63+
=> TensorFlowUtils.LoadTensorFlowModel(CatalogUtils.GetEnvironment(catalog), modelLocation, treatOutputAsBatched);
3864
}
3965
}

0 commit comments

Comments (0)