Skip to content

Commit d672d7d

Browse files
committed
fix: supplying hyperparameters to training step constructor drops hyperparameters specified in estimator
1 parent c82bd52 commit d672d7d

File tree

2 files changed

+134
-1
lines changed

2 files changed

+134
-1
lines changed

src/stepfunctions/steps/sagemaker.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,11 @@ def __init__(self, state_id, estimator, job_name, data=None, hyperparameters=Non
104104
parameters['TrainingJobName'] = job_name
105105

106106
if hyperparameters is not None:
107-
parameters['HyperParameters'] = hyperparameters
107+
merged_hyperparameters = {}
108+
if estimator.hyperparameters() is not None:
109+
merged_hyperparameters.update(estimator.hyperparameters())
110+
merged_hyperparameters.update(hyperparameters)
111+
parameters['HyperParameters'] = merged_hyperparameters
108112

109113
if experiment_config is not None:
110114
parameters['ExperimentConfig'] = experiment_config

tests/unit/test_sagemaker_steps.py

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -482,6 +482,135 @@ def test_training_step_creation_with_framework(tensorflow_estimator):
482482
'End': True
483483
}
484484

485+
@patch('botocore.client.BaseClient._make_api_call', new=mock_boto_api_call)
486+
@patch.object(boto3.session.Session, 'region_name', 'us-east-1')
487+
def training_step_merges_hyperparameters_from_constructor_and_estimator(tensorflow_estimator):
488+
step = TrainingStep('Training',
489+
estimator=tensorflow_estimator,
490+
data={'train': 's3://sagemaker/train'},
491+
job_name='tensorflow-job',
492+
mini_batch_size=1024,
493+
hyperparameters={
494+
'key': 'value'
495+
}
496+
)
497+
498+
assert step.to_dict() == {
499+
'Type': 'Task',
500+
'Parameters': {
501+
'AlgorithmSpecification': {
502+
'TrainingImage': TENSORFLOW_IMAGE,
503+
'TrainingInputMode': 'File'
504+
},
505+
'InputDataConfig': [
506+
{
507+
'DataSource': {
508+
'S3DataSource': {
509+
'S3DataDistributionType': 'FullyReplicated',
510+
'S3DataType': 'S3Prefix',
511+
'S3Uri': 's3://sagemaker/train'
512+
}
513+
},
514+
'ChannelName': 'train'
515+
}
516+
],
517+
'OutputDataConfig': {
518+
'S3OutputPath': 's3://sagemaker/models'
519+
},
520+
'DebugHookConfig': {
521+
'S3OutputPath': 's3://sagemaker/models/debug'
522+
},
523+
'StoppingCondition': {
524+
'MaxRuntimeInSeconds': 86400
525+
},
526+
'ResourceConfig': {
527+
'InstanceCount': 1,
528+
'InstanceType': 'ml.p2.xlarge',
529+
'VolumeSizeInGB': 30
530+
},
531+
'RoleArn': EXECUTION_ROLE,
532+
'HyperParameters': {
533+
'checkpoint_path': '"s3://sagemaker/models/sagemaker-tensorflow/checkpoints"',
534+
'evaluation_steps': '100',
535+
'key': 'value',
536+
'sagemaker_container_log_level': '20',
537+
'sagemaker_job_name': '"tensorflow-job"',
538+
'sagemaker_program': '"tf_train.py"',
539+
'sagemaker_region': '"us-east-1"',
540+
'sagemaker_submit_directory': '"s3://sagemaker/source"',
541+
'training_steps': '1000',
542+
},
543+
'TrainingJobName': 'tensorflow-job',
544+
},
545+
'Resource': 'arn:aws:states:::sagemaker:createTrainingJob.sync',
546+
'End': True
547+
}
548+
549+
550+
@patch('botocore.client.BaseClient._make_api_call', new=mock_boto_api_call)
551+
@patch.object(boto3.session.Session, 'region_name', 'us-east-1')
552+
def training_step_uses_constructor_hyperparameters_when_duplicates_supplied_in_estimator(tensorflow_estimator):
553+
step = TrainingStep('Training',
554+
estimator=tensorflow_estimator,
555+
data={'train': 's3://sagemaker/train'},
556+
job_name='tensorflow-job',
557+
mini_batch_size=1024,
558+
hyperparameters={
559+
# set as 1000 in estimator
560+
'training_steps': '500'
561+
}
562+
)
563+
564+
assert step.to_dict() == {
565+
'Type': 'Task',
566+
'Parameters': {
567+
'AlgorithmSpecification': {
568+
'TrainingImage': TENSORFLOW_IMAGE,
569+
'TrainingInputMode': 'File'
570+
},
571+
'InputDataConfig': [
572+
{
573+
'DataSource': {
574+
'S3DataSource': {
575+
'S3DataDistributionType': 'FullyReplicated',
576+
'S3DataType': 'S3Prefix',
577+
'S3Uri': 's3://sagemaker/train'
578+
}
579+
},
580+
'ChannelName': 'train'
581+
}
582+
],
583+
'OutputDataConfig': {
584+
'S3OutputPath': 's3://sagemaker/models'
585+
},
586+
'DebugHookConfig': {
587+
'S3OutputPath': 's3://sagemaker/models/debug'
588+
},
589+
'StoppingCondition': {
590+
'MaxRuntimeInSeconds': 86400
591+
},
592+
'ResourceConfig': {
593+
'InstanceCount': 1,
594+
'InstanceType': 'ml.p2.xlarge',
595+
'VolumeSizeInGB': 30
596+
},
597+
'RoleArn': EXECUTION_ROLE,
598+
'HyperParameters': {
599+
'checkpoint_path': '"s3://sagemaker/models/sagemaker-tensorflow/checkpoints"',
600+
'evaluation_steps': '100',
601+
'sagemaker_container_log_level': '20',
602+
'sagemaker_job_name': '"tensorflow-job"',
603+
'sagemaker_program': '"tf_train.py"',
604+
'sagemaker_region': '"us-east-1"',
605+
'sagemaker_submit_directory': '"s3://sagemaker/source"',
606+
'training_steps': '500',
607+
},
608+
'TrainingJobName': 'tensorflow-job',
609+
},
610+
'Resource': 'arn:aws:states:::sagemaker:createTrainingJob.sync',
611+
'End': True
612+
}
613+
485614

486615
@patch.object(boto3.session.Session, 'region_name', 'us-east-1')
487616
def test_transform_step_creation(pca_transformer):

0 commit comments

Comments
 (0)