Merged
2 changes: 1 addition & 1 deletion docs/source/torch/features/lora.md
@@ -33,7 +33,7 @@ The PyTorch backend provides LoRA support, allowing you to:

```python
from tensorrt_llm import LLM
from tensorrt_llm.lora_manager import LoraConfig
from tensorrt_llm.lora_helper import LoraConfig
from tensorrt_llm.executor.request import LoRARequest
from tensorrt_llm.sampling_params import SamplingParams
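The documentation snippet now imports `LoraConfig` from `tensorrt_llm.lora_helper` rather than `tensorrt_llm.lora_manager`. A minimal sketch of how the updated imports are typically wired together; the model and adapter paths are placeholders, and the `lora_config`/`lora_request` keywords are assumed to behave as in the multi-LoRA example touched by this PR:

```python
from tensorrt_llm import LLM
from tensorrt_llm.executor.request import LoRARequest
from tensorrt_llm.lora_helper import LoraConfig  # previously tensorrt_llm.lora_manager
from tensorrt_llm.sampling_params import SamplingParams

# Placeholder paths; substitute a real base model and an HF LoRA adapter.
lora_config = LoraConfig(lora_dir=["/path/to/lora_adapter"], max_lora_rank=8)
llm = LLM(model="/path/to/base_model", lora_config=lora_config)

outputs = llm.generate(
    ["What is the capital of France?"],
    sampling_params=SamplingParams(max_tokens=32),
    lora_request=[LoRARequest("my-adapter", 1, "/path/to/lora_adapter")],
)
print(outputs[0].outputs[0].text)
```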

2 changes: 1 addition & 1 deletion examples/llm-api/llm_multilora.py
@@ -5,7 +5,7 @@

from tensorrt_llm import LLM
from tensorrt_llm.executor import LoRARequest
from tensorrt_llm.lora_manager import LoraConfig
from tensorrt_llm.lora_helper import LoraConfig


def main():
2 changes: 2 additions & 0 deletions tensorrt_llm/__init__.py
@@ -33,6 +33,7 @@ def _add_trt_llm_dll_directory():
# otherwise `MemoryError: std::bad_alloc` pattern error will be raised.
import xgrammar # noqa

import tensorrt_llm._torch.models as torch_models
import tensorrt_llm.functional as functional
import tensorrt_llm.math_utils as math_utils
import tensorrt_llm.models as models
@@ -82,6 +83,7 @@ def _add_trt_llm_dll_directory():
'default_trtnet',
'precision',
'net_guard',
'torch_models',
'Network',
'Mapping',
'MnnvlMemory',
2 changes: 1 addition & 1 deletion tensorrt_llm/_torch/models/modeling_phi4mm.py
@@ -22,7 +22,7 @@
from ...inputs import (ExtraProcessedInputs, InputProcessor, TextPrompt,
register_input_processor)
from ...logger import logger
from ...lora_manager import LoraConfig
from ...lora_helper import LoraConfig
from ...sampling_params import SamplingParams
from ..attention_backend import AttentionMetadata
from ..model_config import ModelConfig
4 changes: 2 additions & 2 deletions tensorrt_llm/_torch/modules/fused_moe/quantization.py
@@ -6,7 +6,7 @@
import torch.nn.functional as F
from torch import nn

from tensorrt_llm import logger
import tensorrt_llm.logger as trtllm_logger
from tensorrt_llm._utils import get_sm_version
from tensorrt_llm.quantization.utils.fp4_utils import (
float4_sf_dtype, get_reorder_rows_for_gated_act_gemm_row_indices,
@@ -743,7 +743,7 @@ def load_weights(self, module: torch.nn.Module, weights: List[Dict],
if int(name.split(".")[0]) not in expert_ids:
continue
weight_name = name.replace("weight_scale_inv", "weight")
logger.debug(f"Resmoothing {weight_name}")
trtllm_logger.logger.debug(f"Resmoothing {weight_name}")
weight = weights[weight_name][:]
scale = weights[name][:]
weights[weight_name], weights[name] = resmooth_to_fp8_e8m0(
6 changes: 3 additions & 3 deletions tensorrt_llm/_torch/pyexecutor/_util.py
@@ -13,9 +13,9 @@
from tensorrt_llm.bindings.executor import DecodingMode, ExecutorConfig
from tensorrt_llm.llmapi.llm_args import PeftCacheConfig
from tensorrt_llm.logger import logger
from tensorrt_llm.lora_manager import (LoraConfig,
get_default_trtllm_modules_to_hf_modules,
load_torch_lora)
from tensorrt_llm.lora_helper import (LoraConfig,
get_default_trtllm_modules_to_hf_modules)
from tensorrt_llm.lora_manager import load_torch_lora
from tensorrt_llm.mapping import Mapping

from ..model_config import ModelConfig
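Across the PR the LoRA imports split the same way: configuration types and the module-name mapping now come from `tensorrt_llm.lora_helper`, while checkpoint loading and runtime management stay in `tensorrt_llm.lora_manager`. A representative import block, limited to names that appear elsewhere in this diff:

```python
# Config-level helpers: the LoraConfig dataclass plus the module-name mapping.
from tensorrt_llm.lora_helper import (LoraConfig,
                                      get_default_trtllm_modules_to_hf_modules)
# Checkpoint loading and runtime management remain in lora_manager.
from tensorrt_llm.lora_manager import (LoraManager, LoraModelConfig,
                                       load_torch_lora)
```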
3 changes: 2 additions & 1 deletion tensorrt_llm/_torch/pyexecutor/model_engine.py
@@ -26,7 +26,8 @@
from tensorrt_llm.inputs.multimodal import (MultimodalParams,
MultimodalRuntimeData)
from tensorrt_llm.logger import logger
from tensorrt_llm.lora_manager import LoraConfig, LoraModelConfig
from tensorrt_llm.lora_helper import LoraConfig
from tensorrt_llm.lora_manager import LoraModelConfig
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models.modeling_utils import QuantAlgo
from tensorrt_llm.quantization.utils.fp4_utils import float4_e2m1x2
2 changes: 1 addition & 1 deletion tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
@@ -13,7 +13,7 @@
from tensorrt_llm.bindings.executor import ContextChunkingPolicy, ExecutorConfig
from tensorrt_llm.bindings.internal.batch_manager import ContextChunkingConfig
from tensorrt_llm.logger import logger
from tensorrt_llm.lora_manager import LoraConfig
from tensorrt_llm.lora_helper import LoraConfig
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.quantization import QuantAlgo

3 changes: 2 additions & 1 deletion tensorrt_llm/_torch/pyexecutor/resource_manager.py
@@ -10,7 +10,8 @@
import tensorrt_llm
import tensorrt_llm.bindings
from tensorrt_llm.bindings.BuildInfo import ENABLE_MULTI_DEVICE
from tensorrt_llm.lora_manager import LoraConfig, LoraManager, LoraModelConfig
from tensorrt_llm.lora_helper import LoraConfig
from tensorrt_llm.lora_manager import LoraManager, LoraModelConfig
from tensorrt_llm.sampling_params import SamplingParams

from ..._utils import binding_dtype_size, binding_to_str_dtype, nvtx_range
2 changes: 1 addition & 1 deletion tensorrt_llm/builder.py
@@ -36,7 +36,7 @@
from .functional import PositionEmbeddingType
from .graph_rewriting import optimize
from .logger import logger
from .lora_manager import LoraConfig
from .lora_helper import LoraConfig
from .models import PretrainedConfig, PretrainedModel
from .models.modeling_utils import SpeculativeDecodingMode, optimize_model
from .network import Network, net_guard
3 changes: 2 additions & 1 deletion tensorrt_llm/commands/build.py
@@ -31,7 +31,8 @@
from tensorrt_llm.bindings import KVCacheType
from tensorrt_llm.builder import BuildConfig, Engine, build
from tensorrt_llm.logger import logger, severity_map
from tensorrt_llm.lora_manager import LoraConfig, LoraManager
from tensorrt_llm.lora_helper import LoraConfig
from tensorrt_llm.lora_manager import LoraManager
from tensorrt_llm.models import MODEL_MAP, PretrainedConfig
from tensorrt_llm.models.modeling_utils import SpeculativeDecodingMode
from tensorrt_llm.plugin import PluginConfig, add_plugin_argument
5 changes: 5 additions & 0 deletions tensorrt_llm/disaggregated_params.py
@@ -1,6 +1,11 @@
from dataclasses import dataclass
from typing import List, Optional

# isort: off
# needed before trying to import bindings to load tensorrt_libs
import tensorrt as trt # noqa
# isort: on

from tensorrt_llm.bindings import executor as tllme


2 changes: 1 addition & 1 deletion tensorrt_llm/executor/executor.py
@@ -15,7 +15,7 @@

from tensorrt_llm.inputs.multimodal import MultimodalParams
from tensorrt_llm.logger import logger, set_level
from tensorrt_llm.lora_manager import LoraConfig
from tensorrt_llm.lora_helper import LoraConfig

from .._utils import mpi_world_size
from ..bindings import executor as tllm
3 changes: 2 additions & 1 deletion tensorrt_llm/executor/worker.py
@@ -24,7 +24,8 @@
from ..llmapi.utils import (AsyncQueue, ManagedThread, _SyncQueue,
clear_sched_affinity, print_colored_debug,
print_traceback_on_error)
from ..lora_manager import LoraConfig, LoraManager
from ..lora_helper import LoraConfig
from ..lora_manager import LoraManager
from ..metrics import RequestEventTiming
from ..prompt_adapter_manager import PromptAdapterManager
from ..runtime import ModelConfig
2 changes: 1 addition & 1 deletion tensorrt_llm/llmapi/build_cache.py
@@ -12,7 +12,7 @@
import filelock

import tensorrt_llm
from tensorrt_llm import BuildConfig
from tensorrt_llm.builder import BuildConfig
from tensorrt_llm.llmapi.utils import enable_llm_debug, print_colored
from tensorrt_llm.logger import logger

4 changes: 2 additions & 2 deletions tensorrt_llm/llmapi/llm_args.py
@@ -19,8 +19,8 @@
from strenum import StrEnum
from transformers import PreTrainedTokenizerBase

from tensorrt_llm.lora_manager import (LoraConfig,
get_default_trtllm_modules_to_hf_modules)
from tensorrt_llm.lora_helper import (LoraConfig,
get_default_trtllm_modules_to_hf_modules)

from .._utils import mpi_rank
from ..auto_parallel import AutoParallelConfig, infer_cluster_config
101 changes: 101 additions & 0 deletions tensorrt_llm/lora_helper.py
@@ -0,0 +1,101 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass, field
from typing import Dict, List, Optional

from ._utils import DictConversion


def get_missing_qkv_modules_from_lora_modules(
lora_target_modules: List[str]) -> List[str]:
"""Get missing QKV modules from LoRA target modules.

In the current design, q_lora_params, k_lora_params and v_lora_params should all be enabled or
all disabled at the same time. However, some LoRA checkpoints (e.g. BART) contain only two of them,
so zero tensors are used to fill in the missing ones.
"""
missing_qkv_modules = []
if any(x in lora_target_modules for x in ["attn_q", "attn_k", "attn_v"]):
for lora_module in ["attn_q", "attn_k", "attn_v"]:
if lora_module not in lora_target_modules:
missing_qkv_modules.append(lora_module)
if any(x in lora_target_modules
for x in ["cross_attn_q", "cross_attn_k", "cross_attn_v"]):
for lora_module in ["cross_attn_q", "cross_attn_k", "cross_attn_v"]:
if lora_module not in lora_target_modules:
missing_qkv_modules.append(lora_module)
return missing_qkv_modules


def get_default_trtllm_modules_to_hf_modules():
"""Get default mapping from TensorRT-LLM module names to HuggingFace module names."""
return {
"attn_q": "q_proj",
"attn_k": "k_proj",
"attn_v": "v_proj",
"attn_dense": "o_proj",
"mlp_h_to_4h": "gate_proj",
"mlp_4h_to_h": "down_proj",
"mlp_gate": "up_proj",
"mlp_gate_up": "gate_up_proj",
"moe_h_to_4h": "w1",
"moe_4h_to_h": "w2",
"moe_gate": "w3",
"moe_router": "gate",
}


def use_lora(
model,
lora_config: "LoraConfig",
trtllm_modules_to_hf_modules: Optional[Dict[str, str]] = None,
):
"""Use LoRA with the given model and configuration.

This function is a wrapper that delegates to the appropriate loading function
based on the LoRA checkpoint source.
"""
if lora_config.lora_ckpt_source == "nemo":
from .lora_manager import load_nemo_lora
load_nemo_lora(model, lora_config)
elif lora_config.lora_ckpt_source == "hf":
from .lora_manager import load_hf_lora
load_hf_lora(model, lora_config, trtllm_modules_to_hf_modules)
else:
raise ValueError(
f"Unsupported lora_ckpt_source: {lora_config.lora_ckpt_source}")


@dataclass
class LoraConfig(DictConversion):
lora_dir: List[str] = field(default_factory=list)
lora_ckpt_source: str = "hf"
max_lora_rank: int = 64
lora_target_modules: List[str] = field(default_factory=list)
trtllm_modules_to_hf_modules: Dict[str, str] = field(default_factory=dict)
max_loras: Optional[int] = None
max_cpu_loras: Optional[int] = None

def __post_init__(self):
assert self.lora_ckpt_source in [
"hf", "nemo"
], (f"lora_ckpt_source must be one of 'hf' or 'nemo', got {self.lora_ckpt_source}"
)

@property
def missing_qkv_modules(self) -> List[str]:
return get_missing_qkv_modules_from_lora_modules(
self.lora_target_modules)
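The new helpers are small and self-contained, so their behavior is easy to check directly. A quick sketch, assuming a working TensorRT-LLM installation (the adapter path is a placeholder):

```python
from tensorrt_llm.lora_helper import (LoraConfig,
                                      get_default_trtllm_modules_to_hf_modules,
                                      get_missing_qkv_modules_from_lora_modules)

# A checkpoint that adapts only Q and V: K is reported as missing so callers
# can fill it with zero tensors.
assert get_missing_qkv_modules_from_lora_modules(["attn_q", "attn_v"]) == ["attn_k"]

# The same logic is exposed as a property on LoraConfig.
config = LoraConfig(lora_dir=["/path/to/adapter"],  # placeholder path
                    lora_target_modules=["attn_q", "attn_v"],
                    max_lora_rank=32)
assert config.missing_qkv_modules == ["attn_k"]

# Default TRT-LLM -> HuggingFace module-name mapping.
assert get_default_trtllm_modules_to_hf_modules()["attn_q"] == "q_proj"

# An unsupported checkpoint source is rejected in __post_init__.
try:
    LoraConfig(lora_ckpt_source="something_else")
except AssertionError as err:
    print(err)
```

Note that `use_lora` imports `load_nemo_lora`/`load_hf_lora` lazily inside the function body, which keeps `lora_helper` free of a module-level dependency on `lora_manager`; that appears to be the point of the split.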
76 changes: 9 additions & 67 deletions tensorrt_llm/lora_manager.py
@@ -5,7 +5,7 @@
import tarfile
import warnings
from collections import defaultdict
from dataclasses import dataclass, field
from dataclasses import dataclass
from functools import lru_cache
from pathlib import Path
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Union
@@ -16,8 +16,13 @@

from tensorrt_llm.bindings import internal as tb_internal

from ._utils import DictConversion, pad_vocab_size, release_gc, str_dtype_to_torch, torch_to_numpy
from ._utils import pad_vocab_size, release_gc, str_dtype_to_torch, torch_to_numpy
from .layers.linear import ColumnLinear
from .lora_helper import (
LoraConfig,
get_default_trtllm_modules_to_hf_modules,
get_missing_qkv_modules_from_lora_modules,
)
from .mapping import Mapping
from .models.convert_utils import get_model_path, load_state_dict, split_matrix_tp

@@ -232,26 +237,6 @@ def norm_dora_magnitude(
return norm_m


@dataclass
class LoraConfig(DictConversion):
lora_dir: List[str] = field(default_factory=list)
lora_ckpt_source: str = "hf"
max_lora_rank: int = 64
lora_target_modules: List[str] = field(default_factory=list)
trtllm_modules_to_hf_modules: Dict[str, str] = field(default_factory=dict)
max_loras: int | None = None
max_cpu_loras: int | None = None

def __post_init__(self):
assert self.lora_ckpt_source in ["hf", "nemo"], (
f"lora_ckpt_source must be one of 'hf' or 'nemo', got {self.lora_ckpt_source}"
)

@property
def missing_qkv_modules(self) -> List[str]:
return LoraManager.get_missing_qkv_modules(self.lora_target_modules)


@dataclass
class LoraModelConfig:
lora_target_modules: list[str]
@@ -430,23 +415,6 @@ def load_nemo_lora(model, lora_config: LoraConfig):
lora_config.lora_target_modules = lora_loader.lora_target_modules


def get_default_trtllm_modules_to_hf_modules():
return {
"attn_q": "q_proj",
"attn_k": "k_proj",
"attn_v": "v_proj",
"attn_dense": "o_proj",
"mlp_h_to_4h": "gate_proj",
"mlp_4h_to_h": "down_proj",
"mlp_gate": "up_proj",
"mlp_gate_up": "gate_up_proj",
"moe_h_to_4h": "w1",
"moe_4h_to_h": "w2",
"moe_gate": "w3",
"moe_router": "gate",
}


def load_torch_hf_lora(lora_config: LoraConfig):
"""This is a shortned version of load_hf_lora that is used for torch models.

@@ -628,19 +596,6 @@ def load_hf_lora(
).to(torch_dtype)


def use_lora(
model,
lora_config: LoraConfig,
trtllm_modules_to_hf_modules: Optional[Dict[str, str]] = None,
):
if lora_config.lora_ckpt_source == "nemo":
load_nemo_lora(model, lora_config)
elif lora_config.lora_ckpt_source == "hf":
load_hf_lora(model, lora_config, trtllm_modules_to_hf_modules)
else:
raise ValueError(f"Unsupported lora_ckpt_source: {lora_config.lora_ckpt_source}")


def unpack_nemo_weights(nemo_archive_path: str) -> Tuple[Dict, Dict[str, torch.Tensor]]:
"""Unpack model config and weights from a NeMo .nemo archive file.

@@ -762,21 +717,8 @@ def is_adapter_in_cpu_cache(self, adapter_uid: int) -> bool:
)

@staticmethod
def get_missing_qkv_modules(lora_target_modules):
# In current design, q_lora_params, k_lora_params and v_lora_params should be all enabled or
# all disabled at the same time.
# However, some lora checkpoint (e.g. BART) only contain two of them, so we use zero tensor
# to fill the missing ones.
missing_qkv_modules = []
if any(x in lora_target_modules for x in ["attn_q", "attn_k", "attn_v"]):
for lora_module in ["attn_q", "attn_k", "attn_v"]:
if lora_module not in lora_target_modules:
missing_qkv_modules.append(lora_module)
if any(x in lora_target_modules for x in ["cross_attn_q", "cross_attn_k", "cross_attn_v"]):
for lora_module in ["cross_attn_q", "cross_attn_k", "cross_attn_v"]:
if lora_module not in lora_target_modules:
missing_qkv_modules.append(lora_module)
return missing_qkv_modules
def get_missing_qkv_modules(lora_target_modules: List[str]) -> List[str]:
return get_missing_qkv_modules_from_lora_modules(lora_target_modules)

@property
def missing_qkv_modules(self) -> List[str]:
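Since `LoraManager.get_missing_qkv_modules` now simply delegates to the helper, existing callers see identical behavior. A brief compatibility check, assuming an importable installation:

```python
from tensorrt_llm.lora_helper import get_missing_qkv_modules_from_lora_modules
from tensorrt_llm.lora_manager import LoraManager

modules = ["cross_attn_q", "cross_attn_k"]
# The legacy static method and the new helper agree.
assert (LoraManager.get_missing_qkv_modules(modules)
        == get_missing_qkv_modules_from_lora_modules(modules)
        == ["cross_attn_v"])
```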