Skip to content

Commit 36698ac

Browse files
committed
update multi-node tests
Signed-off-by: Xin He (SW-GPU) <[email protected]>
1 parent c7269ea commit 36698ac

File tree

3 files changed

+70
-10
lines changed

3 files changed

+70
-10
lines changed

tests/integration/defs/common.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -956,3 +956,23 @@ def get_dummy_spec_decoding_heads(hf_model_dir,
956956
export_hf_checkpoint(model,
957957
dtype=model.config.torch_dtype,
958958
export_dir=os.path.join(save_dir, 'fp8'))
959+
960+
961+
def get_mmlu_accuracy(output):
    """Extract the MMLU weighted average accuracy from eval tool output.

    Scans *output* line by line for the marker emitted by the MMLU
    evaluator and parses the numeric score that follows it.

    Args:
        output: Full captured stdout/stderr of the evaluation run.

    Returns:
        The accuracy as a float.

    Raises:
        Exception: if no line containing the marker is found.
    """
    marker = "MMLU weighted average accuracy:"
    # First line carrying the marker wins; None if absent.
    hit = next((ln for ln in output.split('\n') if marker in ln), None)

    if hit is None:
        raise Exception(
            f"Could not find 'MMLU weighted average accuracy:' in output. Full output:\n{output}"
        )

    # Line shape: "... MMLU weighted average accuracy: <score> (<counts>)";
    # take the text after the marker and drop the trailing " (...)" part.
    tail = hit.split("MMLU weighted average accuracy: ")[1]
    mmlu_accuracy = float(tail.split(" (")[0])

    print(f"MMLU weighted average accuracy is: {mmlu_accuracy}")

    return mmlu_accuracy

tests/integration/defs/test_e2e.py

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,9 @@
2828
from defs.trt_test_alternative import (check_call, check_call_negative_test,
2929
check_output)
3030

31-
from .common import (PluginOptions, convert_weights, prune_checkpoint,
32-
quantize_data, refit_model, venv_check_call)
31+
from .common import (PluginOptions, convert_weights, get_mmlu_accuracy,
32+
prune_checkpoint, quantize_data, refit_model,
33+
venv_check_call)
3334
from .conftest import (llm_models_root, skip_no_sm120, skip_nvlink_inactive,
3435
skip_post_blackwell, skip_pre_blackwell, skip_pre_hopper,
3536
tests_path, unittest_path)
@@ -42,6 +43,7 @@
4243
os.environ['TLLM_LOG_LEVEL'] = 'INFO'
4344

4445
_MEM_FRACTION_50 = 0.5
46+
_MEM_FRACTION_80 = 0.8
4547
_MEM_FRACTION_95 = 0.95
4648

4749

@@ -2574,4 +2576,43 @@ def test_ptp_quickstart_advanced_llama_multi_nodes(llm_root, llm_venv,
25742576
check_call(" ".join(run_cmd), shell=True, env=llm_venv._new_env)
25752577

25762578

2577-
# End of Pivot-To-Python examples
2579+
@pytest.mark.timeout(5400)
@pytest.mark.skip_less_device_memory(80000)
@pytest.mark.skip_less_device(4)
@pytest.mark.parametrize("eval_task", ["mmlu"])
@pytest.mark.parametrize("tp_size,pp_size,ep_size", [(16, 1, 8), (8, 2, 8)],
                         ids=["tp16", "tp8pp2"])
@pytest.mark.parametrize("model_path", [
    pytest.param('llama-3.3-models/Llama-3.3-70B-Instruct',
                 marks=skip_pre_hopper),
    pytest.param('llama4-models/Llama-4-Maverick-17B-128E-Instruct',
                 marks=skip_pre_hopper),
    pytest.param('llama4-models/Llama-4-Maverick-17B-128E-Instruct-FP8',
                 marks=skip_pre_hopper),
    pytest.param('Qwen3/Qwen3-235B-A22B', marks=skip_pre_hopper),
    pytest.param('Qwen3/saved_models_Qwen3-235B-A22B_nvfp4_hf',
                 marks=skip_pre_blackwell),
    pytest.param('DeepSeek-R1/DeepSeek-R1-0528-FP4', marks=skip_pre_blackwell),
])
def test_multi_nodes_eval(llm_venv, model_path, tp_size, pp_size, ep_size,
                          eval_task):
    """Launch trtllm-eval across nodes and gate on MMLU accuracy.

    Builds a `trtllm-llmapi-launch trtllm-eval ...` command for the given
    model and parallelism layout, runs it, and asserts the parsed MMLU
    score exceeds the threshold.
    """
    if "Llama-4" in model_path and tp_size == 16:
        pytest.skip("Llama-4 with tp16 is not supported")

    mmlu_threshold = 81.5
    launch_args = [
        "trtllm-llmapi-launch",
        "trtllm-eval",
        f"--model={llm_models_root()}/{model_path}",
        f"--ep_size={ep_size}",
        f"--tp_size={tp_size}",
        f"--pp_size={pp_size}",
        f"--kv_cache_free_gpu_memory_fraction={_MEM_FRACTION_80}",
        "--max_batch_size=32",
        eval_task,
    ]
    output = check_output(" ".join(launch_args), shell=True,
                          env=llm_venv._new_env)

    # NOTE(review): gated on SLURM rank — presumably only rank 0's captured
    # output contains the aggregated accuracy line; confirm for other ranks.
    if os.environ.get("SLURM_PROCID", '0') == '0':
        mmlu_accuracy = get_mmlu_accuracy(output)
        assert mmlu_accuracy > mmlu_threshold, f"MMLU accuracy {mmlu_accuracy} is less than threshold {mmlu_threshold}"
Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
1-
examples/test_llama.py::test_llm_llama_v3_1_2nodes_8gpus[llama-3.1-8b-disable_fp8-tp16pp1-build]
2-
examples/test_llama.py::test_llm_llama_v3_1_2nodes_8gpus[llama-3.1-8b-disable_fp8-tp16pp1-infer]
3-
examples/test_mixtral.py::test_llm_mixtral_2nodes_8gpus[Mixtral-8x22B-v0.1-plugin-renormalize-tensor_parallel-build]
4-
examples/test_mixtral.py::test_llm_mixtral_2nodes_8gpus[Mixtral-8x22B-v0.1-plugin-renormalize-tensor_parallel-infer]
51
test_e2e.py::test_ptp_quickstart_advanced_deepseek_multi_nodes[DeepSeek-V3]
6-
test_e2e.py::test_ptp_quickstart_advanced_deepseek_multi_nodes[DeepSeek-R1/DeepSeek-R1-0528-FP4]
7-
test_e2e.py::test_ptp_quickstart_advanced_llama_multi_nodes[llama-3.3-models/Llama-3.3-70B-Instruct]
8-
test_e2e.py::test_ptp_quickstart_advanced_llama_multi_nodes[llama4-models/Llama-4-Maverick-17B-128E-Instruct]
2+
test_e2e.py::test_multi_nodes_eval[DeepSeek-R1/DeepSeek-R1-0528-FP4-tp16-mmlu]
3+
test_e2e.py::test_multi_nodes_eval[llama-3.3-models/Llama-3.3-70B-Instruct-tp16-mmlu]
4+
test_e2e.py::test_multi_nodes_eval[llama4-models/Llama-4-Maverick-17B-128E-Instruct-tp8pp2-mmlu]
5+
test_e2e.py::test_multi_nodes_eval[llama4-models/Llama-4-Maverick-17B-128E-Instruct-FP8-tp8pp2-mmlu]
6+
test_e2e.py::test_multi_nodes_eval[Qwen3/Qwen3-235B-A22B-tp16-mmlu]
7+
test_e2e.py::test_multi_nodes_eval[Qwen3/saved_models_Qwen3-235B-A22B_nvfp4_hf-tp16-mmlu]
98
test_e2e.py::test_openai_multinodes_chat_tp16pp1

0 commit comments

Comments
 (0)