56 changes: 20 additions & 36 deletions tensorrt_llm/_torch/models/modeling_gemma3.py
@@ -4,7 +4,6 @@
import torch
from torch import nn
from transformers import Gemma3TextConfig
from transformers.activations import ACT2FN

from tensorrt_llm._torch.models.checkpoints.base_weight_mapper import \
BaseWeightMapper
@@ -20,7 +19,8 @@
from ..modules.attention import Attention
from ..modules.decoder_layer import DecoderLayer
from ..modules.embedding import Embedding
from ..modules.linear import Linear, TensorParallelMode
from ..modules.gated_mlp import GatedMLP
from ..modules.linear import TensorParallelMode
from ..modules.multi_stream_utils import maybe_execute_in_parallel
from ..modules.rms_norm import RMSNorm
from .modeling_utils import (DecoderModel, DecoderModelForCausalLM,
@@ -156,37 +156,10 @@ def apply_rope(self, q: torch.Tensor, k: Optional[torch.Tensor],
return super().apply_rope(q, k, v, position_ids)


class Gemma3MLP(nn.Module):

def __init__(self, model_config: ModelConfig[Gemma3TextConfig]):
super().__init__()
self.config = model_config.pretrained_config
self.hidden_size = self.config.hidden_size
self.intermediate_size = self.config.intermediate_size
self.dtype = self.config.torch_dtype
self.quant_config = model_config.get_quant_config()
self.gate_proj = Linear(self.hidden_size,
self.intermediate_size,
bias=False,
dtype=self.dtype,
quant_config=self.quant_config)
self.up_proj = Linear(self.hidden_size,
self.intermediate_size,
bias=False,
dtype=self.dtype,
quant_config=self.quant_config)
self.down_proj = Linear(self.intermediate_size,
self.hidden_size,
bias=False,
dtype=self.dtype,
quant_config=self.quant_config)
self.act_fn = ACT2FN[self.config.hidden_activation]

@torch.inference_mode()
def forward(self, x):
down_proj = self.down_proj(
self.act_fn(self.gate_proj(x)) * self.up_proj(x))
return down_proj
# This function is written to be compatible with TRTLLM's GatedMLP class.
def pytorch_gelu_tanh(gate_x: torch.Tensor) -> torch.Tensor:
gate, x = gate_x.chunk(2, dim=-1)
return nn.functional.gelu(gate, approximate="tanh") * x
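
# For reference, a minimal self-contained sketch (weight names and shapes below are
# made up for illustration; this is not the TRT-LLM module itself) showing that the
# fused-gate formulation above matches the removed Gemma3MLP computation
# act_fn(gate_proj(x)) * up_proj(x) when act_fn is gelu with tanh approximation:
#
#   import torch
#   from torch import nn
#
#   def pytorch_gelu_tanh(gate_x: torch.Tensor) -> torch.Tensor:
#       gate, x = gate_x.chunk(2, dim=-1)
#       return nn.functional.gelu(gate, approximate="tanh") * x
#
#   # Hypothetical sizes, chosen only for this sketch.
#   hidden_size, intermediate_size = 8, 16
#   w_gate = torch.randn(intermediate_size, hidden_size)
#   w_up = torch.randn(intermediate_size, hidden_size)
#   x = torch.randn(2, hidden_size)
#
#   # Fused path: gate and up projections concatenated along the last dim,
#   # then split again inside the activation (the GatedMLP-style layout).
#   fused = torch.cat([x @ w_gate.T, x @ w_up.T], dim=-1)
#   out_fused = pytorch_gelu_tanh(fused)
#
#   # Unfused path: the original Gemma3MLP formulation.
#   out_ref = nn.functional.gelu(x @ w_gate.T, approximate="tanh") * (x @ w_up.T)
#
#   assert torch.allclose(out_fused, out_ref)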


class Gemma3DecoderLayer(DecoderLayer):
@@ -206,7 +179,13 @@ def __init__(
is_sliding=is_sliding,
)

self.mlp = Gemma3MLP(model_config=model_config)
self.mlp = GatedMLP(hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
bias=False,
activation=pytorch_gelu_tanh,
dtype=config.torch_dtype,
config=model_config,
layer_idx=layer_idx)

self.input_layernorm = RMSNorm(hidden_size=config.hidden_size,
eps=config.rms_norm_eps,
@@ -230,6 +209,7 @@ def forward(
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor] = None,
attention_mask_data: Optional[torch.Tensor] = None,
lora_params: Optional[dict] = None,
**kwargs,
) -> torch.Tensor:

@@ -242,13 +222,14 @@
attention_mask=CustomAttentionMask.CUSTOM if attention_mask_data
is not None else PredefinedAttentionMask.CAUSAL,
attention_mask_data=attention_mask_data,
lora_params=lora_params,
**kwargs,
)
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.pre_feedforward_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = self.mlp(hidden_states, lora_params=lora_params)
hidden_states = self.post_feedforward_layernorm(hidden_states)
hidden_states = residual + hidden_states

@@ -289,6 +270,7 @@ def forward(
inputs_embeds: Optional[torch.FloatTensor] = None,
local_attention_mask_data: Optional[torch.Tensor] = None,
global_attention_mask_data: Optional[torch.Tensor] = None,
lora_params: Optional[dict] = None,
**kwargs,
) -> torch.Tensor:
if (input_ids is None) ^ (inputs_embeds is not None):
@@ -308,7 +290,9 @@
attn_metadata=attn_metadata,
attention_mask_data=local_attention_mask_data
if decoder_layer.self_attn.is_sliding else
global_attention_mask_data)
global_attention_mask_data,
lora_params=lora_params,
)

hidden_states = self.norm(hidden_states)
return hidden_states
1 change: 1 addition & 0 deletions tensorrt_llm/_torch/models/modeling_gemma3vl.py
@@ -213,6 +213,7 @@ def forward(
inputs_embeds=inputs_embeds,
return_context_logits=return_context_logits,
image_token_mask=mm_token_mask,
lora_params=kwargs.get("lora_params", None),
)
return logits

4 changes: 3 additions & 1 deletion tensorrt_llm/_torch/modules/gated_mlp.py
@@ -108,7 +108,9 @@ def __init__(self,
def _apply_activation(self, x):
if self.activation == F.silu:
return swiglu(x)
elif self.activation == None:
elif callable(self.activation):
return self.activation(x)
elif self.activation is None:
return x
else:
raise NotImplementedError(
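
# A standalone sketch of the resulting dispatch order (simplified free functions, not
# the actual GatedMLP class; the real swiglu is presumably a fused kernel): F.silu is
# still matched first and takes the fused swiglu path, any other callable such as
# pytorch_gelu_tanh is applied to the fused tensor directly, and None stays a pass-through.
#
#   import torch
#   import torch.nn.functional as F
#
#   def swiglu(x: torch.Tensor) -> torch.Tensor:
#       # Reference swiglu on a fused [gate | up] tensor, for this sketch only.
#       gate, up = x.chunk(2, dim=-1)
#       return F.silu(gate) * up
#
#   def apply_activation(activation, x: torch.Tensor) -> torch.Tensor:
#       if activation == F.silu:
#           return swiglu(x)
#       elif callable(activation):
#           return activation(x)
#       elif activation is None:
#           return x
#       raise NotImplementedError(f"Unsupported activation: {activation}")
#
#   x = torch.randn(2, 8)
#   print(apply_activation(F.silu, x).shape)      # torch.Size([2, 4]) -- fused swiglu halves the last dim
#   print(apply_activation(torch.tanh, x).shape)  # torch.Size([2, 8]) -- generic callable, shape preserved
#   print(apply_activation(None, x).shape)        # torch.Size([2, 8]) -- identity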
@@ -2,7 +2,7 @@ google/gemma-3-1b-it:
- accuracy: 22.988
- quant_algo: FP8
kv_cache_quant_algo: FP8
accuracy: 22.988
accuracy: 20.699
google/gemma-3-27b-it:
- accuracy: 28.90
gpt2:
2 changes: 1 addition & 1 deletion tests/integration/defs/accuracy/references/mmlu.yaml
@@ -104,7 +104,7 @@ google/gemma-3-1b-it:
- accuracy: 39.0
- quant_algo: FP8
kv_cache_quant_algo: FP8
accuracy: 39.0
accuracy: 37.5
google/gemma-3-27b-it:
- accuracy: 77.80
Qwen/Qwen2-0.5B-Instruct:
4 changes: 4 additions & 0 deletions tests/integration/defs/perf/test_perf.py
@@ -92,6 +92,7 @@
"mistral_7b_v0.1": "mistral-7b-v0.1",
"ministral_8b": "Ministral-8B-Instruct-2410",
"ministral_8b_fp8": "Ministral-8B-Instruct-2410-FP8",
"gemma_3_1b_it": "gemma/gemma-3-1b-it",
"deepseek_r1_fp8": "DeepSeek-R1/DeepSeek-R1",
"deepseek_r1_nvfp4": "DeepSeek-R1/DeepSeek-R1-FP4",
"deepseek_v3_lite_fp8": "DeepSeek-V3-Lite/fp8",
@@ -153,6 +154,7 @@
"ministral_8b_hf": "mistralai/Ministral-8B-Instruct-2410",
"flan_t5_base_hf": "google/flan-t5-small",
"phi_4_mini_instruct_hf": "microsoft/Phi-4-mini-instruct",
"gemma_3_1b_it_hf": "google/gemma-3-1b-it",
}
LORA_MODEL_PATH = {
"llama_v2_13b":
@@ -163,6 +165,8 @@
"lora/llama-3-chinese-8b-instruct-v2-lora/",
"ministral_8b":
"lora/ministral/Ministral-8B-Instruct-2410-Loras-Dummy", # Dummy LoRA for Ministral
"gemma_3_1b_it":
"lora/gemma/gemma-3-1b-it-dummy-lora", # Dummy LoRA for Gemma-3-1B-Instruct
"phi_4_multimodal_instruct_image":
"multimodals/Phi-4-multimodal-instruct/vision-lora",
"phi_4_multimodal_instruct_audio":
@@ -45,6 +45,9 @@ trt_llm_integration_perf_test:
- perf/test_perf.py::test_perf[llama_v3.1_8b-cpp-ootb_except_mha-bfloat16-maxbs:64-bs:64-input_output_len:128,8+512,32]
- perf/test_perf.py::test_perf[llama_v3.1_8b-cpp-ootb_except_mha-bfloat16-maxbs:64-bs:64-input_output_len:128,128+512,32]

# Dummy lora tests
- perf/test_perf.py::test_perf[gemma_3_1b_it-bench-pytorch-bfloat16-maxbs:2-maxnt:1024-input_output_len:128,128-loras:1-reqs:8-con:2]

# Test list validation
- test_list_validation.py::test_list_validation
