Skip to content

Commit 9c07a53

Browse files
committed
Addressed reviewers' comments.
1 parent 735f7ca commit 9c07a53

File tree

6 files changed

+105
-70
lines changed

6 files changed

+105
-70
lines changed

src/Microsoft.ML.Legacy/CSharpApi.cs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15804,7 +15804,12 @@ public sealed partial class TensorFlowScorer : Microsoft.ML.Runtime.EntryPoints.
1580415804
/// <summary>
1580515805
/// Training labels.
1580615806
/// </summary>
15807-
public string LabeLColumn { get; set; } = "Label";
15807+
public string LabelColumn { get; set; }
15808+
15809+
/// <summary>
15810+
/// TensorFlow label node.
15811+
/// </summary>
15812+
public string TensorFlowLabel { get; set; }
1580815813

1580915814
/// <summary>
1581015815
/// The name of the optimization operation in the TensorFlow graph.

src/Microsoft.ML.TensorFlow/TensorFlow/Tensor.cs

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -436,13 +436,15 @@ internal static unsafe TFTensor CreateString(byte[] buffer)
436436
// Clear offset table
437437
IntPtr dst = TF_TensorData(handle);
438438
Marshal.WriteInt64(dst, 0);
439-
var status = new TFStatus();
440-
fixed (byte* src = &buffer[0])
439+
using (var status = new TFStatus())
441440
{
442-
TF_StringEncode(src, (UIntPtr)buffer.Length, (sbyte*)(dst + 8), size, status.handle);
443-
var ok = status.StatusCode == TFCode.Ok;
444-
if (!ok)
445-
return null;
441+
fixed (byte* src = &buffer[0])
442+
{
443+
TF_StringEncode(src, (UIntPtr)buffer.Length, (sbyte*)(dst + 8), size, status.handle);
444+
var ok = status.StatusCode == TFCode.Ok;
445+
if (!ok)
446+
return null;
447+
}
446448
}
447449
return new TFTensor(handle);
448450
}

src/Microsoft.ML.TensorFlow/TensorFlow/Tensorflow.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,7 @@ public TFStatus() : base(TF_NewStatus())
290290

291291
// extern void TF_DeleteStatus (TF_Status *);
292292
[DllImport(NativeBinding.TensorFlowLibrary)]
293-
internal static extern unsafe void TF_DeleteStatus(TF_Status status);
293+
private static extern unsafe void TF_DeleteStatus(TF_Status status);
294294

295295
internal override void NativeDispose(IntPtr handle)
296296
{
@@ -313,7 +313,7 @@ public void SetStatusCode(TFCode code, string msg)
313313

314314
// extern TF_Code TF_GetCode (const TF_Status *s);
315315
[DllImport(NativeBinding.TensorFlowLibrary)]
316-
internal static extern unsafe TFCode TF_GetCode(TF_Status s);
316+
private static extern unsafe TFCode TF_GetCode(TF_Status s);
317317

318318
/// <summary>
319319
/// Gets the status code of this status.

src/Microsoft.ML.TensorFlow/TensorflowTransform.cs

Lines changed: 57 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -60,59 +60,65 @@ public sealed class Arguments : TransformInputBase
6060
public string[] OutputColumns;
6161

6262
/// <summary>
63-
/// The name of the column used as label for training.
63+
/// The name of the label column in the <see cref="IDataView"/> that will be mapped to the label node in the TensorFlow model.
6464
/// </summary>
6565
[Argument(ArgumentType.AtMostOnce, HelpText = "Training labels.", ShortName = "label", SortOrder = 4)]
66-
public string LabeLColumn = DefaultColumnNames.Label;
66+
public string LabelColumn;
67+
68+
/// <summary>
69+
/// The name of the label node in the TensorFlow model.
70+
/// </summary>
71+
[Argument(ArgumentType.AtMostOnce, HelpText = "TensorFlow label node.", ShortName = "TFLabel", SortOrder = 5)]
72+
public string TensorFlowLabel;
6773

6874
/// <summary>
6975
/// Name of the operation in TensorFlow graph that is used for optimizing parameters in the graph.
7076
/// Usually it is the name specified in the minimize method of optimizer in python
7177
/// e.g. optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost, name = "SGDOptimizer").
7278
/// </summary>
73-
[Argument(ArgumentType.AtMostOnce, HelpText = "The name of the optimization operation in the TensorFlow graph.", ShortName = "OptimizationOp", SortOrder = 4)]
79+
[Argument(ArgumentType.AtMostOnce, HelpText = "The name of the optimization operation in the TensorFlow graph.", ShortName = "OptimizationOp", SortOrder = 6)]
7480
public string OptimizationOperation;
7581

7682
/// <summary>
7783
/// The name of the operation in the TensorFlow graph to compute training loss (Optional).
7884
/// </summary>
79-
[Argument(ArgumentType.AtMostOnce, HelpText = "The name of the operation in the TensorFlow graph to compute training loss (Optional)", ShortName = "LossOp", SortOrder = 5)]
85+
[Argument(ArgumentType.AtMostOnce, HelpText = "The name of the operation in the TensorFlow graph to compute training loss (Optional)", ShortName = "LossOp", SortOrder = 7)]
8086
public string LossOperation;
8187

8288
/// <summary>
8389
/// The name of the operation in the TensorFlow graph to compute performance metric during training (Optional).
8490
/// </summary>
85-
[Argument(ArgumentType.AtMostOnce, HelpText = "The name of the operation in the TensorFlow graph to compute performance metric during training (Optional)", ShortName = "MetricOp", SortOrder = 6)]
91+
[Argument(ArgumentType.AtMostOnce, HelpText = "The name of the operation in the TensorFlow graph to compute performance metric during training (Optional)", ShortName = "MetricOp", SortOrder = 8)]
8692
public string MetricOperation;
8793

8894
/// <summary>
8995
/// Number of samples to use for mini-batch training.
9096
/// </summary>
91-
[Argument(ArgumentType.AtMostOnce, HelpText = "Number of samples to use for mini-batch training.", SortOrder = 7)]
97+
[Argument(ArgumentType.AtMostOnce, HelpText = "Number of samples to use for mini-batch training.", SortOrder = 9)]
9298
public int BatchSize = 64;
9399

94100
/// <summary>
95101
/// Number of training iterations.
96102
/// </summary>
97-
[Argument(ArgumentType.AtMostOnce, HelpText = "Number of training iterations.", SortOrder = 8)]
103+
[Argument(ArgumentType.AtMostOnce, HelpText = "Number of training iterations.", SortOrder = 10)]
98104
public int Epoch = 5;
99105

100106
/// <summary>
101107
/// The name of the operation in the TensorFlow graph which sets optimizer learning rate (Optional).
102108
/// </summary>
103-
[Argument(ArgumentType.AtMostOnce, HelpText = "The name of the operation in the TensorFlow graph which sets optimizer learning rate (Optional).", SortOrder = 9)]
109+
[Argument(ArgumentType.AtMostOnce, HelpText = "The name of the operation in the TensorFlow graph which sets optimizer learning rate (Optional).", SortOrder = 11)]
104110
public string LearningRateOperation;
105111

106112
/// <summary>
107113
/// Learning rate to use during optimization.
108114
/// </summary>
109-
[Argument(ArgumentType.AtMostOnce, HelpText = "Learning rate to use during optimization.", SortOrder = 10)]
115+
[Argument(ArgumentType.AtMostOnce, HelpText = "Learning rate to use during optimization.", SortOrder = 12)]
110116
public float LearningRate = 0.01f;
111117

112118
/// <summary>
113119
/// Shuffle training data on each iteration?
114120
/// </summary>
115-
[Argument(ArgumentType.AtMostOnce, HelpText = "Shuffle data before each iteration.", SortOrder = 11)]
121+
[Argument(ArgumentType.AtMostOnce, HelpText = "Shuffle data before each iteration.", SortOrder = 13)]
116122
public bool Shuffle = true;
117123

118124
/// <summary>
@@ -121,7 +127,7 @@ public sealed class Arguments : TransformInputBase
121127
/// Therefore, it's highly unlikely that this parameter is changed from its default value of 'save/Const'.
122128
/// Please change it cautiously if you need to.
123129
/// </summary>
124-
[Argument(ArgumentType.AtMostOnce, HelpText = "Name of the input in TensorFlow graph that specifiy the location for saving/restoring models from disk.", SortOrder = 12)]
130+
[Argument(ArgumentType.AtMostOnce, HelpText = "Name of the input in TensorFlow graph that specifiy the location for saving/restoring models from disk.", SortOrder = 14)]
125131
public string SaveLocationOperation = "save/Const";
126132

127133
/// <summary>
@@ -130,13 +136,13 @@ public sealed class Arguments : TransformInputBase
130136
/// Therefore, it's highly unlikely that this parameter is changed from its default value of 'save/control_dependency'.
131137
/// Please change it cautiously if you need to.
132138
/// </summary>
133-
[Argument(ArgumentType.AtMostOnce, HelpText = "Name of the input in TensorFlow graph that specifiy the location for saving/restoring models from disk.", SortOrder = 13)]
139+
[Argument(ArgumentType.AtMostOnce, HelpText = "Name of the input in TensorFlow graph that specifiy the location for saving/restoring models from disk.", SortOrder = 15)]
134140
public string SaveOperation = "save/control_dependency";
135141

136142
/// <summary>
137143
/// Needed for command line to specify if retraining is requested.
138144
/// </summary>
139-
[Argument(ArgumentType.AtMostOnce, HelpText = "Retrain TensorFlow model.", SortOrder = 15)]
145+
[Argument(ArgumentType.AtMostOnce, HelpText = "Retrain TensorFlow model.", SortOrder = 16)]
140146
public bool ReTrain = false;
141147
}
142148

@@ -300,9 +306,9 @@ private void CheckParameters(Arguments args)
300306
if (Session.Graph[args.OptimizationOperation] == null)
301307
throw _host.ExceptParam(nameof(args.OptimizationOperation), $"Optimization operation '{args.OptimizationOperation}' does not exist in the model");
302308

303-
_host.CheckNonWhiteSpace(args.LabeLColumn, nameof(args.LabeLColumn));
304-
if (Session.Graph[args.LabeLColumn] == null)
305-
throw _host.ExceptParam(nameof(args.LabeLColumn), $"'{args.LabeLColumn}' does not exist in the model");
309+
_host.CheckNonWhiteSpace(args.TensorFlowLabel, nameof(args.TensorFlowLabel));
310+
if (Session.Graph[args.TensorFlowLabel] == null)
311+
throw _host.ExceptParam(nameof(args.TensorFlowLabel), $"'{args.TensorFlowLabel}' does not exist in the model");
306312

307313
_host.CheckNonWhiteSpace(args.SaveLocationOperation, nameof(args.SaveLocationOperation));
308314
if (Session.Graph[args.SaveLocationOperation] == null)
@@ -334,6 +340,33 @@ private void CheckParameters(Arguments args)
334340
}
335341
}
336342

343+
private (int, bool, TFDataType, TFShape) GetInputMetaData(ISchema inputSchema, string columnName, string tfNodeName, int batchSize)
344+
{
345+
if (!inputSchema.TryGetColumnIndex(columnName, out int inputColIndices))
346+
throw _host.Except($"Column {columnName} doesn't exist");
347+
348+
var type = inputSchema.GetColumnType(inputColIndices);
349+
var isInputVector = type.IsVector;
350+
351+
var tfInput = new TFOutput(Graph[tfNodeName]);
352+
var tfInputType = tfInput.OutputType;
353+
var tfInputShape = Graph.GetTensorShape(tfInput);
354+
if (tfInputShape.NumDimensions != -1)
355+
{
356+
var newShape = new long[tfInputShape.NumDimensions];
357+
newShape[0] = tfInputShape[0] == -1 ? batchSize : tfInputShape[0];
358+
359+
for (int j = 1; j < tfInputShape.NumDimensions; j++)
360+
newShape[j] = tfInputShape[j];
361+
tfInputShape = new TFShape(newShape);
362+
}
363+
364+
var expectedType = TensorFlowUtils.Tf2MlNetType(tfInputType);
365+
if (type.ItemType != expectedType)
366+
throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", columnName, expectedType.ToString(), type.ToString());
367+
368+
return (inputColIndices, isInputVector, tfInputType, tfInputShape);
369+
}
337370
private void TrainCore(Arguments args, string model, IDataView input)
338371
{
339372
var inputsForTraining = new string[Inputs.Length + 1];
@@ -347,35 +380,18 @@ private void TrainCore(Arguments args, string model, IDataView input)
347380
inputsForTraining[i] = Inputs[i];
348381
}
349382

350-
inputsForTraining[inputsForTraining.Length - 1] = args.LabeLColumn;
351-
352383
var inputSchema = input.Schema;
353-
for (int i = 0; i < inputsForTraining.Length; i++)
384+
for (int i = 0; i < inputsForTraining.Length - 1; i++)
354385
{
355-
if (!inputSchema.TryGetColumnIndex(inputsForTraining[i], out inputColIndices[i]))
356-
throw _host.Except($"Column {inputsForTraining[i]} doesn't exist");
357-
358-
var type = inputSchema.GetColumnType(inputColIndices[i]);
359-
isInputVector[i] = type.IsVector;
360-
361-
var tfInput = new TFOutput(Graph[inputsForTraining[i]]);
362-
tfInputTypes[i] = tfInput.OutputType;
363-
tfInputShapes[i] = Graph.GetTensorShape(tfInput);
364-
if (tfInputShapes[i].NumDimensions != -1)
365-
{
366-
var newShape = new long[tfInputShapes[i].NumDimensions];
367-
newShape[0] = tfInputShapes[i][0] == -1 ? args.BatchSize : tfInputShapes[i][0];
368-
369-
for (int j = 1; j < tfInputShapes[i].NumDimensions; j++)
370-
newShape[j] = tfInputShapes[i][j];
371-
tfInputShapes[i] = new TFShape(newShape);
372-
}
373-
374-
var expectedType = TensorFlowUtils.Tf2MlNetType(tfInputTypes[i]);
375-
if (type.ItemType != expectedType)
376-
throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", inputsForTraining[i], expectedType.ToString(), type.ToString());
386+
(inputColIndices[i], isInputVector[i], tfInputTypes[i], tfInputShapes[i]) =
387+
GetInputMetaData(inputSchema, inputsForTraining[i], inputsForTraining[i],args.BatchSize);
377388
}
378389

390+
var index = inputsForTraining.Length - 1;
391+
inputsForTraining[index] = args.TensorFlowLabel;
392+
(inputColIndices[index], isInputVector[index], tfInputTypes[index], tfInputShapes[index]) =
393+
GetInputMetaData(inputSchema, args.LabelColumn, inputsForTraining[index], args.BatchSize);
394+
379395
var fetchList = new List<string>();
380396
if (args.LossOperation != null)
381397
fetchList.Add(args.LossOperation);

test/BaselineOutput/Common/EntryPoints/core_manifest.json

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -21773,7 +21773,7 @@
2177321773
"IsNullable": false
2177421774
},
2177521775
{
21776-
"Name": "LabeLColumn",
21776+
"Name": "LabelColumn",
2177721777
"Type": "String",
2177821778
"Desc": "Training labels.",
2177921779
"Aliases": [
@@ -21782,7 +21782,19 @@
2178221782
"Required": false,
2178321783
"SortOrder": 4.0,
2178421784
"IsNullable": false,
21785-
"Default": "Label"
21785+
"Default": null
21786+
},
21787+
{
21788+
"Name": "TensorFlowLabel",
21789+
"Type": "String",
21790+
"Desc": "TensorFlow label node.",
21791+
"Aliases": [
21792+
"TFLabel"
21793+
],
21794+
"Required": false,
21795+
"SortOrder": 5.0,
21796+
"IsNullable": false,
21797+
"Default": null
2178621798
},
2178721799
{
2178821800
"Name": "OptimizationOperation",
@@ -21792,7 +21804,7 @@
2179221804
"OptimizationOp"
2179321805
],
2179421806
"Required": false,
21795-
"SortOrder": 4.0,
21807+
"SortOrder": 6.0,
2179621808
"IsNullable": false,
2179721809
"Default": null
2179821810
},
@@ -21804,7 +21816,7 @@
2180421816
"LossOp"
2180521817
],
2180621818
"Required": false,
21807-
"SortOrder": 5.0,
21819+
"SortOrder": 7.0,
2180821820
"IsNullable": false,
2180921821
"Default": null
2181021822
},
@@ -21816,7 +21828,7 @@
2181621828
"MetricOp"
2181721829
],
2181821830
"Required": false,
21819-
"SortOrder": 6.0,
21831+
"SortOrder": 8.0,
2182021832
"IsNullable": false,
2182121833
"Default": null
2182221834
},
@@ -21825,7 +21837,7 @@
2182521837
"Type": "Int",
2182621838
"Desc": "Number of samples to use for mini-batch training.",
2182721839
"Required": false,
21828-
"SortOrder": 7.0,
21840+
"SortOrder": 9.0,
2182921841
"IsNullable": false,
2183021842
"Default": 64
2183121843
},
@@ -21834,7 +21846,7 @@
2183421846
"Type": "Int",
2183521847
"Desc": "Number of training iterations.",
2183621848
"Required": false,
21837-
"SortOrder": 8.0,
21849+
"SortOrder": 10.0,
2183821850
"IsNullable": false,
2183921851
"Default": 5
2184021852
},
@@ -21843,7 +21855,7 @@
2184321855
"Type": "String",
2184421856
"Desc": "The name of the operation in the TensorFlow graph which sets optimizer learning rate (Optional).",
2184521857
"Required": false,
21846-
"SortOrder": 9.0,
21858+
"SortOrder": 11.0,
2184721859
"IsNullable": false,
2184821860
"Default": null
2184921861
},
@@ -21852,7 +21864,7 @@
2185221864
"Type": "Float",
2185321865
"Desc": "Learning rate to use during optimization.",
2185421866
"Required": false,
21855-
"SortOrder": 10.0,
21867+
"SortOrder": 12.0,
2185621868
"IsNullable": false,
2185721869
"Default": 0.01
2185821870
},
@@ -21861,7 +21873,7 @@
2186121873
"Type": "Bool",
2186221874
"Desc": "Shuffle data before each iteration.",
2186321875
"Required": false,
21864-
"SortOrder": 11.0,
21876+
"SortOrder": 13.0,
2186521877
"IsNullable": false,
2186621878
"Default": true
2186721879
},
@@ -21870,7 +21882,7 @@
2187021882
"Type": "String",
2187121883
"Desc": "Name of the input in TensorFlow graph that specifiy the location for saving/restoring models from disk.",
2187221884
"Required": false,
21873-
"SortOrder": 12.0,
21885+
"SortOrder": 14.0,
2187421886
"IsNullable": false,
2187521887
"Default": "save/Const"
2187621888
},
@@ -21879,7 +21891,7 @@
2187921891
"Type": "String",
2188021892
"Desc": "Name of the input in TensorFlow graph that specifiy the location for saving/restoring models from disk.",
2188121893
"Required": false,
21882-
"SortOrder": 13.0,
21894+
"SortOrder": 15.0,
2188321895
"IsNullable": false,
2188421896
"Default": "save/control_dependency"
2188521897
},
@@ -21888,7 +21900,7 @@
2188821900
"Type": "Bool",
2188921901
"Desc": "Retrain TensorFlow model.",
2189021902
"Required": false,
21891-
"SortOrder": 15.0,
21903+
"SortOrder": 16.0,
2189221904
"IsNullable": false,
2189321905
"Default": false
2189421906
}

0 commit comments

Comments
 (0)