
Commit a15e333

[None][fix] Revert commit 48ddc3d & add test for disagg server with different max_num_tokens (#6259)
Signed-off-by: Pengyun Lin <[email protected]>
1 parent 8c82ee2 commit a15e333

9 files changed, +104 -15 lines changed
Lines changed: 3 additions & 0 deletions

@@ -0,0 +1,3 @@
+[
+    "Place of birth\nThe place of birth (POB) or birthplace is the place where a person was born. This place is often used in legal documents, together with name and date of birth, to uniquely identify a person. Practice regarding whether this place should be a country, a territory or a city/town/locality differs in different countries, but often city or territory is used for native-born citizen passports and countries for foreign-born ones.\nAs a general rule with respect to passports, if the place of birth is to be a country, it's determined to be the country that currently has sovereignty over the actual place of birth, regardless of when the birth actually occurred. The place of birth is not necessarily the place where the parents of the new baby live. If the baby is born in a hospital in another place, that place is the place of birth. In many countries, this also means that the government requires that the birth of the new baby is registered in the place of birth.\nSome countries place less or no importance on the place of birth, instead using alternative geographical characteristics for the purpose of identity documents. For example, Sweden has used the concept of födelsehemort (\"domicile of birth\") since 1947. This means that the domicile of the baby's mother is the registered place of birth.\nSimilarly, Switzerland uses the concept of place of origin. A child born to Swiss parents is automatically assigned the place of origin of the parent with the same last name, so the child either gets their mother's or father's place of origin. A child born to one Swiss parent and one foreign parent acquires the place of origin of their Swiss parent. In a Swiss passport and identity card, the holder's place of origin is stated, not their place of birth. In Japan, the registered domicile is a similar concept.\nIn some countries (primarily in the Americas), the place of birth automatically determines the nationality of the baby, a practice often referred to by the Latin phrase jus soli."
+]

tensorrt_llm/llmapi/llm.py

Lines changed: 9 additions & 1 deletion

@@ -566,6 +566,14 @@ def _check_arguments(self, prompt_len: int, query_len: int,
                 raise ValueError(
                     f"PyTorch backend currently only supports `logprobs=1`. Received `logprobs={sampling_params.logprobs}` (Top{sampling_params.logprobs} logprobs). Please set `logprobs=1` in `sampling_params` instead."
                 )
+            # Check prompt length and query length against max_num_tokens to filter illegal requests.
+            # Skip check for gen-only requests
+            if self.args.backend == "pytorch" and not self.args.enable_chunked_prefill and not is_gen_only:
+                max_num_tokens = self.args.max_num_tokens
+                if max_num_tokens and prompt_len / self.args.parallel_config.cp_size + query_len > max_num_tokens:
+                    raise ValueError(
+                        f"The sum of prompt length ({prompt_len/self.args.parallel_config.cp_size}), query length ({query_len}) should not exceed "
+                        f"max_num_tokens ({max_num_tokens})")
             return

         build_config = self.args.build_config

@@ -582,7 +590,7 @@ def _check_arguments(self, prompt_len: int, query_len: int,
                 (sampling_params.max_tokens or 0) > max_seq_len):
             raise ValueError(
                 f"The sum of prompt length ({prompt_len/self.args.parallel_config.cp_size}) and query length ({query_len}) max_tokens ({sampling_params.max_tokens}) should not exceed "
-                f"max_seq_len ({build_config.max_seq_len})")
+                f"max_seq_len ({max_seq_len})")

         if sampling_params.use_beam_search and sampling_params.best_of > build_config.max_beam_width:
             if sampling_params.n == sampling_params.best_of:
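
To see the new check in isolation, here is a minimal standalone sketch of the same arithmetic; the function name and bare parameters are illustrative, not part of the LLM API. A context-phase request is rejected once its per-CP-rank prompt length plus query length exceeds max_num_tokens, while chunked prefill and generation-only requests skip the check.

# Standalone sketch of the length check added above; names are illustrative.
def check_prompt_fits(prompt_len: int,
                      query_len: int,
                      max_num_tokens: int,
                      cp_size: int = 1,
                      enable_chunked_prefill: bool = False,
                      is_gen_only: bool = False) -> None:
    """Raise ValueError when a context-phase request cannot fit into max_num_tokens."""
    if enable_chunked_prefill or is_gen_only or not max_num_tokens:
        # Chunked prefill splits long prompts; gen-only requests have no prompt to prefill.
        return
    if prompt_len / cp_size + query_len > max_num_tokens:
        raise ValueError(
            f"The sum of prompt length ({prompt_len / cp_size}), query length "
            f"({query_len}) should not exceed max_num_tokens ({max_num_tokens})")


check_prompt_fits(prompt_len=500, query_len=0, max_num_tokens=512)  # fits
try:
    check_prompt_fits(prompt_len=600, query_len=0, max_num_tokens=512)
except ValueError as err:
    print(err)  # rejected: 600 tokens exceed the 512-token budget
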
tests/integration/defs/disaggregated/test_configs/disagg_config_diff_max_tokens.yaml

Lines changed: 23 additions & 0 deletions

@@ -0,0 +1,23 @@
+hostname: localhost
+port: 8000
+model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
+free_gpu_memory_fraction: 0.5
+backend: "pytorch"
+cuda_graph_config: null
+disable_overlap_scheduler: True
+context_servers:
+  num_instances: 1
+  max_num_tokens: 512
+  max_batch_size: 256
+  cache_transceiver_config:
+    backend: default
+  urls:
+    - "localhost:8001"
+generation_servers:
+  num_instances: 1
+  max_num_tokens: 256
+  max_batch_size: 128
+  cache_transceiver_config:
+    backend: default
+  urls:
+    - "localhost:8002"
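
The point of this config is the asymmetric budgets: the context servers advertise max_num_tokens 512 while the generation servers advertise only 256, and the gen-only skip in the new length check is presumably what lets a long prompt that fits the context budget still reach the smaller generation server. The sketch below parses the file and prints those budgets; the local filename is an assumption, and only PyYAML is required.

# Minimal sketch, assuming the YAML above is saved locally under the name used
# in the test's config_map entry.
import yaml

with open("disagg_config_diff_max_tokens.yaml") as f:
    cfg = yaml.safe_load(f)

print(cfg["context_servers"]["max_num_tokens"])      # 512
print(cfg["generation_servers"]["max_num_tokens"])    # 256
print(cfg["context_servers"]["urls"], cfg["generation_servers"]["urls"])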

tests/integration/defs/disaggregated/test_disaggregated.py

Lines changed: 33 additions & 2 deletions

@@ -36,6 +36,8 @@ def get_test_config(test_desc, example_dir, test_root):
     """Get test configuration based on test description."""
     test_configs_root = f"{test_root}/test_configs"
     config_map = {
+        "2_ranks_diff_max_tokens":
+        (2, f"{test_configs_root}/disagg_config_diff_max_tokens.yaml"),
         "2_ranks": (2, f"{example_dir}/disagg_config.yaml"),
         "2_ranks_trt_backend":
         (2, f"{test_configs_root}/disagg_config_trt_backend.yaml"),

@@ -144,7 +146,8 @@ def run_disaggregated_test(example_dir,
                            test_desc,
                            num_iters=5,
                            env=None,
-                           cwd=None):
+                           cwd=None,
+                           prompt_file="prompts.json"):
     """Run disaggregated test with given configuration."""
     cleanup_output_files()
     run_env = env.copy()

@@ -185,10 +188,13 @@ def run_disaggregated_test(example_dir,
             client_cmd = [
                 'python3', f'{client_dir}/disagg_client.py', '-c',
                 f'{example_dir}/disagg_config.yaml', '-p',
-                f'{client_dir}/prompts.json', '--ignore-eos',
+                f'{client_dir}/{prompt_file}', '--ignore-eos',
                 '--server-start-timeout',
                 str(server_start_timeout)
             ]
+            if prompt_file == "long_prompts.json":
+                # Use max_tokens 4 for long prompts to reduce test time
+                client_cmd.extend(['--max-tokens', '4'])
             check_call(client_cmd,
                        env=env,
                        poll_procs=[workers_proc, server_proc])

@@ -217,6 +223,10 @@ def run_disaggregated_test(example_dir,
                        env=env,
                        poll_procs=[workers_proc, server_proc])

+            # Skip output verification for long prompts test
+            if prompt_file == "long_prompts.json":
+                continue
+
             # Verify outputs
             not_expected_strings = ["Berlin Berlin"]

@@ -269,6 +279,27 @@ def run_disaggregated_test(example_dir,
         workers_proc.wait()


+@pytest.mark.parametrize("llama_model_root", ['TinyLlama-1.1B-Chat-v1.0'],
+                         indirect=True)
+def test_disaggregated_diff_max_tokens(disaggregated_test_root,
+                                       disaggregated_example_root, llm_venv,
+                                       llama_model_root):
+    src_dst_dict = {
+        llama_model_root:
+        f"{llm_venv.get_working_directory()}/TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+    }
+    for src, dst in src_dst_dict.items():
+        if not os.path.islink(dst):
+            os.makedirs(os.path.dirname(dst), exist_ok=True)
+            os.symlink(src, dst, target_is_directory=True)
+
+    run_disaggregated_test(disaggregated_example_root,
+                           "2_ranks_diff_max_tokens",
+                           env=llm_venv._new_env,
+                           cwd=llm_venv.get_working_directory(),
+                           prompt_file="long_prompts.json")
+
+
 @pytest.mark.parametrize("llama_model_root", ['TinyLlama-1.1B-Chat-v1.0'],
                          indirect=True)
 def test_disaggregated_single_gpu_with_mpirun(disaggregated_test_root,
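
For reference, a hedged sketch of running only the new integration test locally; the node ID matches the l0_a10.yml entry below, and the disaggregated fixtures still expect two ranks and the TinyLlama checkpoint to be available.

# Minimal sketch, assuming it is run from the tests/integration/defs directory
# so the relative node ID resolves; fixture and GPU requirements are unchanged.
import pytest

pytest.main([
    "disaggregated/test_disaggregated.py"
    "::test_disaggregated_diff_max_tokens[TinyLlama-1.1B-Chat-v1.0]",
    "-v",
])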

tests/integration/test_lists/test-db/l0_a10.yml

Lines changed: 1 addition & 0 deletions

@@ -21,6 +21,7 @@ l0_a10:
   - disaggregated/test_disaggregated.py::test_disaggregated_cuda_graph[TinyLlama-1.1B-Chat-v1.0]
   - disaggregated/test_disaggregated.py::test_disaggregated_mixed[TinyLlama-1.1B-Chat-v1.0]
   - disaggregated/test_disaggregated.py::test_disaggregated_overlap[TinyLlama-1.1B-Chat-v1.0]
+  - disaggregated/test_disaggregated.py::test_disaggregated_diff_max_tokens[TinyLlama-1.1B-Chat-v1.0]
   - test_e2e.py::test_openai_chat_structural_tag_example
   - test_e2e.py::test_openai_chat_json_example
   - test_e2e.py::test_openai_chat_multimodal_example

tests/unittest/llmapi/test_llm.py

Lines changed: 22 additions & 10 deletions

@@ -2239,24 +2239,36 @@ def success_path():
     success_path()


-def _test_llm_capture_request_error(tp_size: int = 1):
-    build_config = BuildConfig()
-    build_config.max_num_tokens = 64
+def _test_llm_capture_request_error(pytorch_backend: bool, tp_size: int = 1):
+    llm_args_extra = {}
+    if pytorch_backend:
+        LLM_CLASS = LLM_torch
+        llm_args_extra["max_num_tokens"] = 64
+    else:
+        LLM_CLASS = LLM
+        build_config = BuildConfig()
+        build_config.max_num_tokens = 64
+        llm_args_extra["fast_build"] = True
+        llm_args_extra["build_config"] = build_config

-    llm = LLM(
+    llm = LLM_CLASS(
         model=llama_model_path,
-        build_config=build_config,
-        fast_build=True,
+        tensor_parallel_size=tp_size,
+        **llm_args_extra,
     )

     prompt = 'A ' * 65  # the minimum max_num_tokens is 64
-
-    with pytest.raises(RequestError):
-        llm.generate(prompt)
+    if pytorch_backend:
+        # pytorch backend will raise ValueError for max_num_tokens
+        with pytest.raises(ValueError):
+            llm.generate(prompt)
+    else:
+        with pytest.raises(RequestError):
+            llm.generate(prompt)


 def test_llm_capture_request_error():
-    _test_llm_capture_request_error(tp_size=1)
+    _test_llm_capture_request_error(pytorch_backend=False, tp_size=1)


 def test_llm_shutdown_executor():
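
A hedged usage sketch of the behavior the refactored helper covers: on the PyTorch backend an oversized prompt now fails fast with ValueError from the new _check_arguments length check, while the TensorRT-engine path still surfaces a RequestError at generation time. The import path below mirrors the LLM_torch alias used in these tests and is an assumption, as is the model path.

# Sketch only: assumes the PyTorch-backend LLM is importable as below and that
# a TinyLlama checkpoint is reachable at the given path.
from tensorrt_llm._torch import LLM as LLM_torch

llm = LLM_torch(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0", max_num_tokens=64)
try:
    llm.generate("A " * 65)  # prompt longer than max_num_tokens
except ValueError as err:
    print(err)  # rejected up front by the new length check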

tests/unittest/llmapi/test_llm_multi_gpu.py

Lines changed: 1 addition & 1 deletion

@@ -463,7 +463,7 @@ def test_llm_get_stats_async_tp2(pytorch_backend):


 def test_llm_capture_request_error():
-    _test_llm_capture_request_error(tp_size=2)
+    _test_llm_capture_request_error(pytorch_backend=False, tp_size=2)


 def test_llm_with_postprocess_parallel_tp2():

tests/unittest/llmapi/test_llm_multi_gpu_pytorch.py

Lines changed: 6 additions & 0 deletions

@@ -7,11 +7,17 @@
 from tensorrt_llm.lora_manager import LoraConfig
 from .lora_test_utils import check_llama_7b_multi_lora_from_request_test_harness
 from .test_llm_pytorch import llama_7b_lora_from_dir_test_harness
+from .test_llm import _test_llm_capture_request_error
 # isort: on

 global_kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.4)


+@pytest.mark.gpu2
+def test_llm_capture_request_error():
+    _test_llm_capture_request_error(pytorch_backend=True, tp_size=2)
+
+
 @pytest.mark.gpu4
 def test_tinyllama_logits_processor_tp2pp2():
     tinyllama_logits_processor_test_harness(backend="pytorch",

tests/unittest/llmapi/test_llm_pytorch.py

Lines changed: 6 additions & 1 deletion

@@ -10,7 +10,8 @@
     check_llama_7b_multi_lora_from_request_test_harness,
     check_llama_7b_multi_unique_lora_adapters_from_request,
     create_mock_nemo_lora_checkpoint)
-from .test_llm import (get_model_path, global_kvcache_config, llama_model_path,
+from .test_llm import (_test_llm_capture_request_error, get_model_path,
+                       global_kvcache_config, llama_model_path,
                        llm_get_stats_async_test_harness,
                        llm_get_stats_test_harness, prompts,
                        run_llm_abort_request,

@@ -76,6 +77,10 @@ def test_llm_get_stats_async(return_context_logits, use_overlap,
                              enable_iter_req_stats=enable_iter_req_stats)


+def test_llm_capture_request_error():
+    _test_llm_capture_request_error(pytorch_backend=True, tp_size=1)
+
+
 @force_ampere
 @pytest.mark.parametrize(
     "sampling_params",
