@@ -297,79 +297,79 @@ private static TensorFlowTransformer Create(IHostEnvironment env, ModelLoadConte
297
297
}
298
298
299
299
// Factory method for SignatureDataTransform.
300
- internal static IDataTransform Create ( IHostEnvironment env , Options args , IDataView input )
300
+ internal static IDataTransform Create ( IHostEnvironment env , Options options , IDataView input )
301
301
{
302
302
Contracts . CheckValue ( env , nameof ( env ) ) ;
303
- env . CheckValue ( args , nameof ( args ) ) ;
303
+ env . CheckValue ( options , nameof ( options ) ) ;
304
304
env . CheckValue ( input , nameof ( input ) ) ;
305
- env . CheckValue ( args . InputColumns , nameof ( args . InputColumns ) ) ;
306
- env . CheckValue ( args . OutputColumns , nameof ( args . OutputColumns ) ) ;
305
+ env . CheckValue ( options . InputColumns , nameof ( options . InputColumns ) ) ;
306
+ env . CheckValue ( options . OutputColumns , nameof ( options . OutputColumns ) ) ;
307
307
308
- return new TensorFlowTransformer ( env , args , input ) . MakeDataTransform ( input ) ;
308
+ return new TensorFlowTransformer ( env , options , input ) . MakeDataTransform ( input ) ;
309
309
}
310
310
311
- internal TensorFlowTransformer ( IHostEnvironment env , Options args , IDataView input )
312
- : this ( env , args , TensorFlowUtils . LoadTensorFlowModel ( env , args . ModelLocation ) , input )
311
+ internal TensorFlowTransformer ( IHostEnvironment env , Options options , IDataView input )
312
+ : this ( env , options , TensorFlowUtils . LoadTensorFlowModel ( env , options . ModelLocation ) , input )
313
313
{
314
314
}
315
315
316
- internal TensorFlowTransformer ( IHostEnvironment env , Options args , TensorFlowModelInfo tensorFlowModel , IDataView input )
317
- : this ( env , tensorFlowModel . Session , args . OutputColumns , args . InputColumns , TensorFlowUtils . IsSavedModel ( env , args . ModelLocation ) ? args . ModelLocation : null , false )
316
+ internal TensorFlowTransformer ( IHostEnvironment env , Options options , TensorFlowModelInfo tensorFlowModel , IDataView input )
317
+ : this ( env , tensorFlowModel . Session , options . OutputColumns , options . InputColumns , TensorFlowUtils . IsSavedModel ( env , options . ModelLocation ) ? options . ModelLocation : null , false )
318
318
{
319
319
320
320
Contracts . CheckValue ( env , nameof ( env ) ) ;
321
- env . CheckValue ( args , nameof ( args ) ) ;
321
+ env . CheckValue ( options , nameof ( options ) ) ;
322
322
323
- if ( args . ReTrain )
323
+ if ( options . ReTrain )
324
324
{
325
325
env . CheckValue ( input , nameof ( input ) ) ;
326
326
327
- CheckTrainingParameters ( args ) ;
327
+ CheckTrainingParameters ( options ) ;
328
328
329
- if ( ! TensorFlowUtils . IsSavedModel ( env , args . ModelLocation ) )
329
+ if ( ! TensorFlowUtils . IsSavedModel ( env , options . ModelLocation ) )
330
330
throw env . ExceptNotSupp ( "TensorFlowTransform: Re-Training of TensorFlow model is only supported for un-frozen model." ) ;
331
- TrainCore ( args , input ) ;
331
+ TrainCore ( options , input ) ;
332
332
}
333
333
}
334
334
335
- private void CheckTrainingParameters ( Options args )
335
+ private void CheckTrainingParameters ( Options options )
336
336
{
337
- Host . CheckNonWhiteSpace ( args . LabelColumn , nameof ( args . LabelColumn ) ) ;
338
- Host . CheckNonWhiteSpace ( args . OptimizationOperation , nameof ( args . OptimizationOperation ) ) ;
339
- if ( Session . Graph [ args . OptimizationOperation ] == null )
340
- throw Host . ExceptParam ( nameof ( args . OptimizationOperation ) , $ "Optimization operation '{ args . OptimizationOperation } ' does not exist in the model") ;
337
+ Host . CheckNonWhiteSpace ( options . LabelColumn , nameof ( options . LabelColumn ) ) ;
338
+ Host . CheckNonWhiteSpace ( options . OptimizationOperation , nameof ( options . OptimizationOperation ) ) ;
339
+ if ( Session . Graph [ options . OptimizationOperation ] == null )
340
+ throw Host . ExceptParam ( nameof ( options . OptimizationOperation ) , $ "Optimization operation '{ options . OptimizationOperation } ' does not exist in the model") ;
341
341
342
- Host . CheckNonWhiteSpace ( args . TensorFlowLabel , nameof ( args . TensorFlowLabel ) ) ;
343
- if ( Session . Graph [ args . TensorFlowLabel ] == null )
344
- throw Host . ExceptParam ( nameof ( args . TensorFlowLabel ) , $ "'{ args . TensorFlowLabel } ' does not exist in the model") ;
342
+ Host . CheckNonWhiteSpace ( options . TensorFlowLabel , nameof ( options . TensorFlowLabel ) ) ;
343
+ if ( Session . Graph [ options . TensorFlowLabel ] == null )
344
+ throw Host . ExceptParam ( nameof ( options . TensorFlowLabel ) , $ "'{ options . TensorFlowLabel } ' does not exist in the model") ;
345
345
346
- Host . CheckNonWhiteSpace ( args . SaveLocationOperation , nameof ( args . SaveLocationOperation ) ) ;
347
- if ( Session . Graph [ args . SaveLocationOperation ] == null )
348
- throw Host . ExceptParam ( nameof ( args . SaveLocationOperation ) , $ "'{ args . SaveLocationOperation } ' does not exist in the model") ;
346
+ Host . CheckNonWhiteSpace ( options . SaveLocationOperation , nameof ( options . SaveLocationOperation ) ) ;
347
+ if ( Session . Graph [ options . SaveLocationOperation ] == null )
348
+ throw Host . ExceptParam ( nameof ( options . SaveLocationOperation ) , $ "'{ options . SaveLocationOperation } ' does not exist in the model") ;
349
349
350
- Host . CheckNonWhiteSpace ( args . SaveOperation , nameof ( args . SaveOperation ) ) ;
351
- if ( Session . Graph [ args . SaveOperation ] == null )
352
- throw Host . ExceptParam ( nameof ( args . SaveOperation ) , $ "'{ args . SaveOperation } ' does not exist in the model") ;
350
+ Host . CheckNonWhiteSpace ( options . SaveOperation , nameof ( options . SaveOperation ) ) ;
351
+ if ( Session . Graph [ options . SaveOperation ] == null )
352
+ throw Host . ExceptParam ( nameof ( options . SaveOperation ) , $ "'{ options . SaveOperation } ' does not exist in the model") ;
353
353
354
- if ( args . LossOperation != null )
354
+ if ( options . LossOperation != null )
355
355
{
356
- Host . CheckNonWhiteSpace ( args . LossOperation , nameof ( args . LossOperation ) ) ;
357
- if ( Session . Graph [ args . LossOperation ] == null )
358
- throw Host . ExceptParam ( nameof ( args . LossOperation ) , $ "'{ args . LossOperation } ' does not exist in the model") ;
356
+ Host . CheckNonWhiteSpace ( options . LossOperation , nameof ( options . LossOperation ) ) ;
357
+ if ( Session . Graph [ options . LossOperation ] == null )
358
+ throw Host . ExceptParam ( nameof ( options . LossOperation ) , $ "'{ options . LossOperation } ' does not exist in the model") ;
359
359
}
360
360
361
- if ( args . MetricOperation != null )
361
+ if ( options . MetricOperation != null )
362
362
{
363
- Host . CheckNonWhiteSpace ( args . MetricOperation , nameof ( args . MetricOperation ) ) ;
364
- if ( Session . Graph [ args . MetricOperation ] == null )
365
- throw Host . ExceptParam ( nameof ( args . MetricOperation ) , $ "'{ args . MetricOperation } ' does not exist in the model") ;
363
+ Host . CheckNonWhiteSpace ( options . MetricOperation , nameof ( options . MetricOperation ) ) ;
364
+ if ( Session . Graph [ options . MetricOperation ] == null )
365
+ throw Host . ExceptParam ( nameof ( options . MetricOperation ) , $ "'{ options . MetricOperation } ' does not exist in the model") ;
366
366
}
367
367
368
- if ( args . LearningRateOperation != null )
368
+ if ( options . LearningRateOperation != null )
369
369
{
370
- Host . CheckNonWhiteSpace ( args . LearningRateOperation , nameof ( args . LearningRateOperation ) ) ;
371
- if ( Session . Graph [ args . LearningRateOperation ] == null )
372
- throw Host . ExceptParam ( nameof ( args . LearningRateOperation ) , $ "'{ args . LearningRateOperation } ' does not exist in the model") ;
370
+ Host . CheckNonWhiteSpace ( options . LearningRateOperation , nameof ( options . LearningRateOperation ) ) ;
371
+ if ( Session . Graph [ options . LearningRateOperation ] == null )
372
+ throw Host . ExceptParam ( nameof ( options . LearningRateOperation ) , $ "'{ options . LearningRateOperation } ' does not exist in the model") ;
373
373
}
374
374
}
375
375
@@ -401,7 +401,7 @@ private void CheckTrainingParameters(Options args)
401
401
return ( inputColIndex , isInputVector , tfInputType , tfInputShape ) ;
402
402
}
403
403
404
- private void TrainCore ( Options args , IDataView input )
404
+ private void TrainCore ( Options options , IDataView input )
405
405
{
406
406
var inputsForTraining = new string [ Inputs . Length + 1 ] ;
407
407
var inputColIndices = new int [ inputsForTraining . Length ] ;
@@ -418,22 +418,22 @@ private void TrainCore(Options args, IDataView input)
418
418
for ( int i = 0 ; i < inputsForTraining . Length - 1 ; i ++ )
419
419
{
420
420
( inputColIndices [ i ] , isInputVector [ i ] , tfInputTypes [ i ] , tfInputShapes [ i ] ) =
421
- GetTrainingInputInfo ( inputSchema , inputsForTraining [ i ] , inputsForTraining [ i ] , args . BatchSize ) ;
421
+ GetTrainingInputInfo ( inputSchema , inputsForTraining [ i ] , inputsForTraining [ i ] , options . BatchSize ) ;
422
422
}
423
423
424
424
var index = inputsForTraining . Length - 1 ;
425
- inputsForTraining [ index ] = args . TensorFlowLabel ;
425
+ inputsForTraining [ index ] = options . TensorFlowLabel ;
426
426
( inputColIndices [ index ] , isInputVector [ index ] , tfInputTypes [ index ] , tfInputShapes [ index ] ) =
427
- GetTrainingInputInfo ( inputSchema , args . LabelColumn , inputsForTraining [ index ] , args . BatchSize ) ;
427
+ GetTrainingInputInfo ( inputSchema , options . LabelColumn , inputsForTraining [ index ] , options . BatchSize ) ;
428
428
429
429
var fetchList = new List < string > ( ) ;
430
- if ( args . LossOperation != null )
431
- fetchList . Add ( args . LossOperation ) ;
432
- if ( args . MetricOperation != null )
433
- fetchList . Add ( args . MetricOperation ) ;
430
+ if ( options . LossOperation != null )
431
+ fetchList . Add ( options . LossOperation ) ;
432
+ if ( options . MetricOperation != null )
433
+ fetchList . Add ( options . MetricOperation ) ;
434
434
435
435
var cols = input . Schema . Where ( c => inputColIndices . Contains ( c . Index ) ) ;
436
- for ( int epoch = 0 ; epoch < args . Epoch ; epoch ++ )
436
+ for ( int epoch = 0 ; epoch < options . Epoch ; epoch ++ )
437
437
{
438
438
using ( var cursor = input . GetRowCursor ( cols ) )
439
439
{
@@ -445,7 +445,7 @@ private void TrainCore(Options args, IDataView input)
445
445
using ( var ch = Host . Start ( "Training TensorFlow model..." ) )
446
446
using ( var pch = Host . StartProgressChannel ( "TensorFlow training progress..." ) )
447
447
{
448
- pch . SetHeader ( new ProgressHeader ( new [ ] { "Loss" , "Metric" } , new [ ] { "Epoch" } ) , ( e ) => e . SetProgress ( 0 , epoch , args . Epoch ) ) ;
448
+ pch . SetHeader ( new ProgressHeader ( new [ ] { "Loss" , "Metric" } , new [ ] { "Epoch" } ) , ( e ) => e . SetProgress ( 0 , epoch , options . Epoch ) ) ;
449
449
450
450
while ( cursor . MoveNext ( ) )
451
451
{
@@ -455,31 +455,31 @@ private void TrainCore(Options args, IDataView input)
455
455
srcTensorGetters [ i ] . BufferTrainingData ( ) ;
456
456
}
457
457
458
- if ( ( ( cursor . Position + 1 ) % args . BatchSize ) == 0 )
458
+ if ( ( ( cursor . Position + 1 ) % options . BatchSize ) == 0 )
459
459
{
460
460
isDataLeft = false ;
461
- var ( l , m ) = TrainBatch ( inputColIndices , inputsForTraining , srcTensorGetters , fetchList , args ) ;
461
+ var ( l , m ) = TrainBatch ( inputColIndices , inputsForTraining , srcTensorGetters , fetchList , options ) ;
462
462
loss += l ;
463
463
metric += m ;
464
464
}
465
465
}
466
466
if ( isDataLeft )
467
467
{
468
468
isDataLeft = false ;
469
- ch . Warning ( "Not training on the last batch. The batch size is less than {0}." , args . BatchSize ) ;
469
+ ch . Warning ( "Not training on the last batch. The batch size is less than {0}." , options . BatchSize ) ;
470
470
}
471
471
pch . Checkpoint ( new double ? [ ] { loss , metric } ) ;
472
472
}
473
473
}
474
474
}
475
- UpdateModelOnDisk ( args . ModelLocation , args ) ;
475
+ UpdateModelOnDisk ( options . ModelLocation , options ) ;
476
476
}
477
477
478
478
private ( float loss , float metric ) TrainBatch ( int [ ] inputColIndices ,
479
479
string [ ] inputsForTraining ,
480
480
ITensorValueGetter [ ] srcTensorGetters ,
481
481
List < string > fetchList ,
482
- Options args )
482
+ Options options )
483
483
{
484
484
float loss = 0 ;
485
485
float metric = 0 ;
@@ -490,9 +490,9 @@ private void TrainCore(Options args, IDataView input)
490
490
runner . AddInput ( inputName , srcTensorGetters [ i ] . GetBufferedBatchTensor ( ) ) ;
491
491
}
492
492
493
- if ( args . LearningRateOperation != null )
494
- runner . AddInput ( args . LearningRateOperation , new TFTensor ( args . LearningRate ) ) ;
495
- runner . AddTarget ( args . OptimizationOperation ) ;
493
+ if ( options . LearningRateOperation != null )
494
+ runner . AddInput ( options . LearningRateOperation , new TFTensor ( options . LearningRate ) ) ;
495
+ runner . AddTarget ( options . OptimizationOperation ) ;
496
496
497
497
if ( fetchList . Count > 0 )
498
498
runner . Fetch ( fetchList . ToArray ( ) ) ;
@@ -509,14 +509,14 @@ private void TrainCore(Options args, IDataView input)
509
509
/// After retraining Session and Graphs are both up-to-date
510
510
/// However model on disk is not which is used to serialzed to ML.Net stream
511
511
/// </summary>
512
- private void UpdateModelOnDisk ( string modelDir , Options args )
512
+ private void UpdateModelOnDisk ( string modelDir , Options options )
513
513
{
514
514
try
515
515
{
516
516
// Save the model on disk
517
517
var path = Path . Combine ( modelDir , DefaultModelFileNames . TmpMlnetModel ) ;
518
- Session . GetRunner ( ) . AddInput ( args . SaveLocationOperation , TFTensor . CreateString ( Encoding . UTF8 . GetBytes ( path ) ) )
519
- . AddTarget ( args . SaveOperation ) . Run ( ) ;
518
+ Session . GetRunner ( ) . AddInput ( options . SaveLocationOperation , TFTensor . CreateString ( Encoding . UTF8 . GetBytes ( path ) ) )
519
+ . AddTarget ( options . SaveOperation ) . Run ( ) ;
520
520
521
521
// Preserve original files
522
522
var variablesPath = Path . Combine ( modelDir , DefaultModelFileNames . VariablesFolder ) ;
@@ -1096,19 +1096,19 @@ internal TensorFlowEstimator(IHostEnvironment env, string[] outputColumnNames, s
1096
1096
{
1097
1097
}
1098
1098
1099
- internal TensorFlowEstimator ( IHostEnvironment env , TensorFlowTransformer . Options args )
1100
- : this ( env , args , TensorFlowUtils . LoadTensorFlowModel ( env , args . ModelLocation ) )
1099
+ internal TensorFlowEstimator ( IHostEnvironment env , TensorFlowTransformer . Options options )
1100
+ : this ( env , options , TensorFlowUtils . LoadTensorFlowModel ( env , options . ModelLocation ) )
1101
1101
{
1102
1102
}
1103
1103
1104
- internal TensorFlowEstimator ( IHostEnvironment env , TensorFlowTransformer . Options args , TensorFlowModelInfo tensorFlowModel )
1104
+ internal TensorFlowEstimator ( IHostEnvironment env , TensorFlowTransformer . Options options , TensorFlowModelInfo tensorFlowModel )
1105
1105
{
1106
1106
_host = Contracts . CheckRef ( env , nameof ( env ) ) . Register ( nameof ( TensorFlowEstimator ) ) ;
1107
- _args = args ;
1107
+ _args = options ;
1108
1108
_tensorFlowModel = tensorFlowModel ;
1109
- var inputTuple = TensorFlowTransformer . GetInputInfo ( _host , tensorFlowModel . Session , args . InputColumns ) ;
1109
+ var inputTuple = TensorFlowTransformer . GetInputInfo ( _host , tensorFlowModel . Session , options . InputColumns ) ;
1110
1110
_tfInputTypes = inputTuple . tfInputTypes ;
1111
- var outputTuple = TensorFlowTransformer . GetOutputInfo ( _host , tensorFlowModel . Session , args . OutputColumns ) ;
1111
+ var outputTuple = TensorFlowTransformer . GetOutputInfo ( _host , tensorFlowModel . Session , options . OutputColumns ) ;
1112
1112
_outputTypes = outputTuple . outputTypes ;
1113
1113
}
1114
1114
0 commit comments