@@ -21,11 +21,12 @@ class OverrideDefinitions:
21
21
This class is used to define the override definitions for the integration tests.
22
22
"""
23
23
24
- override_args : Sequence [str ] = tuple ()
25
- test_descr : str = "default "
24
+ override_args : Sequence [Sequence [ str ]] = tuple (tuple ( " " ) )
25
+ test_descr : str = ""
26
26
27
27
28
28
CONFIG_DIR = "./train_configs"
29
+ test_checkpoint_dir = "./test_runner_checkpoint"
29
30
30
31
"""
31
32
key is the config file name and value is a list of OverrideDefinitions
@@ -34,13 +35,47 @@ class OverrideDefinitions:
34
35
"""
35
36
# Maps a config file name to the extra test flavors to run for it; a config
# with no entry yields an empty list on lookup.
integration_tests_flavors = defaultdict(list)
integration_tests_flavors["debug_model.toml"] = [
    OverrideDefinitions([["--training.compile"]], "1D compile"),
    OverrideDefinitions(
        [["--training.tensor_parallel_degree 2"]],
        "Eager mode 2DParallel",
    ),
    # Two launches sharing one checkpoint folder; the second launch also
    # passes --training.steps 20.
    OverrideDefinitions(
        [
            [f"--checkpoint.folder {test_checkpoint_dir}"],
            [f"--checkpoint.folder {test_checkpoint_dir}", "--training.steps 20"],
        ],
        "Checkpoint Integration Test",
    ),
]
42
58
43
59
60
def run_test(test_flavor: "OverrideDefinitions", full_path: str):
    """Run every launch of one integration-test flavor, in order.

    Args:
        test_flavor: Flavor to run. ``override_args`` holds one sequence of
            command-line overrides per launch; ``test_descr`` labels the
            flavor in the printed banner.
        full_path: Path of the config file, handed to the launcher script
            through the ``CONFIG_FILE`` environment variable.

    Raises:
        RuntimeError: If a launch exits with a non-zero status. Remaining
            launches of the same flavor are not attempted.
    """
    # A flavor may describe a sequence of launches; run them one at a time.
    for override_arg in test_flavor.override_args:
        cmd = f"CONFIG_FILE={full_path} NGPU=4 ./run_llama_train.sh"
        if override_arg:
            cmd += " " + " ".join(override_arg)
        print(
            f"=====Integration test, flavor : {test_flavor.test_descr}, command : {cmd}====="
        )
        # With shell=True the command is passed as a single string, so the
        # leading VAR=value assignments are interpreted by the shell.
        result = subprocess.run(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            shell=True,
        )
        print(result.stdout)
        # Surface failures instead of silently moving on; otherwise a broken
        # launch can never fail the integration run.
        if result.returncode != 0:
            raise RuntimeError(
                f"Integration test failed, flavor : {test_flavor.test_descr}, "
                f"command : {cmd}, exit code : {result.returncode}"
            )
77
+
78
+
44
79
for config_file in os .listdir (CONFIG_DIR ):
45
80
if config_file .endswith (".toml" ):
46
81
full_path = os .path .join (CONFIG_DIR , config_file )
@@ -51,18 +86,6 @@ class OverrideDefinitions:
51
86
test_flavors = [OverrideDefinitions ()] + integration_tests_flavors [
52
87
config_file
53
88
]
89
+
54
90
for test_flavor in test_flavors :
55
- cmd = f"CONFIG_FILE={ full_path } NGPU=4 ./run_llama_train.sh"
56
- if test_flavor .override_args :
57
- cmd += " " + " " .join (test_flavor .override_args )
58
- print (
59
- f"=====Integration test, flavor : { test_flavor .test_descr } , command : { cmd } ====="
60
- )
61
- result = subprocess .run (
62
- [cmd ],
63
- stdout = subprocess .PIPE ,
64
- stderr = subprocess .STDOUT ,
65
- text = True ,
66
- shell = True ,
67
- )
68
- print (result .stdout )
91
+ run_test (test_flavor , full_path )
0 commit comments