From f7c7af5304770667d7358f4b6af8d9b76cab650d Mon Sep 17 00:00:00 2001
From: Venky Ganesh <23023424+venkywonka@users.noreply.github.com>
Date: Mon, 18 Aug 2025 16:18:27 -0700
Subject: [PATCH 1/3] [5464088][fix] Enhance LoRA support in PyTorch model configuration

- Added logging for dtype casting in LoraLayer to ensure compatibility with FP16/BF16.
- Updated model configuration to derive the number of LoRA adapters from the model label, improving flexibility in adapter management.

Signed-off-by: Venky Ganesh <23023424+venkywonka@users.noreply.github.com>
---
 tensorrt_llm/_torch/peft/lora/layer.py              | 11 +++++++++++
 tests/integration/defs/perf/pytorch_model_config.py | 11 ++++++++++-
 2 files changed, 21 insertions(+), 1 deletion(-)

diff --git a/tensorrt_llm/_torch/peft/lora/layer.py b/tensorrt_llm/_torch/peft/lora/layer.py
index fb984614175..8e43abb7eff 100644
--- a/tensorrt_llm/_torch/peft/lora/layer.py
+++ b/tensorrt_llm/_torch/peft/lora/layer.py
@@ -3,6 +3,8 @@
 
 import torch
 
+from tensorrt_llm._utils import logger
+
 
 class LoraModuleType(IntEnum):
     """Enum class representing different types of modules that can have LoRA adapters.
@@ -119,6 +121,15 @@ def forward(
         if len(active_lora_module_ids) == 0:
             return None
         else:
+            # Guard: LoRA custom op only supports FP16/BF16 activations.
+            # If upstream produced FP8 (e.g., FP8 SwiGLU), cast here to avoid runtime failure.
+            if x.dtype not in (torch.float16, torch.bfloat16):
+                target_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported(
+                ) else torch.float16
+                logger.debug(
+                    f"lora_grouped_gemm supports only FP16/BF16. Casting input from {x.dtype} to {target_dtype}."
+                )
+                x = x.to(target_dtype).contiguous()
             lora_outputs = torch.ops.trtllm.lora_grouped_gemm(
                 x,
                 lora_params['host_request_types'][:num_seqs],
diff --git a/tests/integration/defs/perf/pytorch_model_config.py b/tests/integration/defs/perf/pytorch_model_config.py
index 7a68b96e115..822a2a30c3a 100644
--- a/tests/integration/defs/perf/pytorch_model_config.py
+++ b/tests/integration/defs/perf/pytorch_model_config.py
@@ -181,10 +181,19 @@ def get_model_yaml_config(model_label: str,
 
     # lora-specific change for pytorch
     if 'pytorch' in model_label and 'loras' in model_label:
+        # Derive the requested number of adapters from model_label (segment like "loras:X")
+        lora_count = 1
+        for part in model_label.split('-'):
+            if part.startswith('loras:'):
+                lora_count = max(1, int(part.split(':', 1)[1]))
+                break
+
         lora_config = {
             'lora_config': {
                 'lora_dir': lora_dirs if lora_dirs is not None else [],
-                'max_lora_rank': 64
+                'max_lora_rank': 64,
+                'max_loras': lora_count,
+                'max_cpu_loras': lora_count,
             }
         }
         if 'phi_4_multimodal_instruct' in model_label:

From fe858342916027c448533117130e5a385eae3743 Mon Sep 17 00:00:00 2001
From: Venky Ganesh <23023424+venkywonka@users.noreply.github.com>
Date: Tue, 19 Aug 2025 12:02:54 -0700
Subject: [PATCH 2/3] Move activation handling to GatedMLP for LoRA compatibility

- Modified _apply_activation method to accept a for_lora flag, allowing for specific handling of activation during LoRA operations.
- Updated the call to _apply_activation in GatedMLP to pass the for_lora argument, ensuring correct behavior in LoRA scenarios.
- Removed unnecessary dtype casting checks in LoraLayer, simplifying the code.

Signed-off-by: Venky Ganesh <23023424+venkywonka@users.noreply.github.com>
---
 tensorrt_llm/_torch/modules/gated_mlp.py | 22 +++++++++++++++++-----
 tensorrt_llm/_torch/peft/lora/layer.py   | 11 -----------
 2 files changed, 17 insertions(+), 16 deletions(-)

diff --git a/tensorrt_llm/_torch/modules/gated_mlp.py b/tensorrt_llm/_torch/modules/gated_mlp.py
index 8b3e314a9ec..7824d9182e3 100644
--- a/tensorrt_llm/_torch/modules/gated_mlp.py
+++ b/tensorrt_llm/_torch/modules/gated_mlp.py
@@ -5,6 +5,7 @@
 import torch.nn.functional as F
 from torch import nn
 
+from tensorrt_llm.logger import logger
 from tensorrt_llm.mapping import Mapping
 
 from ..distributed import AllReduceParams
@@ -95,12 +96,23 @@ def __init__(self,
             [LoraModuleType.MLP_GATE_UP],
             [2 * self.intermediate_size // mapping.tp_size])
 
-    def _apply_activation(self, x):
+    def _apply_activation(self, x, *, for_lora: bool = False):
         if self.activation == F.silu:
             if self.down_proj.has_fp8_qdq:
-                return swiglu(x,
-                              quant_scale=self.down_proj.input_scale,
-                              quant_type=torch.float8_e4m3fn)
+                if for_lora:
+
+                    target = torch.bfloat16 if torch.cuda.is_bf16_supported(
+                    ) else torch.float16
+                    logger.debug(
+                        f"GatedMLP._apply_activation: LoRA path active; forcing non-FP8 activation dtype {target} (keeping activations in bf16/fp16), layer_idx={self.layer_idx}"
+                    )
+                    return swiglu(x,
+                                  quant_scale=self.down_proj.input_scale,
+                                  quant_type=target)
+                else:
+                    return swiglu(x,
+                                  quant_scale=self.down_proj.input_scale,
+                                  quant_type=torch.float8_e4m3fn)
             else:
                 return swiglu(x)
         elif callable(self.activation):
@@ -152,7 +164,7 @@ def forward_lora(
         if h1_lora is not None:
             h1 = h1 + h1_lora
 
-        h2 = self._apply_activation(h1)
+        h2 = self._apply_activation(h1, for_lora=True)
         output = self.down_proj(h2,
                                 all_reduce_params=final_all_reduce_params,
                                 lora_params=lora_params,
diff --git a/tensorrt_llm/_torch/peft/lora/layer.py b/tensorrt_llm/_torch/peft/lora/layer.py
index 8e43abb7eff..fb984614175 100644
--- a/tensorrt_llm/_torch/peft/lora/layer.py
+++ b/tensorrt_llm/_torch/peft/lora/layer.py
@@ -3,8 +3,6 @@
 
 import torch
 
-from tensorrt_llm._utils import logger
-
 
 class LoraModuleType(IntEnum):
     """Enum class representing different types of modules that can have LoRA adapters.
@@ -121,15 +119,6 @@ def forward(
         if len(active_lora_module_ids) == 0:
             return None
         else:
-            # Guard: LoRA custom op only supports FP16/BF16 activations.
-            # If upstream produced FP8 (e.g., FP8 SwiGLU), cast here to avoid runtime failure.
-            if x.dtype not in (torch.float16, torch.bfloat16):
-                target_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported(
-                ) else torch.float16
-                logger.debug(
-                    f"lora_grouped_gemm supports only FP16/BF16. Casting input from {x.dtype} to {target_dtype}."
-                )
-                x = x.to(target_dtype).contiguous()
             lora_outputs = torch.ops.trtllm.lora_grouped_gemm(
                 x,
                 lora_params['host_request_types'][:num_seqs],

From ab300023a5ba48255560d96c4e4db5f27855d365 Mon Sep 17 00:00:00 2001
From: Venky Ganesh <23023424+venkywonka@users.noreply.github.com>
Date: Wed, 20 Aug 2025 12:17:21 -0700
Subject: [PATCH 3/3] apply review comments

Signed-off-by: Venky Ganesh <23023424+venkywonka@users.noreply.github.com>
---
 tensorrt_llm/_torch/modules/gated_mlp.py | 20 +++++++++-----------
 1 file changed, 9 insertions(+), 11 deletions(-)

diff --git a/tensorrt_llm/_torch/modules/gated_mlp.py b/tensorrt_llm/_torch/modules/gated_mlp.py
index 7824d9182e3..85e91907391 100644
--- a/tensorrt_llm/_torch/modules/gated_mlp.py
+++ b/tensorrt_llm/_torch/modules/gated_mlp.py
@@ -96,19 +96,17 @@ def __init__(self,
             [LoraModuleType.MLP_GATE_UP],
             [2 * self.intermediate_size // mapping.tp_size])
 
-    def _apply_activation(self, x, *, for_lora: bool = False):
+    def _apply_activation(self, x, *, has_lora: bool = False):
         if self.activation == F.silu:
             if self.down_proj.has_fp8_qdq:
-                if for_lora:
-
-                    target = torch.bfloat16 if torch.cuda.is_bf16_supported(
-                    ) else torch.float16
-                    logger.debug(
-                        f"GatedMLP._apply_activation: LoRA path active; forcing non-FP8 activation dtype {target} (keeping activations in bf16/fp16), layer_idx={self.layer_idx}"
+                if has_lora:
+                    # NOTE: This is a WAR, since LoRA grouped_gemm does not support FP8 yet.
+                    # TODO: Remove this path when LoRA grouped_gemm supports FP8
+                    # see: cpp/tensorrt_llm/thop/loraOp.cpp::lora_grouped_gemm
+                    logger.warning(
+                        f"GatedMLP._apply_activation: LoRA path active; forcing non-FP8 activation dtype bf16/fp16, layer_idx={self.layer_idx}"
                     )
-                    return swiglu(x,
-                                  quant_scale=self.down_proj.input_scale,
-                                  quant_type=target)
+                    return swiglu(x)
                 else:
                     return swiglu(x,
                                   quant_scale=self.down_proj.input_scale,
                                   quant_type=torch.float8_e4m3fn)
@@ -164,7 +162,7 @@ def forward_lora(
         if h1_lora is not None:
             h1 = h1 + h1_lora
 
-        h2 = self._apply_activation(h1, for_lora=True)
+        h2 = self._apply_activation(h1, has_lora=True)
        output = self.down_proj(h2,
                                 all_reduce_params=final_all_reduce_params,
                                 lora_params=lora_params,
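
Note on the config change in PATCH 1/3: the "loras:X" segment of the perf-test model label now drives both max_loras and max_cpu_loras. Below is a minimal standalone sketch of that parsing, assuming the hyphen-separated label convention used by get_model_yaml_config; the helper name and the example labels are illustrative only and are not part of the patches.

def lora_count_from_label(model_label: str) -> int:
    """Mirror of the parsing added to get_model_yaml_config: read the
    hyphen-separated "loras:X" segment and clamp the count to at least 1."""
    for part in model_label.split('-'):
        if part.startswith('loras:'):
            return max(1, int(part.split(':', 1)[1]))
    return 1  # default when the label carries no "loras:" segment


# Hypothetical labels, for illustration only:
assert lora_count_from_label('llama_v3.1_8b-bench-pytorch-loras:8-bf16') == 8
assert lora_count_from_label('llama_v3.1_8b-bench-pytorch-bf16') == 1

Passing the same count as max_loras and max_cpu_loras sizes the GPU and host adapter caches to match what the benchmark label requests rather than relying on library defaults.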