@@ -471,8 +471,11 @@ public sealed class Options : TrainerInputBaseWithLabel
471
471
private readonly string _checkpointPath ;
472
472
private readonly string _bottleneckOperationName ;
473
473
private readonly bool _useLRScheduling ;
474
+ private readonly bool _cleanupWorkspace ;
474
475
private int _classCount ;
475
476
private Graph Graph => _session . graph ;
477
+ private static readonly string _resourcePath = Path . Combine ( Path . GetTempPath ( ) , "MLNET" ) ;
478
+ private readonly string _sizeFile ;
476
479
477
480
/// <summary>
478
481
/// Initializes a new instance of <see cref="ImageClassificationTrainer"/>
@@ -518,6 +521,12 @@ internal ImageClassificationTrainer(IHostEnvironment env, Options options)
518
521
if ( string . IsNullOrEmpty ( options . WorkspacePath ) )
519
522
{
520
523
options . WorkspacePath = GetTemporaryDirectory ( ) ;
524
+ _cleanupWorkspace = true ;
525
+ }
526
+
527
+ if ( ! Directory . Exists ( _resourcePath ) )
528
+ {
529
+ Directory . CreateDirectory ( _resourcePath ) ;
521
530
}
522
531
523
532
if ( string . IsNullOrEmpty ( options . TrainSetBottleneckCachedValuesFileName ) )
@@ -542,6 +551,7 @@ internal ImageClassificationTrainer(IHostEnvironment env, Options options)
542
551
_useLRScheduling = _options . LearningRateScheduler != null ;
543
552
_checkpointPath = Path . Combine ( _options . WorkspacePath , _options . FinalModelPrefix +
544
553
ModelFileName [ _options . Arch ] ) ;
554
+ _sizeFile = Path . Combine ( _options . WorkspacePath , "TrainingSetSize.txt" ) ;
545
555
546
556
// Configure bottleneck tensor based on the model.
547
557
var arch = _options . Arch ;
@@ -552,8 +562,8 @@ internal ImageClassificationTrainer(IHostEnvironment env, Options options)
552
562
}
553
563
else if ( arch == Architecture . InceptionV3 )
554
564
{
555
- _bottleneckOperationName = "module_apply_default/hub_output/feature_vector /SpatialSqueeze" ;
556
- _inputTensorName = "Placeholder " ;
565
+ _bottleneckOperationName = "InceptionV3/Logits /SpatialSqueeze" ;
566
+ _inputTensorName = "input " ;
557
567
}
558
568
else if ( arch == Architecture . MobilenetV2 )
559
569
{
@@ -580,7 +590,7 @@ private void InitializeTrainingGraph(IDataView input)
580
590
581
591
_classCount = labelCount == 1 ? 2 : ( int ) labelCount ;
582
592
var imageSize = ImagePreprocessingSize [ _options . Arch ] ;
583
- _session = LoadTensorFlowSessionFromMetaGraph ( Host , _options . Arch , _options . WorkspacePath ) . Session ;
593
+ _session = LoadTensorFlowSessionFromMetaGraph ( Host , _options . Arch ) . Session ;
584
594
( _jpegData , _resizedImage ) = AddJpegDecoding ( imageSize . Item1 , imageSize . Item2 , 3 ) ;
585
595
_jpegDataTensorName = _jpegData . name ;
586
596
_resizedImageTensorName = _resizedImage . name ;
@@ -637,7 +647,7 @@ private protected override ImageClassificationModelParameters TrainModelCore(Tra
637
647
ImageClassificationMetrics . Dataset . Train , _options . MetricsCallback ) ;
638
648
639
649
// Write training set size to a file for use during training
640
- File . WriteAllText ( "TrainingSetSize.txt" , trainingsetSize . ToString ( ) ) ;
650
+ File . WriteAllText ( _sizeFile , trainingsetSize . ToString ( ) ) ;
641
651
}
642
652
643
653
if ( validationSet != null &&
@@ -905,7 +915,7 @@ private void TrainAndEvaluateClassificationLayer(string trainBottleneckFilePath,
905
915
{
906
916
BatchSize = options . BatchSize ,
907
917
BatchesPerEpoch =
908
- ( trainingsetSize < 0 ? GetNumSamples ( "TrainingSetSize.txt" ) : trainingsetSize ) / options . BatchSize
918
+ ( trainingsetSize < 0 ? GetNumSamples ( _sizeFile ) : trainingsetSize ) / options . BatchSize
909
919
} ;
910
920
911
921
for ( int epoch = 0 ; epoch < epochs ; epoch += 1 )
@@ -1129,11 +1139,27 @@ private void TrainAndEvaluateClassificationLayer(string trainBottleneckFilePath,
1129
1139
1130
1140
trainSaver . save ( _session , _checkpointPath ) ;
1131
1141
UpdateTransferLearningModelOnDisk ( _classCount ) ;
1142
+ TryCleanupTemporaryWorkspace ( ) ;
1143
+ }
1144
+
1145
+ private void TryCleanupTemporaryWorkspace ( )
1146
+ {
1147
+ if ( _cleanupWorkspace && Directory . Exists ( _options . WorkspacePath ) )
1148
+ {
1149
+ try
1150
+ {
1151
+ Directory . Delete ( _options . WorkspacePath , true ) ;
1152
+ }
1153
+ catch ( Exception )
1154
+ {
1155
+ //We do not want to stop pipeline due to failed cleanup.
1156
+ }
1157
+ }
1132
1158
}
1133
1159
1134
1160
private ( Session , Tensor , Tensor , Tensor ) BuildEvaluationSession ( int classCount )
1135
1161
{
1136
- var evalGraph = LoadMetaGraph ( Path . Combine ( _options . WorkspacePath , ModelFileName [ _options . Arch ] ) ) ;
1162
+ var evalGraph = LoadMetaGraph ( Path . Combine ( _resourcePath , ModelFileName [ _options . Arch ] ) ) ;
1137
1163
var evalSess = tf . Session ( graph : evalGraph ) ;
1138
1164
Tensor evaluationStep = null ;
1139
1165
Tensor prediction = null ;
@@ -1291,24 +1317,12 @@ private void AddTransferLearningLayer(string labelColumn,
1291
1317
1292
1318
}
1293
1319
1294
- private static TensorFlowSessionWrapper LoadTensorFlowSessionFromMetaGraph ( IHostEnvironment env , Architecture arch , string path )
1320
+ private static TensorFlowSessionWrapper LoadTensorFlowSessionFromMetaGraph ( IHostEnvironment env , Architecture arch )
1295
1321
{
1296
- if ( string . IsNullOrEmpty ( path ) )
1297
- {
1298
- path = GetTemporaryDirectory ( ) ;
1299
- }
1300
-
1301
1322
var modelFileName = ModelFileName [ arch ] ;
1302
- var modelFilePath = Path . Combine ( path , modelFileName ) ;
1323
+ var modelFilePath = Path . Combine ( _resourcePath , modelFileName ) ;
1303
1324
int timeout = 10 * 60 * 1000 ;
1304
- DownloadIfNeeded ( env , modelFileName , path , modelFileName , timeout ) ;
1305
- if ( arch == Architecture . InceptionV3 )
1306
- {
1307
- DownloadIfNeeded ( env , @"tfhub_modules.zip" , path , @"tfhub_modules.zip" , timeout ) ;
1308
- if ( ! Directory . Exists ( @"tfhub_modules" ) )
1309
- ZipFile . ExtractToDirectory ( Path . Combine ( path , @"tfhub_modules.zip" ) , @"tfhub_modules" ) ;
1310
- }
1311
-
1325
+ DownloadIfNeeded ( env , modelFileName , _resourcePath , modelFileName , timeout ) ;
1312
1326
return new TensorFlowSessionWrapper ( GetSession ( env , modelFilePath , true ) , modelFilePath ) ;
1313
1327
}
1314
1328
0 commit comments