|
5 | 5 | # All rights reserved.
|
6 | 6 | import os
|
7 | 7 | import subprocess
|
| 8 | +from collections import defaultdict |
| 9 | +from dataclasses import dataclass |
| 10 | +from typing import Sequence |
8 | 11 |
|
9 | 12 | try:
|
10 | 13 | import tomllib
|
11 | 14 | except ModuleNotFoundError:
|
12 | 15 | import tomli as tomllib
|
13 | 16 |
|
| 17 | + |
| 18 | +@dataclass |
| 19 | +class OverrideDefinitions: |
| 20 | + """ |
| 21 | + This class is used to define the override definitions for the integration tests. |
| 22 | + """ |
| 23 | + |
| 24 | + override_args: Sequence[str] = tuple() |
| 25 | + test_descr: str = "default" |
| 26 | + |
| 27 | + |
14 | 28 | CONFIG_DIR = "./train_configs"
|
| 29 | + |
| 30 | +""" |
| 31 | +key is the config file name and value is a list of OverrideDefinitions |
| 32 | +that is used to generate variations of integration tests based on the |
| 33 | +same root config file. |
| 34 | +""" |
| 35 | +integration_tests_flavors = defaultdict(list) |
| 36 | +integration_tests_flavors["debug_model.toml"] = [ |
| 37 | + OverrideDefinitions(["--training.compile"], "1D compile"), |
| 38 | + OverrideDefinitions( |
| 39 | + ["--training.tensor_parallel_degree 2"], "Eager mode 2DParallel" |
| 40 | + ), |
| 41 | +] |
| 42 | + |
| 43 | + |
15 | 44 | for config_file in os.listdir(CONFIG_DIR):
|
16 | 45 | if config_file.endswith(".toml"):
|
17 | 46 | full_path = os.path.join(CONFIG_DIR, config_file)
|
18 | 47 | with open(full_path, "rb") as f:
|
19 | 48 | config = tomllib.load(f)
|
20 | 49 | is_integration_test = config["job"].get("use_for_integration_test", False)
|
21 | 50 | if is_integration_test:
|
22 |
| - cmd = f"CONFIG_FILE={full_path} NGPU=4 ./run_llama_train.sh" |
23 |
| - print(f"=====Integration test: {cmd}=====") |
24 |
| - result = subprocess.run( |
25 |
| - [cmd], |
26 |
| - stdout=subprocess.PIPE, |
27 |
| - stderr=subprocess.STDOUT, |
28 |
| - text=True, |
29 |
| - shell=True, |
30 |
| - ) |
31 |
| - print(result.stdout) |
| 51 | + test_flavors = [OverrideDefinitions()] + integration_tests_flavors[ |
| 52 | + config_file |
| 53 | + ] |
| 54 | + for test_flavor in test_flavors: |
| 55 | + cmd = f"CONFIG_FILE={full_path} NGPU=4 ./run_llama_train.sh" |
| 56 | + if test_flavor.override_args: |
| 57 | + cmd += " " + " ".join(test_flavor.override_args) |
| 58 | + print( |
| 59 | + f"=====Integration test, flavor : {test_flavor.test_descr}, command : {cmd}=====" |
| 60 | + ) |
| 61 | + result = subprocess.run( |
| 62 | + [cmd], |
| 63 | + stdout=subprocess.PIPE, |
| 64 | + stderr=subprocess.STDOUT, |
| 65 | + text=True, |
| 66 | + shell=True, |
| 67 | + ) |
| 68 | + print(result.stdout) |
0 commit comments