@@ -482,6 +482,135 @@ def test_training_step_creation_with_framework(tensorflow_estimator):
482
482
'End' : True
483
483
}
484
484
485
+ @patch ('botocore.client.BaseClient._make_api_call' , new = mock_boto_api_call )
486
+ @patch .object (boto3 .session .Session , 'region_name' , 'us-east-1' )
487
+ def training_step_merges_hyperparameters_from_constructor_and_estimator (tensorflow_estimator ):
488
+ step = TrainingStep ('Training' ,
489
+ estimator = tensorflow_estimator ,
490
+ data = {'train' : 's3://sagemaker/train' },
491
+ job_name = 'tensorflow-job' ,
492
+ mini_batch_size = 1024 ,
493
+ hyperparameters = {
494
+ 'key' : 'value'
495
+ }
496
+ )
497
+
498
+ assert step .to_dict () == {
499
+ 'Type' : 'Task' ,
500
+ 'Parameters' : {
501
+ 'AlgorithmSpecification' : {
502
+ 'TrainingImage' : TENSORFLOW_IMAGE ,
503
+ 'TrainingInputMode' : 'File'
504
+ },
505
+ 'InputDataConfig' : [
506
+ {
507
+ 'DataSource' : {
508
+ 'S3DataSource' : {
509
+ 'S3DataDistributionType' : 'FullyReplicated' ,
510
+ 'S3DataType' : 'S3Prefix' ,
511
+ 'S3Uri' : 's3://sagemaker/train'
512
+ }
513
+ },
514
+ 'ChannelName' : 'train'
515
+ }
516
+ ],
517
+ 'OutputDataConfig' : {
518
+ 'S3OutputPath' : 's3://sagemaker/models'
519
+ },
520
+ 'DebugHookConfig' : {
521
+ 'S3OutputPath' : 's3://sagemaker/models/debug'
522
+ },
523
+ 'StoppingCondition' : {
524
+ 'MaxRuntimeInSeconds' : 86400
525
+ },
526
+ 'ResourceConfig' : {
527
+ 'InstanceCount' : 1 ,
528
+ 'InstanceType' : 'ml.p2.xlarge' ,
529
+ 'VolumeSizeInGB' : 30
530
+ },
531
+ 'RoleArn' : EXECUTION_ROLE ,
532
+ 'HyperParameters' : {
533
+ 'checkpoint_path' : '"s3://sagemaker/models/sagemaker-tensorflow/checkpoints"' ,
534
+ 'evaluation_steps' : '100' ,
535
+ 'key' : 'value' ,
536
+ 'sagemaker_container_log_level' : '20' ,
537
+ 'sagemaker_job_name' : '"tensorflow-job"' ,
538
+ 'sagemaker_program' : '"tf_train.py"' ,
539
+ 'sagemaker_region' : '"us-east-1"' ,
540
+ 'sagemaker_submit_directory' : '"s3://sagemaker/source"' ,
541
+ 'training_steps' : '1000' ,
542
+ },
543
+ 'TrainingJobName' : 'tensorflow-job' ,
544
+ },
545
+ 'Resource' : 'arn:aws:states:::sagemaker:createTrainingJob.sync' ,
546
+ 'End' : True
547
+ }
548
+
549
+
550
+ @patch ('botocore.client.BaseClient._make_api_call' , new = mock_boto_api_call )
551
+ @patch .object (boto3 .session .Session , 'region_name' , 'us-east-1' )
552
+ def training_step_uses_constructor_hyperparameters_when_duplicates_supplied_in_estimator (tensorflow_estimator ):
553
+ step = TrainingStep ('Training' ,
554
+ estimator = tensorflow_estimator ,
555
+ data = {'train' : 's3://sagemaker/train' },
556
+ job_name = 'tensorflow-job' ,
557
+ mini_batch_size = 1024 ,
558
+ hyperparameters = {
559
+ # set as 1000 in estimator
560
+ 'training_steps' : '500'
561
+ }
562
+ )
563
+
564
+ assert step .to_dict () == {
565
+ 'Type' : 'Task' ,
566
+ 'Parameters' : {
567
+ 'AlgorithmSpecification' : {
568
+ 'TrainingImage' : TENSORFLOW_IMAGE ,
569
+ 'TrainingInputMode' : 'File'
570
+ },
571
+ 'InputDataConfig' : [
572
+ {
573
+ 'DataSource' : {
574
+ 'S3DataSource' : {
575
+ 'S3DataDistributionType' : 'FullyReplicated' ,
576
+ 'S3DataType' : 'S3Prefix' ,
577
+ 'S3Uri' : 's3://sagemaker/train'
578
+ }
579
+ },
580
+ 'ChannelName' : 'train'
581
+ }
582
+ ],
583
+ 'OutputDataConfig' : {
584
+ 'S3OutputPath' : 's3://sagemaker/models'
585
+ },
586
+ 'DebugHookConfig' : {
587
+ 'S3OutputPath' : 's3://sagemaker/models/debug'
588
+ },
589
+ 'StoppingCondition' : {
590
+ 'MaxRuntimeInSeconds' : 86400
591
+ },
592
+ 'ResourceConfig' : {
593
+ 'InstanceCount' : 1 ,
594
+ 'InstanceType' : 'ml.p2.xlarge' ,
595
+ 'VolumeSizeInGB' : 30
596
+ },
597
+ 'RoleArn' : EXECUTION_ROLE ,
598
+ 'HyperParameters' : {
599
+ 'checkpoint_path' : '"s3://sagemaker/models/sagemaker-tensorflow/checkpoints"' ,
600
+ 'evaluation_steps' : '100' ,
601
+ 'sagemaker_container_log_level' : '20' ,
602
+ 'sagemaker_job_name' : '"tensorflow-job"' ,
603
+ 'sagemaker_program' : '"tf_train.py"' ,
604
+ 'sagemaker_region' : '"us-east-1"' ,
605
+ 'sagemaker_submit_directory' : '"s3://sagemaker/source"' ,
606
+ 'training_steps' : '500' ,
607
+ },
608
+ 'TrainingJobName' : 'tensorflow-job' ,
609
+ },
610
+ 'Resource' : 'arn:aws:states:::sagemaker:createTrainingJob.sync' ,
611
+ 'End' : True
612
+ }
613
+
485
614
486
615
@patch .object (boto3 .session .Session , 'region_name' , 'us-east-1' )
487
616
def test_transform_step_creation (pca_transformer ):
0 commit comments