Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 15 additions & 5 deletions tensorrt_llm/_torch/modules/gated_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import torch.nn.functional as F
from torch import nn

from tensorrt_llm.logger import logger
from tensorrt_llm.mapping import Mapping

from ..distributed import AllReduceParams
Expand Down Expand Up @@ -95,12 +96,21 @@ def __init__(self,
[LoraModuleType.MLP_GATE_UP],
[2 * self.intermediate_size // mapping.tp_size])

def _apply_activation(self, x):
def _apply_activation(self, x, *, has_lora: bool = False):
if self.activation == F.silu:
if self.down_proj.has_fp8_qdq:
return swiglu(x,
quant_scale=self.down_proj.input_scale,
quant_type=torch.float8_e4m3fn)
if has_lora:
# NOTE: This is a WAR, since LoRA grouped_gemm does not support FP8 yet.
# TODO: Remove this path when LoRA grouped_gemm supports FP8
# see: cpp/tensorrt_llm/thop/loraOp.cpp::lora_grouped_gemm
logger.warning(
f"GatedMLP._apply_activation: LoRA path active; forcing non-FP8 activation dtype bf16/fp16, layer_idx={self.layer_idx}"
)
return swiglu(x)
else:
return swiglu(x,
quant_scale=self.down_proj.input_scale,
quant_type=torch.float8_e4m3fn)
else:
return swiglu(x)
elif callable(self.activation):
Expand Down Expand Up @@ -152,7 +162,7 @@ def forward_lora(
if h1_lora is not None:
h1 = h1 + h1_lora

h2 = self._apply_activation(h1)
h2 = self._apply_activation(h1, has_lora=True)
output = self.down_proj(h2,
all_reduce_params=final_all_reduce_params,
lora_params=lora_params,
Expand Down
11 changes: 10 additions & 1 deletion tests/integration/defs/perf/pytorch_model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,10 +181,19 @@ def get_model_yaml_config(model_label: str,

# lora-specific change for pytorch
if 'pytorch' in model_label and 'loras' in model_label:
# Derive the requested number of adapters from model_label (segment like "loras:X")
lora_count = 1
for part in model_label.split('-'):
if part.startswith('loras:'):
lora_count = max(1, int(part.split(':', 1)[1]))
break

lora_config = {
'lora_config': {
'lora_dir': lora_dirs if lora_dirs is not None else [],
'max_lora_rank': 64
'max_lora_rank': 64,
'max_loras': lora_count,
'max_cpu_loras': lora_count,
}
}
if 'phi_4_multimodal_instruct' in model_label:
Expand Down