From 5773543a58095ef4f3d23dd525088f031d9beb04 Mon Sep 17 00:00:00 2001 From: Frank Dong Date: Thu, 10 Dec 2020 19:32:44 +0000 Subject: [PATCH 1/2] fix tensorflow issue on sample repo --- .../TensorflowTransform.cs | 11 ++++++++- .../TensorflowUtils.cs | 24 ++++++++++++------- .../TensorflowTests.cs | 4 ++-- 3 files changed, 28 insertions(+), 11 deletions(-) diff --git a/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs b/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs index bd0f08ca98..b2f08bc96e 100644 --- a/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs +++ b/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs @@ -509,7 +509,16 @@ public Mapper(TensorFlowTransformer parent, DataViewSchema inputSchema) : var shape = originalShape.dims; if (shape == null || (shape.Length == 0)) - _fullySpecifiedShapes[i] = new TensorShape(); + { + if (_isInputVector[i]) + { + vecType = (VectorDataViewType)type; + var colTypeDims = vecType.Dimensions.Select(dim => (int)dim).ToArray(); + _fullySpecifiedShapes[i] = new TensorShape(colTypeDims); + } + else + _fullySpecifiedShapes[i] = new TensorShape(); + } else { vecType = (VectorDataViewType)type; diff --git a/src/Microsoft.ML.TensorFlow/TensorflowUtils.cs b/src/Microsoft.ML.TensorFlow/TensorflowUtils.cs index 0b990b282e..a7bfbf5aac 100644 --- a/src/Microsoft.ML.TensorFlow/TensorflowUtils.cs +++ b/src/Microsoft.ML.TensorFlow/TensorflowUtils.cs @@ -52,13 +52,6 @@ internal static DataViewSchema GetModelSchema(IExceptionContext ectx, Graph grap if (mlType == null || op.NumOutputs <= 0) continue; - // Construct the final ML.NET type of a Tensorflow variable. - var tensorShape = op.output.TensorShape.dims; - var columnType = new VectorDataViewType(mlType); - if (!(Utils.Size(tensorShape) == 1 && tensorShape[0] <= 0) && - (Utils.Size(tensorShape) > 0 && tensorShape.Skip(1).All(x => x > 0))) - columnType = new VectorDataViewType(mlType, tensorShape[0] > 0 ? 
tensorShape : tensorShape.Skip(1).ToArray()); - // There can be at most two metadata fields. // 1. The first field always presents. Its value is this operator's type. For example, // if an output is produced by an "Softmax" operator, the value of this field should be "Softmax". @@ -83,7 +76,22 @@ internal static DataViewSchema GetModelSchema(IExceptionContext ectx, Graph grap (ref VBuffer> value) => { upstreamOperatorNames.CopyTo(ref value); }); } - schemaBuilder.AddColumn(op.name, columnType, metadataBuilder.ToAnnotations()); + // Construct the final ML.NET type of a Tensorflow variable. + var tensorShape = op.output.TensorShape.dims; + + if(tensorShape == null) + { + schemaBuilder.AddColumn(op.name, mlType, metadataBuilder.ToAnnotations()); + } + else + { + DataViewType columnType = new VectorDataViewType(mlType); + if (!(Utils.Size(tensorShape) == 1 && tensorShape[0] <= 0) && + (Utils.Size(tensorShape) > 0 && tensorShape.Skip(1).All(x => x > 0))) + columnType = new VectorDataViewType(mlType, tensorShape[0] > 0 ? 
tensorShape : tensorShape.Skip(1).ToArray()); + + schemaBuilder.AddColumn(op.name, columnType, metadataBuilder.ToAnnotations()); + } } return schemaBuilder.ToSchema(); } diff --git a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs index 2c3e3ed8fe..3899b0b753 100644 --- a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs +++ b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs @@ -1262,10 +1262,10 @@ class TextOutput class PrimitiveInput { - [LoadColumn(0, 1)] + [LoadColumn(0)] public string input1; - [LoadColumn(1, 2)] + [LoadColumn(1)] public string input2; } From c5d3638e819291006b2d0e215cb792e2db23ad20 Mon Sep 17 00:00:00 2001 From: Frank Dong Date: Thu, 10 Dec 2020 21:42:20 +0000 Subject: [PATCH 2/2] add comments --- src/Microsoft.ML.TensorFlow/TensorflowTransform.cs | 2 ++ src/Microsoft.ML.TensorFlow/TensorflowUtils.cs | 2 ++ .../ScenariosWithDirectInstantiation/TensorflowTests.cs | 6 ++++-- 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs b/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs index b2f08bc96e..f5e85f158f 100644 --- a/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs +++ b/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs @@ -510,6 +510,7 @@ public Mapper(TensorFlowTransformer parent, DataViewSchema inputSchema) : if (shape == null || (shape.Length == 0)) { + // for vector type input, TensorShape should match the vector type's dimensions if (_isInputVector[i]) { vecType = (VectorDataViewType)type; @@ -517,6 +518,7 @@ public Mapper(TensorFlowTransformer parent, DataViewSchema inputSchema) : _fullySpecifiedShapes[i] = new TensorShape(colTypeDims); } else + // for primitive type use default TensorShape _fullySpecifiedShapes[i] = new TensorShape(); } else diff --git a/src/Microsoft.ML.TensorFlow/TensorflowUtils.cs 
b/src/Microsoft.ML.TensorFlow/TensorflowUtils.cs index a7bfbf5aac..805aedcb3d 100644 --- a/src/Microsoft.ML.TensorFlow/TensorflowUtils.cs +++ b/src/Microsoft.ML.TensorFlow/TensorflowUtils.cs @@ -81,10 +81,12 @@ internal static DataViewSchema GetModelSchema(IExceptionContext ectx, Graph grap if(tensorShape == null) { + // primitive column type schemaBuilder.AddColumn(op.name, mlType, metadataBuilder.ToAnnotations()); } else { + // vector column type DataViewType columnType = new VectorDataViewType(mlType); if (!(Utils.Size(tensorShape) == 1 && tensorShape[0] <= 0) && (Utils.Size(tensorShape) > 0 && tensorShape.Skip(1).All(x => x > 0))) diff --git a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs index 3899b0b753..175e958f1b 100644 --- a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs +++ b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs @@ -1305,8 +1305,10 @@ public void TensorFlowPrimitiveInputTest() { using var tensorFlowModel = _mlContext.Model.LoadTensorFlowModel(@"model_primitive_input_test"); var schema = tensorFlowModel.GetModelSchema(); - Assert.True(schema.TryGetColumnIndex("input1", out var colIndex)); - Assert.True(schema.TryGetColumnIndex("input2", out colIndex)); + Assert.True(schema.GetColumnOrNull("input1").HasValue); + Assert.True(schema.GetColumnOrNull("input1").Value.Type is TextDataViewType); + Assert.True(schema.GetColumnOrNull("input2").HasValue); + Assert.True(schema.GetColumnOrNull("input2").Value.Type is TextDataViewType); var dataview = _mlContext.Data.CreateTextLoader().Load(new MultiFileSource(null));