Skip to content

Commit 72e0eee

Browse files
committed
PR feedback.
1 parent a43dee1 commit 72e0eee

File tree

6 files changed

+21
-38
lines changed

6 files changed

+21
-38
lines changed

docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/MulticlassClassification/ImageClassification/LearningRateSchedulingCifarResnetTransferLearning.cs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,6 @@ public static void Example()
7878
LearningRate = 0.01f,
7979
MetricsCallback = (metrics) => Console.WriteLine(metrics),
8080
ValidationSet = testDataset,
81-
DisableEarlyStopping = true,
8281
ReuseValidationSetBottleneckCachedValues = false,
8382
ReuseTrainSetBottleneckCachedValues = false,
8483
// Use linear scaling rule and Learning rate decay as an option

docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/MulticlassClassification/ImageClassification/ResnetV2101TransferLearningTrainTestSplit.cs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,7 @@ public static void Example()
7272
BatchSize = 10,
7373
LearningRate = 0.01f,
7474
MetricsCallback = (metrics) => Console.WriteLine(metrics),
75-
ValidationSet = testDataset,
76-
DisableEarlyStopping = true
75+
ValidationSet = testDataset
7776
};
7877

7978
var pipeline = mlContext.MulticlassClassification.Trainers.ImageClassification(options)

src/Microsoft.ML.Dnn/DnnCatalog.cs

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -81,29 +81,28 @@ internal static DnnRetrainEstimator RetrainDnnModel(
8181

8282
/// <summary>
8383
/// Performs image classification using transfer learning.
84-
/// Usage of this API requires additional NuGet dependencies on TensorFlow redist, see linked document for more information.
84+
/// Usage of this API requires additional NuGet dependencies on TensorFlow redist, see linked document
85+
/// for more information.
8586
/// <format type="text/markdown">
8687
/// <![CDATA[
8788
/// [!include[io](~/../docs/samples/docs/api-reference/tensorflow-usage.md)]
8889
/// ]]>
8990
/// </format>
9091
/// </summary>
9192
/// <param name="catalog">Catalog</param>
92-
/// <param name="options">An <see cref="ImageClassificationTrainer.Options"/> object specifying advanced options for <see cref="ImageClassificationTrainer"/>.</param>
93+
/// <param name="options">An <see cref="ImageClassificationTrainer.Options"/> object specifying advanced
94+
/// options for <see cref="ImageClassificationTrainer"/>.</param>
9395

9496
public static ImageClassificationTrainer ImageClassification(
9597
this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog,
96-
ImageClassificationTrainer.Options options)
97-
{
98-
options.EarlyStoppingCriteria = options.DisableEarlyStopping ? null : options.EarlyStoppingCriteria ?? new ImageClassificationTrainer.EarlyStopping();
98+
ImageClassificationTrainer.Options options) =>
99+
new ImageClassificationTrainer(CatalogUtils.GetEnvironment(catalog), options);
99100

100-
var env = CatalogUtils.GetEnvironment(catalog);
101-
return new ImageClassificationTrainer(env, options);
102-
}
103101

104102
/// <summary>
105103
/// Performs image classification using transfer learning.
106-
/// Usage of this API requires additional NuGet dependencies on TensorFlow redist, see linked document for more information.
104+
/// Usage of this API requires additional NuGet dependencies on TensorFlow redist, see linked document for
105+
/// more information.
107106
/// <format type="text/markdown">
108107
/// <![CDATA[
109108
/// [!include[io](~/../docs/samples/docs/api-reference/tensorflow-usage.md)]

src/Microsoft.ML.Dnn/ImageClassificationTrainer.cs

Lines changed: 9 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ namespace Microsoft.ML.Dnn
6565
/// ]]>
6666
/// </format>
6767
/// </remarks>
68-
public class ImageClassificationTrainer :
68+
public sealed class ImageClassificationTrainer :
6969
TrainerEstimatorBase<MulticlassPredictionTransformer<ImageClassificationModelParameters>,
7070
ImageClassificationModelParameters>
7171
{
@@ -117,12 +117,6 @@ public enum EarlyStoppingMetric
117117
Loss
118118
}
119119

120-
/// <summary>
121-
/// Callback that returns DNN statistics during bottlenack phase and training phase.
122-
/// Train metrics may be null when bottleneck phase is running, so have check!
123-
/// </summary>
124-
public delegate void ImageClassificationMetricsCallback(ImageClassificationMetrics metrics);
125-
126120
/// <summary>
127121
/// DNN training metrics.
128122
/// </summary>
@@ -333,6 +327,9 @@ public enum Dataset
333327
public override string ToString() => Train != null ? Train.ToString() : Bottleneck.ToString();
334328
}
335329

330+
/// <summary>
331+
/// Options class for <see cref="ImageClassificationTrainer"/>.
332+
/// </summary>
336333
public sealed class Options : TrainerInputBaseWithLabel
337334
{
338335
/// <summary>
@@ -353,12 +350,6 @@ public sealed class Options : TrainerInputBaseWithLabel
353350
[Argument(ArgumentType.AtMostOnce, HelpText = "Learning rate to use during optimization.", SortOrder = 12)]
354351
public float LearningRate = 0.01f;
355352

356-
/// <summary>
357-
/// Whether to disable use of early stopping technique. Training will go on for the full epoch count.
358-
/// </summary>
359-
[Argument(ArgumentType.AtMostOnce, HelpText = "Whether to disable use of early stopping technique. Training will go on for the full epoch count.", SortOrder = 15)]
360-
public bool DisableEarlyStopping = false;
361-
362353
/// <summary>
363354
/// Early stopping technique parameters to be used to terminate training when training metric stops improving.
364355
/// </summary>
@@ -393,7 +384,7 @@ public sealed class Options : TrainerInputBaseWithLabel
393384
/// Callback to report statistics on accuracy/cross entropy during training phase.
394385
/// </summary>
395386
[Argument(ArgumentType.AtMostOnce, HelpText = "Callback to report metrics during training and validation phase.", SortOrder = 15)]
396-
public ImageClassificationMetricsCallback MetricsCallback = null;
387+
public Action<ImageClassificationMetrics> MetricsCallback = null;
397388

398389
/// <summary>
399390
/// Indicates the path where the newly retrained model should be saved.
@@ -499,7 +490,8 @@ internal ImageClassificationTrainer(IHostEnvironment env,
499490
LabelColumnName = labelColumn,
500491
ScoreColumnName = scoreColumn,
501492
PredictedLabelColumnName = predictedLabelColumn,
502-
ValidationSet = validationSet
493+
ValidationSet = validationSet,
494+
EarlyStoppingCriteria = new EarlyStopping()
503495
})
504496
{
505497
}
@@ -520,9 +512,6 @@ internal ImageClassificationTrainer(IHostEnvironment env, Options options)
520512
Host.CheckNonEmpty(options.PredictedLabelColumnName, nameof(options.PredictedLabelColumnName));
521513

522514
_options = options;
523-
_options.EarlyStoppingCriteria = _options.DisableEarlyStopping ? null : _options.EarlyStoppingCriteria ??
524-
new EarlyStopping();
525-
526515
_session = DnnUtils.LoadDnnModel(env, _options.Arch, true).Session;
527516
_useLRScheduling = _options.LearningRateScheduler != null;
528517
_checkpointPath = _options.ModelSavePath ??
@@ -723,7 +712,7 @@ public Tensor ProcessImage(in VBuffer<byte> imageBuffer)
723712

724713
private int CacheFeaturizedImagesToDisk(IDataView input, string labelColumnName, string imageColumnName,
725714
ImageProcessor imageProcessor, string inputTensorName, string outputTensorName, string cacheFilePath,
726-
ImageClassificationMetrics.Dataset dataset, ImageClassificationMetricsCallback metricsCallback)
715+
ImageClassificationMetrics.Dataset dataset, Action<ImageClassificationMetrics> metricsCallback)
727716
{
728717
var labelColumn = input.Schema[labelColumnName];
729718

@@ -809,7 +798,7 @@ private void TrainAndEvaluateClassificationLayer(string trainBottleneckFilePath,
809798
int epochs = options.Epoch;
810799
float learningRate = options.LearningRate;
811800
bool evaluateOnly = !string.IsNullOrEmpty(validationSetBottleneckFilePath);
812-
ImageClassificationMetricsCallback statisticsCallback = _options.MetricsCallback;
801+
Action<ImageClassificationMetrics> statisticsCallback = _options.MetricsCallback;
813802
var trainingSet = GetShuffledData(trainBottleneckFilePath);
814803
IDataView validationSet = null;
815804
if (options.ValidationSet != null && !string.IsNullOrEmpty(validationSetBottleneckFilePath))

test/Microsoft.ML.Benchmarks/ImageClassificationBench.cs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,8 +91,7 @@ public TransformerChain<KeyToValueMappingTransformer> TrainResnetV250()
9191
LearningRate = 0.01f,
9292
EarlyStoppingCriteria = new ImageClassificationTrainer.EarlyStopping(minDelta: 0.001f, patience: 20, metric: ImageClassificationTrainer.EarlyStoppingMetric.Loss),
9393
ValidationSet = testDataset,
94-
ModelSavePath = assetsPath,
95-
DisableEarlyStopping = true
94+
ModelSavePath = assetsPath
9695
};
9796
var pipeline = mlContext.MulticlassClassification.Trainers.ImageClassification(options)
9897
.Append(mlContext.Transforms.Conversion.MapKeyToValue(

test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1339,8 +1339,7 @@ public void TensorFlowImageClassification(ImageClassificationTrainer.Architectur
13391339
LearningRate = 0.01f,
13401340
MetricsCallback = (metrics) => Console.WriteLine(metrics),
13411341
TestOnTrainSet = false,
1342-
ValidationSet = validationSet,
1343-
DisableEarlyStopping = true
1342+
ValidationSet = validationSet
13441343
};
13451344

13461345
var pipeline = mlContext.Transforms.LoadImages("Image", fullImagesetFolderPath, false, "ImagePath") // false indicates we want the image as a VBuffer<byte>
@@ -1477,7 +1476,6 @@ public void TensorFlowImageClassificationWithLRScheduling()
14771476
LearningRate = 0.01f,
14781477
MetricsCallback = (metrics) => Console.WriteLine(metrics),
14791478
ValidationSet = validationSet,
1480-
DisableEarlyStopping = true,
14811479
ReuseValidationSetBottleneckCachedValues = false,
14821480
ReuseTrainSetBottleneckCachedValues = false,
14831481
// Using Exponential Decay for learning rate scheduling
@@ -1624,7 +1622,7 @@ public void TensorFlowImageClassificationEarlyStoppingIncreasing()
16241622
EarlyStoppingCriteria = new ImageClassificationTrainer.EarlyStopping(),
16251623
MetricsCallback = (metrics) => { Console.WriteLine(metrics); lastEpoch = metrics.Train != null ? metrics.Train.Epoch : 0; },
16261624
TestOnTrainSet = false,
1627-
ValidationSet = validationSet,
1625+
ValidationSet = validationSet
16281626
};
16291627

16301628
var pipeline = mlContext.Transforms.LoadImages("Image", fullImagesetFolderPath, false, "ImagePath") // false indicates we want the image as a VBuffer<byte>

0 commit comments

Comments
 (0)