Skip to content

Commit c662d1f

Browse files
committed
Skipping CPU and lowering batch size
1 parent 14fde5e commit c662d1f

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

test/integration/sagemaker/test_multi_worker_mirrored.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from sagemaker.tensorflow import TensorFlow
1818
from sagemaker.utils import unique_name_from_base
1919

20+
import pytest
2021

2122
RESOURCE_PATH = os.path.join(os.path.dirname(__file__), "..", "..", "resources")
2223

@@ -44,19 +45,20 @@ def test_keras_example(
4445
assert "TF_CONFIG=" in logs
4546

4647

48+
@pytest.mark.skip_cpu
4749
def test_tf_model_garden(
4850
sagemaker_session, instance_type, image_uri, tmpdir, framework_version, capsys
4951
):
5052
epochs = 10
51-
batch_size = 512
52-
train_steps = int(1024 * epochs / batch_size)
53+
global_batch_size = 64
54+
train_steps = int(1024 * epochs / global_batch_size)
5355
steps_per_loop = train_steps // 10
5456
overrides = (
5557
f"runtime.enable_xla=False,"
5658
f"runtime.num_gpus=1,"
5759
f"runtime.distribution_strategy=multi_worker_mirrored,"
5860
f"runtime.mixed_precision_dtype=float16,"
59-
f"task.train_data.global_batch_size={batch_size},"
61+
f"task.train_data.global_batch_size={global_batch_size},"
6062
f"task.train_data.input_path=/opt/ml/input/data/training/validation*,"
6163
f"task.train_data.cache=True,"
6264
f"trainer.train_steps={train_steps},"

0 commit comments

Comments
 (0)