17
17
import tomli as tomllib
18
18
19
19
20
- parser = argparse .ArgumentParser ()
21
- parser .add_argument ("output_dir" )
22
- args = parser .parse_args ()
23
-
24
-
25
20
@dataclass
26
21
class OverrideDefinitions :
27
22
"""
@@ -32,77 +27,77 @@ class OverrideDefinitions:
32
27
test_descr : str = "default"
33
28
34
29
35
def build_test_list(args):
    """
    Build the table of integration-test flavors.

    The returned mapping is keyed by config file name; each value is the
    list of OverrideDefinitions used to generate variations of integration
    tests based on that same root config file. Every flavor dumps into its
    own sub-folder of ``args.output_dir``.
    """

    def dump_folder(subdir):
        # Each flavor writes its artifacts to <output_dir>/<subdir>/.
        return f"--job.dump_folder {args.output_dir}/{subdir}/"

    enable_ckpt = "--checkpoint.enable_checkpoint"
    weights_only = "--checkpoint.model_weights_only"

    debug_model_flavors = [
        OverrideDefinitions(
            [[dump_folder("default")]],
            "Default",
        ),
        OverrideDefinitions(
            [["--training.compile", dump_folder("1d_compile")]],
            "1D compile",
        ),
        OverrideDefinitions(
            [
                [
                    "--training.tensor_parallel_degree 2 --model.norm_type=rmsnorm",
                    dump_folder("eager_2d"),
                ],
            ],
            "Eager mode 2DParallel",
        ),
        OverrideDefinitions(
            [
                # Two command lines against the same dump folder; the
                # second one additionally passes "--training.steps 20".
                [enable_ckpt, dump_folder("full_checkpoint")],
                [
                    enable_ckpt,
                    dump_folder("full_checkpoint"),
                    "--training.steps 20",
                ],
            ],
            "Checkpoint Integration Test - Save Load Full Checkpoint",
        ),
        OverrideDefinitions(
            [
                [
                    enable_ckpt,
                    dump_folder("model_weights_only_fp32"),
                    weights_only,
                ],
            ],
            "Checkpoint Integration Test - Save Model Weights Only fp32",
        ),
        OverrideDefinitions(
            [
                [
                    enable_ckpt,
                    dump_folder("model_weights_only_bf16"),
                    weights_only,
                    "--checkpoint.export_dtype bfloat16",
                ],
            ],
            "Checkpoint Integration Test - Save Model Weights Only bf16",
        ),
    ]

    # defaultdict(list): config files with no registered flavors map to [].
    flavors_by_config = defaultdict(list)
    flavors_by_config["debug_model.toml"] = debug_model_flavors
    return flavors_by_config
106
101
107
102
108
103
def run_test (test_flavor : OverrideDefinitions , full_path : str ):
@@ -128,12 +123,33 @@ def run_test(test_flavor: OverrideDefinitions, full_path: str):
128
123
)
129
124
130
125
131
def run_tests(args):
    """
    Run every registered test flavor for each opted-in config file.

    Scans ``args.config_dir`` for ``*.toml`` files, parses each one, and for
    those whose ``[job]`` table sets ``use_for_integration_test`` to a truthy
    value, calls ``run_test`` once per flavor returned by
    ``build_test_list(args)`` for that file name.

    Args:
        args: Parsed command-line namespace; ``config_dir`` is read here and
            the whole namespace is forwarded to ``build_test_list``.
    """
    integration_tests_flavors = build_test_list(args)

    # os.listdir() returns entries in arbitrary order; sorting makes the
    # order in which tests run reproducible.
    for config_file in sorted(os.listdir(args.config_dir)):
        if not config_file.endswith(".toml"):
            continue

        full_path = os.path.join(args.config_dir, config_file)
        with open(full_path, "rb") as f:
            config = tomllib.load(f)

        # A config without a [job] table is skipped like one that has not
        # opted in, instead of aborting the whole run with a KeyError.
        job_section = config.get("job", {})
        if not job_section.get("use_for_integration_test", False):
            continue

        # .get() avoids inserting empty entries for unknown config names.
        for test_flavor in integration_tests_flavors.get(config_file, []):
            run_test(test_flavor, full_path)
139
+
140
+
141
def main(argv=None):
    """
    Command-line entry point for the integration-test runner.

    Parses the positional ``output_dir`` and the optional ``--config_dir``
    (default ``./train_configs``), makes sure the output directory exists
    and is empty, then hands the parsed arguments to ``run_tests``.

    Args:
        argv: Optional list of argument strings. ``None`` (the default)
            lets argparse read ``sys.argv[1:]``, so calling ``main()``
            with no arguments behaves as before.

    Raises:
        RuntimeError: If the output directory already contains entries.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("output_dir")
    parser.add_argument("--config_dir", default="./train_configs")
    args = parser.parse_args(argv)

    # exist_ok=True creates the directory when it is missing without the
    # check-then-create race of a separate os.path.exists() test.
    os.makedirs(args.output_dir, exist_ok=True)
    if os.listdir(args.output_dir):
        raise RuntimeError("Please provide an empty output directory.")
    run_tests(args)


if __name__ == "__main__":
    main()
0 commit comments