
Commit f5719de

Revert "[nvbug 5304752][fix] enhance _check_arguments to filter illegal requests for pytorch backend (#5541)"

This reverts commit 388b491.

Signed-off-by: Pengyun Lin <[email protected]>

1 parent 4d0bcbc · commit f5719de

File tree

5 files changed: +18 -47 lines changed
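For context, the reverted change added an early length filter to `_check_arguments` for the PyTorch backend. Below is a minimal standalone sketch of that filter, reconstructed from the llm.py diff that follows; the helper name and the final assert are illustrative only and are not part of TensorRT-LLM.

```python
from typing import Optional


def exceeds_max_num_tokens(prompt_len: int,
                           query_len: int,
                           cp_size: int,
                           max_num_tokens: Optional[int],
                           backend: str,
                           enable_chunked_prefill: bool) -> bool:
    """Reconstruction of the reverted filter: True means the request would
    have been rejected with a ValueError before reaching the engine."""
    if backend != "pytorch" or enable_chunked_prefill:
        return False  # the filter only applied to non-chunked PyTorch runs
    if not max_num_tokens:
        return False  # no budget configured, nothing to enforce
    # Prompt tokens are divided across context-parallel ranks (cp_size).
    return prompt_len / cp_size + query_len > max_num_tokens


# With max_num_tokens=64 (the value used by the tests below), a 65-token
# prompt would have been rejected up front.
assert exceeds_max_num_tokens(65, 0, cp_size=1, max_num_tokens=64,
                              backend="pytorch", enable_chunked_prefill=False)
```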

tensorrt_llm/llmapi/llm.py

Lines changed: 1 addition & 8 deletions
```diff
@@ -529,13 +529,6 @@ def _check_arguments(self, prompt_len: int, query_len: int,
             raise ValueError(
                 f"PyTorch backend currently only supports `logprobs=1`. Received `logprobs={sampling_params.logprobs}` (Top{sampling_params.logprobs} logprobs). Please set `logprobs=1` in `sampling_params` instead."
             )
-            # Check prompt length and query length against max_num_tokens to filter illegal requests.
-            if self.args.backend == "pytorch" and not self.args.enable_chunked_prefill:
-                max_num_tokens = self.args.max_num_tokens
-                if max_num_tokens and prompt_len / self.args.parallel_config.cp_size + query_len > max_num_tokens:
-                    raise ValueError(
-                        f"The sum of prompt length ({prompt_len/self.args.parallel_config.cp_size}), query length ({query_len}) and max_tokens ({sampling_params.max_tokens}) should not exceed "
-                        f"max_num_tokens ({max_num_tokens})")
             return
 
         build_config = self.args.build_config
@@ -552,7 +545,7 @@ def _check_arguments(self, prompt_len: int, query_len: int,
             (sampling_params.max_tokens or 0) > max_seq_len):
             raise ValueError(
                 f"The sum of prompt length ({prompt_len/self.args.parallel_config.cp_size}) and query length ({query_len}) max_tokens ({sampling_params.max_tokens}) should not exceed "
-                f"max_seq_len ({max_seq_len})")
+                f"max_seq_len ({build_config.max_seq_len})")
 
         if sampling_params.use_beam_search and sampling_params.best_of > build_config.max_beam_width:
             if sampling_params.n == sampling_params.best_of:
```
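After the revert, the only length guard left in this method for engine-backed runs is the `max_seq_len` comparison in the second hunk. A sketch of that surviving condition as a standalone predicate; the names mirror the diff and the snippet is illustrative, not the library's code.

```python
from typing import Optional


def exceeds_max_seq_len(prompt_len: int,
                        query_len: int,
                        cp_size: int,
                        max_tokens: Optional[int],
                        max_seq_len: int) -> bool:
    """Per-rank prompt length plus query length plus requested output
    tokens must fit within the engine's max_seq_len."""
    return prompt_len / cp_size + query_len + (max_tokens or 0) > max_seq_len


# Example: a 100-token prompt requesting 50 output tokens overflows a
# 128-token max_seq_len and would trigger the ValueError shown above.
assert exceeds_max_seq_len(100, 0, cp_size=1, max_tokens=50, max_seq_len=128)
```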

tests/unittest/llmapi/test_llm.py

Lines changed: 10 additions & 23 deletions
```diff
@@ -2060,37 +2060,24 @@ def success_path():
     success_path()
 
 
-def _test_llm_capture_request_error(pytorch_backend: bool, tp_size: int = 1):
-    llm_args_extra = {}
-    if pytorch_backend:
-        from tensorrt_llm._torch import LLM as LLM_torch
-        LLM_CLASS = LLM_torch
-        llm_args_extra["max_num_tokens"] = 64
-    else:
-        LLM_CLASS = LLM
-        build_config = BuildConfig()
-        build_config.max_num_tokens = 64
-        llm_args_extra["fast_build"] = True
-        llm_args_extra["build_config"] = build_config
+def _test_llm_capture_request_error(tp_size: int = 1):
+    build_config = BuildConfig()
+    build_config.max_num_tokens = 64
 
-    llm = LLM_CLASS(
+    llm = LLM(
         model=llama_model_path,
-        tensor_parallel_size=tp_size,
-        **llm_args_extra,
+        build_config=build_config,
+        fast_build=True,
     )
 
     prompt = 'A ' * 65  # the minimum max_num_tokens is 64
-    if pytorch_backend:
-        # pytorch backend will raise ValueError for max_num_tokens
-        with pytest.raises(ValueError):
-            llm.generate(prompt)
-    else:
-        with pytest.raises(RequestError):
-            llm.generate(prompt)
+
+    with pytest.raises(RequestError):
+        llm.generate(prompt)
 
 
 def test_llm_capture_request_error():
-    _test_llm_capture_request_error(pytorch_backend=False, tp_size=1)
+    _test_llm_capture_request_error(tp_size=1)
 
 
 def test_llm_api_jupyter_scenario():
```
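The restored `_test_llm_capture_request_error` exercises only the TensorRT-engine path: an over-budget prompt is expected to surface as `RequestError` at `generate` time. A rough standalone version of the same flow, where the model path is a placeholder for `llama_model_path`, the imports assume these symbols are exposed from `tensorrt_llm.llmapi`, and running it requires a GPU:

```python
from tensorrt_llm.llmapi import LLM, BuildConfig, RequestError

build_config = BuildConfig()
build_config.max_num_tokens = 64  # the minimum max_num_tokens is 64

# Placeholder checkpoint path; the test uses the shared llama_model_path.
llm = LLM(model="/path/to/llama", build_config=build_config, fast_build=True)

try:
    llm.generate("A " * 65)  # tokenizes to more than the 64-token budget
except RequestError as err:
    print(f"request rejected as expected: {err}")
```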

tests/unittest/llmapi/test_llm_multi_gpu.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -451,7 +451,7 @@ def test_llm_get_stats_async_tp2(pytorch_backend):
 
 
 def test_llm_capture_request_error():
-    _test_llm_capture_request_error(pytorch_backend=False, tp_size=2)
+    _test_llm_capture_request_error(tp_size=2)
 
 
 def test_llm_with_postprocess_parallel_tp2():
```

tests/unittest/llmapi/test_llm_multi_gpu_pytorch.py

Lines changed: 0 additions & 6 deletions
```diff
@@ -5,17 +5,11 @@
 from tensorrt_llm.llmapi import KvCacheConfig
 from .test_llm_pytorch import (llama_v2_13b_lora_test_harness,
                                llama_7b_multi_lora_test_harness)
-from .test_llm import _test_llm_capture_request_error
 # isort: on
 
 global_kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.4)
 
 
-@pytest.mark.gpu2
-def test_llm_capture_request_error():
-    _test_llm_capture_request_error(pytorch_backend=True, tp_size=2)
-
-
 @pytest.mark.gpu4
 def test_tinyllama_logits_processor_tp2pp2():
     tinyllama_logits_processor_test_harness(backend="pytorch",
```

tests/unittest/llmapi/test_llm_pytorch.py

Lines changed: 6 additions & 9 deletions
```diff
@@ -4,11 +4,12 @@
 from tensorrt_llm.sampling_params import SamplingParams
 
 # isort: off
-from .test_llm import (
-    get_model_path, global_kvcache_config, llama_model_path,
-    llm_get_stats_async_test_harness, llm_get_stats_test_harness, prompts,
-    run_llm_abort_request, run_llm_with_postprocess_parallel_and_result_handler,
-    tinyllama_logits_processor_test_harness, _test_llm_capture_request_error)
+from .test_llm import (get_model_path, global_kvcache_config, llama_model_path,
+                       llm_get_stats_async_test_harness,
+                       llm_get_stats_test_harness, prompts,
+                       run_llm_abort_request,
+                       run_llm_with_postprocess_parallel_and_result_handler,
+                       tinyllama_logits_processor_test_harness)
 from utils.util import force_ampere, similar, skip_gpu_memory_less_than_40gb, skip_gpu_memory_less_than_80gb, skip_gpu_memory_less_than_138gb
 from utils.llm_data import llm_models_root
 from tensorrt_llm.lora_manager import LoraConfig
@@ -63,10 +64,6 @@ def test_llm_get_stats_async(return_context_logits, use_overlap,
                              enable_iter_req_stats=enable_iter_req_stats)
 
 
-def test_llm_capture_request_error():
-    _test_llm_capture_request_error(pytorch_backend=True, tp_size=1)
-
-
 @force_ampere
 @pytest.mark.parametrize(
     "sampling_params",
```