Commit 388b491

LinPolydc3671 authored and committed
[nvbug 5304752][fix] enhance _check_arguments to filter illegal requests for pytorch backend (NVIDIA#5541)
Signed-off-by: Pengyun Lin <[email protected]>
1 parent c321fb8 commit 388b491

File tree: 5 files changed, +46 -19 lines


tensorrt_llm/llmapi/llm.py

Lines changed: 8 additions & 1 deletion
@@ -542,6 +542,13 @@ def _check_arguments(self, prompt_len: int, query_len: int,
                 raise ValueError(
                     f"PyTorch backend currently only supports `logprobs=1`. Received `logprobs={sampling_params.logprobs}` (Top{sampling_params.logprobs} logprobs). Please set `logprobs=1` in `sampling_params` instead."
                 )
+            # Check prompt length and query length against max_num_tokens to filter illegal requests.
+            if self.args.backend == "pytorch" and not self.args.enable_chunked_prefill:
+                max_num_tokens = self.args.max_num_tokens
+                if max_num_tokens and prompt_len / self.args.parallel_config.cp_size + query_len > max_num_tokens:
+                    raise ValueError(
+                        f"The sum of prompt length ({prompt_len/self.args.parallel_config.cp_size}), query length ({query_len}) and max_tokens ({sampling_params.max_tokens}) should not exceed "
+                        f"max_num_tokens ({max_num_tokens})")
             return
 
         build_config = self.args.build_config
@@ -558,7 +565,7 @@ def _check_arguments(self, prompt_len: int, query_len: int,
                 (sampling_params.max_tokens or 0) > max_seq_len):
             raise ValueError(
                 f"The sum of prompt length ({prompt_len/self.args.parallel_config.cp_size}) and query length ({query_len}) max_tokens ({sampling_params.max_tokens}) should not exceed "
-                f"max_seq_len ({build_config.max_seq_len})")
+                f"max_seq_len ({max_seq_len})")
 
         if sampling_params.use_beam_search and sampling_params.best_of > build_config.max_beam_width:
             if sampling_params.n == sampling_params.best_of:
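The guard added above boils down to a size comparison on the incoming request. Below is a minimal standalone sketch of that comparison with a hypothetical helper name; the real check lives in `LLM._check_arguments` and reads `max_num_tokens`, `cp_size`, and `enable_chunked_prefill` from the LLM arguments as shown in the diff.

from typing import Optional


def exceeds_max_num_tokens(prompt_len: int,
                           query_len: int,
                           cp_size: int,
                           max_num_tokens: Optional[int],
                           enable_chunked_prefill: bool) -> bool:
    # The early rejection is skipped when chunked prefill is enabled
    # (long contexts get chunked) or when max_num_tokens is unset.
    if enable_chunked_prefill or not max_num_tokens:
        return False
    # The prompt is divided across cp_size context-parallel ranks;
    # the query tokens are not.
    return prompt_len / cp_size + query_len > max_num_tokens


# 65 prompt tokens with max_num_tokens=64 (the numbers used in the updated
# unit test) are rejected when chunked prefill is off, tolerated when it is on.
assert exceeds_max_num_tokens(65, 0, cp_size=1, max_num_tokens=64,
                              enable_chunked_prefill=False)
assert not exceeds_max_num_tokens(65, 0, cp_size=1, max_num_tokens=64,
                                  enable_chunked_prefill=True)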

tests/unittest/llmapi/test_llm.py

Lines changed: 22 additions & 10 deletions
@@ -2111,24 +2111,36 @@ def success_path():
     success_path()
 
 
-def _test_llm_capture_request_error(tp_size: int = 1):
-    build_config = BuildConfig()
-    build_config.max_num_tokens = 64
+def _test_llm_capture_request_error(pytorch_backend: bool, tp_size: int = 1):
+    llm_args_extra = {}
+    if pytorch_backend:
+        LLM_CLASS = LLM_torch
+        llm_args_extra["max_num_tokens"] = 64
+    else:
+        LLM_CLASS = LLM
+        build_config = BuildConfig()
+        build_config.max_num_tokens = 64
+        llm_args_extra["fast_build"] = True
+        llm_args_extra["build_config"] = build_config
 
-    llm = LLM(
+    llm = LLM_CLASS(
         model=llama_model_path,
-        build_config=build_config,
-        fast_build=True,
+        tensor_parallel_size=tp_size,
+        **llm_args_extra,
     )
 
     prompt = 'A ' * 65  # the minimum max_num_tokens is 64
-
-    with pytest.raises(RequestError):
-        llm.generate(prompt)
+    if pytorch_backend:
+        # pytorch backend will raise ValueError for max_num_tokens
+        with pytest.raises(ValueError):
+            llm.generate(prompt)
+    else:
+        with pytest.raises(RequestError):
+            llm.generate(prompt)
 
 
 def test_llm_capture_request_error():
-    _test_llm_capture_request_error(tp_size=1)
+    _test_llm_capture_request_error(pytorch_backend=False, tp_size=1)
 
 
 def test_llm_shutdown_executor():
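For the PyTorch backend, the reworked harness exercises the new rejection end to end: build an LLM with max_num_tokens=64 and submit a prompt longer than that. A rough usage sketch follows; the import path for the PyTorch LLM class and the model path are assumptions for illustration, not taken from this diff.

import pytest

# Assumed import path for the PyTorch-backend LLM class; the test module
# already provides it as LLM_torch.
from tensorrt_llm._torch import LLM as LLM_torch

llama_model_path = "/path/to/llama/checkpoint"  # placeholder


def test_overlong_prompt_rejected_up_front():
    llm = LLM_torch(
        model=llama_model_path,
        tensor_parallel_size=1,
        max_num_tokens=64,  # the minimum max_num_tokens is 64
    )
    prompt = 'A ' * 65  # tokenizes to more than max_num_tokens tokens
    # With the enhanced _check_arguments, the over-long request is rejected
    # synchronously with ValueError; the TensorRT-engine path still surfaces
    # it as a RequestError, as the else branch of the harness shows.
    with pytest.raises(ValueError):
        llm.generate(prompt)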

tests/unittest/llmapi/test_llm_multi_gpu.py

Lines changed: 1 addition & 1 deletion
@@ -455,7 +455,7 @@ def test_llm_get_stats_async_tp2(pytorch_backend):
 
 
 def test_llm_capture_request_error():
-    _test_llm_capture_request_error(tp_size=2)
+    _test_llm_capture_request_error(pytorch_backend=False, tp_size=2)
 
 
 def test_llm_with_postprocess_parallel_tp2():

tests/unittest/llmapi/test_llm_multi_gpu_pytorch.py

Lines changed: 6 additions & 1 deletion
@@ -5,12 +5,17 @@
 from tensorrt_llm.llmapi import KvCacheConfig
 from .test_llm_pytorch import (llama_7b_lora_from_dir_test_harness,
                                llama_7b_multi_lora_from_request_test_harness)
-
+from .test_llm import _test_llm_capture_request_error
 # isort: on
 
 global_kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.4)
 
 
+@pytest.mark.gpu2
+def test_llm_capture_request_error():
+    _test_llm_capture_request_error(pytorch_backend=True, tp_size=2)
+
+
 @pytest.mark.gpu4
 def test_tinyllama_logits_processor_tp2pp2():
     tinyllama_logits_processor_test_harness(backend="pytorch",

tests/unittest/llmapi/test_llm_pytorch.py

Lines changed: 9 additions & 6 deletions
@@ -5,12 +5,11 @@
 from tensorrt_llm.sampling_params import SamplingParams
 
 # isort: off
-from .test_llm import (get_model_path, global_kvcache_config, llama_model_path,
-                       llm_get_stats_async_test_harness,
-                       llm_get_stats_test_harness, prompts,
-                       run_llm_abort_request,
-                       run_llm_with_postprocess_parallel_and_result_handler,
-                       tinyllama_logits_processor_test_harness)
+from .test_llm import (
+    get_model_path, global_kvcache_config, llama_model_path,
+    llm_get_stats_async_test_harness, llm_get_stats_test_harness, prompts,
+    run_llm_abort_request, run_llm_with_postprocess_parallel_and_result_handler,
+    tinyllama_logits_processor_test_harness, _test_llm_capture_request_error)
 from utils.util import force_ampere, similar, skip_gpu_memory_less_than_40gb, skip_gpu_memory_less_than_80gb, skip_gpu_memory_less_than_138gb
 from utils.llm_data import llm_models_root
 from tensorrt_llm.lora_manager import LoraConfig
@@ -65,6 +64,10 @@ def test_llm_get_stats_async(return_context_logits, use_overlap,
                              enable_iter_req_stats=enable_iter_req_stats)
 
 
+def test_llm_capture_request_error():
+    _test_llm_capture_request_error(pytorch_backend=True, tp_size=1)
+
+
 @force_ampere
 @pytest.mark.parametrize(
     "sampling_params",
