diff --git a/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs b/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs index 0aaae57598..495a74fdac 100644 --- a/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs +++ b/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs @@ -3,6 +3,7 @@ // See the LICENSE file in the project root for more information. using System; +using System.Collections.Generic; using System.IO; using System.Linq; using Microsoft.ML.Runtime; @@ -42,37 +43,41 @@ internal sealed class TensorFlowMapper : IRowMapper private readonly TFShape[] _tfInputShapes; private readonly TFDataType[] _tfInputTypes; - private readonly string _outputColName; - private readonly ColumnType _outputColType; - private readonly TFDataType _tfOutputType; - + private readonly string[] _outputColNames; + private readonly ColumnType[] _outputColTypes; + private readonly TFDataType[] _tfOutputTypes; private const int BatchSize = 1; public const string LoaderSignature = "TFMapper"; private static VersionInfo GetVersionInfo() { return new VersionInfo( modelSignature: "TENSFLOW", - verWrittenCur: 0x00010001, // Initial - verReadableCur: 0x00010001, + //verWrittenCur: 0x00010001, // Initial + verWrittenCur: 0x00010002, // Upgraded when change for multiple outputs was implemented. + verReadableCur: 0x00010002, verWeCanReadBack: 0x00010001, loaderSignature: LoaderSignature); } - public TensorFlowMapper(IHostEnvironment env, ISchema inputSchema, byte[] modelBytes, string[] inputColNames, string outputColName) + public TensorFlowMapper(IHostEnvironment env, ISchema inputSchema, byte[] modelBytes, string[] inputColNames, string[] outputColNames) { Contracts.CheckValue(env, nameof(env)); _host = env.Register("TensorFlowMapper"); _host.CheckValue(inputSchema, nameof(inputSchema)); _host.CheckNonEmpty(modelBytes, nameof(modelBytes)); _host.CheckNonEmpty(inputColNames, nameof(inputColNames)); - _host.CheckNonEmpty(outputColName, nameof(outputColName)); + _host.CheckNonEmpty(outputColNames, nameof(outputColNames)); + for (int i = 0; i < inputColNames.Length; i++) + _host.CheckNonWhiteSpace(inputColNames[i], nameof(inputColNames)); + for (int i = 0; i < outputColNames.Length; i++) + _host.CheckNonWhiteSpace(outputColNames[i], nameof(outputColNames)); _session = LoadTFSession(modelBytes, null); - _host.CheckValue(_session.Graph[outputColName], nameof(outputColName), "Output does not exist in the model"); _host.Check(inputColNames.All(name => _session.Graph[name] != null), "One of the input does not exist in the model"); + _host.Check(outputColNames.All(name => _session.Graph[name] != null), "One of the output does not exist in the model"); - _outputColName = outputColName; - (_outputColType, _tfOutputType) = GetOutputTypes(_session.Graph, _outputColName); + _outputColNames = outputColNames; + (_outputColTypes, _tfOutputTypes) = GetOutputTypes(_session.Graph, _outputColNames); (_inputColNames, _inputColIndices, _isVectorInput, _tfInputShapes, _tfInputTypes) = GetInputMetaData(_session.Graph, inputColNames, inputSchema); } @@ -93,9 +98,20 @@ public static TensorFlowMapper Create(IHostEnvironment env, ModelLoadContext ctx if (!ctx.TryLoadBinaryStream("TFModel", r => data = r.ReadByteArray())) throw env.ExceptDecode(); - var outputColName = ctx.LoadNonEmptyString(); + bool isMultiOutput = ctx.Header.ModelVerReadable >= 0x00010002; + + var numOutputs = 1; + if (isMultiOutput) + { + numOutputs = ctx.Reader.ReadInt32(); + } + + Contracts.CheckDecode(numOutputs > 0); + var outputColNames = new string[numOutputs]; + for (int j = 0; j < outputColNames.Length; j++) + outputColNames[j] = ctx.LoadNonEmptyString(); - return new TensorFlowMapper(env, schema, data, source, outputColName); + return new TensorFlowMapper(env, schema, data, source, outputColNames); } public void Save(ModelSaveContext ctx) @@ -111,12 +127,15 @@ public void Save(ModelSaveContext ctx) { w.WriteByteArray(buffer.ToArray()); }); - Contracts.AssertNonEmpty(_inputColNames); + _host.AssertNonEmpty(_inputColNames); ctx.Writer.Write(_inputColNames.Length); foreach (var colName in _inputColNames) ctx.SaveNonEmptyString(colName); - ctx.SaveNonEmptyString(_outputColName); + _host.AssertNonEmpty(_outputColNames); + ctx.Writer.Write(_outputColNames.Length); + foreach (var colName in _outputColNames) + ctx.SaveNonEmptyString(colName); } private TFSession LoadTFSession(byte[] modelBytes, string modelArg) @@ -164,21 +183,63 @@ private ITensorValueGetter[] GetTensorValueGetters(IRow input) return srcTensorGetters; } - private Delegate MakeGetter(IRow input) + private class OutputCache { - var type = TFTensor.TypeFromTensorType(_tfOutputType); - _host.Assert(type == _outputColType.ItemType.RawType); - return Utils.MarshalInvoke(MakeGetter, type, input, _outputColType); + public long Position; + public Dictionary Outputs; + public OutputCache() + { + Position = -1; + Outputs = new Dictionary(); + } } - private Delegate MakeGetter(IRow input, ColumnType columnType) + private Delegate[] MakeGetter(IRow input, Func activeOutput) { _host.AssertValue(input); - _host.Assert(typeof(T) == columnType.ItemType.RawType); - var srcTensorGetters = GetTensorValueGetters(input); + var outputCache = new OutputCache(); + var activeOutputColNames = _outputColNames.Where((x, i) => activeOutput(i)).ToArray(); + + var valueGetters = new List(); + for (int i = 0; i < _outputColNames.Length; i++) + { + if (activeOutput(i)) + { + var type = TFTensor.TypeFromTensorType(_tfOutputTypes[i]); + _host.Assert(type == _outputColTypes[i].ItemType.RawType); + var srcTensorGetters = GetTensorValueGetters(input); + valueGetters.Add(Utils.MarshalInvoke(MakeGetter, type, input, i, srcTensorGetters, activeOutputColNames, outputCache)); + } + } + return valueGetters.ToArray(); + } + private Delegate MakeGetter(IRow input, int iinfo, ITensorValueGetter[] srcTensorGetters, string[] activeOutputColNames, OutputCache outputCache) + { + _host.AssertValue(input); ValueGetter> valuegetter = (ref VBuffer dst) => + { + UpdateCacheIfNeeded(input.Position, srcTensorGetters, activeOutputColNames, outputCache); + + var values = dst.Values; + var indices = dst.Indices; + if (Utils.Size(values) < _outputColTypes[iinfo].VectorSize) + { + values = new T[_outputColTypes[iinfo].VectorSize]; + indices = new int[_outputColTypes[iinfo].VectorSize]; + } + + TensorFlowUtils.FetchData(outputCache.Outputs[_outputColNames[iinfo]].Data, values); + dst = new VBuffer(values.Length, values, indices); + }; + return valuegetter; + + } + + private void UpdateCacheIfNeeded(long position, ITensorValueGetter[] srcTensorGetters, string[] activeOutputColNames, OutputCache outputCache) + { + if (outputCache.Position != position) { var runner = _session.GetRunner(); for (int i = 0; i < _inputColIndices.Length; i++) @@ -187,28 +248,24 @@ private Delegate MakeGetter(IRow input, ColumnType columnType) runner.AddInput(inputName, srcTensorGetters[i].GetTensor()); } - var tensors = runner.Fetch(_outputColName).Run(); - + var tensors = runner.Fetch(activeOutputColNames).Run(); Contracts.Assert(tensors.Length > 0); - var values = dst.Values; - if (Utils.Size(values) < _outputColType.VectorSize) - values = new T[_outputColType.VectorSize]; + for (int j = 0; j < tensors.Length; j++) + { + outputCache.Outputs[activeOutputColNames[j]] = tensors[j]; + } - TensorFlowUtils.FetchData(tensors[0].Data, values); - dst = new VBuffer(values.Length, values, dst.Indices); - }; - return valuegetter; + outputCache.Position = position; + } } public Delegate[] CreateGetters(IRow input, Func activeOutput, out Action disposer) { - var getters = new Delegate[1]; disposer = null; using (var ch = _host.Start("CreateGetters")) { - if (activeOutput(0)) - getters[0] = MakeGetter(input); + var getters = MakeGetter(input, activeOutput); ch.Done(); return getters; } @@ -216,26 +273,37 @@ public Delegate[] CreateGetters(IRow input, Func activeOutput, out Ac public Func GetDependencies(Func activeOutput) { - return col => activeOutput(0) && _inputColIndices.Any(i => i == col); + return col => Enumerable.Range(0, _outputColNames.Length).Any(i => activeOutput(i)) && _inputColIndices.Any(i => i == col); } public RowMapperColumnInfo[] GetOutputColumns() { - return new[] { new RowMapperColumnInfo(_outputColName, _outputColType, null) }; + var info = new RowMapperColumnInfo[_outputColNames.Length]; + for (int i = 0; i < _outputColNames.Length; i++) + info[i] = new RowMapperColumnInfo(_outputColNames[i], _outputColTypes[i], null); + return info; } - private static (ColumnType, TFDataType) GetOutputTypes(TFGraph graph, string columnName) + private static (ColumnType[], TFDataType[]) GetOutputTypes(TFGraph graph, string[] columnNames) { Contracts.AssertValue(graph); - Contracts.AssertNonEmpty(columnName); - Contracts.AssertValue(graph[columnName]); + Contracts.AssertNonEmpty(columnNames); + Contracts.Assert(columnNames.All(name => graph[name] != null), "One of the output does not exist in the model"); - var tfoutput = new TFOutput(graph[columnName]); - var shape = graph.GetTensorShape(tfoutput); + var columnTypes = new ColumnType[columnNames.Length]; + var tfTypes = new TFDataType[columnNames.Length]; + for (int i = 0; i < columnNames.Length; i++) + { + var tfoutput = new TFOutput(graph[columnNames[i]]); + var shape = graph.GetTensorShape(tfoutput); - int[] dims = shape.ToIntArray().Skip(shape[0] == -1 ? BatchSize : 0).ToArray(); - var type = TensorFlowUtils.Tf2MlNetType(tfoutput.OutputType); - return (new VectorType(type, dims), tfoutput.OutputType); + int[] dims = shape.ToIntArray().Skip(shape[0] == -1 ? BatchSize : 0).ToArray(); + + var type = TensorFlowUtils.Tf2MlNetType(tfoutput.OutputType); + columnTypes[i] = new VectorType(type, dims); + tfTypes[i] = tfoutput.OutputType; + } + return (columnTypes, tfTypes); } private static (string[], int[], bool[], TFShape[], TFDataType[]) GetInputMetaData(TFGraph graph, string[] source, ISchema inputSchema) @@ -292,7 +360,7 @@ public sealed class Arguments : TransformInputBase public string[] InputColumns; [Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "The name of the output", ShortName = "output", SortOrder = 2)] - public string OutputColumn; + public string[] OutputColumns; } public const string Summary = "Transforms the data using the TensorFlow model."; @@ -310,7 +378,20 @@ public sealed class Arguments : TransformInputBase /// Name of the input column(s). Keep it same as in the TensorFlow model. public static IDataTransform Create(IHostEnvironment env, IDataView input, string modelFile, string name, params string[] source) { - return Create(env, new Arguments() { InputColumns = source, OutputColumn = name, ModelFile = modelFile }, input); + return Create(env, new Arguments() { InputColumns = source, OutputColumns = new[] { name }, ModelFile = modelFile }, input); + } + + /// + /// Convenience constructor for public facing API. + /// + /// Host Environment. + /// Input . This is the output from previous transform or loader. + /// This is the frozen tensorflow model file. https://www.tensorflow.org/mobile/prepare_models + /// Name of the output column(s). Keep it same as in the Tensorflow model. + /// Name of the input column(s). Keep it same as in the Tensorflow model. + public static IDataTransform Create(IHostEnvironment env, IDataView input, string modelFile, string[] names, string[] source) + { + return Create(env, new Arguments() { InputColumns = source, OutputColumns = names, ModelFile = modelFile }, input); } /// @@ -322,11 +403,15 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV host.CheckUserArg(Utils.Size(args.InputColumns) > 0, nameof(args.InputColumns)); for (int i = 0; i < args.InputColumns.Length; i++) host.CheckNonWhiteSpace(args.InputColumns[i], nameof(args.InputColumns)); + for (int i = 0; i < args.OutputColumns.Length; i++) + host.CheckNonWhiteSpace(args.OutputColumns[i], nameof(args.OutputColumns)); + host.CheckUserArg(args.OutputColumns.Distinct().Count() == args.OutputColumns.Length, + nameof(args.OutputColumns), "Some of the output columns specified multiple times."); host.CheckNonWhiteSpace(args.ModelFile, nameof(args.ModelFile)); host.CheckUserArg(File.Exists(args.ModelFile), nameof(args.ModelFile)); var modelBytes = File.ReadAllBytes(args.ModelFile); - var mapper = new TensorFlowMapper(host, input.Schema, modelBytes, args.InputColumns, args.OutputColumn); + var mapper = new TensorFlowMapper(host, input.Schema, modelBytes, args.InputColumns, args.OutputColumns); return new RowToRowMapperTransform(host, input, mapper); } diff --git a/src/Microsoft.ML.TensorFlow/doc.xml b/src/Microsoft.ML.TensorFlow/doc.xml index 63410cda02..329c75ca55 100644 --- a/src/Microsoft.ML.TensorFlow/doc.xml +++ b/src/Microsoft.ML.TensorFlow/doc.xml @@ -7,11 +7,11 @@ Extracts hidden layers' values from a pre-trained Tensorflow model. - The TensorflowTransform extracts the specified output from the operation computed on the graph (given the input(s)) using a pre-trained Tensorflow model. - The transform takes as input the Tensorflow model together with the names of the inputs to the model and name of the operation for which output values will be extracted from the model. + The TensorflowTransform extracts the specified outputs from the operations computed on the graph (given the input(s)) using a pre-trained Tensorflow model. + The transform takes as input the Tensorflow model together with the names of the inputs to the model and names of the operations for which output values will be extracted from the model. This transform requires the Microsoft.ML.TensorFlow nuget to be installed. - + The TensorflowTransform has following assumptions regarding the input, output and processing of data. @@ -19,15 +19,15 @@ The transform supports scoring only one example at a time. The name of input column(s) should match the name of input(s) in Tensorflow model. - The name of the output column should match one of the operations in the Tensorflow graph. + The name of each output column should match one of the operations in the Tensorflow graph. Currently, float and double are the only acceptable data types for input/output. - Upon success, the transform will introduce a new column in based on the name of the output column specified. + Upon success, the transform will introduce a new column in corresponding to each output column specified. - + The inputs and outputs of a TensorFlow model can be obtained using the summarize_graph tool. - + diff --git a/src/Microsoft.ML/CSharpApi.cs b/src/Microsoft.ML/CSharpApi.cs index 66718150d1..b8d2c12bb4 100644 --- a/src/Microsoft.ML/CSharpApi.cs +++ b/src/Microsoft.ML/CSharpApi.cs @@ -15939,7 +15939,7 @@ public sealed partial class TensorFlowScorer : Microsoft.ML.Runtime.EntryPoints. /// /// The name of the output /// - public string OutputColumn { get; set; } + public string[] OutputColumns { get; set; } /// /// Input dataset diff --git a/test/BaselineOutput/Common/EntryPoints/core_manifest.json b/test/BaselineOutput/Common/EntryPoints/core_manifest.json index 99db39b419..a13bc5eded 100644 --- a/test/BaselineOutput/Common/EntryPoints/core_manifest.json +++ b/test/BaselineOutput/Common/EntryPoints/core_manifest.json @@ -21919,8 +21919,11 @@ "IsNullable": false }, { - "Name": "OutputColumn", - "Type": "String", + "Name": "OutputColumns", + "Type": { + "Kind": "Array", + "ItemType": "String" + }, "Desc": "The name of the output", "Aliases": [ "output" diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs index 24e53757a7..a567f5988a 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs @@ -978,7 +978,7 @@ public void TestTensorFlowEntryPoint() { Data = importOutput.Data, InputColumns = new[] { "Placeholder" }, - OutputColumn = "Softmax", + OutputColumns = new[] { "Softmax" }, ModelFile = "mnist_model/frozen_saved_model.pb" }; var tfTransformOutput = experiment.Add(tfTransformInput); diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs index 0a1acadfd5..ded58f50bb 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs @@ -3742,7 +3742,7 @@ public void EntryPointTensorFlowTransform() { @"'InputColumns': [ 'Placeholder' ], 'ModelFile': 'mnist_model/frozen_saved_model.pb', - 'OutputColumn': 'Softmax'" + 'OutputColumns': [ 'Softmax' ]" }); } } diff --git a/test/Microsoft.ML.Tests/Scenarios/TensorflowTests.cs b/test/Microsoft.ML.Tests/Scenarios/TensorflowTests.cs index affdbea72c..909d4c1a75 100644 --- a/test/Microsoft.ML.Tests/Scenarios/TensorflowTests.cs +++ b/test/Microsoft.ML.Tests/Scenarios/TensorflowTests.cs @@ -50,7 +50,7 @@ public void TensorFlowTransformCifarLearningPipelineTest() { ModelFile = model_location, InputColumns = new[] { "Input" }, - OutputColumn = "Output" + OutputColumns = new[] { "Output" } }); pipeline.Add(new ColumnConcatenator(outputColumn: "Features", "Output")); diff --git a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs index 0bffb5c4d0..812c5e9a24 100644 --- a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs +++ b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs @@ -103,9 +103,13 @@ public void TensorFlowTransformMNISTConvTest() } }, new MultiFileSource(dataPath)); - IDataView trans = TensorFlowTransform.Create(env, loader, model_location, "Softmax", "Placeholder"); - trans = new ConcatTransform(env, trans, "reshape_input", "Placeholder"); - trans = TensorFlowTransform.Create(env, trans, model_location, "dense/Relu", "reshape_input"); + IDataView trans = CopyColumnsTransform.Create(env, new CopyColumnsTransform.Arguments() + { + Column = new[] { new CopyColumnsTransform.Column() + { Name = "reshape_input", Source = "Placeholder" } + } + }, loader); + trans = TensorFlowTransform.Create(env, trans, model_location, new[] { "Softmax", "dense/Relu" }, new[] { "Placeholder", "reshape_input" }); trans = new ConcatTransform(env, trans, "Features", "Softmax", "dense/Relu"); var trainer = new LightGbmMulticlassTrainer(env, new LightGbmArguments());