diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 56aa00a30d3a..5c3b22001636 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -339,6 +339,7 @@ class CompilationConfig: "vllm.mamba_mixer2", "vllm.mamba_mixer", "vllm.short_conv", + "vllm.linear_attention", ] def compute_hash(self) -> str: diff --git a/vllm/model_executor/models/minimax_text_01.py b/vllm/model_executor/models/minimax_text_01.py index 0e854bd7d913..176a40179bca 100644 --- a/vllm/model_executor/models/minimax_text_01.py +++ b/vllm/model_executor/models/minimax_text_01.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Inference-only MiniMaxText01 model.""" -import copy import math from collections.abc import Iterable from typing import TYPE_CHECKING, Optional, Union @@ -19,13 +18,14 @@ from vllm import envs from vllm.attention import Attention, AttentionMetadata +from vllm.compilation.decorators import support_torch_compile from vllm.config import (CacheConfig, ModelConfig, VllmConfig, get_current_vllm_config) from vllm.distributed.communication_op import tensor_model_parallel_all_reduce from vllm.distributed.parallel_state import ( get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) -from vllm.forward_context import get_forward_context +from vllm.forward_context import ForwardContext, get_forward_context from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import FusedMoE @@ -43,12 +43,15 @@ MambaStateDtypeCalculator, MambaStateShapeCalculator) from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) +from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.utils import maybe_prefix from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors +from vllm.utils import direct_register_custom_op from vllm.v1.attention.backends.linear_attn import LinearAttentionMetadata from .interfaces import HasInnerState, IsHybrid @@ -143,61 +146,6 @@ def forward( return self._forward(x) -class MiniMaxText01RotaryEmbedding(CustomOp): - name = "MiniMaxText01RotaryEmbedding" - - def __init__( - self, - head_size: int, - rotary_dim: int, - max_position: int, - base: float, - is_neox_style: bool, - cache_dtype: torch.dtype, - ) -> None: - super().__init__() - self.head_size = head_size - self.rotary_dim = rotary_dim - self.max_position_embeddings = max_position - self.base = base - self.is_neox_style = is_neox_style - self.cache_dtype = cache_dtype - cache = self._compute_cos_sin_cache().to(cache_dtype) - self.register_buffer("cos_sin_cache", cache, persistent=False) - - def _compute_inv_freq(self, base: float) -> torch.Tensor: - """Compute the inverse frequency.""" - inv_freq = 1.0 / (base**(torch.arange( - 0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim)) - return inv_freq - - def _compute_cos_sin_cache(self) -> torch.Tensor: - """Compute the cos and sin cache.""" - inv_freq = self._compute_inv_freq(self.base) - t = torch.arange(self.max_position_embeddings, dtype=torch.float) - freqs = torch.einsum("i,j -> ij", t, inv_freq) - cos = freqs.cos() - sin = freqs.sin() - cache = torch.cat((cos, sin), dim=-1) - return cache - - def forward( - self, - positions: torch.Tensor, - query: torch.Tensor, - key: torch.Tensor, - ) -> tuple[torch.Tensor, torch.Tensor]: - from vllm import _custom_ops as ops - self.cos_sin_cache = self.cos_sin_cache.to(positions.device) - query_cast = query.to(self.cache_dtype) - key_cast = key.to(self.cache_dtype) - ops.rotary_embedding(positions, query_cast, key_cast, self.head_size, - self.cos_sin_cache, self.is_neox_style) - query = query_cast.to(query.dtype) - key = key_cast.to(key.dtype) - return query, key - - class MiniMaxText01MLP(nn.Module): def __init__( @@ -526,20 +474,40 @@ def _decode_infer(self, q, k, v, kv_cache, state_indices_tensor, slot_id, 32) return hidden - def forward(self, hidden_states: torch.Tensor, positions: torch.Tensor, - kv_caches: MinimaxCacheParams, **kwargs) -> torch.Tensor: - qkv, _ = self.qkv_proj(hidden_states) + def forward(self, hidden_states: torch.Tensor, output: torch.Tensor, + positions: torch.Tensor, + kv_caches: MinimaxCacheParams) -> None: + if not envs.VLLM_USE_V1: + self._forward(hidden_states, output, positions, kv_caches) + else: + torch.ops.vllm.linear_attention( + hidden_states, + output, + positions, + self.prefix, + ) + + def _forward(self, hidden_states: torch.Tensor, output: torch.Tensor, + positions: torch.Tensor, + kv_caches: Optional[MinimaxCacheParams]) -> None: + forward_context = get_forward_context() + attn_metadata: AttentionMetadata = forward_context.attn_metadata + if envs.VLLM_USE_V1 and attn_metadata is not None: + assert isinstance(attn_metadata, dict) + attn_metadata = attn_metadata[self.prefix] + assert isinstance(attn_metadata, LinearAttentionMetadata) + num_actual_tokens = attn_metadata.num_prefill_tokens + \ + attn_metadata.num_decode_tokens + else: + num_actual_tokens = hidden_states.shape[0] + + qkv, _ = self.qkv_proj(hidden_states[:num_actual_tokens]) qkv32 = qkv.to(torch.float32) qkvact = torch.nn.functional.silu(qkv32) qkvact = qkvact.view((qkv.shape[0], self.tp_heads, -1)) q, k, v = torch.split(qkvact, [self.head_dim] * 3, dim=-1) - forward_context = get_forward_context() - attn_metadata = forward_context.attn_metadata if envs.VLLM_USE_V1: if attn_metadata is not None: - assert isinstance(attn_metadata, dict) - attn_metadata = attn_metadata[self.prefix] - assert isinstance(attn_metadata, LinearAttentionMetadata) kv_cache = self.kv_cache[forward_context.virtual_engine][0] state_indices_tensor = attn_metadata.state_indices_tensor @@ -578,13 +546,11 @@ def forward(self, hidden_states: torch.Tensor, positions: torch.Tensor, hidden = self._decode_infer(q, k, v, kv_cache, state_indices_tensor, attn_metadata) - hidden = self.norm._forward(hidden) - gate, _ = self.output_gate(hidden_states) + gate, _ = self.output_gate(hidden_states[:num_actual_tokens]) hidden = F.sigmoid(gate) * hidden hidden = hidden.to(hidden_states.dtype) - hidden, _ = self.out_proj(hidden) - return hidden + output[:num_actual_tokens], _ = self.out_proj(hidden) class MiniMaxText01Attention(nn.Module): @@ -652,23 +618,23 @@ def __init__( quant_config=quant_config, prefix=f"{prefix}.attn", ) + self.rotary_emb = get_rope( + head_size=self.head_dim, + rotary_dim=rotary_dim, + max_position=max_position, + base=int(rope_theta), + is_neox_style=True, + dtype=torch.float32, + ) return - def forward(self, hidden_states: torch.Tensor, positions: torch.Tensor, - **kwargs) -> torch.Tensor: - forward_context = get_forward_context() - attn_metadata = forward_context.attn_metadata + def forward(self, hidden_states: torch.Tensor, output: torch.Tensor, + positions: torch.Tensor, **kwargs) -> None: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - if envs.VLLM_USE_V1: - if attn_metadata is not None: - q, k = attn_metadata[f"{self.prefix}.attn"].rotary_emb( - positions, q, k) - else: - q, k = attn_metadata.rotary_emb(positions, q, k) + q, k = self.rotary_emb(positions, q, k) attn_output = self.attn(q, k, v) - output, _ = self.o_proj(attn_output) - return output + output[:], _ = self.o_proj(attn_output) class MiniMaxText01DecoderLayer(nn.Module): @@ -816,16 +782,15 @@ def forward(self, is_warmup: bool = False, **kwargs) -> tuple[torch.Tensor, torch.Tensor]: - forward_context = get_forward_context() - attn_metadata = forward_context.attn_metadata layernorm_input = hidden_states layernorm_output = self.input_layernorm(layernorm_input) residual = layernorm_output if self.postnorm else layernorm_input - self_attention_output = self.self_attn( + self_attention_output = torch.empty_like(layernorm_output) + self.self_attn( hidden_states=layernorm_output, + output=self_attention_output, positions=positions, kv_caches=kv_caches, - attn_metadata=attn_metadata, ) residual = residual * self.layernorm_attention_alpha @@ -839,8 +804,8 @@ def forward(self, if self.expert_num == 1: hidden_states = self.mlp(layernorm_output) else: - moe_hidden_states = self.block_sparse_moe( - copy.deepcopy(layernorm_output)) + moe_layernorm_output = layernorm_output.clone() + moe_hidden_states = self.block_sparse_moe(moe_layernorm_output) if self.shared_moe: before_moe_dtype = layernorm_output.dtype moe_hidden_fp32 = moe_hidden_states.to(torch.float32) @@ -878,18 +843,16 @@ def shared_moe_coefficient_loader(param: torch.Tensor, return +@support_torch_compile class MiniMaxText01Model(nn.Module): - def __init__( - self, - config: MiniMaxConfig, - model_config: Optional[ModelConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - cache_config: Optional[CacheConfig] = None, - scheduler_config=None, - prefix: str = "", - ) -> None: + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() + config: MiniMaxConfig = vllm_config.model_config.hf_config + model_config = vllm_config.model_config + quant_config = vllm_config.quant_config + cache_config = vllm_config.cache_config + scheduler_config = vllm_config.scheduler_config self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size @@ -976,24 +939,6 @@ def layer_fn(prefix): self.minimax_cache = MinimaxCacheManager( dtype=torch.float32, cache_shape=self.cache_shape) - rope_theta = getattr(config, "rope_theta", 10000) - head_dim = getattr(config, "head_dim", None) - if head_dim is None: - head_dim = config.hidden_size // config.num_attention_heads - if hasattr(config, "max_model_len") and isinstance( - config.max_model_len, int): - max_position_embeddings = min(config.max_position_embeddings, - config.max_model_len) - self.rotary_emb = MiniMaxText01RotaryEmbedding( - head_dim, - rotary_dim=config.rotary_dim - if hasattr(config, "rotary_dim") else head_dim, - max_position=max_position_embeddings, - base=int(rope_theta), - is_neox_style=True, - cache_dtype=torch.float32, - ) - norm_kwargs = {} if hasattr(config, "rms_norm_eps"): norm_kwargs["eps"] = config.rms_norm_eps @@ -1043,12 +988,11 @@ def forward(self, attn_metadata = forward_context.attn_metadata if not envs.VLLM_USE_V1 and attn_metadata is None: return None - if "request_ids_to_seq_ids" not in kwargs: - kwargs["request_ids_to_seq_ids"] = {} - if "finished_requests_ids" not in kwargs: - kwargs["finished_requests_ids"] = [] - if not envs.VLLM_USE_V1: + if "request_ids_to_seq_ids" not in kwargs: + kwargs["request_ids_to_seq_ids"] = {} + if "finished_requests_ids" not in kwargs: + kwargs["finished_requests_ids"] = [] ( minimax_cache_tensors, state_indices_tensor, @@ -1077,16 +1021,6 @@ def forward(self, for i in range(self.start_layer, self.end_layer): layer = self.layers[i] - if attn_metadata is not None: - # TODO (tdoublep): this whole thing with the rotary_emb is - # weird. we shouldn't be passing it via attn_metadata imo. - if envs.VLLM_USE_V1: - if isinstance(layer.self_attn, MiniMaxText01Attention): - attn_metadata[layer.prefix + - ".attn"].rotary_emb = self.rotary_emb - else: - attn_metadata.rotary_emb = self.rotary_emb - _caches = None if not envs.VLLM_USE_V1 and isinstance( layer.self_attn, MiniMaxText01LinearAttention): @@ -1120,7 +1054,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: super().__init__() config = vllm_config.model_config.hf_config - quant_config = vllm_config.quant_config lora_config = vllm_config.lora_config self.config = config self.lora_config = lora_config @@ -1133,13 +1066,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: self.unpadded_vocab_size = self.config.vocab_size if hasattr(vllm_config.model_config, "max_model_len"): self.config.max_model_len = vllm_config.model_config.max_model_len - self.model = MiniMaxText01Model( - self.config, - model_config=vllm_config.model_config, - cache_config=vllm_config.cache_config, - quant_config=quant_config, - scheduler_config=vllm_config.scheduler_config, - prefix=maybe_prefix(prefix, "model")) + self.model = MiniMaxText01Model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) if get_pp_group().is_last_rank: self.lm_head = ParallelLMHead( self.unpadded_vocab_size, @@ -1469,3 +1397,35 @@ def get_mamba_state_shape_from_config( tp_size=parallel_config.tensor_parallel_size, head_dim=hf_config.head_dim, ) + + +def linear_attention( + hidden_states: torch.Tensor, + output: torch.Tensor, + positions: torch.Tensor, + layer_name: str, +) -> None: + forward_context: ForwardContext = get_forward_context() + self = forward_context.no_compile_layers[layer_name] + self._forward(hidden_states=hidden_states, + output=output, + positions=positions, + kv_caches=None) + + +def linear_attention_fake( + hidden_states: torch.Tensor, + output: torch.Tensor, + positions: torch.Tensor, + layer_name: str, +) -> None: + return + + +direct_register_custom_op( + op_name="linear_attention", + op_func=linear_attention, + mutates_args=["output"], + fake_impl=linear_attention_fake, + dispatch_key=current_platform.dispatch_key, +)