diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py
index a9de63245d97..45f5dea08521 100644
--- a/vllm/model_executor/models/mllama.py
+++ b/vllm/model_executor/models/mllama.py
@@ -43,7 +43,6 @@
 from vllm.logger import init_logger
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
-                                               QKVCrossParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
@@ -814,11 +813,20 @@ def __init__(
         self.q_local_size = self.num_local_heads * self.head_dim
         self.kv_local_size = self.num_local_key_value_heads * self.head_dim
 
-        self.qkv_proj = QKVCrossParallelLinear(
+        # TODO(Isotr0py): Use QKVCrossParallelLinear when it supports
+        # quantization
+        self.q_proj = ColumnParallelLinear(
+            input_size=self.hidden_size,
+            output_size=self.num_heads * self.head_dim,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.q_proj",
+        )
+        self.kv_proj = QKVParallelLinear(
             self.hidden_size,
             self.head_dim,
-            self.num_heads,
-            self.num_key_value_heads,
+            total_num_heads=0,
+            total_num_kv_heads=self.num_key_value_heads,
             bias=False,
             quant_config=quant_config,
             prefix=f"{prefix}.qkv_proj",
@@ -854,11 +862,15 @@ def forward(
         kv_range_for_decode: Optional[List[Tuple[int, int]]],
         cross_attention_states: Optional[torch.Tensor],
     ) -> torch.Tensor:
-        q, k, v = self.qkv_proj(hidden_states, cross_attention_states)
+        q, _ = self.q_proj(hidden_states)
         if cross_attention_states is not None:
+            kv, _ = self.kv_proj(cross_attention_states)
+            k, v = kv.split([self.kv_local_size, self.kv_local_size], dim=-1)
             k = k.view(-1, self.num_local_key_value_heads, self.head_dim)
             v = v.view(-1, self.num_local_key_value_heads, self.head_dim)
             k = self.k_norm(k)
+        else:
+            k = v = None
 
         q = q.view(-1, self.num_local_heads, self.head_dim)
         q = self.q_norm(q)
@@ -1149,8 +1161,13 @@ def forward(
 class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
                                      SupportsV0Only):
     packed_modules_mapping = {
-        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
-        "gate_up_proj": ["gate_proj", "up_proj"]
+        "self_attn.qkv_proj": [
+            "self_attn.q_proj",
+            "self_attn.k_proj",
+            "self_attn.v_proj",
+        ],
+        "cross_attn.kv_proj": ["cross_attn.k_proj", "cross_attn.v_proj"],
+        "gate_up_proj": ["gate_proj", "up_proj"],
     }
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
@@ -1420,9 +1437,11 @@ def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
-            (".qkv_proj", ".q_proj", "q"),
-            (".qkv_proj", ".k_proj", "k"),
-            (".qkv_proj", ".v_proj", "v"),
+            (".self_attn.qkv_proj", ".self_attn.q_proj", "q"),
+            (".self_attn.qkv_proj", ".self_attn.k_proj", "k"),
+            (".self_attn.qkv_proj", ".self_attn.v_proj", "v"),
+            (".cross_attn.kv_proj", ".cross_attn.k_proj", "k"),
+            (".cross_attn.kv_proj", ".cross_attn.v_proj", "v"),
             (".gate_up_proj", ".gate_proj", 0),
             (".gate_up_proj", ".up_proj", 1),
         ]