Skip to content

Commit 7a8f063

Browse files
bpstark authored and codemzs committed
Modified how data is saved to disk (#4424)
* Modified how data is saved to disk: pre-trained meta files are now always stored in one location, which allows multiple runs to re-use the same meta file without having to re-download it. Additionally, added the ability to clean up the temporary workspace used to train the model. This should prevent running out of disk space when running multiple training sessions sequentially. * Fixed extra line error. * Changes based on comments. * Comment fixes. * Address comments. * Fixed error. * Fix for when to download. * Removed special handling of Inception, as we have fixed the meta file to not have extras that need to be downloaded.
1 parent b9c68bf commit 7a8f063

File tree

2 files changed

+38
-22
lines changed

2 files changed

+38
-22
lines changed

src/Microsoft.ML.Vision/ImageClassificationTrainer.cs

Lines changed: 35 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -471,8 +471,11 @@ public sealed class Options : TrainerInputBaseWithLabel
471471
private readonly string _checkpointPath;
472472
private readonly string _bottleneckOperationName;
473473
private readonly bool _useLRScheduling;
474+
private readonly bool _cleanupWorkspace;
474475
private int _classCount;
475476
private Graph Graph => _session.graph;
477+
private static readonly string _resourcePath = Path.Combine(Path.GetTempPath(), "MLNET");
478+
private readonly string _sizeFile;
476479

477480
/// <summary>
478481
/// Initializes a new instance of <see cref="ImageClassificationTrainer"/>
@@ -518,6 +521,12 @@ internal ImageClassificationTrainer(IHostEnvironment env, Options options)
518521
if (string.IsNullOrEmpty(options.WorkspacePath))
519522
{
520523
options.WorkspacePath = GetTemporaryDirectory();
524+
_cleanupWorkspace = true;
525+
}
526+
527+
if (!Directory.Exists(_resourcePath))
528+
{
529+
Directory.CreateDirectory(_resourcePath);
521530
}
522531

523532
if (string.IsNullOrEmpty(options.TrainSetBottleneckCachedValuesFileName))
@@ -542,6 +551,7 @@ internal ImageClassificationTrainer(IHostEnvironment env, Options options)
542551
_useLRScheduling = _options.LearningRateScheduler != null;
543552
_checkpointPath = Path.Combine(_options.WorkspacePath, _options.FinalModelPrefix +
544553
ModelFileName[_options.Arch]);
554+
_sizeFile = Path.Combine(_options.WorkspacePath, "TrainingSetSize.txt");
545555

546556
// Configure bottleneck tensor based on the model.
547557
var arch = _options.Arch;
@@ -552,8 +562,8 @@ internal ImageClassificationTrainer(IHostEnvironment env, Options options)
552562
}
553563
else if (arch == Architecture.InceptionV3)
554564
{
555-
_bottleneckOperationName = "module_apply_default/hub_output/feature_vector/SpatialSqueeze";
556-
_inputTensorName = "Placeholder";
565+
_bottleneckOperationName = "InceptionV3/Logits/SpatialSqueeze";
566+
_inputTensorName = "input";
557567
}
558568
else if (arch == Architecture.MobilenetV2)
559569
{
@@ -580,7 +590,7 @@ private void InitializeTrainingGraph(IDataView input)
580590

581591
_classCount = labelCount == 1 ? 2 : (int)labelCount;
582592
var imageSize = ImagePreprocessingSize[_options.Arch];
583-
_session = LoadTensorFlowSessionFromMetaGraph(Host, _options.Arch, _options.WorkspacePath).Session;
593+
_session = LoadTensorFlowSessionFromMetaGraph(Host, _options.Arch).Session;
584594
(_jpegData, _resizedImage) = AddJpegDecoding(imageSize.Item1, imageSize.Item2, 3);
585595
_jpegDataTensorName = _jpegData.name;
586596
_resizedImageTensorName = _resizedImage.name;
@@ -637,7 +647,7 @@ private protected override ImageClassificationModelParameters TrainModelCore(Tra
637647
ImageClassificationMetrics.Dataset.Train, _options.MetricsCallback);
638648

639649
// Write training set size to a file for use during training
640-
File.WriteAllText("TrainingSetSize.txt", trainingsetSize.ToString());
650+
File.WriteAllText(_sizeFile, trainingsetSize.ToString());
641651
}
642652

643653
if (validationSet != null &&
@@ -905,7 +915,7 @@ private void TrainAndEvaluateClassificationLayer(string trainBottleneckFilePath,
905915
{
906916
BatchSize = options.BatchSize,
907917
BatchesPerEpoch =
908-
(trainingsetSize < 0 ? GetNumSamples("TrainingSetSize.txt") : trainingsetSize) / options.BatchSize
918+
(trainingsetSize < 0 ? GetNumSamples(_sizeFile) : trainingsetSize) / options.BatchSize
909919
};
910920

911921
for (int epoch = 0; epoch < epochs; epoch += 1)
@@ -1129,11 +1139,27 @@ private void TrainAndEvaluateClassificationLayer(string trainBottleneckFilePath,
11291139

11301140
trainSaver.save(_session, _checkpointPath);
11311141
UpdateTransferLearningModelOnDisk(_classCount);
1142+
TryCleanupTemporaryWorkspace();
1143+
}
1144+
1145+
private void TryCleanupTemporaryWorkspace()
1146+
{
1147+
if (_cleanupWorkspace && Directory.Exists(_options.WorkspacePath))
1148+
{
1149+
try
1150+
{
1151+
Directory.Delete(_options.WorkspacePath, true);
1152+
}
1153+
catch (Exception)
1154+
{
1155+
// We do not want to stop the pipeline due to a failed cleanup.
1156+
}
1157+
}
11321158
}
11331159

11341160
private (Session, Tensor, Tensor, Tensor) BuildEvaluationSession(int classCount)
11351161
{
1136-
var evalGraph = LoadMetaGraph(Path.Combine(_options.WorkspacePath, ModelFileName[_options.Arch]));
1162+
var evalGraph = LoadMetaGraph(Path.Combine(_resourcePath, ModelFileName[_options.Arch]));
11371163
var evalSess = tf.Session(graph: evalGraph);
11381164
Tensor evaluationStep = null;
11391165
Tensor prediction = null;
@@ -1291,24 +1317,12 @@ private void AddTransferLearningLayer(string labelColumn,
12911317

12921318
}
12931319

1294-
private static TensorFlowSessionWrapper LoadTensorFlowSessionFromMetaGraph(IHostEnvironment env, Architecture arch, string path)
1320+
private static TensorFlowSessionWrapper LoadTensorFlowSessionFromMetaGraph(IHostEnvironment env, Architecture arch)
12951321
{
1296-
if (string.IsNullOrEmpty(path))
1297-
{
1298-
path = GetTemporaryDirectory();
1299-
}
1300-
13011322
var modelFileName = ModelFileName[arch];
1302-
var modelFilePath = Path.Combine(path, modelFileName);
1323+
var modelFilePath = Path.Combine(_resourcePath, modelFileName);
13031324
int timeout = 10 * 60 * 1000;
1304-
DownloadIfNeeded(env, modelFileName, path, modelFileName, timeout);
1305-
if (arch == Architecture.InceptionV3)
1306-
{
1307-
DownloadIfNeeded(env, @"tfhub_modules.zip", path, @"tfhub_modules.zip", timeout);
1308-
if (!Directory.Exists(@"tfhub_modules"))
1309-
ZipFile.ExtractToDirectory(Path.Combine(path, @"tfhub_modules.zip"), @"tfhub_modules");
1310-
}
1311-
1325+
DownloadIfNeeded(env, modelFileName, _resourcePath, modelFileName, timeout);
13121326
return new TensorFlowSessionWrapper(GetSession(env, modelFilePath, true), modelFilePath);
13131327
}
13141328

test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1289,6 +1289,7 @@ public void TensorFlowImageClassificationDefault()
12891289
[InlineData(ImageClassificationTrainer.Architecture.ResnetV2101)]
12901290
[InlineData(ImageClassificationTrainer.Architecture.MobilenetV2)]
12911291
[InlineData(ImageClassificationTrainer.Architecture.ResnetV250)]
1292+
[InlineData(ImageClassificationTrainer.Architecture.InceptionV3)]
12921293
public void TensorFlowImageClassification(ImageClassificationTrainer.Architecture arch)
12931294
{
12941295
string assetsRelativePath = @"assets";
@@ -1582,8 +1583,9 @@ internal void TensorFlowImageClassificationWithLRScheduling(LearningRateSchedule
15821583

15831584
Assert.True(File.Exists(Path.Combine(options.WorkspacePath, options.TrainSetBottleneckCachedValuesFileName)));
15841585
Assert.True(File.Exists(Path.Combine(options.WorkspacePath, options.ValidationSetBottleneckCachedValuesFileName)));
1585-
Assert.True(File.Exists(Path.Combine(options.WorkspacePath, ImageClassificationTrainer.ModelFileName[options.Arch])));
1586+
Assert.True(File.Exists(Path.Combine(options.WorkspacePath, "TrainingSetSize.txt")));
15861587
Directory.Delete(options.WorkspacePath, true);
1588+
Assert.True(File.Exists(Path.Combine(Path.GetTempPath(), "MLNET", ImageClassificationTrainer.ModelFileName[options.Arch])));
15871589
}
15881590

15891591
[TensorFlowFact]

0 commit comments

Comments
 (0)