3 changes: 3 additions & 0 deletions examples/disaggregated/clients/long_prompts.json
@@ -0,0 +1,3 @@
[
"Place of birth\nThe place of birth (POB) or birthplace is the place where a person was born. This place is often used in legal documents, together with name and date of birth, to uniquely identify a person. Practice regarding whether this place should be a country, a territory or a city/town/locality differs in different countries, but often city or territory is used for native-born citizen passports and countries for foreign-born ones.\nAs a general rule with respect to passports, if the place of birth is to be a country, it's determined to be the country that currently has sovereignty over the actual place of birth, regardless of when the birth actually occurred. The place of birth is not necessarily the place where the parents of the new baby live. If the baby is born in a hospital in another place, that place is the place of birth. In many countries, this also means that the government requires that the birth of the new baby is registered in the place of birth.\nSome countries place less or no importance on the place of birth, instead using alternative geographical characteristics for the purpose of identity documents. For example, Sweden has used the concept of födelsehemort (\"domicile of birth\") since 1947. This means that the domicile of the baby's mother is the registered place of birth.\nSimilarly, Switzerland uses the concept of place of origin. A child born to Swiss parents is automatically assigned the place of origin of the parent with the same last name, so the child either gets their mother's or father's place of origin. A child born to one Swiss parent and one foreign parent acquires the place of origin of their Swiss parent. In a Swiss passport and identity card, the holder's place of origin is stated, not their place of birth. In Japan, the registered domicile is a similar concept.\nIn some countries (primarily in the Americas), the place of birth automatically determines the nationality of the baby, a practice often referred to by the Latin phrase jus soli."
]
10 changes: 9 additions & 1 deletion tensorrt_llm/llmapi/llm.py
@@ -566,6 +566,14 @@ def _check_arguments(self, prompt_len: int, query_len: int,
raise ValueError(
f"PyTorch backend currently only supports `logprobs=1`. Received `logprobs={sampling_params.logprobs}` (Top{sampling_params.logprobs} logprobs). Please set `logprobs=1` in `sampling_params` instead."
)
# Check prompt length and query length against max_num_tokens to filter illegal requests.
# Skip check for gen-only requests
if self.args.backend == "pytorch" and not self.args.enable_chunked_prefill and not is_gen_only:
max_num_tokens = self.args.max_num_tokens
if max_num_tokens and prompt_len / self.args.parallel_config.cp_size + query_len > max_num_tokens:
raise ValueError(
f"The sum of prompt length ({prompt_len/self.args.parallel_config.cp_size}), query length ({query_len}) should not exceed "
f"max_num_tokens ({max_num_tokens})")
return

build_config = self.args.build_config
@@ -582,7 +590,7 @@ def _check_arguments(self, prompt_len: int, query_len: int,
(sampling_params.max_tokens or 0) > max_seq_len):
raise ValueError(
f"The sum of prompt length ({prompt_len/self.args.parallel_config.cp_size}) and query length ({query_len}) max_tokens ({sampling_params.max_tokens}) should not exceed "
f"max_seq_len ({build_config.max_seq_len})")
f"max_seq_len ({max_seq_len})")

if sampling_params.use_beam_search and sampling_params.best_of > build_config.max_beam_width:
if sampling_params.n == sampling_params.best_of:
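
For illustration, a minimal sketch of how the new check surfaces to callers. It assumes the PyTorch-backend LLM exported from tensorrt_llm._torch (the LLM_torch alias used in the unit tests below) and the TinyLlama checkpoint used elsewhere in this PR; with chunked prefill disabled and max_num_tokens=64, an over-long prompt is now rejected with a ValueError before it reaches the executor:

# Sketch only: demonstrates the new prompt-length check for the PyTorch backend.
# Assumes tensorrt_llm._torch.LLM (the LLM_torch alias used in the unit tests)
# and a locally available TinyLlama checkpoint.
from tensorrt_llm._torch import LLM as LLM_torch

llm = LLM_torch(
    model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    max_num_tokens=64,  # the minimum max_num_tokens, per the test below
)

prompt = "A " * 65  # tokenizes to more than 64 tokens

try:
    llm.generate(prompt)
except ValueError as err:
    # Raised by _check_arguments when
    # prompt_len / cp_size + query_len > max_num_tokens.
    print(err)
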
23 changes: 23 additions & 0 deletions tests/integration/defs/disaggregated/test_configs/disagg_config_diff_max_tokens.yaml
@@ -0,0 +1,23 @@
hostname: localhost
port: 8000
model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
free_gpu_memory_fraction: 0.5
backend: "pytorch"
cuda_graph_config: null
disable_overlap_scheduler: True
context_servers:
num_instances: 1
max_num_tokens: 512
max_batch_size: 256
cache_transceiver_config:
backend: default
urls:
- "localhost:8001"
generation_servers:
num_instances: 1
max_num_tokens: 256
max_batch_size: 128
cache_transceiver_config:
backend: default
urls:
- "localhost:8002"
35 changes: 33 additions & 2 deletions tests/integration/defs/disaggregated/test_disaggregated.py
@@ -36,6 +36,8 @@ def get_test_config(test_desc, example_dir, test_root):
"""Get test configuration based on test description."""
test_configs_root = f"{test_root}/test_configs"
config_map = {
"2_ranks_diff_max_tokens":
(2, f"{test_configs_root}/disagg_config_diff_max_tokens.yaml"),
"2_ranks": (2, f"{example_dir}/disagg_config.yaml"),
"2_ranks_trt_backend":
(2, f"{test_configs_root}/disagg_config_trt_backend.yaml"),
@@ -144,7 +146,8 @@ def run_disaggregated_test(example_dir,
test_desc,
num_iters=5,
env=None,
cwd=None):
cwd=None,
prompt_file="prompts.json"):
"""Run disaggregated test with given configuration."""
cleanup_output_files()
run_env = env.copy()
@@ -185,10 +188,13 @@
client_cmd = [
'python3', f'{client_dir}/disagg_client.py', '-c',
f'{example_dir}/disagg_config.yaml', '-p',
f'{client_dir}/prompts.json', '--ignore-eos',
f'{client_dir}/{prompt_file}', '--ignore-eos',
'--server-start-timeout',
str(server_start_timeout)
]
if prompt_file == "long_prompts.json":
# Use max_tokens 4 for long prompts to reduce test time
client_cmd.extend(['--max-tokens', '4'])
check_call(client_cmd,
env=env,
poll_procs=[workers_proc, server_proc])
@@ -217,6 +223,10 @@ def run_disaggregated_test(example_dir,
env=env,
poll_procs=[workers_proc, server_proc])

# Skip output verification for long prompts test
if prompt_file == "long_prompts.json":
continue

# Verify outputs
not_expected_strings = ["Berlin Berlin"]

@@ -269,6 +279,27 @@ def run_disaggregated_test(example_dir,
workers_proc.wait()


@pytest.mark.parametrize("llama_model_root", ['TinyLlama-1.1B-Chat-v1.0'],
indirect=True)
def test_disaggregated_diff_max_tokens(disaggregated_test_root,
disaggregated_example_root, llm_venv,
llama_model_root):
src_dst_dict = {
llama_model_root:
f"{llm_venv.get_working_directory()}/TinyLlama/TinyLlama-1.1B-Chat-v1.0",
}
for src, dst in src_dst_dict.items():
if not os.path.islink(dst):
os.makedirs(os.path.dirname(dst), exist_ok=True)
os.symlink(src, dst, target_is_directory=True)

run_disaggregated_test(disaggregated_example_root,
"2_ranks_diff_max_tokens",
env=llm_venv._new_env,
cwd=llm_venv.get_working_directory(),
prompt_file="long_prompts.json")


@pytest.mark.parametrize("llama_model_root", ['TinyLlama-1.1B-Chat-v1.0'],
indirect=True)
def test_disaggregated_single_gpu_with_mpirun(disaggregated_test_root,
1 change: 1 addition & 0 deletions tests/integration/test_lists/test-db/l0_a10.yml
@@ -21,6 +21,7 @@ l0_a10:
- disaggregated/test_disaggregated.py::test_disaggregated_cuda_graph[TinyLlama-1.1B-Chat-v1.0]
- disaggregated/test_disaggregated.py::test_disaggregated_mixed[TinyLlama-1.1B-Chat-v1.0]
- disaggregated/test_disaggregated.py::test_disaggregated_overlap[TinyLlama-1.1B-Chat-v1.0]
- disaggregated/test_disaggregated.py::test_disaggregated_diff_max_tokens[TinyLlama-1.1B-Chat-v1.0]
- test_e2e.py::test_openai_chat_structural_tag_example
- test_e2e.py::test_openai_chat_json_example
- test_e2e.py::test_openai_chat_multimodal_example
32 changes: 22 additions & 10 deletions tests/unittest/llmapi/test_llm.py
@@ -2239,24 +2239,36 @@ def success_path():
success_path()


def _test_llm_capture_request_error(tp_size: int = 1):
build_config = BuildConfig()
build_config.max_num_tokens = 64
def _test_llm_capture_request_error(pytorch_backend: bool, tp_size: int = 1):
llm_args_extra = {}
if pytorch_backend:
LLM_CLASS = LLM_torch
llm_args_extra["max_num_tokens"] = 64
else:
LLM_CLASS = LLM
build_config = BuildConfig()
build_config.max_num_tokens = 64
llm_args_extra["fast_build"] = True
llm_args_extra["build_config"] = build_config

llm = LLM(
llm = LLM_CLASS(
model=llama_model_path,
build_config=build_config,
fast_build=True,
tensor_parallel_size=tp_size,
**llm_args_extra,
)

prompt = 'A ' * 65 # the minimum max_num_tokens is 64

with pytest.raises(RequestError):
llm.generate(prompt)
if pytorch_backend:
# The PyTorch backend raises ValueError when the prompt exceeds max_num_tokens
with pytest.raises(ValueError):
llm.generate(prompt)
else:
with pytest.raises(RequestError):
llm.generate(prompt)


def test_llm_capture_request_error():
_test_llm_capture_request_error(tp_size=1)
_test_llm_capture_request_error(pytorch_backend=False, tp_size=1)


def test_llm_shutdown_executor():
2 changes: 1 addition & 1 deletion tests/unittest/llmapi/test_llm_multi_gpu.py
@@ -463,7 +463,7 @@ def test_llm_get_stats_async_tp2(pytorch_backend):


def test_llm_capture_request_error():
_test_llm_capture_request_error(tp_size=2)
_test_llm_capture_request_error(pytorch_backend=False, tp_size=2)


def test_llm_with_postprocess_parallel_tp2():
6 changes: 6 additions & 0 deletions tests/unittest/llmapi/test_llm_multi_gpu_pytorch.py
@@ -7,11 +7,17 @@
from tensorrt_llm.lora_manager import LoraConfig
from .lora_test_utils import check_llama_7b_multi_lora_from_request_test_harness
from .test_llm_pytorch import llama_7b_lora_from_dir_test_harness
from .test_llm import _test_llm_capture_request_error
# isort: on

global_kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.4)


@pytest.mark.gpu2
def test_llm_capture_request_error():
_test_llm_capture_request_error(pytorch_backend=True, tp_size=2)


@pytest.mark.gpu4
def test_tinyllama_logits_processor_tp2pp2():
tinyllama_logits_processor_test_harness(backend="pytorch",
7 changes: 6 additions & 1 deletion tests/unittest/llmapi/test_llm_pytorch.py
@@ -10,7 +10,8 @@
check_llama_7b_multi_lora_from_request_test_harness,
check_llama_7b_multi_unique_lora_adapters_from_request,
create_mock_nemo_lora_checkpoint)
from .test_llm import (get_model_path, global_kvcache_config, llama_model_path,
from .test_llm import (_test_llm_capture_request_error, get_model_path,
global_kvcache_config, llama_model_path,
llm_get_stats_async_test_harness,
llm_get_stats_test_harness, prompts,
run_llm_abort_request,
@@ -76,6 +77,10 @@ def test_llm_get_stats_async(return_context_logits, use_overlap,
enable_iter_req_stats=enable_iter_req_stats)


def test_llm_capture_request_error():
_test_llm_capture_request_error(pytorch_backend=True, tp_size=1)


@force_ampere
@pytest.mark.parametrize(
"sampling_params",