Skip to content

Commit 2dfc51a

Browse files
authored
Fix naming of Options argument in TensorFlowTransform API
1 parent e0e36af commit 2dfc51a

File tree

2 files changed

+73
-73
lines changed

2 files changed

+73
-73
lines changed

src/Microsoft.ML.TensorFlow/TensorflowCatalog.cs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -49,20 +49,20 @@ public static TensorFlowEstimator ScoreTensorFlowModel(this TransformsCatalog ca
4949
/// The model is specified in the <see cref="TensorFlowTransformer.Options.ModelLocation"/>.
5050
/// </summary>
5151
/// <param name="catalog">The transform's catalog.</param>
52-
/// <param name="args">The <see cref="TensorFlowTransformer.Options"/> specifying the inputs and the settings of the <see cref="TensorFlowEstimator"/>.</param>
52+
/// <param name="options">The <see cref="TensorFlowTransformer.Options"/> specifying the inputs and the settings of the <see cref="TensorFlowEstimator"/>.</param>
5353
public static TensorFlowEstimator TensorFlow(this TransformsCatalog catalog,
54-
TensorFlowTransformer.Options args)
55-
=> new TensorFlowEstimator(CatalogUtils.GetEnvironment(catalog), args);
54+
TensorFlowTransformer.Options options)
55+
=> new TensorFlowEstimator(CatalogUtils.GetEnvironment(catalog), options);
5656

5757
/// <summary>
5858
/// Scores or retrains (based on setting of the <see cref="TensorFlowTransformer.Options.ReTrain"/>) a pre-traiend TensorFlow model specified via <paramref name="tensorFlowModel"/>.
5959
/// </summary>
6060
/// <param name="catalog">The transform's catalog.</param>
61-
/// <param name="args">The <see cref="TensorFlowTransformer.Options"/> specifying the inputs and the settings of the <see cref="TensorFlowEstimator"/>.</param>
61+
/// <param name="options">The <see cref="TensorFlowTransformer.Options"/> specifying the inputs and the settings of the <see cref="TensorFlowEstimator"/>.</param>
6262
/// <param name="tensorFlowModel">The pre-trained TensorFlow model.</param>
6363
public static TensorFlowEstimator TensorFlow(this TransformsCatalog catalog,
64-
TensorFlowTransformer.Options args,
64+
TensorFlowTransformer.Options options,
6565
TensorFlowModelInfo tensorFlowModel)
66-
=> new TensorFlowEstimator(CatalogUtils.GetEnvironment(catalog), args, tensorFlowModel);
66+
=> new TensorFlowEstimator(CatalogUtils.GetEnvironment(catalog), options, tensorFlowModel);
6767
}
6868
}

src/Microsoft.ML.TensorFlow/TensorflowTransform.cs

Lines changed: 67 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -297,79 +297,79 @@ private static TensorFlowTransformer Create(IHostEnvironment env, ModelLoadConte
297297
}
298298

299299
// Factory method for SignatureDataTransform.
300-
internal static IDataTransform Create(IHostEnvironment env, Options args, IDataView input)
300+
internal static IDataTransform Create(IHostEnvironment env, Options options, IDataView input)
301301
{
302302
Contracts.CheckValue(env, nameof(env));
303-
env.CheckValue(args, nameof(args));
303+
env.CheckValue(options, nameof(options));
304304
env.CheckValue(input, nameof(input));
305-
env.CheckValue(args.InputColumns, nameof(args.InputColumns));
306-
env.CheckValue(args.OutputColumns, nameof(args.OutputColumns));
305+
env.CheckValue(options.InputColumns, nameof(options.InputColumns));
306+
env.CheckValue(options.OutputColumns, nameof(options.OutputColumns));
307307

308-
return new TensorFlowTransformer(env, args, input).MakeDataTransform(input);
308+
return new TensorFlowTransformer(env, options, input).MakeDataTransform(input);
309309
}
310310

311-
internal TensorFlowTransformer(IHostEnvironment env, Options args, IDataView input)
312-
: this(env, args, TensorFlowUtils.LoadTensorFlowModel(env, args.ModelLocation), input)
311+
internal TensorFlowTransformer(IHostEnvironment env, Options options, IDataView input)
312+
: this(env, options, TensorFlowUtils.LoadTensorFlowModel(env, options.ModelLocation), input)
313313
{
314314
}
315315

316-
internal TensorFlowTransformer(IHostEnvironment env, Options args, TensorFlowModelInfo tensorFlowModel, IDataView input)
317-
: this(env, tensorFlowModel.Session, args.OutputColumns, args.InputColumns, TensorFlowUtils.IsSavedModel(env, args.ModelLocation) ? args.ModelLocation : null, false)
316+
internal TensorFlowTransformer(IHostEnvironment env, Options options, TensorFlowModelInfo tensorFlowModel, IDataView input)
317+
: this(env, tensorFlowModel.Session, options.OutputColumns, options.InputColumns, TensorFlowUtils.IsSavedModel(env, options.ModelLocation) ? options.ModelLocation : null, false)
318318
{
319319

320320
Contracts.CheckValue(env, nameof(env));
321-
env.CheckValue(args, nameof(args));
321+
env.CheckValue(options, nameof(options));
322322

323-
if (args.ReTrain)
323+
if (options.ReTrain)
324324
{
325325
env.CheckValue(input, nameof(input));
326326

327-
CheckTrainingParameters(args);
327+
CheckTrainingParameters(options);
328328

329-
if (!TensorFlowUtils.IsSavedModel(env, args.ModelLocation))
329+
if (!TensorFlowUtils.IsSavedModel(env, options.ModelLocation))
330330
throw env.ExceptNotSupp("TensorFlowTransform: Re-Training of TensorFlow model is only supported for un-frozen model.");
331-
TrainCore(args, input);
331+
TrainCore(options, input);
332332
}
333333
}
334334

335-
private void CheckTrainingParameters(Options args)
335+
private void CheckTrainingParameters(Options options)
336336
{
337-
Host.CheckNonWhiteSpace(args.LabelColumn, nameof(args.LabelColumn));
338-
Host.CheckNonWhiteSpace(args.OptimizationOperation, nameof(args.OptimizationOperation));
339-
if (Session.Graph[args.OptimizationOperation] == null)
340-
throw Host.ExceptParam(nameof(args.OptimizationOperation), $"Optimization operation '{args.OptimizationOperation}' does not exist in the model");
337+
Host.CheckNonWhiteSpace(options.LabelColumn, nameof(options.LabelColumn));
338+
Host.CheckNonWhiteSpace(options.OptimizationOperation, nameof(options.OptimizationOperation));
339+
if (Session.Graph[options.OptimizationOperation] == null)
340+
throw Host.ExceptParam(nameof(options.OptimizationOperation), $"Optimization operation '{options.OptimizationOperation}' does not exist in the model");
341341

342-
Host.CheckNonWhiteSpace(args.TensorFlowLabel, nameof(args.TensorFlowLabel));
343-
if (Session.Graph[args.TensorFlowLabel] == null)
344-
throw Host.ExceptParam(nameof(args.TensorFlowLabel), $"'{args.TensorFlowLabel}' does not exist in the model");
342+
Host.CheckNonWhiteSpace(options.TensorFlowLabel, nameof(options.TensorFlowLabel));
343+
if (Session.Graph[options.TensorFlowLabel] == null)
344+
throw Host.ExceptParam(nameof(options.TensorFlowLabel), $"'{options.TensorFlowLabel}' does not exist in the model");
345345

346-
Host.CheckNonWhiteSpace(args.SaveLocationOperation, nameof(args.SaveLocationOperation));
347-
if (Session.Graph[args.SaveLocationOperation] == null)
348-
throw Host.ExceptParam(nameof(args.SaveLocationOperation), $"'{args.SaveLocationOperation}' does not exist in the model");
346+
Host.CheckNonWhiteSpace(options.SaveLocationOperation, nameof(options.SaveLocationOperation));
347+
if (Session.Graph[options.SaveLocationOperation] == null)
348+
throw Host.ExceptParam(nameof(options.SaveLocationOperation), $"'{options.SaveLocationOperation}' does not exist in the model");
349349

350-
Host.CheckNonWhiteSpace(args.SaveOperation, nameof(args.SaveOperation));
351-
if (Session.Graph[args.SaveOperation] == null)
352-
throw Host.ExceptParam(nameof(args.SaveOperation), $"'{args.SaveOperation}' does not exist in the model");
350+
Host.CheckNonWhiteSpace(options.SaveOperation, nameof(options.SaveOperation));
351+
if (Session.Graph[options.SaveOperation] == null)
352+
throw Host.ExceptParam(nameof(options.SaveOperation), $"'{options.SaveOperation}' does not exist in the model");
353353

354-
if (args.LossOperation != null)
354+
if (options.LossOperation != null)
355355
{
356-
Host.CheckNonWhiteSpace(args.LossOperation, nameof(args.LossOperation));
357-
if (Session.Graph[args.LossOperation] == null)
358-
throw Host.ExceptParam(nameof(args.LossOperation), $"'{args.LossOperation}' does not exist in the model");
356+
Host.CheckNonWhiteSpace(options.LossOperation, nameof(options.LossOperation));
357+
if (Session.Graph[options.LossOperation] == null)
358+
throw Host.ExceptParam(nameof(options.LossOperation), $"'{options.LossOperation}' does not exist in the model");
359359
}
360360

361-
if (args.MetricOperation != null)
361+
if (options.MetricOperation != null)
362362
{
363-
Host.CheckNonWhiteSpace(args.MetricOperation, nameof(args.MetricOperation));
364-
if (Session.Graph[args.MetricOperation] == null)
365-
throw Host.ExceptParam(nameof(args.MetricOperation), $"'{args.MetricOperation}' does not exist in the model");
363+
Host.CheckNonWhiteSpace(options.MetricOperation, nameof(options.MetricOperation));
364+
if (Session.Graph[options.MetricOperation] == null)
365+
throw Host.ExceptParam(nameof(options.MetricOperation), $"'{options.MetricOperation}' does not exist in the model");
366366
}
367367

368-
if (args.LearningRateOperation != null)
368+
if (options.LearningRateOperation != null)
369369
{
370-
Host.CheckNonWhiteSpace(args.LearningRateOperation, nameof(args.LearningRateOperation));
371-
if (Session.Graph[args.LearningRateOperation] == null)
372-
throw Host.ExceptParam(nameof(args.LearningRateOperation), $"'{args.LearningRateOperation}' does not exist in the model");
370+
Host.CheckNonWhiteSpace(options.LearningRateOperation, nameof(options.LearningRateOperation));
371+
if (Session.Graph[options.LearningRateOperation] == null)
372+
throw Host.ExceptParam(nameof(options.LearningRateOperation), $"'{options.LearningRateOperation}' does not exist in the model");
373373
}
374374
}
375375

@@ -401,7 +401,7 @@ private void CheckTrainingParameters(Options args)
401401
return (inputColIndex, isInputVector, tfInputType, tfInputShape);
402402
}
403403

404-
private void TrainCore(Options args, IDataView input)
404+
private void TrainCore(Options options, IDataView input)
405405
{
406406
var inputsForTraining = new string[Inputs.Length + 1];
407407
var inputColIndices = new int[inputsForTraining.Length];
@@ -418,22 +418,22 @@ private void TrainCore(Options args, IDataView input)
418418
for (int i = 0; i < inputsForTraining.Length - 1; i++)
419419
{
420420
(inputColIndices[i], isInputVector[i], tfInputTypes[i], tfInputShapes[i]) =
421-
GetTrainingInputInfo(inputSchema, inputsForTraining[i], inputsForTraining[i], args.BatchSize);
421+
GetTrainingInputInfo(inputSchema, inputsForTraining[i], inputsForTraining[i], options.BatchSize);
422422
}
423423

424424
var index = inputsForTraining.Length - 1;
425-
inputsForTraining[index] = args.TensorFlowLabel;
425+
inputsForTraining[index] = options.TensorFlowLabel;
426426
(inputColIndices[index], isInputVector[index], tfInputTypes[index], tfInputShapes[index]) =
427-
GetTrainingInputInfo(inputSchema, args.LabelColumn, inputsForTraining[index], args.BatchSize);
427+
GetTrainingInputInfo(inputSchema, options.LabelColumn, inputsForTraining[index], options.BatchSize);
428428

429429
var fetchList = new List<string>();
430-
if (args.LossOperation != null)
431-
fetchList.Add(args.LossOperation);
432-
if (args.MetricOperation != null)
433-
fetchList.Add(args.MetricOperation);
430+
if (options.LossOperation != null)
431+
fetchList.Add(options.LossOperation);
432+
if (options.MetricOperation != null)
433+
fetchList.Add(options.MetricOperation);
434434

435435
var cols = input.Schema.Where(c => inputColIndices.Contains(c.Index));
436-
for (int epoch = 0; epoch < args.Epoch; epoch++)
436+
for (int epoch = 0; epoch < options.Epoch; epoch++)
437437
{
438438
using (var cursor = input.GetRowCursor(cols))
439439
{
@@ -445,7 +445,7 @@ private void TrainCore(Options args, IDataView input)
445445
using (var ch = Host.Start("Training TensorFlow model..."))
446446
using (var pch = Host.StartProgressChannel("TensorFlow training progress..."))
447447
{
448-
pch.SetHeader(new ProgressHeader(new[] { "Loss", "Metric" }, new[] { "Epoch" }), (e) => e.SetProgress(0, epoch, args.Epoch));
448+
pch.SetHeader(new ProgressHeader(new[] { "Loss", "Metric" }, new[] { "Epoch" }), (e) => e.SetProgress(0, epoch, options.Epoch));
449449

450450
while (cursor.MoveNext())
451451
{
@@ -455,31 +455,31 @@ private void TrainCore(Options args, IDataView input)
455455
srcTensorGetters[i].BufferTrainingData();
456456
}
457457

458-
if (((cursor.Position + 1) % args.BatchSize) == 0)
458+
if (((cursor.Position + 1) % options.BatchSize) == 0)
459459
{
460460
isDataLeft = false;
461-
var (l, m) = TrainBatch(inputColIndices, inputsForTraining, srcTensorGetters, fetchList, args);
461+
var (l, m) = TrainBatch(inputColIndices, inputsForTraining, srcTensorGetters, fetchList, options);
462462
loss += l;
463463
metric += m;
464464
}
465465
}
466466
if (isDataLeft)
467467
{
468468
isDataLeft = false;
469-
ch.Warning("Not training on the last batch. The batch size is less than {0}.", args.BatchSize);
469+
ch.Warning("Not training on the last batch. The batch size is less than {0}.", options.BatchSize);
470470
}
471471
pch.Checkpoint(new double?[] { loss, metric });
472472
}
473473
}
474474
}
475-
UpdateModelOnDisk(args.ModelLocation, args);
475+
UpdateModelOnDisk(options.ModelLocation, options);
476476
}
477477

478478
private (float loss, float metric) TrainBatch(int[] inputColIndices,
479479
string[] inputsForTraining,
480480
ITensorValueGetter[] srcTensorGetters,
481481
List<string> fetchList,
482-
Options args)
482+
Options options)
483483
{
484484
float loss = 0;
485485
float metric = 0;
@@ -490,9 +490,9 @@ private void TrainCore(Options args, IDataView input)
490490
runner.AddInput(inputName, srcTensorGetters[i].GetBufferedBatchTensor());
491491
}
492492

493-
if (args.LearningRateOperation != null)
494-
runner.AddInput(args.LearningRateOperation, new TFTensor(args.LearningRate));
495-
runner.AddTarget(args.OptimizationOperation);
493+
if (options.LearningRateOperation != null)
494+
runner.AddInput(options.LearningRateOperation, new TFTensor(options.LearningRate));
495+
runner.AddTarget(options.OptimizationOperation);
496496

497497
if (fetchList.Count > 0)
498498
runner.Fetch(fetchList.ToArray());
@@ -509,14 +509,14 @@ private void TrainCore(Options args, IDataView input)
509509
/// After retraining Session and Graphs are both up-to-date
510510
/// However model on disk is not which is used to serialzed to ML.Net stream
511511
/// </summary>
512-
private void UpdateModelOnDisk(string modelDir, Options args)
512+
private void UpdateModelOnDisk(string modelDir, Options options)
513513
{
514514
try
515515
{
516516
// Save the model on disk
517517
var path = Path.Combine(modelDir, DefaultModelFileNames.TmpMlnetModel);
518-
Session.GetRunner().AddInput(args.SaveLocationOperation, TFTensor.CreateString(Encoding.UTF8.GetBytes(path)))
519-
.AddTarget(args.SaveOperation).Run();
518+
Session.GetRunner().AddInput(options.SaveLocationOperation, TFTensor.CreateString(Encoding.UTF8.GetBytes(path)))
519+
.AddTarget(options.SaveOperation).Run();
520520

521521
// Preserve original files
522522
var variablesPath = Path.Combine(modelDir, DefaultModelFileNames.VariablesFolder);
@@ -1096,19 +1096,19 @@ internal TensorFlowEstimator(IHostEnvironment env, string[] outputColumnNames, s
10961096
{
10971097
}
10981098

1099-
internal TensorFlowEstimator(IHostEnvironment env, TensorFlowTransformer.Options args)
1100-
: this(env, args, TensorFlowUtils.LoadTensorFlowModel(env, args.ModelLocation))
1099+
internal TensorFlowEstimator(IHostEnvironment env, TensorFlowTransformer.Options options)
1100+
: this(env, options, TensorFlowUtils.LoadTensorFlowModel(env, options.ModelLocation))
11011101
{
11021102
}
11031103

1104-
internal TensorFlowEstimator(IHostEnvironment env, TensorFlowTransformer.Options args, TensorFlowModelInfo tensorFlowModel)
1104+
internal TensorFlowEstimator(IHostEnvironment env, TensorFlowTransformer.Options options, TensorFlowModelInfo tensorFlowModel)
11051105
{
11061106
_host = Contracts.CheckRef(env, nameof(env)).Register(nameof(TensorFlowEstimator));
1107-
_args = args;
1107+
_args = options;
11081108
_tensorFlowModel = tensorFlowModel;
1109-
var inputTuple = TensorFlowTransformer.GetInputInfo(_host, tensorFlowModel.Session, args.InputColumns);
1109+
var inputTuple = TensorFlowTransformer.GetInputInfo(_host, tensorFlowModel.Session, options.InputColumns);
11101110
_tfInputTypes = inputTuple.tfInputTypes;
1111-
var outputTuple = TensorFlowTransformer.GetOutputInfo(_host, tensorFlowModel.Session, args.OutputColumns);
1111+
var outputTuple = TensorFlowTransformer.GetOutputInfo(_host, tensorFlowModel.Session, options.OutputColumns);
11121112
_outputTypes = outputTuple.outputTypes;
11131113
}
11141114

0 commit comments

Comments
 (0)