diff --git a/tensorrt_llm/_torch/models/modeling_gemma3.py b/tensorrt_llm/_torch/models/modeling_gemma3.py
index 727e45018e1..10acffae9d6 100644
--- a/tensorrt_llm/_torch/models/modeling_gemma3.py
+++ b/tensorrt_llm/_torch/models/modeling_gemma3.py
@@ -4,7 +4,6 @@
 import torch
 from torch import nn
 from transformers import Gemma3TextConfig
-from transformers.activations import ACT2FN
 
 from tensorrt_llm._torch.models.checkpoints.base_weight_mapper import \
     BaseWeightMapper
@@ -20,7 +19,8 @@
 from ..modules.attention import Attention
 from ..modules.decoder_layer import DecoderLayer
 from ..modules.embedding import Embedding
-from ..modules.linear import Linear, TensorParallelMode
+from ..modules.gated_mlp import GatedMLP
+from ..modules.linear import TensorParallelMode
 from ..modules.multi_stream_utils import maybe_execute_in_parallel
 from ..modules.rms_norm import RMSNorm
 from .modeling_utils import (DecoderModel, DecoderModelForCausalLM,
@@ -156,37 +156,10 @@ def apply_rope(self, q: torch.Tensor, k: Optional[torch.Tensor],
         return super().apply_rope(q, k, v, position_ids)
 
 
-class Gemma3MLP(nn.Module):
-
-    def __init__(self, model_config: ModelConfig[Gemma3TextConfig]):
-        super().__init__()
-        self.config = model_config.pretrained_config
-        self.hidden_size = self.config.hidden_size
-        self.intermediate_size = self.config.intermediate_size
-        self.dtype = self.config.torch_dtype
-        self.quant_config = model_config.get_quant_config()
-        self.gate_proj = Linear(self.hidden_size,
-                                self.intermediate_size,
-                                bias=False,
-                                dtype=self.dtype,
-                                quant_config=self.quant_config)
-        self.up_proj = Linear(self.hidden_size,
-                              self.intermediate_size,
-                              bias=False,
-                              dtype=self.dtype,
-                              quant_config=self.quant_config)
-        self.down_proj = Linear(self.intermediate_size,
-                                self.hidden_size,
-                                bias=False,
-                                dtype=self.dtype,
-                                quant_config=self.quant_config)
-        self.act_fn = ACT2FN[self.config.hidden_activation]
-
-    @torch.inference_mode()
-    def forward(self, x):
-        down_proj = self.down_proj(
-            self.act_fn(self.gate_proj(x)) * self.up_proj(x))
-        return down_proj
+# This function is written to be compatible with TRTLLM's GatedMLP class.
+def pytorch_gelu_tanh(gate_x: torch.Tensor) -> torch.Tensor:
+    gate, x = gate_x.chunk(2, dim=-1)
+    return nn.functional.gelu(gate, approximate="tanh") * x
 
 
 class Gemma3DecoderLayer(DecoderLayer):
@@ -206,7 +179,13 @@ def __init__(
             is_sliding=is_sliding,
         )
 
-        self.mlp = Gemma3MLP(model_config=model_config)
+        self.mlp = GatedMLP(hidden_size=config.hidden_size,
+                            intermediate_size=config.intermediate_size,
+                            bias=False,
+                            activation=pytorch_gelu_tanh,
+                            dtype=config.torch_dtype,
+                            config=model_config,
+                            layer_idx=layer_idx)
 
         self.input_layernorm = RMSNorm(hidden_size=config.hidden_size,
                                        eps=config.rms_norm_eps,
@@ -230,6 +209,7 @@ def forward(
         attn_metadata: AttentionMetadata,
         residual: Optional[torch.Tensor] = None,
         attention_mask_data: Optional[torch.Tensor] = None,
+        lora_params: Optional[dict] = None,
         **kwargs,
     ) -> torch.Tensor:
 
@@ -242,13 +222,14 @@ def forward(
             attention_mask=CustomAttentionMask.CUSTOM
             if attention_mask_data is not None else
             PredefinedAttentionMask.CAUSAL,
             attention_mask_data=attention_mask_data,
+            lora_params=lora_params,
             **kwargs,
         )
         hidden_states = self.post_attention_layernorm(hidden_states)
         hidden_states = residual + hidden_states
         residual = hidden_states
         hidden_states = self.pre_feedforward_layernorm(hidden_states)
-        hidden_states = self.mlp(hidden_states)
+        hidden_states = self.mlp(hidden_states, lora_params=lora_params)
         hidden_states = self.post_feedforward_layernorm(hidden_states)
         hidden_states = residual + hidden_states
@@ -289,6 +270,7 @@ def forward(
         inputs_embeds: Optional[torch.FloatTensor] = None,
         local_attention_mask_data: Optional[torch.Tensor] = None,
         global_attention_mask_data: Optional[torch.Tensor] = None,
+        lora_params: Optional[dict] = None,
         **kwargs,
     ) -> torch.Tensor:
         if (input_ids is None) ^ (inputs_embeds is not None):
@@ -308,7 +290,9 @@ def forward(
                 attn_metadata=attn_metadata,
                 attention_mask_data=local_attention_mask_data
                 if decoder_layer.self_attn.is_sliding else
-                global_attention_mask_data)
+                global_attention_mask_data,
+                lora_params=lora_params,
+            )
 
         hidden_states = self.norm(hidden_states)
         return hidden_states
diff --git a/tensorrt_llm/_torch/models/modeling_gemma3vl.py b/tensorrt_llm/_torch/models/modeling_gemma3vl.py
index 671f3390358..b2f30ba63e6 100644
--- a/tensorrt_llm/_torch/models/modeling_gemma3vl.py
+++ b/tensorrt_llm/_torch/models/modeling_gemma3vl.py
@@ -213,6 +213,7 @@ def forward(
             inputs_embeds=inputs_embeds,
             return_context_logits=return_context_logits,
             image_token_mask=mm_token_mask,
+            lora_params=kwargs.get("lora_params", None),
         )
         return logits
 
diff --git a/tensorrt_llm/_torch/modules/gated_mlp.py b/tensorrt_llm/_torch/modules/gated_mlp.py
index ef419651bb3..3f45ae80651 100644
--- a/tensorrt_llm/_torch/modules/gated_mlp.py
+++ b/tensorrt_llm/_torch/modules/gated_mlp.py
@@ -108,7 +108,9 @@ def __init__(self,
     def _apply_activation(self, x):
         if self.activation == F.silu:
             return swiglu(x)
-        elif self.activation == None:
+        elif callable(self.activation):
+            return self.activation(x)
+        elif self.activation is None:
             return x
         else:
             raise NotImplementedError(
diff --git a/tests/integration/defs/accuracy/references/cnn_dailymail.yaml b/tests/integration/defs/accuracy/references/cnn_dailymail.yaml
index fd4c43093fc..29e099f5816 100644
--- a/tests/integration/defs/accuracy/references/cnn_dailymail.yaml
+++ b/tests/integration/defs/accuracy/references/cnn_dailymail.yaml
@@ -2,7 +2,7 @@ google/gemma-3-1b-it:
   - accuracy: 22.988
   - quant_algo: FP8
     kv_cache_quant_algo: FP8
-    accuracy: 22.988
+    accuracy: 20.699
 google/gemma-3-27b-it:
   - accuracy: 28.90
 gpt2:
diff --git a/tests/integration/defs/accuracy/references/mmlu.yaml b/tests/integration/defs/accuracy/references/mmlu.yaml
index 3d387f36b80..20c9cf119a8 100644
--- a/tests/integration/defs/accuracy/references/mmlu.yaml
+++ b/tests/integration/defs/accuracy/references/mmlu.yaml
@@ -104,7 +104,7 @@ google/gemma-3-1b-it:
   - accuracy: 39.0
   - quant_algo: FP8
     kv_cache_quant_algo: FP8
-    accuracy: 39.0
+    accuracy: 37.5
 google/gemma-3-27b-it:
   - accuracy: 77.80
 Qwen/Qwen2-0.5B-Instruct:
diff --git a/tests/integration/defs/perf/test_perf.py b/tests/integration/defs/perf/test_perf.py
index 4459521c637..d46cc198801 100644
--- a/tests/integration/defs/perf/test_perf.py
+++ b/tests/integration/defs/perf/test_perf.py
@@ -92,6 +92,7 @@
     "mistral_7b_v0.1": "mistral-7b-v0.1",
     "ministral_8b": "Ministral-8B-Instruct-2410",
     "ministral_8b_fp8": "Ministral-8B-Instruct-2410-FP8",
+    "gemma_3_1b_it": "gemma/gemma-3-1b-it",
     "deepseek_r1_fp8": "DeepSeek-R1/DeepSeek-R1",
     "deepseek_r1_nvfp4": "DeepSeek-R1/DeepSeek-R1-FP4",
     "deepseek_v3_lite_fp8": "DeepSeek-V3-Lite/fp8",
@@ -153,6 +154,7 @@
     "ministral_8b_hf": "mistralai/Ministral-8B-Instruct-2410",
     "flan_t5_base_hf": "google/flan-t5-small",
     "phi_4_mini_instruct_hf": "microsoft/Phi-4-mini-instruct",
+    "gemma_3_1b_it_hf": "google/gemma-3-1b-it",
 }
 LORA_MODEL_PATH = {
     "llama_v2_13b":
@@ -163,6 +165,8 @@
     "lora/llama-3-chinese-8b-instruct-v2-lora/",
     "ministral_8b":
     "lora/ministral/Ministral-8B-Instruct-2410-Loras-Dummy",  # Dummy LoRA for Ministral
+    "gemma_3_1b_it":
+    "lora/gemma/gemma-3-1b-it-dummy-lora",  # Dummy LoRA for Gemma-3-1B-Instruct
     "phi_4_multimodal_instruct_image":
     "multimodals/Phi-4-multimodal-instruct/vision-lora",
     "phi_4_multimodal_instruct_audio":
diff --git a/tests/integration/test_lists/qa/trt_llm_integration_perf_test.yml b/tests/integration/test_lists/qa/trt_llm_integration_perf_test.yml
index ec0fb2e3c74..1d2e3e01507 100644
--- a/tests/integration/test_lists/qa/trt_llm_integration_perf_test.yml
+++ b/tests/integration/test_lists/qa/trt_llm_integration_perf_test.yml
@@ -45,6 +45,9 @@ trt_llm_integration_perf_test:
   - perf/test_perf.py::test_perf[llama_v3.1_8b-cpp-ootb_except_mha-bfloat16-maxbs:64-bs:64-input_output_len:128,8+512,32]
   - perf/test_perf.py::test_perf[llama_v3.1_8b-cpp-ootb_except_mha-bfloat16-maxbs:64-bs:64-input_output_len:128,128+512,32]
 
+  # Dummy lora tests
+  - perf/test_perf.py::test_perf[gemma_3_1b_it-bench-pytorch-bfloat16-maxbs:2-maxnt:1024-input_output_len:128,128-loras:1-reqs:8-con:2]
+
   # Test list validation
   - test_list_validation.py::test_list_validation
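
Reviewer note: a minimal, self-contained sketch (plain PyTorch only; the toy sizes and variable names are placeholders, not TRT-LLM API) checking that the new pytorch_gelu_tanh activation matches the removed Gemma3MLP forward path gelu_tanh(gate_proj(x)) * up_proj(x), assuming GatedMLP hands the activation the concatenated [gate | up] projection output, which is what the chunk(2, dim=-1) in pytorch_gelu_tanh implies.

import torch
from torch import nn


def pytorch_gelu_tanh(gate_x: torch.Tensor) -> torch.Tensor:
    # Same function as added in this diff: split the fused [gate | up] tensor,
    # apply tanh-approximated GELU to the gate half, and multiply by the up half.
    gate, x = gate_x.chunk(2, dim=-1)
    return nn.functional.gelu(gate, approximate="tanh") * x


torch.manual_seed(0)
hidden, inter = 16, 32  # arbitrary toy sizes, not Gemma3 config values
gate_proj = nn.Linear(hidden, inter, bias=False)
up_proj = nn.Linear(hidden, inter, bias=False)
x = torch.randn(4, hidden)

# Removed Gemma3MLP path (up to down_proj): ACT2FN["gelu_pytorch_tanh"](gate) * up.
ref = nn.functional.gelu(gate_proj(x), approximate="tanh") * up_proj(x)

# New path: the activation sees the gate and up outputs concatenated into one tensor.
out = pytorch_gelu_tanh(torch.cat([gate_proj(x), up_proj(x)], dim=-1))

torch.testing.assert_close(out, ref)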
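
For context on the gated_mlp.py change, a standalone sketch of the new _apply_activation dispatch order (the swiglu helper below is a local stand-in for TRT-LLM's fused helper, not the real kernel): F.silu is matched first and keeps the fused SwiGLU path even though it is itself callable, any other callable such as pytorch_gelu_tanh is invoked directly, and None remains a pass-through.

import torch
import torch.nn.functional as F


def swiglu(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for the fused swiglu helper: silu(gate) * up on the fused tensor.
    gate, up = x.chunk(2, dim=-1)
    return F.silu(gate) * up


def apply_activation(activation, x: torch.Tensor) -> torch.Tensor:
    # Mirrors the branch order in GatedMLP._apply_activation after this change.
    if activation == F.silu:
        return swiglu(x)
    elif callable(activation):
        return activation(x)
    elif activation is None:
        return x
    else:
        raise NotImplementedError(f"Not a supported activation: {activation}")


x = torch.randn(2, 8)
assert torch.equal(apply_activation(None, x), x)                   # pass-through
assert torch.allclose(apply_activation(F.silu, x), swiglu(x))      # fused swiglu path
assert apply_activation(lambda t: t * 2, x).shape == x.shape       # custom callable branch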