diff --git a/build/Dependencies.props b/build/Dependencies.props
index 9d2174267b..221d1e6552 100644
--- a/build/Dependencies.props
+++ b/build/Dependencies.props
@@ -43,7 +43,7 @@
0.11.3
0.0.3-test
- 0.0.7-test
+ 0.0.10-test
0.0.4-test
diff --git a/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs b/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs
index f87c2da37c..cfe041445e 100644
--- a/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs
+++ b/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs
@@ -667,15 +667,6 @@ internal static (TFDataType[] tfInputTypes, TFShape[] tfInputShapes) GetInputInf
var tfInput = new TFOutput(session.Graph[inputs[i]]);
tfInputTypes[i] = tfInput.OutputType;
tfInputShapes[i] = session.Graph.GetTensorShape(tfInput);
- if (tfInputShapes[i].NumDimensions != -1)
- {
- var newShape = new long[tfInputShapes[i].NumDimensions];
- newShape[0] = tfInputShapes[i][0] == -1 ? BatchSize : tfInputShapes[i][0];
-
- for (int j = 1; j < tfInputShapes[i].NumDimensions; j++)
- newShape[j] = tfInputShapes[i][j];
- tfInputShapes[i] = new TFShape(newShape);
- }
}
return (tfInputTypes, tfInputShapes);
}
@@ -698,7 +689,14 @@ internal static (TFDataType[] tfOutputTypes, ColumnType[] outputTypes) GetOutput
{
var tfOutput = new TFOutput(session.Graph[outputs[i]]);
var shape = session.Graph.GetTensorShape(tfOutput);
+
+ // The transformer can only retreive the output as fixed length vector with shape of kind [-1, d1, d2, d3, ...]
+ // i.e. the first dimension (if unknown) is assumed to be batch dimension.
+ // If there are other dimension that are unknown the transformer will return a variable length vector.
+ // This is the work around in absence of reshape transformer.
int[] dims = shape.NumDimensions > 0 ? shape.ToIntArray().Skip(shape[0] == -1 ? 1 : 0).ToArray() : new[] { 0 };
+ for (int j = 0; j < dims.Length; j++)
+ dims[j] = dims[j] == -1 ? 0 : dims[j];
var type = TensorFlowUtils.Tf2MlNetType(tfOutput.OutputType);
outputTypes[i] = new VectorType(type, dims);
tfOutputTypes[i] = tfOutput.OutputType;
@@ -837,14 +835,22 @@ public Mapper(TensorFlowTransformer parent, Schema inputSchema) :
var originalShape = _parent.TFInputShapes[i];
var shape = originalShape.ToIntArray();
- var colTypeDims = vecType.Dimensions.Prepend(1).Select(dim => (long)dim).ToArray();
+ var colTypeDims = vecType.Dimensions.Select(dim => (long)dim).ToArray();
if (shape == null)
_fullySpecifiedShapes[i] = new TFShape(colTypeDims);
- else if (vecType.Dimensions.Length == 1)
+ else
{
// If the column is one dimension we make sure that the total size of the TF shape matches.
// Compute the total size of the known dimensions of the shape.
- int valCount = shape.Where(x => x > 0).Aggregate((x, y) => x * y);
+ int valCount = 1;
+ int numOfUnkDim = 0;
+ foreach (var s in shape)
+ {
+ if (s > 0)
+ valCount *= s;
+ else
+ numOfUnkDim++;
+ }
// The column length should be divisible by this, so that the other dimensions can be integral.
int typeValueCount = type.GetValueCount();
if (typeValueCount % valCount != 0)
@@ -853,8 +859,8 @@ public Mapper(TensorFlowTransformer parent, Schema inputSchema) :
// If the shape is multi-dimensional, we should be able to create the length of the vector by plugging
// in a single value for the unknown shapes. For example, if the shape is [?,?,3], then there should exist a value
// d such that d*d*3 is equal to the length of the input column.
- var d = originalShape.NumDimensions > 2 ? Math.Pow(typeValueCount / valCount, 1.0 / (originalShape.NumDimensions - 2)) : 1;
- if (originalShape.NumDimensions > 2 && d - (int)d != 0)
+ var d = numOfUnkDim > 0 ? Math.Pow(typeValueCount / valCount, 1.0 / numOfUnkDim) : 0;
+ if (d - (int)d != 0)
throw Contracts.Except($"Input shape mismatch: Input '{_parent.Inputs[i]}' has shape {originalShape.ToString()}, but input data is of length {typeValueCount}.");
// Fill in the unknown dimensions.
@@ -863,17 +869,6 @@ public Mapper(TensorFlowTransformer parent, Schema inputSchema) :
l[ishape] = originalShape[ishape] == -1 ? (int)d : originalShape[ishape];
_fullySpecifiedShapes[i] = new TFShape(l);
}
- else
- {
- if (shape.Select((dim, j) => dim != -1 && dim != colTypeDims[j]).Any(b => b))
- throw Contracts.Except($"Input shape mismatch: Input '{_parent.Inputs[i]}' has shape {originalShape.ToString()}, but input data is {vecType.ToString()}.");
-
- // Fill in the unknown dimensions.
- var l = new long[originalShape.NumDimensions];
- for (int ishape = 0; ishape < originalShape.NumDimensions; ishape++)
- l[ishape] = originalShape[ishape] == -1 ? colTypeDims[ishape] : originalShape[ishape];
- _fullySpecifiedShapes[i] = new TFShape(l);
- }
}
}
diff --git a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs
index fd8d3183bb..398d426684 100644
--- a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs
+++ b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs
@@ -74,6 +74,117 @@ public void TensorFlowTransformMatrixMultiplicationTest()
}
}
+ private class ShapeData
+ {
+ // Data will be passed as 1-D vector.
+ // Intended data shape [5], model shape [None]
+ [VectorType(5)]
+ public float[] OneDim;
+
+ // Data will be passed as flat vector.
+ // Intended data shape [2,2], model shape [2, None]
+ [VectorType(4)]
+ public float[] TwoDim;
+
+ // Data will be passed as 3-D vector.
+ // Intended data shape [1, 2, 2], model shape [1, None, 2]
+ [VectorType(1, 2, 2)]
+ public float[] ThreeDim;
+
+ // Data will be passed as flat vector.
+ // Intended data shape [1, 2, 2, 3], model shape [1, None, None, 3]
+ [VectorType(12)]
+ public float[] FourDim;
+
+ // Data will be passed as 4-D vector.
+ // Intended data shape [2, 2, 2, 2], model shape [2, 2, 2, 2]
+ [VectorType(2, 2, 2, 2)]
+ public float[] FourDimKnown;
+ }
+
+ private List GetShapeData()
+ {
+ return new List(new ShapeData[] {
+ new ShapeData() { OneDim = new[] { 0.1f, 0.2f, 0.3f, 0.4f, 0.5f },
+ TwoDim = new[] { 1.0f, 2.0f, 3.0f, 4.0f },
+ ThreeDim = new[] { 11.0f, 12.0f, 13.0f, 14.0f },
+ FourDim = new[]{ 21.0f, 22.0f, 23.0f, 24.0f, 25.0f, 26.0f,
+ 27.0f, 28.0f, 29.0f, 30.0f, 31.0f, 32.0f },
+ FourDimKnown = new[]{ 41.0f , 42.0f, 43.0f, 44.0f, 45.0f, 46.0f, 47.0f, 48.0f,
+ 49.0f , 50.0f, 51.0f, 52.0f, 53.0f, 54.0f, 55.0f, 56.0f}
+ },
+ new ShapeData() { OneDim = new[] { 100.1f, 100.2f, 100.3f, 100.4f, 100.5f },
+ TwoDim = new[] { 101.0f, 102.0f, 103.0f, 104.0f },
+ ThreeDim = new[] { 111.0f, 112.0f, 113.0f, 114.0f },
+ FourDim = new[]{ 121.0f, 122.0f, 123.0f, 124.0f, 125.0f, 126.0f,
+ 127.0f, 128.0f, 129.0f, 130.0f, 131.0f, 132.0f},
+ FourDimKnown = new[]{ 141.0f , 142.0f, 143.0f, 144.0f, 145.0f, 146.0f, 147.0f, 148.0f,
+ 149.0f , 150.0f, 151.0f, 152.0f, 153.0f, 154.0f, 155.0f, 156.0f }
+ }
+ });
+ }
+
+ [ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))] // TensorFlow is 64-bit only
+ public void TensorFlowTransformInputShapeTest()
+ {
+ var modelLocation = "model_shape_test";
+ var mlContext = new MLContext(seed: 1, conc: 1);
+ var data = GetShapeData();
+ // Pipeline
+ var loader = mlContext.Data.ReadFromEnumerable(data);
+ var inputs = new string[] { "OneDim", "TwoDim", "ThreeDim", "FourDim", "FourDimKnown" };
+ var outputs = new string[] { "o_OneDim", "o_TwoDim", "o_ThreeDim", "o_FourDim", "o_FourDimKnown" };
+
+ var trans = mlContext.Transforms.ScoreTensorFlowModel(modelLocation, outputs, inputs).Fit(loader).Transform(loader);
+
+ using (var cursor = trans.GetRowCursorForAllColumns())
+ {
+ int outColIndex = 5;
+ var oneDimgetter = cursor.GetGetter>(outColIndex);
+ var twoDimgetter = cursor.GetGetter>(outColIndex + 1);
+ var threeDimgetter = cursor.GetGetter>(outColIndex + 2);
+ var fourDimgetter = cursor.GetGetter>(outColIndex + 3);
+ var fourDimKnowngetter = cursor.GetGetter>(outColIndex + 4);
+
+ VBuffer oneDim = default;
+ VBuffer twoDim = default;
+ VBuffer threeDim = default;
+ VBuffer fourDim = default;
+ VBuffer fourDimKnown = default;
+ foreach (var sample in data)
+ {
+ Assert.True(cursor.MoveNext());
+
+ oneDimgetter(ref oneDim);
+ twoDimgetter(ref twoDim);
+ threeDimgetter(ref threeDim);
+ fourDimgetter(ref fourDim);
+ fourDimKnowngetter(ref fourDimKnown);
+
+ var oneDimValues = oneDim.GetValues();
+ Assert.Equal(sample.OneDim.Length, oneDimValues.Length);
+ Assert.True(oneDimValues.SequenceEqual(sample.OneDim));
+
+ var twoDimValues = twoDim.GetValues();
+ Assert.Equal(sample.TwoDim.Length, twoDimValues.Length);
+ Assert.True(twoDimValues.SequenceEqual(sample.TwoDim));
+
+ var threeDimValues = threeDim.GetValues();
+ Assert.Equal(sample.ThreeDim.Length, threeDimValues.Length);
+ Assert.True(threeDimValues.SequenceEqual(sample.ThreeDim));
+
+ var fourDimValues = fourDim.GetValues();
+ Assert.Equal(sample.FourDim.Length, fourDimValues.Length);
+ Assert.True(fourDimValues.SequenceEqual(sample.FourDim));
+
+ var fourDimKnownValues = fourDimKnown.GetValues();
+ Assert.Equal(sample.FourDimKnown.Length, fourDimKnownValues.Length);
+ Assert.True(fourDimKnownValues.SequenceEqual(sample.FourDimKnown));
+ }
+ Assert.False(cursor.MoveNext());
+ }
+ }
+
private class TypesData
{
[VectorType(2)]
@@ -142,7 +253,7 @@ public void TensorFlowTransformInputOutputTypesTest()
var loader = mlContext.Data.ReadFromEnumerable(data);
- var inputs = new string[]{"f64", "f32", "i64", "i32", "i16", "i8", "u64", "u32", "u16", "u8","b"};
+ var inputs = new string[] { "f64", "f32", "i64", "i32", "i16", "i8", "u64", "u32", "u16", "u8", "b" };
var outputs = new string[] { "o_f64", "o_f32", "o_i64", "o_i32", "o_i16", "o_i8", "o_u64", "o_u32", "o_u16", "o_u8", "o_b" };
var trans = mlContext.Transforms.ScoreTensorFlowModel(model_location, outputs, inputs).Fit(loader).Transform(loader); ;
@@ -160,7 +271,7 @@ public void TensorFlowTransformInputOutputTypesTest()
var u8getter = cursor.GetGetter>(20);
var boolgetter = cursor.GetGetter>(21);
-
+
VBuffer f64 = default;
VBuffer f32 = default;
VBuffer i64 = default;
@@ -449,7 +560,7 @@ public void TensorFlowTransformMNISTLRTrainingTest()
ReTrain = true
}))
.Append(mlContext.Transforms.Concatenate("Features", "Prediction"))
- .Append(mlContext.Transforms.Conversion.MapValueToKey("KeyLabel","Label", maxNumKeys: 10))
+ .Append(mlContext.Transforms.Conversion.MapValueToKey("KeyLabel", "Label", maxNumKeys: 10))
.Append(mlContext.MulticlassClassification.Trainers.LightGbm("KeyLabel", "Features"));
var trainedModel = pipe.Fit(trainData);