diff --git a/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs b/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs index bd0f08ca98..f5e85f158f 100644 --- a/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs +++ b/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs @@ -509,7 +509,18 @@ public Mapper(TensorFlowTransformer parent, DataViewSchema inputSchema) : var shape = originalShape.dims; if (shape == null || (shape.Length == 0)) - _fullySpecifiedShapes[i] = new TensorShape(); + { + // The graph gives no dims for this input: for a vector-typed column, build the TensorShape from the column type's own dimensions. + if (_isInputVector[i]) + { + vecType = (VectorDataViewType)type; + var colTypeDims = vecType.Dimensions.Select(dim => (int)dim).ToArray(); + _fullySpecifiedShapes[i] = new TensorShape(colTypeDims); + } + else + // For a primitive-typed column, use the default TensorShape. + _fullySpecifiedShapes[i] = new TensorShape(); + } else { vecType = (VectorDataViewType)type; diff --git a/src/Microsoft.ML.TensorFlow/TensorflowUtils.cs b/src/Microsoft.ML.TensorFlow/TensorflowUtils.cs index 0b990b282e..805aedcb3d 100644 --- a/src/Microsoft.ML.TensorFlow/TensorflowUtils.cs +++ b/src/Microsoft.ML.TensorFlow/TensorflowUtils.cs @@ -52,13 +52,6 @@ internal static DataViewSchema GetModelSchema(IExceptionContext ectx, Graph grap if (mlType == null || op.NumOutputs <= 0) continue; - // Construct the final ML.NET type of a Tensorflow variable. - var tensorShape = op.output.TensorShape.dims; - var columnType = new VectorDataViewType(mlType); - if (!(Utils.Size(tensorShape) == 1 && tensorShape[0] <= 0) && - (Utils.Size(tensorShape) > 0 && tensorShape.Skip(1).All(x => x > 0))) - columnType = new VectorDataViewType(mlType, tensorShape[0] > 0 ? tensorShape : tensorShape.Skip(1).ToArray()); - // There can be at most two metadata fields. // 1. The first field always presents. Its value is this operator's type. For example, // if an output is produced by an "Softmax" operator, the value of this field should be "Softmax". 
@@ -83,7 +76,24 @@ internal static DataViewSchema GetModelSchema(IExceptionContext ectx, Graph grap (ref VBuffer> value) => { upstreamOperatorNames.CopyTo(ref value); }); } - schemaBuilder.AddColumn(op.name, columnType, metadataBuilder.ToAnnotations()); + // Construct the final ML.NET type of a TensorFlow variable from the dims of the op's output. + var tensorShape = op.output.TensorShape.dims; + + if(tensorShape == null) + { + // Null dims: add the op as a column of the primitive type itself. NOTE(review): TensorFlowTransformer.Mapper also treats a zero-length dims array as shapeless; confirm that empty dims should still become a vector column here. + schemaBuilder.AddColumn(op.name, mlType, metadataBuilder.ToAnnotations()); + } + else + { + // Otherwise add a vector column; explicit dimensions are used only when every dim after the first is positive and the shape is not a single non-positive dim (a non-positive leading dim is dropped). + DataViewType columnType = new VectorDataViewType(mlType); + if (!(Utils.Size(tensorShape) == 1 && tensorShape[0] <= 0) && + (Utils.Size(tensorShape) > 0 && tensorShape.Skip(1).All(x => x > 0))) + columnType = new VectorDataViewType(mlType, tensorShape[0] > 0 ? tensorShape : tensorShape.Skip(1).ToArray()); + + schemaBuilder.AddColumn(op.name, columnType, metadataBuilder.ToAnnotations()); + } } return schemaBuilder.ToSchema(); } diff --git a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs index 2c3e3ed8fe..175e958f1b 100644 --- a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs +++ b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs @@ -1262,10 +1262,10 @@ class TextOutput class PrimitiveInput { - [LoadColumn(0, 1)] + [LoadColumn(0)] public string input1; - [LoadColumn(1, 2)] + [LoadColumn(1)] public string input2; } @@ -1305,8 +1305,10 @@ public void TensorFlowPrimitiveInputTest() { using var tensorFlowModel = _mlContext.Model.LoadTensorFlowModel(@"model_primitive_input_test"); var schema = tensorFlowModel.GetModelSchema(); - Assert.True(schema.TryGetColumnIndex("input1", out var colIndex)); - Assert.True(schema.TryGetColumnIndex("input2", out colIndex)); + Assert.True(schema.GetColumnOrNull("input1").HasValue); + Assert.True(schema.GetColumnOrNull("input1").Value.Type is 
TextDataViewType); + Assert.True(schema.GetColumnOrNull("input2").HasValue); + Assert.True(schema.GetColumnOrNull("input2").Value.Type is TextDataViewType); var dataview = _mlContext.Data.CreateTextLoader().Load(new MultiFileSource(null));