Skip to content

Commit 7cd88ed

Browse files
committed
Addressed reviewers' comments.
1 parent e742885 commit 7cd88ed

File tree

3 files changed

+68
-34
lines changed

3 files changed

+68
-34
lines changed

src/Microsoft.ML.TensorFlow/TensorFlowModel.cs

Lines changed: 44 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -80,14 +80,54 @@ public TensorFlowEstimator ScoreTensorFlowModel(string[] outputColumnNames, stri
8080
=> new TensorFlowEstimator(_env, outputColumnNames, inputColumnNames, ModelPath);
8181

8282
/// <summary>
83-
/// Create the <see cref="TensorFlowEstimator"/> for scoring or retraining using the tensorflow model.
83+
/// Retrain the TensorFlow model on new data.
8484
/// The model is not loaded again; instead, the information contained in the <see cref="TensorFlowModel"/> class is reused
8585
/// (cf. <see cref="TensorFlowModel.ModelPath"/> and <see cref="TensorFlowModel.Session"/>).
8686
/// </summary>
87-
/// <param name="options">The <see cref="TensorFlowEstimator.Options"/> specifying the inputs and the settings of the <see cref="TensorFlowEstimator"/>.</param>
88-
public TensorFlowEstimator CreateTensorFlowEstimator(TensorFlowEstimator.Options options)
87+
/// <param name="inputColumnNames">The names of the model inputs.</param>
88+
/// <param name="outputColumnNames">The names of the requested model outputs.</param>
89+
/// <param name="labelColumnName">Name of the label column.</param>
90+
/// <param name="tensorFlowLabel">Name of the node in TensorFlow graph that is used as label during training in TensorFlow.
91+
/// The value of <paramref name="labelColumnName"/> from <see cref="IDataView"/> is fed to this node.</param>
92+
/// <param name="optimizationOperation">The name of the optimization operation in the TensorFlow graph.</param>
93+
/// <param name="epoch">Number of training iterations.</param>
94+
/// <param name="batchSize">Number of samples to use for mini-batch training.</param>
95+
/// <param name="lossOperation">The name of the operation in the TensorFlow graph to compute training loss (Optional).</param>
96+
/// <param name="metricOperation">The name of the operation in the TensorFlow graph to compute performance metric during training (Optional).</param>
97+
/// <param name="learningRateOperation">The name of the operation in the TensorFlow graph which sets optimizer learning rate (Optional).</param>
98+
/// <param name="learningRate">Learning rate to use during optimization (Optional).</param>
99+
/// <remarks>
100+
/// The support for retraining is experimental.
101+
/// </remarks>
102+
public TensorFlowEstimator RetrainTensorFlowModel(
103+
string[] outputColumnNames,
104+
string[] inputColumnNames,
105+
string labelColumnName,
106+
string tensorFlowLabel,
107+
string optimizationOperation,
108+
int epoch = 10,
109+
int batchSize = 20,
110+
string lossOperation= null,
111+
string metricOperation = null,
112+
string learningRateOperation = null,
113+
float learningRate = 0.01f)
89114
{
90-
options.ModelLocation = ModelPath;
115+
var options = new TensorFlowEstimator.Options()
116+
{
117+
ModelLocation = ModelPath,
118+
InputColumns = inputColumnNames,
119+
OutputColumns = outputColumnNames,
120+
LabelColumn = labelColumnName,
121+
TensorFlowLabel = tensorFlowLabel,
122+
OptimizationOperation = optimizationOperation,
123+
LossOperation = lossOperation,
124+
MetricOperation = metricOperation,
125+
Epoch = epoch,
126+
LearningRateOperation = learningRateOperation,
127+
LearningRate = learningRate,
128+
BatchSize = batchSize,
129+
ReTrain = true
130+
};
91131
return new TensorFlowEstimator(_env, options, this);
92132
}
93133
}

src/Microsoft.ML.TensorFlow/TensorflowCatalog.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ public static class TensorflowCatalog
1414
{
1515
/// <summary>
1616
/// Load a TensorFlow model into memory. This is a convenience method that allows the model to be loaded once and subsequently used for querying the schema and creating a
17-
/// <see cref="TensorFlowEstimator"/> using <see cref="TensorFlowModel.CreateTensorFlowEstimator(TensorFlowEstimator.Options)"/>.
17+
/// <see cref="TensorFlowEstimator"/> using <see cref="TensorFlowModel.ScoreTensorFlowModel(string, string)"/>.
1818
/// </summary>
1919
/// <param name="catalog">The transform's catalog.</param>
2020
/// <param name="modelLocation">Location of the TensorFlow model.</param>

test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs

Lines changed: 23 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -544,20 +544,17 @@ public void TensorFlowTransformMNISTLRTrainingTest()
544544

545545
var pipe = mlContext.Transforms.Categorical.OneHotEncoding("OneHotLabel", "Label")
546546
.Append(mlContext.Transforms.Normalize(new NormalizingEstimator.MinMaxColumnOptions("Features", "Placeholder")))
547-
.Append(mlContext.Model.LoadTensorFlowModel(model_location).CreateTensorFlowEstimator(new TensorFlowEstimator.Options()
548-
{
549-
InputColumns = new[] { "Features" },
550-
OutputColumns = new[] { "Prediction", "b" },
551-
LabelColumn = "OneHotLabel",
552-
TensorFlowLabel = "Label",
553-
OptimizationOperation = "SGDOptimizer",
554-
LossOperation = "Loss",
555-
Epoch = 10,
556-
LearningRateOperation = "SGDOptimizer/learning_rate",
557-
LearningRate = 0.001f,
558-
BatchSize = 20,
559-
ReTrain = true
560-
}))
547+
.Append(mlContext.Model.LoadTensorFlowModel(model_location).RetrainTensorFlowModel(
548+
inputColumnNames: new[] { "Features" },
549+
outputColumnNames: new[] { "Prediction", "b" },
550+
labelColumnName: "OneHotLabel",
551+
tensorFlowLabel: "Label",
552+
optimizationOperation: "SGDOptimizer",
553+
lossOperation: "Loss",
554+
epoch: 10,
555+
learningRateOperation: "SGDOptimizer/learning_rate",
556+
learningRate: 0.001f,
557+
batchSize: 20))
561558
.Append(mlContext.Transforms.Concatenate("Features", "Prediction"))
562559
.Append(mlContext.Transforms.Conversion.MapValueToKey("KeyLabel", "Label", maxNumKeys: 10))
563560
.Append(mlContext.MulticlassClassification.Trainers.LightGbm("KeyLabel", "Features"));
@@ -661,21 +658,18 @@ private void ExecuteTFTransformMNISTConvTrainingTest(bool shuffle, int? shuffleS
661658
}
662659

663660
var pipe = mlContext.Transforms.CopyColumns(("Features", "Placeholder"))
664-
.Append(mlContext.Model.LoadTensorFlowModel(modelLocation).CreateTensorFlowEstimator(new TensorFlowEstimator.Options()
665-
{
666-
InputColumns = new[] { "Features" },
667-
OutputColumns = new[] { "Prediction" },
668-
LabelColumn = "TfLabel",
669-
TensorFlowLabel = "Label",
670-
OptimizationOperation = "MomentumOp",
671-
LossOperation = "Loss",
672-
MetricOperation = "Accuracy",
673-
Epoch = 10,
674-
LearningRateOperation = "learning_rate",
675-
LearningRate = 0.01f,
676-
BatchSize = 20,
677-
ReTrain = true
678-
}))
661+
.Append(mlContext.Model.LoadTensorFlowModel(modelLocation).RetrainTensorFlowModel(
662+
inputColumnNames: new[] { "Features" },
663+
outputColumnNames: new[] { "Prediction" },
664+
labelColumnName: "TfLabel",
665+
tensorFlowLabel: "Label",
666+
optimizationOperation: "MomentumOp",
667+
lossOperation: "Loss",
668+
metricOperation: "Accuracy",
669+
epoch: 10,
670+
learningRateOperation: "learning_rate",
671+
learningRate: 0.01f,
672+
batchSize: 20))
679673
.Append(mlContext.Transforms.Concatenate("Features", "Prediction"))
680674
.AppendCacheCheckpoint(mlContext)
681675
.Append(mlContext.MulticlassClassification.Trainers.LightGbm(new LightGBM.Options()

0 commit comments

Comments
 (0)