diff --git a/src/Microsoft.ML.TensorFlow/TensorflowCatalog.cs b/src/Microsoft.ML.TensorFlow/TensorflowCatalog.cs index 517c6687c7..2b46731076 100644 --- a/src/Microsoft.ML.TensorFlow/TensorflowCatalog.cs +++ b/src/Microsoft.ML.TensorFlow/TensorflowCatalog.cs @@ -49,20 +49,20 @@ public static TensorFlowEstimator ScoreTensorFlowModel(this TransformsCatalog ca /// The model is specified in the . /// /// The transform's catalog. - /// The specifying the inputs and the settings of the . + /// The specifying the inputs and the settings of the . public static TensorFlowEstimator TensorFlow(this TransformsCatalog catalog, - TensorFlowTransformer.Options args) - => new TensorFlowEstimator(CatalogUtils.GetEnvironment(catalog), args); + TensorFlowTransformer.Options options) + => new TensorFlowEstimator(CatalogUtils.GetEnvironment(catalog), options); /// /// Scores or retrains (based on setting of the ) a pre-traiend TensorFlow model specified via . /// /// The transform's catalog. - /// The specifying the inputs and the settings of the . + /// The specifying the inputs and the settings of the . /// The pre-trained TensorFlow model. public static TensorFlowEstimator TensorFlow(this TransformsCatalog catalog, - TensorFlowTransformer.Options args, + TensorFlowTransformer.Options options, TensorFlowModelInfo tensorFlowModel) - => new TensorFlowEstimator(CatalogUtils.GetEnvironment(catalog), args, tensorFlowModel); + => new TensorFlowEstimator(CatalogUtils.GetEnvironment(catalog), options, tensorFlowModel); } } diff --git a/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs b/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs index ad28f48525..452ca70f24 100644 --- a/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs +++ b/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs @@ -297,79 +297,79 @@ private static TensorFlowTransformer Create(IHostEnvironment env, ModelLoadConte } // Factory method for SignatureDataTransform. - internal static IDataTransform Create(IHostEnvironment env, Options args, IDataView input) + internal static IDataTransform Create(IHostEnvironment env, Options options, IDataView input) { Contracts.CheckValue(env, nameof(env)); - env.CheckValue(args, nameof(args)); + env.CheckValue(options, nameof(options)); env.CheckValue(input, nameof(input)); - env.CheckValue(args.InputColumns, nameof(args.InputColumns)); - env.CheckValue(args.OutputColumns, nameof(args.OutputColumns)); + env.CheckValue(options.InputColumns, nameof(options.InputColumns)); + env.CheckValue(options.OutputColumns, nameof(options.OutputColumns)); - return new TensorFlowTransformer(env, args, input).MakeDataTransform(input); + return new TensorFlowTransformer(env, options, input).MakeDataTransform(input); } - internal TensorFlowTransformer(IHostEnvironment env, Options args, IDataView input) - : this(env, args, TensorFlowUtils.LoadTensorFlowModel(env, args.ModelLocation), input) + internal TensorFlowTransformer(IHostEnvironment env, Options options, IDataView input) + : this(env, options, TensorFlowUtils.LoadTensorFlowModel(env, options.ModelLocation), input) { } - internal TensorFlowTransformer(IHostEnvironment env, Options args, TensorFlowModelInfo tensorFlowModel, IDataView input) - : this(env, tensorFlowModel.Session, args.OutputColumns, args.InputColumns, TensorFlowUtils.IsSavedModel(env, args.ModelLocation) ? args.ModelLocation : null, false) + internal TensorFlowTransformer(IHostEnvironment env, Options options, TensorFlowModelInfo tensorFlowModel, IDataView input) + : this(env, tensorFlowModel.Session, options.OutputColumns, options.InputColumns, TensorFlowUtils.IsSavedModel(env, options.ModelLocation) ? options.ModelLocation : null, false) { Contracts.CheckValue(env, nameof(env)); - env.CheckValue(args, nameof(args)); + env.CheckValue(options, nameof(options)); - if (args.ReTrain) + if (options.ReTrain) { env.CheckValue(input, nameof(input)); - CheckTrainingParameters(args); + CheckTrainingParameters(options); - if (!TensorFlowUtils.IsSavedModel(env, args.ModelLocation)) + if (!TensorFlowUtils.IsSavedModel(env, options.ModelLocation)) throw env.ExceptNotSupp("TensorFlowTransform: Re-Training of TensorFlow model is only supported for un-frozen model."); - TrainCore(args, input); + TrainCore(options, input); } } - private void CheckTrainingParameters(Options args) + private void CheckTrainingParameters(Options options) { - Host.CheckNonWhiteSpace(args.LabelColumn, nameof(args.LabelColumn)); - Host.CheckNonWhiteSpace(args.OptimizationOperation, nameof(args.OptimizationOperation)); - if (Session.Graph[args.OptimizationOperation] == null) - throw Host.ExceptParam(nameof(args.OptimizationOperation), $"Optimization operation '{args.OptimizationOperation}' does not exist in the model"); + Host.CheckNonWhiteSpace(options.LabelColumn, nameof(options.LabelColumn)); + Host.CheckNonWhiteSpace(options.OptimizationOperation, nameof(options.OptimizationOperation)); + if (Session.Graph[options.OptimizationOperation] == null) + throw Host.ExceptParam(nameof(options.OptimizationOperation), $"Optimization operation '{options.OptimizationOperation}' does not exist in the model"); - Host.CheckNonWhiteSpace(args.TensorFlowLabel, nameof(args.TensorFlowLabel)); - if (Session.Graph[args.TensorFlowLabel] == null) - throw Host.ExceptParam(nameof(args.TensorFlowLabel), $"'{args.TensorFlowLabel}' does not exist in the model"); + Host.CheckNonWhiteSpace(options.TensorFlowLabel, nameof(options.TensorFlowLabel)); + if (Session.Graph[options.TensorFlowLabel] == null) + throw Host.ExceptParam(nameof(options.TensorFlowLabel), $"'{options.TensorFlowLabel}' does not exist in the model"); - Host.CheckNonWhiteSpace(args.SaveLocationOperation, nameof(args.SaveLocationOperation)); - if (Session.Graph[args.SaveLocationOperation] == null) - throw Host.ExceptParam(nameof(args.SaveLocationOperation), $"'{args.SaveLocationOperation}' does not exist in the model"); + Host.CheckNonWhiteSpace(options.SaveLocationOperation, nameof(options.SaveLocationOperation)); + if (Session.Graph[options.SaveLocationOperation] == null) + throw Host.ExceptParam(nameof(options.SaveLocationOperation), $"'{options.SaveLocationOperation}' does not exist in the model"); - Host.CheckNonWhiteSpace(args.SaveOperation, nameof(args.SaveOperation)); - if (Session.Graph[args.SaveOperation] == null) - throw Host.ExceptParam(nameof(args.SaveOperation), $"'{args.SaveOperation}' does not exist in the model"); + Host.CheckNonWhiteSpace(options.SaveOperation, nameof(options.SaveOperation)); + if (Session.Graph[options.SaveOperation] == null) + throw Host.ExceptParam(nameof(options.SaveOperation), $"'{options.SaveOperation}' does not exist in the model"); - if (args.LossOperation != null) + if (options.LossOperation != null) { - Host.CheckNonWhiteSpace(args.LossOperation, nameof(args.LossOperation)); - if (Session.Graph[args.LossOperation] == null) - throw Host.ExceptParam(nameof(args.LossOperation), $"'{args.LossOperation}' does not exist in the model"); + Host.CheckNonWhiteSpace(options.LossOperation, nameof(options.LossOperation)); + if (Session.Graph[options.LossOperation] == null) + throw Host.ExceptParam(nameof(options.LossOperation), $"'{options.LossOperation}' does not exist in the model"); } - if (args.MetricOperation != null) + if (options.MetricOperation != null) { - Host.CheckNonWhiteSpace(args.MetricOperation, nameof(args.MetricOperation)); - if (Session.Graph[args.MetricOperation] == null) - throw Host.ExceptParam(nameof(args.MetricOperation), $"'{args.MetricOperation}' does not exist in the model"); + Host.CheckNonWhiteSpace(options.MetricOperation, nameof(options.MetricOperation)); + if (Session.Graph[options.MetricOperation] == null) + throw Host.ExceptParam(nameof(options.MetricOperation), $"'{options.MetricOperation}' does not exist in the model"); } - if (args.LearningRateOperation != null) + if (options.LearningRateOperation != null) { - Host.CheckNonWhiteSpace(args.LearningRateOperation, nameof(args.LearningRateOperation)); - if (Session.Graph[args.LearningRateOperation] == null) - throw Host.ExceptParam(nameof(args.LearningRateOperation), $"'{args.LearningRateOperation}' does not exist in the model"); + Host.CheckNonWhiteSpace(options.LearningRateOperation, nameof(options.LearningRateOperation)); + if (Session.Graph[options.LearningRateOperation] == null) + throw Host.ExceptParam(nameof(options.LearningRateOperation), $"'{options.LearningRateOperation}' does not exist in the model"); } } @@ -401,7 +401,7 @@ private void CheckTrainingParameters(Options args) return (inputColIndex, isInputVector, tfInputType, tfInputShape); } - private void TrainCore(Options args, IDataView input) + private void TrainCore(Options options, IDataView input) { var inputsForTraining = new string[Inputs.Length + 1]; var inputColIndices = new int[inputsForTraining.Length]; @@ -418,22 +418,22 @@ private void TrainCore(Options args, IDataView input) for (int i = 0; i < inputsForTraining.Length - 1; i++) { (inputColIndices[i], isInputVector[i], tfInputTypes[i], tfInputShapes[i]) = - GetTrainingInputInfo(inputSchema, inputsForTraining[i], inputsForTraining[i], args.BatchSize); + GetTrainingInputInfo(inputSchema, inputsForTraining[i], inputsForTraining[i], options.BatchSize); } var index = inputsForTraining.Length - 1; - inputsForTraining[index] = args.TensorFlowLabel; + inputsForTraining[index] = options.TensorFlowLabel; (inputColIndices[index], isInputVector[index], tfInputTypes[index], tfInputShapes[index]) = - GetTrainingInputInfo(inputSchema, args.LabelColumn, inputsForTraining[index], args.BatchSize); + GetTrainingInputInfo(inputSchema, options.LabelColumn, inputsForTraining[index], options.BatchSize); var fetchList = new List(); - if (args.LossOperation != null) - fetchList.Add(args.LossOperation); - if (args.MetricOperation != null) - fetchList.Add(args.MetricOperation); + if (options.LossOperation != null) + fetchList.Add(options.LossOperation); + if (options.MetricOperation != null) + fetchList.Add(options.MetricOperation); var cols = input.Schema.Where(c => inputColIndices.Contains(c.Index)); - for (int epoch = 0; epoch < args.Epoch; epoch++) + for (int epoch = 0; epoch < options.Epoch; epoch++) { using (var cursor = input.GetRowCursor(cols)) { @@ -445,7 +445,7 @@ private void TrainCore(Options args, IDataView input) using (var ch = Host.Start("Training TensorFlow model...")) using (var pch = Host.StartProgressChannel("TensorFlow training progress...")) { - pch.SetHeader(new ProgressHeader(new[] { "Loss", "Metric" }, new[] { "Epoch" }), (e) => e.SetProgress(0, epoch, args.Epoch)); + pch.SetHeader(new ProgressHeader(new[] { "Loss", "Metric" }, new[] { "Epoch" }), (e) => e.SetProgress(0, epoch, options.Epoch)); while (cursor.MoveNext()) { @@ -455,10 +455,10 @@ private void TrainCore(Options args, IDataView input) srcTensorGetters[i].BufferTrainingData(); } - if (((cursor.Position + 1) % args.BatchSize) == 0) + if (((cursor.Position + 1) % options.BatchSize) == 0) { isDataLeft = false; - var (l, m) = TrainBatch(inputColIndices, inputsForTraining, srcTensorGetters, fetchList, args); + var (l, m) = TrainBatch(inputColIndices, inputsForTraining, srcTensorGetters, fetchList, options); loss += l; metric += m; } @@ -466,20 +466,20 @@ private void TrainCore(Options args, IDataView input) if (isDataLeft) { isDataLeft = false; - ch.Warning("Not training on the last batch. The batch size is less than {0}.", args.BatchSize); + ch.Warning("Not training on the last batch. The batch size is less than {0}.", options.BatchSize); } pch.Checkpoint(new double?[] { loss, metric }); } } } - UpdateModelOnDisk(args.ModelLocation, args); + UpdateModelOnDisk(options.ModelLocation, options); } private (float loss, float metric) TrainBatch(int[] inputColIndices, string[] inputsForTraining, ITensorValueGetter[] srcTensorGetters, List fetchList, - Options args) + Options options) { float loss = 0; float metric = 0; @@ -490,9 +490,9 @@ private void TrainCore(Options args, IDataView input) runner.AddInput(inputName, srcTensorGetters[i].GetBufferedBatchTensor()); } - if (args.LearningRateOperation != null) - runner.AddInput(args.LearningRateOperation, new TFTensor(args.LearningRate)); - runner.AddTarget(args.OptimizationOperation); + if (options.LearningRateOperation != null) + runner.AddInput(options.LearningRateOperation, new TFTensor(options.LearningRate)); + runner.AddTarget(options.OptimizationOperation); if (fetchList.Count > 0) runner.Fetch(fetchList.ToArray()); @@ -509,14 +509,14 @@ private void TrainCore(Options args, IDataView input) /// After retraining Session and Graphs are both up-to-date /// However model on disk is not which is used to serialzed to ML.Net stream /// - private void UpdateModelOnDisk(string modelDir, Options args) + private void UpdateModelOnDisk(string modelDir, Options options) { try { // Save the model on disk var path = Path.Combine(modelDir, DefaultModelFileNames.TmpMlnetModel); - Session.GetRunner().AddInput(args.SaveLocationOperation, TFTensor.CreateString(Encoding.UTF8.GetBytes(path))) - .AddTarget(args.SaveOperation).Run(); + Session.GetRunner().AddInput(options.SaveLocationOperation, TFTensor.CreateString(Encoding.UTF8.GetBytes(path))) + .AddTarget(options.SaveOperation).Run(); // Preserve original files var variablesPath = Path.Combine(modelDir, DefaultModelFileNames.VariablesFolder); @@ -1096,19 +1096,19 @@ internal TensorFlowEstimator(IHostEnvironment env, string[] outputColumnNames, s { } - internal TensorFlowEstimator(IHostEnvironment env, TensorFlowTransformer.Options args) - : this(env, args, TensorFlowUtils.LoadTensorFlowModel(env, args.ModelLocation)) + internal TensorFlowEstimator(IHostEnvironment env, TensorFlowTransformer.Options options) + : this(env, options, TensorFlowUtils.LoadTensorFlowModel(env, options.ModelLocation)) { } - internal TensorFlowEstimator(IHostEnvironment env, TensorFlowTransformer.Options args, TensorFlowModelInfo tensorFlowModel) + internal TensorFlowEstimator(IHostEnvironment env, TensorFlowTransformer.Options options, TensorFlowModelInfo tensorFlowModel) { _host = Contracts.CheckRef(env, nameof(env)).Register(nameof(TensorFlowEstimator)); - _args = args; + _args = options; _tensorFlowModel = tensorFlowModel; - var inputTuple = TensorFlowTransformer.GetInputInfo(_host, tensorFlowModel.Session, args.InputColumns); + var inputTuple = TensorFlowTransformer.GetInputInfo(_host, tensorFlowModel.Session, options.InputColumns); _tfInputTypes = inputTuple.tfInputTypes; - var outputTuple = TensorFlowTransformer.GetOutputInfo(_host, tensorFlowModel.Session, args.OutputColumns); + var outputTuple = TensorFlowTransformer.GetOutputInfo(_host, tensorFlowModel.Session, options.OutputColumns); _outputTypes = outputTuple.outputTypes; }