Skip to content

Commit d758406

Browse files
committed
support sequence of tests and add checkpoint test
address comments ghstack-source-id: 7d6c51a Pull Request resolved: #198
1 parent 2c21f36 commit d758406

File tree

1 file changed

+41
-18
lines changed

1 file changed

+41
-18
lines changed

test/test_runner.py

Lines changed: 41 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,12 @@ class OverrideDefinitions:
2121
This class is used to define the override definitions for the integration tests.
2222
"""
2323

24-
override_args: Sequence[str] = tuple()
25-
test_descr: str = "default"
24+
override_args: Sequence[Sequence[str]] = tuple(tuple(" "))
25+
test_descr: str = ""
2626

2727

2828
CONFIG_DIR = "./train_configs"
29+
test_checkpoint_dir = "./test_runner_checkpoint"
2930

3031
"""
3132
key is the config file name and value is a list of OverrideDefinitions
@@ -34,13 +35,47 @@ class OverrideDefinitions:
3435
"""
3536
integration_tests_flavors = defaultdict(list)
3637
integration_tests_flavors["debug_model.toml"] = [
37-
OverrideDefinitions(["--training.compile"], "1D compile"),
3838
OverrideDefinitions(
39-
["--training.tensor_parallel_degree 2"], "Eager mode 2DParallel"
39+
[
40+
["--training.compile"],
41+
],
42+
"1D compile",
43+
),
44+
OverrideDefinitions(
45+
[
46+
["--training.tensor_parallel_degree 2"],
47+
],
48+
"Eager mode 2DParallel",
49+
),
50+
OverrideDefinitions(
51+
[
52+
[f"--checkpoint.folder {test_checkpoint_dir}"],
53+
[f"--checkpoint.folder {test_checkpoint_dir}", "--training.steps 20"],
54+
],
55+
"Checkpoint Integration Test",
4056
),
4157
]
4258

4359

60+
def run_test(test_flavor: OverrideDefinitions, full_path: str):
61+
# run_test supports sequence of tests.
62+
for override_arg in test_flavor.override_args:
63+
cmd = f"CONFIG_FILE={full_path} NGPU=4 ./run_llama_train.sh"
64+
if override_arg:
65+
cmd += " " + " ".join(override_arg)
66+
print(
67+
f"=====Integration test, flavor : {test_flavor.test_descr}, command : {cmd}====="
68+
)
69+
result = subprocess.run(
70+
[cmd],
71+
stdout=subprocess.PIPE,
72+
stderr=subprocess.STDOUT,
73+
text=True,
74+
shell=True,
75+
)
76+
print(result.stdout)
77+
78+
4479
for config_file in os.listdir(CONFIG_DIR):
4580
if config_file.endswith(".toml"):
4681
full_path = os.path.join(CONFIG_DIR, config_file)
@@ -51,18 +86,6 @@ class OverrideDefinitions:
5186
test_flavors = [OverrideDefinitions()] + integration_tests_flavors[
5287
config_file
5388
]
89+
5490
for test_flavor in test_flavors:
55-
cmd = f"CONFIG_FILE={full_path} NGPU=4 ./run_llama_train.sh"
56-
if test_flavor.override_args:
57-
cmd += " " + " ".join(test_flavor.override_args)
58-
print(
59-
f"=====Integration test, flavor : {test_flavor.test_descr}, command : {cmd}====="
60-
)
61-
result = subprocess.run(
62-
[cmd],
63-
stdout=subprocess.PIPE,
64-
stderr=subprocess.STDOUT,
65-
text=True,
66-
shell=True,
67-
)
68-
print(result.stdout)
91+
run_test(test_flavor, full_path)

0 commit comments

Comments
 (0)