Skip to content

Commit f53c0fc

Browse files
wconstabtianyu-l
authored andcommitted
Make test_runner.py warn on non-empty output dir
also wrap logic into functions and clean up global vars ghstack-source-id: 815c582 Pull Request resolved: #343
1 parent 76d0956 commit f53c0fc

File tree

1 file changed

+89
-73
lines changed

1 file changed

+89
-73
lines changed

test_runner.py

Lines changed: 89 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,6 @@
1717
import tomli as tomllib
1818

1919

20-
parser = argparse.ArgumentParser()
21-
parser.add_argument("output_dir")
22-
args = parser.parse_args()
23-
24-
2520
@dataclass
2621
class OverrideDefinitions:
2722
"""
@@ -32,77 +27,77 @@ class OverrideDefinitions:
3227
test_descr: str = "default"
3328

3429

35-
CONFIG_DIR = "./train_configs"
36-
37-
"""
38-
key is the config file name and value is a list of OverrideDefinitions
39-
that is used to generate variations of integration tests based on the
40-
same root config file.
41-
"""
42-
integration_tests_flavors = defaultdict(list)
43-
integration_tests_flavors["debug_model.toml"] = [
44-
OverrideDefinitions(
45-
[
46-
[
47-
f"--job.dump_folder {args.output_dir}/default/",
48-
],
49-
],
50-
"Default",
51-
),
52-
OverrideDefinitions(
53-
[
30+
def build_test_list(args):
31+
"""
32+
key is the config file name and value is a list of OverrideDefinitions
33+
that is used to generate variations of integration tests based on the
34+
same root config file.
35+
"""
36+
integration_tests_flavors = defaultdict(list)
37+
integration_tests_flavors["debug_model.toml"] = [
38+
OverrideDefinitions(
5439
[
55-
"--training.compile",
56-
f"--job.dump_folder {args.output_dir}/1d_compile/",
40+
[
41+
f"--job.dump_folder {args.output_dir}/default/",
42+
],
5743
],
58-
],
59-
"1D compile",
60-
),
61-
OverrideDefinitions(
62-
[
44+
"Default",
45+
),
46+
OverrideDefinitions(
6347
[
64-
"--training.tensor_parallel_degree 2 --model.norm_type=rmsnorm",
65-
f"--job.dump_folder {args.output_dir}/eager_2d/",
48+
[
49+
"--training.compile",
50+
f"--job.dump_folder {args.output_dir}/1d_compile/",
51+
],
6652
],
67-
],
68-
"Eager mode 2DParallel",
69-
),
70-
OverrideDefinitions(
71-
[
53+
"1D compile",
54+
),
55+
OverrideDefinitions(
7256
[
73-
"--checkpoint.enable_checkpoint",
74-
f"--job.dump_folder {args.output_dir}/full_checkpoint/",
57+
[
58+
"--training.tensor_parallel_degree 2 --model.norm_type=rmsnorm",
59+
f"--job.dump_folder {args.output_dir}/eager_2d/",
60+
],
7561
],
62+
"Eager mode 2DParallel",
63+
),
64+
OverrideDefinitions(
7665
[
77-
"--checkpoint.enable_checkpoint",
78-
f"--job.dump_folder {args.output_dir}/full_checkpoint/",
79-
"--training.steps 20",
66+
[
67+
"--checkpoint.enable_checkpoint",
68+
f"--job.dump_folder {args.output_dir}/full_checkpoint/",
69+
],
70+
[
71+
"--checkpoint.enable_checkpoint",
72+
f"--job.dump_folder {args.output_dir}/full_checkpoint/",
73+
"--training.steps 20",
74+
],
8075
],
81-
],
82-
"Checkpoint Integration Test - Save Load Full Checkpoint",
83-
),
84-
OverrideDefinitions(
85-
[
76+
"Checkpoint Integration Test - Save Load Full Checkpoint",
77+
),
78+
OverrideDefinitions(
8679
[
87-
"--checkpoint.enable_checkpoint",
88-
f"--job.dump_folder {args.output_dir}/model_weights_only_fp32/",
89-
"--checkpoint.model_weights_only",
80+
[
81+
"--checkpoint.enable_checkpoint",
82+
f"--job.dump_folder {args.output_dir}/model_weights_only_fp32/",
83+
"--checkpoint.model_weights_only",
84+
],
9085
],
91-
],
92-
"Checkpoint Integration Test - Save Model Weights Only fp32",
93-
),
94-
OverrideDefinitions(
95-
[
86+
"Checkpoint Integration Test - Save Model Weights Only fp32",
87+
),
88+
OverrideDefinitions(
9689
[
97-
"--checkpoint.enable_checkpoint",
98-
f"--job.dump_folder {args.output_dir}/model_weights_only_bf16/",
99-
"--checkpoint.model_weights_only",
100-
"--checkpoint.export_dtype bfloat16",
90+
[
91+
"--checkpoint.enable_checkpoint",
92+
f"--job.dump_folder {args.output_dir}/model_weights_only_bf16/",
93+
"--checkpoint.model_weights_only",
94+
"--checkpoint.export_dtype bfloat16",
95+
],
10196
],
102-
],
103-
"Checkpoint Integration Test - Save Model Weights Only bf16",
104-
),
105-
]
97+
"Checkpoint Integration Test - Save Model Weights Only bf16",
98+
),
99+
]
100+
return integration_tests_flavors
106101

107102

108103
def run_test(test_flavor: OverrideDefinitions, full_path: str):
@@ -128,12 +123,33 @@ def run_test(test_flavor: OverrideDefinitions, full_path: str):
128123
)
129124

130125

131-
for config_file in os.listdir(CONFIG_DIR):
132-
if config_file.endswith(".toml"):
133-
full_path = os.path.join(CONFIG_DIR, config_file)
134-
with open(full_path, "rb") as f:
135-
config = tomllib.load(f)
136-
is_integration_test = config["job"].get("use_for_integration_test", False)
137-
if is_integration_test:
138-
for test_flavor in integration_tests_flavors[config_file]:
139-
run_test(test_flavor, full_path)
126+
def run_tests(args):
127+
integration_tests_flavors = build_test_list(args)
128+
for config_file in os.listdir(args.config_dir):
129+
if config_file.endswith(".toml"):
130+
full_path = os.path.join(args.config_dir, config_file)
131+
with open(full_path, "rb") as f:
132+
config = tomllib.load(f)
133+
is_integration_test = config["job"].get(
134+
"use_for_integration_test", False
135+
)
136+
if is_integration_test:
137+
for test_flavor in integration_tests_flavors[config_file]:
138+
run_test(test_flavor, full_path)
139+
140+
141+
def main():
142+
parser = argparse.ArgumentParser()
143+
parser.add_argument("output_dir")
144+
parser.add_argument("--config_dir", default="./train_configs")
145+
args = parser.parse_args()
146+
147+
if not os.path.exists(args.output_dir):
148+
os.makedirs(args.output_dir)
149+
if os.listdir(args.output_dir):
150+
raise RuntimeError("Please provide an empty output directory.")
151+
run_tests(args)
152+
153+
154+
if __name__ == "__main__":
155+
main()

0 commit comments

Comments
 (0)