11
11
from dataclasses import dataclass
12
12
from typing import Sequence
13
13
14
+ from torchtitan .logging_utils import logger
15
+
14
16
try :
15
17
import tomllib
16
18
except ModuleNotFoundError :
@@ -25,6 +27,8 @@ class OverrideDefinitions:
25
27
26
28
override_args : Sequence [Sequence [str ]] = tuple (tuple (" " ))
27
29
test_descr : str = "default"
30
+ requires_seed_checkpoint : bool = False
31
+ ngpu : int = 4
28
32
29
33
30
34
def build_test_list (args ):
@@ -35,6 +39,78 @@ def build_test_list(args):
35
39
"""
36
40
integration_tests_flavors = defaultdict (list )
37
41
integration_tests_flavors ["debug_model.toml" ] = [
42
+ OverrideDefinitions (
43
+ [
44
+ [
45
+ "--checkpoint.enable_checkpoint" ,
46
+ f"--job.dump_folder { args .output_dir } /pp_1f1b/" ,
47
+ "--experimental.pipeline_parallel_degree 2" ,
48
+ "--experimental.pipeline_parallel_split_points layers.1" ,
49
+ "--experimental.pipeline_parallel_schedule 1f1b" ,
50
+ "--training.data_parallel_degree 1" ,
51
+ ],
52
+ ],
53
+ "PP 1D test 1f1b" ,
54
+ requires_seed_checkpoint = True ,
55
+ ngpu = 2 ,
56
+ ),
57
+ OverrideDefinitions (
58
+ [
59
+ [
60
+ "--checkpoint.enable_checkpoint" ,
61
+ f"--job.dump_folder { args .output_dir } /pp_gpipe/" ,
62
+ "--experimental.pipeline_parallel_degree 2" ,
63
+ "--experimental.pipeline_parallel_split_points layers.1" ,
64
+ "--experimental.pipeline_parallel_schedule gpipe" ,
65
+ "--training.data_parallel_degree 1" ,
66
+ ],
67
+ ],
68
+ "PP 1D test gpipe" ,
69
+ requires_seed_checkpoint = True ,
70
+ ngpu = 2 ,
71
+ ),
72
+ OverrideDefinitions (
73
+ [
74
+ [
75
+ "--checkpoint.enable_checkpoint" ,
76
+ f"--job.dump_folder { args .output_dir } /pp_dp_1f1b/" ,
77
+ "--experimental.pipeline_parallel_degree 2" ,
78
+ "--experimental.pipeline_parallel_split_points layers.1" ,
79
+ "--experimental.pipeline_parallel_schedule 1f1b" ,
80
+ "--training.data_parallel_degree 2" ,
81
+ ],
82
+ ],
83
+ "PP+DP 1f1b 2D test" ,
84
+ requires_seed_checkpoint = True ,
85
+ ),
86
+ OverrideDefinitions (
87
+ [
88
+ [
89
+ "--checkpoint.enable_checkpoint" ,
90
+ f"--job.dump_folder { args .output_dir } /pp_dp_gpipe/" ,
91
+ "--experimental.pipeline_parallel_degree 2" ,
92
+ "--experimental.pipeline_parallel_split_points layers.1" ,
93
+ "--experimental.pipeline_parallel_schedule gpipe" ,
94
+ "--training.data_parallel_degree 2" ,
95
+ ],
96
+ ],
97
+ "PP+DP gpipe 2D test" ,
98
+ requires_seed_checkpoint = True ,
99
+ ),
100
+ OverrideDefinitions (
101
+ [
102
+ [
103
+ "--checkpoint.enable_checkpoint" ,
104
+ f"--job.dump_folder { args .output_dir } /pp_tp/" ,
105
+ "--experimental.pipeline_parallel_degree 2" ,
106
+ "--experimental.pipeline_parallel_split_points layers.1" ,
107
+ "--training.tensor_parallel_degree 2" ,
108
+ "--model.norm_type rmsnorm" , # fused_rmsnorm not yet compatible with TP
109
+ ],
110
+ ],
111
+ "PP+TP 2D test" ,
112
+ requires_seed_checkpoint = True ,
113
+ ),
38
114
OverrideDefinitions (
39
115
[
40
116
[
@@ -100,23 +176,43 @@ def build_test_list(args):
100
176
return integration_tests_flavors
101
177
102
178
179
+ def _run_cmd (cmd ):
180
+ return subprocess .run (
181
+ [cmd ],
182
+ stdout = subprocess .PIPE ,
183
+ stderr = subprocess .STDOUT ,
184
+ text = True ,
185
+ shell = True ,
186
+ )
187
+
188
+
103
189
def run_test (test_flavor : OverrideDefinitions , full_path : str ):
104
190
# run_test supports sequence of tests.
105
191
for override_arg in test_flavor .override_args :
106
- cmd = f"CONFIG_FILE={ full_path } NGPU=4 LOG_RANK=0,1,2,3 ./run_llama_train.sh"
192
+
193
+ cmd = f"CONFIG_FILE={ full_path } NGPU={ test_flavor .ngpu } LOG_RANK=0,1,2,3 ./run_llama_train.sh"
107
194
if override_arg :
108
195
cmd += " " + " " .join (override_arg )
109
- print (
196
+ logger . info (
110
197
f"=====Integration test, flavor : { test_flavor .test_descr } , command : { cmd } ====="
111
198
)
112
- result = subprocess .run (
113
- [cmd ],
114
- stdout = subprocess .PIPE ,
115
- stderr = subprocess .STDOUT ,
116
- text = True ,
117
- shell = True ,
118
- )
119
- print (result .stdout )
199
+
200
+ if test_flavor .requires_seed_checkpoint :
201
+ dump_folder_arg = None
202
+ for arg in override_arg :
203
+ if "--job.dump_folder" in arg :
204
+ dump_folder_arg = arg
205
+ assert (
206
+ dump_folder_arg is not None
207
+ ), "Can't use seed checkpoint if folder is not specified"
208
+ logger .info ("Creating seed checkpoint" )
209
+ result = _run_cmd (
210
+ f"CONFIG_FILE={ full_path } ./create_seed_checkpoint.sh { dump_folder_arg } "
211
+ )
212
+ logger .info (result .stdout )
213
+
214
+ result = _run_cmd (cmd )
215
+ logger .info (result .stdout )
120
216
if result .returncode != 0 :
121
217
raise Exception (
122
218
f"Integration test failed, flavor : { test_flavor .test_descr } , command : { cmd } "
0 commit comments