@@ -28,8 +28,9 @@
 from defs.trt_test_alternative import (check_call, check_call_negative_test,
                                        check_output)

-from .common import (PluginOptions, convert_weights, prune_checkpoint,
-                     quantize_data, refit_model, venv_check_call)
+from .common import (PluginOptions, convert_weights, get_mmlu_accuracy,
+                     prune_checkpoint, quantize_data, refit_model,
+                     venv_check_call)
 from .conftest import (llm_models_root, skip_no_sm120, skip_nvlink_inactive,
                        skip_post_blackwell, skip_pre_blackwell, skip_pre_hopper,
                        tests_path, unittest_path)
@@ -42,6 +43,7 @@
 os.environ['TLLM_LOG_LEVEL'] = 'INFO'

 _MEM_FRACTION_50 = 0.5
+_MEM_FRACTION_80 = 0.8
 _MEM_FRACTION_95 = 0.95


@@ -2574,4 +2576,43 @@ def test_ptp_quickstart_advanced_llama_multi_nodes(llm_root, llm_venv,
     check_call(" ".join(run_cmd), shell=True, env=llm_venv._new_env)


-# End of Pivot-To-Python examples
+@pytest.mark.timeout(5400)
+@pytest.mark.skip_less_device_memory(80000)
+@pytest.mark.skip_less_device(4)
+@pytest.mark.parametrize("eval_task", ["mmlu"])
+@pytest.mark.parametrize("tp_size,pp_size,ep_size", [(16, 1, 8), (8, 2, 8)],
+                         ids=["tp16", "tp8pp2"])
+@pytest.mark.parametrize("model_path", [
+    pytest.param('llama-3.3-models/Llama-3.3-70B-Instruct',
+                 marks=skip_pre_hopper),
+    pytest.param('llama4-models/Llama-4-Maverick-17B-128E-Instruct',
+                 marks=skip_pre_hopper),
+    pytest.param('llama4-models/Llama-4-Maverick-17B-128E-Instruct-FP8',
+                 marks=skip_pre_hopper),
+    pytest.param('Qwen3/Qwen3-235B-A22B', marks=skip_pre_hopper),
+    pytest.param('Qwen3/saved_models_Qwen3-235B-A22B_nvfp4_hf',
+                 marks=skip_pre_blackwell),
+    pytest.param('DeepSeek-R1/DeepSeek-R1-0528-FP4', marks=skip_pre_blackwell),
+])
+def test_multi_nodes_eval(llm_venv, model_path, tp_size, pp_size, ep_size,
+                          eval_task):
+    if "Llama-4" in model_path and tp_size == 16:
+        pytest.skip("Llama-4 with tp16 is not supported")
+
+    mmlu_threshold = 81.5
+    run_cmd = [
+        "trtllm-llmapi-launch",
+        "trtllm-eval",
+        f"--model={llm_models_root()}/{model_path}",
+        f"--ep_size={ep_size}",
+        f"--tp_size={tp_size}",
+        f"--pp_size={pp_size}",
+        f"--kv_cache_free_gpu_memory_fraction={_MEM_FRACTION_80}",
+        "--max_batch_size=32",
+        eval_task,
+    ]
+    output = check_output(" ".join(run_cmd), shell=True, env=llm_venv._new_env)
+
+    if os.environ.get("SLURM_PROCID", '0') == '0':
+        mmlu_accuracy = get_mmlu_accuracy(output)
+        assert mmlu_accuracy > mmlu_threshold, f"MMLU accuracy {mmlu_accuracy} is less than threshold {mmlu_threshold}"
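The new test calls `get_mmlu_accuracy`, which this change imports from `.common` but does not show. A minimal sketch of the shape such a helper plausibly takes, assuming `trtllm-eval` prints a summary line like `MMLU weighted average accuracy: 81.9`; the exact log wording and regex are assumptions, not part of this diff:

```python
import re


def get_mmlu_accuracy(output: str) -> float:
    """Parse the MMLU accuracy figure out of trtllm-eval stdout.

    Assumes the evaluator emits a line such as
    'MMLU weighted average accuracy: 81.9' -- the precise format is a
    guess; only the call site in the diff above is authoritative.
    """
    match = re.search(r"MMLU weighted average accuracy:\s*([\d.]+)", output)
    if match is None:
        raise ValueError("MMLU accuracy not found in eval output")
    return float(match.group(1))
```

The `SLURM_PROCID` guard in the test restricts the parse-and-assert step to rank 0, presumably because under `trtllm-llmapi-launch` every Slurm task executes the test body but only the driver rank captures the eval summary.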