Merged
32 commits
b029c0e
Fixed BindCapacityScheduler to pass peft_cache_manager to the CPP bin…
amitz-nv Jul 15, 2025
c79b271
Removed unnecessary changes in test_llm.py
amitz-nv Jul 15, 2025
d193d30
Refactored LoRA eviction tests
amitz-nv Jul 15, 2025
28d86b1
Added type hint to peft_cache_manager
amitz-nv Jul 15, 2025
b87d58a
Add forgotten llm_args in test_llm.py, fix formatting in test_llm_pyt…
amitz-nv Jul 15, 2025
a7d6ea5
Pass peft_cache_manager=None to BindCapacityScheduler in create_autod…
amitz-nv Jul 15, 2025
0ebe6aa
Fix target name in test
amitz-nv Jul 15, 2025
728b32f
Changed GuaranteedNoEvictScheduler to try call peftCacheManager->dete…
amitz-nv Jul 15, 2025
26923b6
Format comments in tests
amitz-nv Jul 15, 2025
f4875fa
Remove debug prints from test
amitz-nv Jul 15, 2025
f5ca3b4
Update missingPeftTask CPP test to expect the error message starts wi…
amitz-nv Jul 15, 2025
82e18f9
Refactored shared lora test logic into lora_test_utils.py
amitz-nv Jul 15, 2025
9cc4cc4
PeftCacheManager::determineNumPages throws exception with 'not suppor…
amitz-nv Jul 15, 2025
e16aae7
Add docstring to check_multi_unique_lora_adapters_from_request
amitz-nv Jul 15, 2025
8758975
Fix imports of test_llm.py
amitz-nv Jul 15, 2025
bdfa780
Improved check_multi_unique_lora_adapters_from_request docstring
amitz-nv Jul 15, 2025
2c3b771
Fix imports in test_llm_multi_gpu.py and in test_llm_multi_gpu_pytorc…
amitz-nv Jul 15, 2025
6625682
Revert changes in _TrtLLM._build_model, move LLM creation to test so …
amitz-nv Jul 15, 2025
ed68f49
Change the 'should include adapter weights with request' to be based …
amitz-nv Jul 15, 2025
b1d0bf6
test_llm_pytorch.py - Minor docstring fix, readability improvement
amitz-nv Jul 15, 2025
b0f91f2
Update test_llm_multi_gpu_pytorch.py to also disable cuda_graph until…
amitz-nv Jul 15, 2025
10149b7
Fix formatting of lora_test_utils.py
amitz-nv Jul 15, 2025
9e9e02e
Improve test case documentation
amitz-nv Jul 15, 2025
f3c330c
Fix docstring of is_adapter_in_cpu_cache
amitz-nv Jul 15, 2025
abad5c4
Add 'is_task_cached' method binding to CPP PeftCacheManager class
amitz-nv Jul 15, 2025
b6e99e9
Improve comment over not supporting LoRA optimization in TRT-python flow
amitz-nv Jul 15, 2025
b9d6c9e
Change cpp_peft_cache_manager argument in LoraManager constructor to …
amitz-nv Jul 15, 2025
6189c47
Fix typo in lora test
amitz-nv Jul 16, 2025
e4ff01a
Revert added note in exception message in TRT flow, as the LoRA optim…
amitz-nv Jul 16, 2025
6e0b872
Fix LLM args in multi GPU LoRA tests
amitz-nv Jul 17, 2025
1085670
Improve resource release in test util function run_function_in_sub_pr…
amitz-nv Jul 17, 2025
352f429
Improve formatting - split long import line
amitz-nv Jul 17, 2025
14 changes: 11 additions & 3 deletions cpp/tensorrt_llm/batch_manager/peftCacheManager.cpp
@@ -591,20 +591,28 @@ SizeType32 PeftCacheManager::determineNumPages(std::shared_ptr<LlmRequest> llmRe
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
if (llmRequest->getLoraTaskId().has_value())
{
auto taskId = llmRequest->getLoraTaskId().value();
try
{
return mHostLoraCache->determineNumPages(llmRequest->getLoraTaskId().value());
return mHostLoraCache->determineNumPages(taskId);
}
catch (std::runtime_error& e)
{
if (llmRequest->getLoraConfig().has_value())
{
return mHostLoraCache->determineNumPages(llmRequest->getLoraConfig().value());
}
else
if (!llmRequest->getLoraWeights().has_value())
{
throw;
auto const reqId = llmRequest->mRequestId;
std::string errMsg
= "Request ID " + std::to_string(reqId) + " has no LoRA adapter weights while configured with LoRA task "
+ std::to_string(taskId) + " that's not found in LoRA CPU cache."
" Note that currently a request with LoRA task that was already loaded is sent without its LoRA weights to save its serialization, copy and deserialization,"
" so if this LoRA task was evicted from LoRA CPU cache, then its reuse is currently not supported.";
throw PeftTaskNotCachedException(errMsg);
}
throw;
}
}
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
3 changes: 2 additions & 1 deletion cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp
@@ -469,7 +469,8 @@ void tb::BasePeftCacheManagerBindings::initBindings(py::module_& m)

py::classh<tb::PeftCacheManager, tb::BasePeftCacheManager>(m, "PeftCacheManager")
.def(py::init<tb::PeftCacheManagerConfig, tr::ModelConfig, tr::WorldConfig, tr::BufferManager>(),
py::arg("config"), py::arg("model_config"), py::arg("world_config"), py::arg("buffer_manager"));
py::arg("config"), py::arg("model_config"), py::arg("world_config"), py::arg("buffer_manager"))
.def("is_task_cached", &tb::PeftCacheManager::isTaskCached, py::arg("taskId"));

py::classh<tb::NoOpPeftCacheManager, tb::BasePeftCacheManager>(m, "NoOpPeftCacheManager").def(py::init());
}
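For orientation, a minimal sketch (not part of this PR) of how the new binding is reached from Python. The helper function is hypothetical; the `is_task_cached` call and the `tb_internal.batch_manager.PeftCacheManager` type path follow the binding and type hints added in this PR.

```python
from tensorrt_llm.bindings import internal as tb_internal


def adapter_is_resident(
        peft_cache_manager: "tb_internal.batch_manager.PeftCacheManager",
        lora_task_id: int) -> bool:
    # Maps directly onto the C++ PeftCacheManager::isTaskCached exposed above.
    return peft_cache_manager.is_task_cached(lora_task_id)
```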
4 changes: 3 additions & 1 deletion tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py
@@ -286,7 +286,9 @@ def create_autodeploy_executor(executor_config: ExecutorConfig, checkpoint_dir:
resource_manager.resource_managers.move_to_end(ResourceManagerType.KV_CACHE_MANAGER, last=True)

# scheduling
capacitor_scheduler = BindCapacityScheduler(ad_config.max_batch_size, kv_cache_manager.impl)
capacitor_scheduler = BindCapacityScheduler(
ad_config.max_batch_size, kv_cache_manager.impl, peft_cache_manager=None
)
mb_scheduler = BindMicroBatchScheduler(
ad_config.max_batch_size, engine.cache_seq_interface.info.max_num_tokens
)
2 changes: 2 additions & 0 deletions tensorrt_llm/_torch/pyexecutor/_util.py
@@ -438,6 +438,7 @@ def create_py_executor_instance(
f"Cannot overwrite existing resource manager {key}.")
resources[key] = value

peft_cache_manager = None
if lora_config is not None:
from tensorrt_llm.bindings import LoraModule

@@ -513,6 +514,7 @@ def create_py_executor_instance(
capacity_scheduler = BindCapacityScheduler(
max_num_sequences,
kv_cache_manager.impl if kv_cache_manager is not None else None,
peft_cache_manager.impl if peft_cache_manager is not None else None,
executor_config.scheduler_config.capacity_scheduler_policy,
two_step_lookahead=mapping.has_pp())
mb_scheduler = BindMicroBatchScheduler(executor_config.max_batch_size,
2 changes: 1 addition & 1 deletion tensorrt_llm/_torch/pyexecutor/resource_manager.py
@@ -1224,7 +1224,7 @@ def update_resources(self, scheduled_batch: ScheduledRequests):
pass

def free_resources(self, request: LlmRequest):
pass
self.impl.mark_request_done(request)

def shutdown(self):
pass
5 changes: 4 additions & 1 deletion tensorrt_llm/_torch/pyexecutor/scheduler.py
@@ -73,12 +73,14 @@ def __init__(
self,
max_num_requests: int,
kv_cache_manager,
peft_cache_manager: tb_internal.batch_manager.PeftCacheManager | None,
scheduler_policy: tb_executor.CapacitySchedulerPolicy = tb_executor.
CapacitySchedulerPolicy.GUARANTEED_NO_EVICT,
two_step_lookahead: bool = False,
):
super(BindCapacityScheduler, self).__init__()
self.kv_cache_manager = kv_cache_manager
self.peft_cache_manager = peft_cache_manager

self.impl = tb_internal.algorithms.CapacityScheduler(
max_num_requests=max_num_requests,
@@ -91,7 +93,8 @@ def __init__(
def schedule_request(
self, active_requests: RequestList
) -> tuple[list[LlmRequest], list[LlmRequest], list[LlmRequest]]:
return self.impl(active_requests, self.kv_cache_manager)
return self.impl(active_requests, self.kv_cache_manager,
self.peft_cache_manager)


class GuaranteedNoEvictScheduler(CapacityScheduler):
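For orientation, a minimal sketch of the updated construction path, mirroring the call sites changed above in ad_executor.py and _util.py. `kv_cache_manager`, `peft_cache_manager` and `active_requests` are placeholders for objects owned by the py executor; only the parameter names and the `.impl`/`None` handling are taken from the diff.

```python
from tensorrt_llm._torch.pyexecutor.scheduler import BindCapacityScheduler

capacity_scheduler = BindCapacityScheduler(
    max_num_requests=max_batch_size,
    kv_cache_manager=kv_cache_manager.impl,
    # None is accepted where no PEFT cache manager is available,
    # e.g. in the auto-deploy executor or when LoRA is disabled.
    peft_cache_manager=peft_cache_manager.impl
    if peft_cache_manager is not None else None,
)

# schedule_request() now forwards the PEFT cache manager to the C++
# CapacityScheduler and returns three lists of LlmRequest, per its type hint.
scheduled = capacity_scheduler.schedule_request(active_requests)
```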
23 changes: 17 additions & 6 deletions tensorrt_llm/executor/worker.py
@@ -150,13 +150,23 @@ def _create_engine():
self._runtime_model_config = _engine_config_to_model_config(
engine_config)
if engine_config.build_config.plugin_config.lora_plugin:
self._lora_manager = LoraManager()
# TODO(azuker): Passing the peft cache manager to LoraManager enables a LoRA optimization
# (see the LoraManager constructor docstring). Getting the peft cache manager from this
# point in the TRT flow is currently not supported (it lives at the CPP
# Executor->ExecutorImpl->TrtGptModel->mPeftCacheManager); therefore this LoRA
# optimization is not available in the TRT-python flow for now.
self._lora_manager = LoraManager(cpp_peft_cache_manager=None)
if engine_config.build_config.max_prompt_embedding_table_size > 0:
self._prompt_adapter_manager = PromptAdapterManager()

if getattr(executor_config, "backend",
"") == "pytorch" and lora_config is not None:
self._lora_manager = LoraManager()
from tensorrt_llm._torch.pyexecutor.resource_manager import \
ResourceManagerType
peft_cache_manager = self.engine.resource_manager.resource_managers.get(
ResourceManagerType.PEFT_CACHE_MANAGER)
self._lora_manager = LoraManager(
cpp_peft_cache_manager=peft_cache_manager.impl)
lora_model_config = self.engine.model_engine.lora_model_config
assert lora_model_config is not None
self._lora_model_config = lora_model_config
@@ -362,15 +372,16 @@ def _load_prompt_adapter(self,
def _enqueue_request(self, request: GenerationRequest) -> int:
assert request.id is not None
if self._lora_manager is not None and request.lora_request is not None:
loaded_new_lora_adapter = self._load_lora_adapter(
request.lora_request)
adapter_in_cache = self._lora_manager.is_adapter_in_cpu_cache(
request.lora_request.adapter_id)
self._load_lora_adapter(request.lora_request)
uid = str(request.lora_request.adapter_id)
lora_config = tllm.LoraConfig(
task_id=request.lora_request.adapter_id,
weights=self._lora_manager.cpp_lora_weights[uid]
if loaded_new_lora_adapter else None,
if not adapter_in_cache else None,
config=self._lora_manager.cpp_lora_config[uid]
if loaded_new_lora_adapter else None)
if not adapter_in_cache else None)
else:
lora_config = None

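A condensed, hypothetical sketch of the enqueue-side optimization introduced above: the adapter's weights and config are attached to the request only when the adapter is not already resident in the LoRA CPU cache. The executor-bindings import alias and the `load_adapter` callable are assumptions; the attribute and argument names follow the diff.

```python
from tensorrt_llm.bindings import executor as tllm


def make_lora_config(lora_manager, lora_request, load_adapter) -> "tllm.LoraConfig":
    # In _enqueue_request the cache is checked *before* the adapter is (re)loaded,
    # so a newly loaded adapter still ships its weights with its first request.
    adapter_in_cache = lora_manager.is_adapter_in_cpu_cache(lora_request.adapter_id)
    load_adapter(lora_request)  # e.g. the worker's own _load_lora_adapter helper
    uid = str(lora_request.adapter_id)
    return tllm.LoraConfig(
        task_id=lora_request.adapter_id,
        weights=None if adapter_in_cache else lora_manager.cpp_lora_weights[uid],
        config=None if adapter_in_cache else lora_manager.cpp_lora_config[uid])
```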
27 changes: 25 additions & 2 deletions tensorrt_llm/lora_manager.py
@@ -11,6 +11,8 @@
import torch
import yaml

from tensorrt_llm.bindings import internal as tb_internal

from ._utils import DictConversion, pad_vocab_size, release_gc, str_dtype_to_torch, torch_to_numpy
from .layers.linear import ColumnLinear
from .mapping import Mapping
@@ -436,8 +438,16 @@ class LoraManager(object):
"mlp_gate_up": 18,
}

def __init__(self):
"""Constructor."""
def __init__(
self, cpp_peft_cache_manager: tb_internal.batch_manager.PeftCacheManager | None = None
):
"""Constructor.

Args:
cpp_peft_cache_manager (PeftCacheManager, optional): used by the is_adapter_in_cpu_cache method, which enables a
LoRA performance optimization: the LoRA adapter weights are not sent with an LLM request when the adapter is
already loaded in the LoRA CPU cache.
"""
# _lora_uid_to_low_ranks: dict[str -> dict[int -> dict[str -> int]]]
# {
# uid: {
@@ -473,6 +483,19 @@ def __init__(self):
self._cpp_lora_weights: Dict[str, torch.Tensor] = {} # on cpu
self._cpp_lora_config: Dict[str, torch.Tensor] = {} # on cpu
self.lora_target_modules: List[str] = []
self._cpp_peft_cache_manager = cpp_peft_cache_manager

def is_adapter_in_cpu_cache(self, adapter_uid: int) -> bool:
"""Best effort to check if a LoRA adapter is in the LoRA CPU cache.

If no cpp_peft_cache_manager instance was given when this LoraManager instance was constructed, False is
returned.
"""
return (
self._cpp_peft_cache_manager.is_task_cached(adapter_uid)
if self._cpp_peft_cache_manager
else False
)

@staticmethod
def get_missing_qkv_modules(lora_target_modules):
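A short usage sketch of the new constructor argument, assuming nothing beyond what the diff adds: `peft_cache_manager` stands for the PEFT_CACHE_MANAGER resource manager retrieved in executor/worker.py above; the adapter UID is illustrative.

```python
from tensorrt_llm.lora_manager import LoraManager


def build_lora_manager(peft_cache_manager=None) -> LoraManager:
    # peft_cache_manager is the PEFT_CACHE_MANAGER resource manager (PyTorch flow),
    # or None in flows where it is not reachable (e.g. the TRT-python flow).
    impl = peft_cache_manager.impl if peft_cache_manager is not None else None
    return LoraManager(cpp_peft_cache_manager=impl)


# Without a cache manager the check degrades to always-False, i.e. adapter weights
# are always attached to requests -- the current behaviour of the TRT-python flow.
assert not build_lora_manager(None).is_adapter_in_cpu_cache(adapter_uid=42)
```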
116 changes: 116 additions & 0 deletions tests/unittest/llmapi/lora_test_utils.py
@@ -0,0 +1,116 @@
from typing import OrderedDict, Type

from utils.llm_data import llm_models_root
from utils.util import duplicate_list_to_length, flatten_list, similar

from tensorrt_llm import SamplingParams
from tensorrt_llm.executor.request import LoRARequest
from tensorrt_llm.llmapi.llm import BaseLLM


def check_llama_7b_multi_unique_lora_adapters_from_request(
lora_adapter_count_per_call: list[int], repeat_calls: int,
repeats_per_call: int, llm_class: Type[BaseLLM], **llm_kwargs):
"""Calls llm.generate s.t. for each C in lora_adapter_count_per_call, llm.generate is called with C requests
repeated 'repeats_per_call' times, where each request is configured with a unique LoRA adapter ID.
This entire process is done in a loop 'repeats_per_call' times with the same requests.
Asserts the output of each llm.generate call is similar to the expected.
""" # noqa: D205
total_lora_adapters = sum(lora_adapter_count_per_call)
hf_model_dir = f"{llm_models_root()}/llama-models/llama-7b-hf"
hf_lora_dirs = [
f"{llm_models_root()}/llama-models/luotuo-lora-7b-0.1",
f"{llm_models_root()}/llama-models/Japanese-Alpaca-LoRA-7b-v0"
]
# Each prompt should have a reference for every LoRA adapter dir (in the same order as in hf_lora_dirs)
prompt_to_references = OrderedDict({
"美国的首都在哪里? \n答案:": [
"美国的首都是华盛顿。\n\n美国的",
"纽约\n\n### カンファレンスの",
],
"アメリカ合衆国の首都はどこですか? \n答え:": [
"华盛顿。\n\n英国の首都是什",
"ワシントン\nQ1. アメリカ合衆国",
],
})

prompts_to_generate = duplicate_list_to_length(
flatten_list([[prompt] * len(hf_lora_dirs)
for prompt in prompt_to_references.keys()]),
total_lora_adapters)
references = duplicate_list_to_length(
flatten_list(list(prompt_to_references.values())), total_lora_adapters)
lora_requests = [
LoRARequest(str(i), i, hf_lora_dirs[i % len(hf_lora_dirs)])
for i in range(total_lora_adapters)
]
llm = llm_class(hf_model_dir, **llm_kwargs)

# Perform repeats of the same requests to test reuse and reload of adapters previously unloaded from cache
try:
for _ in range(repeat_calls):
last_idx = 0
for adapter_count in lora_adapter_count_per_call:
sampling_params = SamplingParams(max_tokens=20)
outputs = llm.generate(
prompts_to_generate[last_idx:last_idx + adapter_count] *
repeats_per_call,
sampling_params,
lora_request=lora_requests[last_idx:last_idx +
adapter_count] *
repeats_per_call)
for output, ref in zip(
outputs, references[last_idx:last_idx + adapter_count] *
repeats_per_call):
assert similar(output.outputs[0].text, ref)
last_idx += adapter_count
finally:
llm.shutdown()


def check_llama_7b_multi_lora_from_request_test_harness(
llm_class: Type[BaseLLM], **llm_kwargs) -> None:
hf_model_dir = f"{llm_models_root()}/llama-models/llama-7b-hf"
hf_lora_dir1 = f"{llm_models_root()}/llama-models/luotuo-lora-7b-0.1"
hf_lora_dir2 = f"{llm_models_root()}/llama-models/Japanese-Alpaca-LoRA-7b-v0"
prompts = [
"美国的首都在哪里? \n答案:",
"美国的首都在哪里? \n答案:",
"美国的首都在哪里? \n答案:",
"アメリカ合衆国の首都はどこですか? \n答え:",
"アメリカ合衆国の首都はどこですか? \n答え:",
"アメリカ合衆国の首都はどこですか? \n答え:",
]
references = [
"沃尔玛\n\n## 新闻\n\n* ",
"美国的首都是华盛顿。\n\n美国的",
"纽约\n\n### カンファレンスの",
"Washington, D.C.\nWashington, D.C. is the capital of the United",
"华盛顿。\n\n英国の首都是什",
"ワシントン\nQ1. アメリカ合衆国",
]
key_words = [
"沃尔玛",
"华盛顿",
"纽约",
"Washington",
"华盛顿",
"ワシントン",
]
lora_req1 = LoRARequest("luotuo", 1, hf_lora_dir1)
lora_req2 = LoRARequest("Japanese", 2, hf_lora_dir2)
sampling_params = SamplingParams(max_tokens=20)

llm = llm_class(hf_model_dir, **llm_kwargs)
try:
outputs = llm.generate(prompts,
sampling_params,
lora_request=[
None, lora_req1, lora_req2, None, lora_req1,
lora_req2
])
finally:
llm.shutdown()
for output, ref, key_word in zip(outputs, references, key_words):
assert similar(output.outputs[0].text,
ref) or key_word in output.outputs[0].text
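These helpers are parametrized by the concrete LLM class under test. A hypothetical invocation from a test module in the same directory might look like the following; the argument values and any LoRA-enabling llm kwargs are illustrative, not taken from this PR.

```python
from tensorrt_llm import LLM

# Assumed to run from a test module alongside lora_test_utils.py.
from lora_test_utils import check_llama_7b_multi_unique_lora_adapters_from_request

check_llama_7b_multi_unique_lora_adapters_from_request(
    lora_adapter_count_per_call=[2, 3],  # two generate calls: 2 then 3 unique adapters
    repeat_calls=2,  # run the whole sequence twice to exercise adapter reuse/reload
    repeats_per_call=1,
    llm_class=LLM,
    # any remaining kwargs (e.g. a LoRA-enabling config) are forwarded to
    # llm_class(hf_model_dir, **llm_kwargs) inside the helper
)
```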