diff --git a/src/Microsoft.ML.TensorFlow/TensorflowCatalog.cs b/src/Microsoft.ML.TensorFlow/TensorflowCatalog.cs
index 517c6687c7..2b46731076 100644
--- a/src/Microsoft.ML.TensorFlow/TensorflowCatalog.cs
+++ b/src/Microsoft.ML.TensorFlow/TensorflowCatalog.cs
@@ -49,20 +49,20 @@ public static TensorFlowEstimator ScoreTensorFlowModel(this TransformsCatalog ca
/// The model is specified in the .
///
/// The transform's catalog.
- /// The specifying the inputs and the settings of the .
+ /// The specifying the inputs and the settings of the .
public static TensorFlowEstimator TensorFlow(this TransformsCatalog catalog,
- TensorFlowTransformer.Options args)
- => new TensorFlowEstimator(CatalogUtils.GetEnvironment(catalog), args);
+ TensorFlowTransformer.Options options)
+ => new TensorFlowEstimator(CatalogUtils.GetEnvironment(catalog), options);
///
/// Scores or retrains (based on setting of the ) a pre-traiend TensorFlow model specified via .
///
/// The transform's catalog.
- /// The specifying the inputs and the settings of the .
+ /// The specifying the inputs and the settings of the .
/// The pre-trained TensorFlow model.
public static TensorFlowEstimator TensorFlow(this TransformsCatalog catalog,
- TensorFlowTransformer.Options args,
+ TensorFlowTransformer.Options options,
TensorFlowModelInfo tensorFlowModel)
- => new TensorFlowEstimator(CatalogUtils.GetEnvironment(catalog), args, tensorFlowModel);
+ => new TensorFlowEstimator(CatalogUtils.GetEnvironment(catalog), options, tensorFlowModel);
}
}
diff --git a/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs b/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs
index ad28f48525..452ca70f24 100644
--- a/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs
+++ b/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs
@@ -297,79 +297,79 @@ private static TensorFlowTransformer Create(IHostEnvironment env, ModelLoadConte
}
// Factory method for SignatureDataTransform.
- internal static IDataTransform Create(IHostEnvironment env, Options args, IDataView input)
+ internal static IDataTransform Create(IHostEnvironment env, Options options, IDataView input)
{
Contracts.CheckValue(env, nameof(env));
- env.CheckValue(args, nameof(args));
+ env.CheckValue(options, nameof(options));
env.CheckValue(input, nameof(input));
- env.CheckValue(args.InputColumns, nameof(args.InputColumns));
- env.CheckValue(args.OutputColumns, nameof(args.OutputColumns));
+ env.CheckValue(options.InputColumns, nameof(options.InputColumns));
+ env.CheckValue(options.OutputColumns, nameof(options.OutputColumns));
- return new TensorFlowTransformer(env, args, input).MakeDataTransform(input);
+ return new TensorFlowTransformer(env, options, input).MakeDataTransform(input);
}
- internal TensorFlowTransformer(IHostEnvironment env, Options args, IDataView input)
- : this(env, args, TensorFlowUtils.LoadTensorFlowModel(env, args.ModelLocation), input)
+ internal TensorFlowTransformer(IHostEnvironment env, Options options, IDataView input)
+ : this(env, options, TensorFlowUtils.LoadTensorFlowModel(env, options.ModelLocation), input)
{
}
- internal TensorFlowTransformer(IHostEnvironment env, Options args, TensorFlowModelInfo tensorFlowModel, IDataView input)
- : this(env, tensorFlowModel.Session, args.OutputColumns, args.InputColumns, TensorFlowUtils.IsSavedModel(env, args.ModelLocation) ? args.ModelLocation : null, false)
+ internal TensorFlowTransformer(IHostEnvironment env, Options options, TensorFlowModelInfo tensorFlowModel, IDataView input)
+ : this(env, tensorFlowModel.Session, options.OutputColumns, options.InputColumns, TensorFlowUtils.IsSavedModel(env, options.ModelLocation) ? options.ModelLocation : null, false)
{
Contracts.CheckValue(env, nameof(env));
- env.CheckValue(args, nameof(args));
+ env.CheckValue(options, nameof(options));
- if (args.ReTrain)
+ if (options.ReTrain)
{
env.CheckValue(input, nameof(input));
- CheckTrainingParameters(args);
+ CheckTrainingParameters(options);
- if (!TensorFlowUtils.IsSavedModel(env, args.ModelLocation))
+ if (!TensorFlowUtils.IsSavedModel(env, options.ModelLocation))
throw env.ExceptNotSupp("TensorFlowTransform: Re-Training of TensorFlow model is only supported for un-frozen model.");
- TrainCore(args, input);
+ TrainCore(options, input);
}
}
- private void CheckTrainingParameters(Options args)
+ private void CheckTrainingParameters(Options options)
{
- Host.CheckNonWhiteSpace(args.LabelColumn, nameof(args.LabelColumn));
- Host.CheckNonWhiteSpace(args.OptimizationOperation, nameof(args.OptimizationOperation));
- if (Session.Graph[args.OptimizationOperation] == null)
- throw Host.ExceptParam(nameof(args.OptimizationOperation), $"Optimization operation '{args.OptimizationOperation}' does not exist in the model");
+ Host.CheckNonWhiteSpace(options.LabelColumn, nameof(options.LabelColumn));
+ Host.CheckNonWhiteSpace(options.OptimizationOperation, nameof(options.OptimizationOperation));
+ if (Session.Graph[options.OptimizationOperation] == null)
+ throw Host.ExceptParam(nameof(options.OptimizationOperation), $"Optimization operation '{options.OptimizationOperation}' does not exist in the model");
- Host.CheckNonWhiteSpace(args.TensorFlowLabel, nameof(args.TensorFlowLabel));
- if (Session.Graph[args.TensorFlowLabel] == null)
- throw Host.ExceptParam(nameof(args.TensorFlowLabel), $"'{args.TensorFlowLabel}' does not exist in the model");
+ Host.CheckNonWhiteSpace(options.TensorFlowLabel, nameof(options.TensorFlowLabel));
+ if (Session.Graph[options.TensorFlowLabel] == null)
+ throw Host.ExceptParam(nameof(options.TensorFlowLabel), $"'{options.TensorFlowLabel}' does not exist in the model");
- Host.CheckNonWhiteSpace(args.SaveLocationOperation, nameof(args.SaveLocationOperation));
- if (Session.Graph[args.SaveLocationOperation] == null)
- throw Host.ExceptParam(nameof(args.SaveLocationOperation), $"'{args.SaveLocationOperation}' does not exist in the model");
+ Host.CheckNonWhiteSpace(options.SaveLocationOperation, nameof(options.SaveLocationOperation));
+ if (Session.Graph[options.SaveLocationOperation] == null)
+ throw Host.ExceptParam(nameof(options.SaveLocationOperation), $"'{options.SaveLocationOperation}' does not exist in the model");
- Host.CheckNonWhiteSpace(args.SaveOperation, nameof(args.SaveOperation));
- if (Session.Graph[args.SaveOperation] == null)
- throw Host.ExceptParam(nameof(args.SaveOperation), $"'{args.SaveOperation}' does not exist in the model");
+ Host.CheckNonWhiteSpace(options.SaveOperation, nameof(options.SaveOperation));
+ if (Session.Graph[options.SaveOperation] == null)
+ throw Host.ExceptParam(nameof(options.SaveOperation), $"'{options.SaveOperation}' does not exist in the model");
- if (args.LossOperation != null)
+ if (options.LossOperation != null)
{
- Host.CheckNonWhiteSpace(args.LossOperation, nameof(args.LossOperation));
- if (Session.Graph[args.LossOperation] == null)
- throw Host.ExceptParam(nameof(args.LossOperation), $"'{args.LossOperation}' does not exist in the model");
+ Host.CheckNonWhiteSpace(options.LossOperation, nameof(options.LossOperation));
+ if (Session.Graph[options.LossOperation] == null)
+ throw Host.ExceptParam(nameof(options.LossOperation), $"'{options.LossOperation}' does not exist in the model");
}
- if (args.MetricOperation != null)
+ if (options.MetricOperation != null)
{
- Host.CheckNonWhiteSpace(args.MetricOperation, nameof(args.MetricOperation));
- if (Session.Graph[args.MetricOperation] == null)
- throw Host.ExceptParam(nameof(args.MetricOperation), $"'{args.MetricOperation}' does not exist in the model");
+ Host.CheckNonWhiteSpace(options.MetricOperation, nameof(options.MetricOperation));
+ if (Session.Graph[options.MetricOperation] == null)
+ throw Host.ExceptParam(nameof(options.MetricOperation), $"'{options.MetricOperation}' does not exist in the model");
}
- if (args.LearningRateOperation != null)
+ if (options.LearningRateOperation != null)
{
- Host.CheckNonWhiteSpace(args.LearningRateOperation, nameof(args.LearningRateOperation));
- if (Session.Graph[args.LearningRateOperation] == null)
- throw Host.ExceptParam(nameof(args.LearningRateOperation), $"'{args.LearningRateOperation}' does not exist in the model");
+ Host.CheckNonWhiteSpace(options.LearningRateOperation, nameof(options.LearningRateOperation));
+ if (Session.Graph[options.LearningRateOperation] == null)
+ throw Host.ExceptParam(nameof(options.LearningRateOperation), $"'{options.LearningRateOperation}' does not exist in the model");
}
}
@@ -401,7 +401,7 @@ private void CheckTrainingParameters(Options args)
return (inputColIndex, isInputVector, tfInputType, tfInputShape);
}
- private void TrainCore(Options args, IDataView input)
+ private void TrainCore(Options options, IDataView input)
{
var inputsForTraining = new string[Inputs.Length + 1];
var inputColIndices = new int[inputsForTraining.Length];
@@ -418,22 +418,22 @@ private void TrainCore(Options args, IDataView input)
for (int i = 0; i < inputsForTraining.Length - 1; i++)
{
(inputColIndices[i], isInputVector[i], tfInputTypes[i], tfInputShapes[i]) =
- GetTrainingInputInfo(inputSchema, inputsForTraining[i], inputsForTraining[i], args.BatchSize);
+ GetTrainingInputInfo(inputSchema, inputsForTraining[i], inputsForTraining[i], options.BatchSize);
}
var index = inputsForTraining.Length - 1;
- inputsForTraining[index] = args.TensorFlowLabel;
+ inputsForTraining[index] = options.TensorFlowLabel;
(inputColIndices[index], isInputVector[index], tfInputTypes[index], tfInputShapes[index]) =
- GetTrainingInputInfo(inputSchema, args.LabelColumn, inputsForTraining[index], args.BatchSize);
+ GetTrainingInputInfo(inputSchema, options.LabelColumn, inputsForTraining[index], options.BatchSize);
var fetchList = new List();
- if (args.LossOperation != null)
- fetchList.Add(args.LossOperation);
- if (args.MetricOperation != null)
- fetchList.Add(args.MetricOperation);
+ if (options.LossOperation != null)
+ fetchList.Add(options.LossOperation);
+ if (options.MetricOperation != null)
+ fetchList.Add(options.MetricOperation);
var cols = input.Schema.Where(c => inputColIndices.Contains(c.Index));
- for (int epoch = 0; epoch < args.Epoch; epoch++)
+ for (int epoch = 0; epoch < options.Epoch; epoch++)
{
using (var cursor = input.GetRowCursor(cols))
{
@@ -445,7 +445,7 @@ private void TrainCore(Options args, IDataView input)
using (var ch = Host.Start("Training TensorFlow model..."))
using (var pch = Host.StartProgressChannel("TensorFlow training progress..."))
{
- pch.SetHeader(new ProgressHeader(new[] { "Loss", "Metric" }, new[] { "Epoch" }), (e) => e.SetProgress(0, epoch, args.Epoch));
+ pch.SetHeader(new ProgressHeader(new[] { "Loss", "Metric" }, new[] { "Epoch" }), (e) => e.SetProgress(0, epoch, options.Epoch));
while (cursor.MoveNext())
{
@@ -455,10 +455,10 @@ private void TrainCore(Options args, IDataView input)
srcTensorGetters[i].BufferTrainingData();
}
- if (((cursor.Position + 1) % args.BatchSize) == 0)
+ if (((cursor.Position + 1) % options.BatchSize) == 0)
{
isDataLeft = false;
- var (l, m) = TrainBatch(inputColIndices, inputsForTraining, srcTensorGetters, fetchList, args);
+ var (l, m) = TrainBatch(inputColIndices, inputsForTraining, srcTensorGetters, fetchList, options);
loss += l;
metric += m;
}
@@ -466,20 +466,20 @@ private void TrainCore(Options args, IDataView input)
if (isDataLeft)
{
isDataLeft = false;
- ch.Warning("Not training on the last batch. The batch size is less than {0}.", args.BatchSize);
+ ch.Warning("Not training on the last batch. The batch size is less than {0}.", options.BatchSize);
}
pch.Checkpoint(new double?[] { loss, metric });
}
}
}
- UpdateModelOnDisk(args.ModelLocation, args);
+ UpdateModelOnDisk(options.ModelLocation, options);
}
private (float loss, float metric) TrainBatch(int[] inputColIndices,
string[] inputsForTraining,
ITensorValueGetter[] srcTensorGetters,
List fetchList,
- Options args)
+ Options options)
{
float loss = 0;
float metric = 0;
@@ -490,9 +490,9 @@ private void TrainCore(Options args, IDataView input)
runner.AddInput(inputName, srcTensorGetters[i].GetBufferedBatchTensor());
}
- if (args.LearningRateOperation != null)
- runner.AddInput(args.LearningRateOperation, new TFTensor(args.LearningRate));
- runner.AddTarget(args.OptimizationOperation);
+ if (options.LearningRateOperation != null)
+ runner.AddInput(options.LearningRateOperation, new TFTensor(options.LearningRate));
+ runner.AddTarget(options.OptimizationOperation);
if (fetchList.Count > 0)
runner.Fetch(fetchList.ToArray());
@@ -509,14 +509,14 @@ private void TrainCore(Options args, IDataView input)
/// After retraining Session and Graphs are both up-to-date
/// However model on disk is not which is used to serialzed to ML.Net stream
///
- private void UpdateModelOnDisk(string modelDir, Options args)
+ private void UpdateModelOnDisk(string modelDir, Options options)
{
try
{
// Save the model on disk
var path = Path.Combine(modelDir, DefaultModelFileNames.TmpMlnetModel);
- Session.GetRunner().AddInput(args.SaveLocationOperation, TFTensor.CreateString(Encoding.UTF8.GetBytes(path)))
- .AddTarget(args.SaveOperation).Run();
+ Session.GetRunner().AddInput(options.SaveLocationOperation, TFTensor.CreateString(Encoding.UTF8.GetBytes(path)))
+ .AddTarget(options.SaveOperation).Run();
// Preserve original files
var variablesPath = Path.Combine(modelDir, DefaultModelFileNames.VariablesFolder);
@@ -1096,19 +1096,19 @@ internal TensorFlowEstimator(IHostEnvironment env, string[] outputColumnNames, s
{
}
- internal TensorFlowEstimator(IHostEnvironment env, TensorFlowTransformer.Options args)
- : this(env, args, TensorFlowUtils.LoadTensorFlowModel(env, args.ModelLocation))
+ internal TensorFlowEstimator(IHostEnvironment env, TensorFlowTransformer.Options options)
+ : this(env, options, TensorFlowUtils.LoadTensorFlowModel(env, options.ModelLocation))
{
}
- internal TensorFlowEstimator(IHostEnvironment env, TensorFlowTransformer.Options args, TensorFlowModelInfo tensorFlowModel)
+ internal TensorFlowEstimator(IHostEnvironment env, TensorFlowTransformer.Options options, TensorFlowModelInfo tensorFlowModel)
{
_host = Contracts.CheckRef(env, nameof(env)).Register(nameof(TensorFlowEstimator));
- _args = args;
+ _args = options;
_tensorFlowModel = tensorFlowModel;
- var inputTuple = TensorFlowTransformer.GetInputInfo(_host, tensorFlowModel.Session, args.InputColumns);
+ var inputTuple = TensorFlowTransformer.GetInputInfo(_host, tensorFlowModel.Session, options.InputColumns);
_tfInputTypes = inputTuple.tfInputTypes;
- var outputTuple = TensorFlowTransformer.GetOutputInfo(_host, tensorFlowModel.Session, args.OutputColumns);
+ var outputTuple = TensorFlowTransformer.GetOutputInfo(_host, tensorFlowModel.Session, options.OutputColumns);
_outputTypes = outputTuple.outputTypes;
}