Skip to content

Fix naming of Options argument in TensorFlowTransform API #2392

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Feb 4, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions src/Microsoft.ML.TensorFlow/TensorflowCatalog.cs
Original file line number Diff line number Diff line change
Expand Up @@ -49,20 +49,20 @@ public static TensorFlowEstimator ScoreTensorFlowModel(this TransformsCatalog ca
/// The model is specified in the <see cref="TensorFlowTransformer.Options.ModelLocation"/>.
/// </summary>
/// <param name="catalog">The transform's catalog.</param>
/// <param name="args">The <see cref="TensorFlowTransformer.Options"/> specifying the inputs and the settings of the <see cref="TensorFlowEstimator"/>.</param>
/// <param name="options">The <see cref="TensorFlowTransformer.Options"/> specifying the inputs and the settings of the <see cref="TensorFlowEstimator"/>.</param>
public static TensorFlowEstimator TensorFlow(this TransformsCatalog catalog,
TensorFlowTransformer.Options args)
=> new TensorFlowEstimator(CatalogUtils.GetEnvironment(catalog), args);
TensorFlowTransformer.Options options)
=> new TensorFlowEstimator(CatalogUtils.GetEnvironment(catalog), options);

/// <summary>
/// Scores or retrains (based on setting of the <see cref="TensorFlowTransformer.Options.ReTrain"/>) a pre-traiend TensorFlow model specified via <paramref name="tensorFlowModel"/>.
/// </summary>
/// <param name="catalog">The transform's catalog.</param>
/// <param name="args">The <see cref="TensorFlowTransformer.Options"/> specifying the inputs and the settings of the <see cref="TensorFlowEstimator"/>.</param>
/// <param name="options">The <see cref="TensorFlowTransformer.Options"/> specifying the inputs and the settings of the <see cref="TensorFlowEstimator"/>.</param>
/// <param name="tensorFlowModel">The pre-trained TensorFlow model.</param>
public static TensorFlowEstimator TensorFlow(this TransformsCatalog catalog,
TensorFlowTransformer.Options args,
TensorFlowTransformer.Options options,
TensorFlowModelInfo tensorFlowModel)
=> new TensorFlowEstimator(CatalogUtils.GetEnvironment(catalog), args, tensorFlowModel);
=> new TensorFlowEstimator(CatalogUtils.GetEnvironment(catalog), options, tensorFlowModel);
}
}
134 changes: 67 additions & 67 deletions src/Microsoft.ML.TensorFlow/TensorflowTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -297,79 +297,79 @@ private static TensorFlowTransformer Create(IHostEnvironment env, ModelLoadConte
}

// Factory method for SignatureDataTransform.
internal static IDataTransform Create(IHostEnvironment env, Options args, IDataView input)
internal static IDataTransform Create(IHostEnvironment env, Options options, IDataView input)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(args, nameof(args));
env.CheckValue(options, nameof(options));
env.CheckValue(input, nameof(input));
env.CheckValue(args.InputColumns, nameof(args.InputColumns));
env.CheckValue(args.OutputColumns, nameof(args.OutputColumns));
env.CheckValue(options.InputColumns, nameof(options.InputColumns));
env.CheckValue(options.OutputColumns, nameof(options.OutputColumns));

return new TensorFlowTransformer(env, args, input).MakeDataTransform(input);
return new TensorFlowTransformer(env, options, input).MakeDataTransform(input);
}

internal TensorFlowTransformer(IHostEnvironment env, Options args, IDataView input)
: this(env, args, TensorFlowUtils.LoadTensorFlowModel(env, args.ModelLocation), input)
internal TensorFlowTransformer(IHostEnvironment env, Options options, IDataView input)
: this(env, options, TensorFlowUtils.LoadTensorFlowModel(env, options.ModelLocation), input)
{
}

internal TensorFlowTransformer(IHostEnvironment env, Options args, TensorFlowModelInfo tensorFlowModel, IDataView input)
: this(env, tensorFlowModel.Session, args.OutputColumns, args.InputColumns, TensorFlowUtils.IsSavedModel(env, args.ModelLocation) ? args.ModelLocation : null, false)
internal TensorFlowTransformer(IHostEnvironment env, Options options, TensorFlowModelInfo tensorFlowModel, IDataView input)
: this(env, tensorFlowModel.Session, options.OutputColumns, options.InputColumns, TensorFlowUtils.IsSavedModel(env, options.ModelLocation) ? options.ModelLocation : null, false)
{

Contracts.CheckValue(env, nameof(env));
env.CheckValue(args, nameof(args));
env.CheckValue(options, nameof(options));

if (args.ReTrain)
if (options.ReTrain)
{
env.CheckValue(input, nameof(input));

CheckTrainingParameters(args);
CheckTrainingParameters(options);

if (!TensorFlowUtils.IsSavedModel(env, args.ModelLocation))
if (!TensorFlowUtils.IsSavedModel(env, options.ModelLocation))
throw env.ExceptNotSupp("TensorFlowTransform: Re-Training of TensorFlow model is only supported for un-frozen model.");
TrainCore(args, input);
TrainCore(options, input);
}
}

private void CheckTrainingParameters(Options args)
private void CheckTrainingParameters(Options options)
{
Host.CheckNonWhiteSpace(args.LabelColumn, nameof(args.LabelColumn));
Host.CheckNonWhiteSpace(args.OptimizationOperation, nameof(args.OptimizationOperation));
if (Session.Graph[args.OptimizationOperation] == null)
throw Host.ExceptParam(nameof(args.OptimizationOperation), $"Optimization operation '{args.OptimizationOperation}' does not exist in the model");
Host.CheckNonWhiteSpace(options.LabelColumn, nameof(options.LabelColumn));
Host.CheckNonWhiteSpace(options.OptimizationOperation, nameof(options.OptimizationOperation));
if (Session.Graph[options.OptimizationOperation] == null)
throw Host.ExceptParam(nameof(options.OptimizationOperation), $"Optimization operation '{options.OptimizationOperation}' does not exist in the model");

Host.CheckNonWhiteSpace(args.TensorFlowLabel, nameof(args.TensorFlowLabel));
if (Session.Graph[args.TensorFlowLabel] == null)
throw Host.ExceptParam(nameof(args.TensorFlowLabel), $"'{args.TensorFlowLabel}' does not exist in the model");
Host.CheckNonWhiteSpace(options.TensorFlowLabel, nameof(options.TensorFlowLabel));
if (Session.Graph[options.TensorFlowLabel] == null)
throw Host.ExceptParam(nameof(options.TensorFlowLabel), $"'{options.TensorFlowLabel}' does not exist in the model");

Host.CheckNonWhiteSpace(args.SaveLocationOperation, nameof(args.SaveLocationOperation));
if (Session.Graph[args.SaveLocationOperation] == null)
throw Host.ExceptParam(nameof(args.SaveLocationOperation), $"'{args.SaveLocationOperation}' does not exist in the model");
Host.CheckNonWhiteSpace(options.SaveLocationOperation, nameof(options.SaveLocationOperation));
if (Session.Graph[options.SaveLocationOperation] == null)
throw Host.ExceptParam(nameof(options.SaveLocationOperation), $"'{options.SaveLocationOperation}' does not exist in the model");

Host.CheckNonWhiteSpace(args.SaveOperation, nameof(args.SaveOperation));
if (Session.Graph[args.SaveOperation] == null)
throw Host.ExceptParam(nameof(args.SaveOperation), $"'{args.SaveOperation}' does not exist in the model");
Host.CheckNonWhiteSpace(options.SaveOperation, nameof(options.SaveOperation));
if (Session.Graph[options.SaveOperation] == null)
throw Host.ExceptParam(nameof(options.SaveOperation), $"'{options.SaveOperation}' does not exist in the model");

if (args.LossOperation != null)
if (options.LossOperation != null)
{
Host.CheckNonWhiteSpace(args.LossOperation, nameof(args.LossOperation));
if (Session.Graph[args.LossOperation] == null)
throw Host.ExceptParam(nameof(args.LossOperation), $"'{args.LossOperation}' does not exist in the model");
Host.CheckNonWhiteSpace(options.LossOperation, nameof(options.LossOperation));
if (Session.Graph[options.LossOperation] == null)
throw Host.ExceptParam(nameof(options.LossOperation), $"'{options.LossOperation}' does not exist in the model");
}

if (args.MetricOperation != null)
if (options.MetricOperation != null)
{
Host.CheckNonWhiteSpace(args.MetricOperation, nameof(args.MetricOperation));
if (Session.Graph[args.MetricOperation] == null)
throw Host.ExceptParam(nameof(args.MetricOperation), $"'{args.MetricOperation}' does not exist in the model");
Host.CheckNonWhiteSpace(options.MetricOperation, nameof(options.MetricOperation));
if (Session.Graph[options.MetricOperation] == null)
throw Host.ExceptParam(nameof(options.MetricOperation), $"'{options.MetricOperation}' does not exist in the model");
}

if (args.LearningRateOperation != null)
if (options.LearningRateOperation != null)
{
Host.CheckNonWhiteSpace(args.LearningRateOperation, nameof(args.LearningRateOperation));
if (Session.Graph[args.LearningRateOperation] == null)
throw Host.ExceptParam(nameof(args.LearningRateOperation), $"'{args.LearningRateOperation}' does not exist in the model");
Host.CheckNonWhiteSpace(options.LearningRateOperation, nameof(options.LearningRateOperation));
if (Session.Graph[options.LearningRateOperation] == null)
throw Host.ExceptParam(nameof(options.LearningRateOperation), $"'{options.LearningRateOperation}' does not exist in the model");
}
}

Expand Down Expand Up @@ -401,7 +401,7 @@ private void CheckTrainingParameters(Options args)
return (inputColIndex, isInputVector, tfInputType, tfInputShape);
}

private void TrainCore(Options args, IDataView input)
private void TrainCore(Options options, IDataView input)
{
var inputsForTraining = new string[Inputs.Length + 1];
var inputColIndices = new int[inputsForTraining.Length];
Expand All @@ -418,22 +418,22 @@ private void TrainCore(Options args, IDataView input)
for (int i = 0; i < inputsForTraining.Length - 1; i++)
{
(inputColIndices[i], isInputVector[i], tfInputTypes[i], tfInputShapes[i]) =
GetTrainingInputInfo(inputSchema, inputsForTraining[i], inputsForTraining[i], args.BatchSize);
GetTrainingInputInfo(inputSchema, inputsForTraining[i], inputsForTraining[i], options.BatchSize);
}

var index = inputsForTraining.Length - 1;
inputsForTraining[index] = args.TensorFlowLabel;
inputsForTraining[index] = options.TensorFlowLabel;
(inputColIndices[index], isInputVector[index], tfInputTypes[index], tfInputShapes[index]) =
GetTrainingInputInfo(inputSchema, args.LabelColumn, inputsForTraining[index], args.BatchSize);
GetTrainingInputInfo(inputSchema, options.LabelColumn, inputsForTraining[index], options.BatchSize);

var fetchList = new List<string>();
if (args.LossOperation != null)
fetchList.Add(args.LossOperation);
if (args.MetricOperation != null)
fetchList.Add(args.MetricOperation);
if (options.LossOperation != null)
fetchList.Add(options.LossOperation);
if (options.MetricOperation != null)
fetchList.Add(options.MetricOperation);

var cols = input.Schema.Where(c => inputColIndices.Contains(c.Index));
for (int epoch = 0; epoch < args.Epoch; epoch++)
for (int epoch = 0; epoch < options.Epoch; epoch++)
{
using (var cursor = input.GetRowCursor(cols))
{
Expand All @@ -445,7 +445,7 @@ private void TrainCore(Options args, IDataView input)
using (var ch = Host.Start("Training TensorFlow model..."))
using (var pch = Host.StartProgressChannel("TensorFlow training progress..."))
{
pch.SetHeader(new ProgressHeader(new[] { "Loss", "Metric" }, new[] { "Epoch" }), (e) => e.SetProgress(0, epoch, args.Epoch));
pch.SetHeader(new ProgressHeader(new[] { "Loss", "Metric" }, new[] { "Epoch" }), (e) => e.SetProgress(0, epoch, options.Epoch));

while (cursor.MoveNext())
{
Expand All @@ -455,31 +455,31 @@ private void TrainCore(Options args, IDataView input)
srcTensorGetters[i].BufferTrainingData();
}

if (((cursor.Position + 1) % args.BatchSize) == 0)
if (((cursor.Position + 1) % options.BatchSize) == 0)
{
isDataLeft = false;
var (l, m) = TrainBatch(inputColIndices, inputsForTraining, srcTensorGetters, fetchList, args);
var (l, m) = TrainBatch(inputColIndices, inputsForTraining, srcTensorGetters, fetchList, options);
loss += l;
metric += m;
}
}
if (isDataLeft)
{
isDataLeft = false;
ch.Warning("Not training on the last batch. The batch size is less than {0}.", args.BatchSize);
ch.Warning("Not training on the last batch. The batch size is less than {0}.", options.BatchSize);
}
pch.Checkpoint(new double?[] { loss, metric });
}
}
}
UpdateModelOnDisk(args.ModelLocation, args);
UpdateModelOnDisk(options.ModelLocation, options);
}

private (float loss, float metric) TrainBatch(int[] inputColIndices,
string[] inputsForTraining,
ITensorValueGetter[] srcTensorGetters,
List<string> fetchList,
Options args)
Options options)
{
float loss = 0;
float metric = 0;
Expand All @@ -490,9 +490,9 @@ private void TrainCore(Options args, IDataView input)
runner.AddInput(inputName, srcTensorGetters[i].GetBufferedBatchTensor());
}

if (args.LearningRateOperation != null)
runner.AddInput(args.LearningRateOperation, new TFTensor(args.LearningRate));
runner.AddTarget(args.OptimizationOperation);
if (options.LearningRateOperation != null)
runner.AddInput(options.LearningRateOperation, new TFTensor(options.LearningRate));
runner.AddTarget(options.OptimizationOperation);

if (fetchList.Count > 0)
runner.Fetch(fetchList.ToArray());
Expand All @@ -509,14 +509,14 @@ private void TrainCore(Options args, IDataView input)
/// After retraining Session and Graphs are both up-to-date
/// However model on disk is not which is used to serialzed to ML.Net stream
/// </summary>
private void UpdateModelOnDisk(string modelDir, Options args)
private void UpdateModelOnDisk(string modelDir, Options options)
{
try
{
// Save the model on disk
var path = Path.Combine(modelDir, DefaultModelFileNames.TmpMlnetModel);
Session.GetRunner().AddInput(args.SaveLocationOperation, TFTensor.CreateString(Encoding.UTF8.GetBytes(path)))
.AddTarget(args.SaveOperation).Run();
Session.GetRunner().AddInput(options.SaveLocationOperation, TFTensor.CreateString(Encoding.UTF8.GetBytes(path)))
.AddTarget(options.SaveOperation).Run();

// Preserve original files
var variablesPath = Path.Combine(modelDir, DefaultModelFileNames.VariablesFolder);
Expand Down Expand Up @@ -1096,19 +1096,19 @@ internal TensorFlowEstimator(IHostEnvironment env, string[] outputColumnNames, s
{
}

internal TensorFlowEstimator(IHostEnvironment env, TensorFlowTransformer.Options args)
: this(env, args, TensorFlowUtils.LoadTensorFlowModel(env, args.ModelLocation))
internal TensorFlowEstimator(IHostEnvironment env, TensorFlowTransformer.Options options)
: this(env, options, TensorFlowUtils.LoadTensorFlowModel(env, options.ModelLocation))
{
}

internal TensorFlowEstimator(IHostEnvironment env, TensorFlowTransformer.Options args, TensorFlowModelInfo tensorFlowModel)
internal TensorFlowEstimator(IHostEnvironment env, TensorFlowTransformer.Options options, TensorFlowModelInfo tensorFlowModel)
{
_host = Contracts.CheckRef(env, nameof(env)).Register(nameof(TensorFlowEstimator));
_args = args;
_args = options;
_tensorFlowModel = tensorFlowModel;
var inputTuple = TensorFlowTransformer.GetInputInfo(_host, tensorFlowModel.Session, args.InputColumns);
var inputTuple = TensorFlowTransformer.GetInputInfo(_host, tensorFlowModel.Session, options.InputColumns);
_tfInputTypes = inputTuple.tfInputTypes;
var outputTuple = TensorFlowTransformer.GetOutputInfo(_host, tensorFlowModel.Session, args.OutputColumns);
var outputTuple = TensorFlowTransformer.GetOutputInfo(_host, tensorFlowModel.Session, options.OutputColumns);
_outputTypes = outputTuple.outputTypes;
}

Expand Down