From 9ad6271231a0ee26181bd131f807aafa01a57670 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Sun, 10 Aug 2025 05:26:47 -0400 Subject: [PATCH 1/7] Enable compile for minimax Signed-off-by: Thomas Parnell --- vllm/config/compilation.py | 1 + vllm/model_executor/models/minimax_text_01.py | 132 ++++++++++++------ 2 files changed, 89 insertions(+), 44 deletions(-) diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 8a78d811b9a2..f6ffe4505331 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -425,4 +425,5 @@ def set_splitting_ops_for_v1(self): "vllm.unified_attention", "vllm.unified_attention_with_output", "vllm.mamba_mixer2", + "vllm.linear_attention", ] diff --git a/vllm/model_executor/models/minimax_text_01.py b/vllm/model_executor/models/minimax_text_01.py index 3d14a6ad5c3a..2cf52b0e5a44 100644 --- a/vllm/model_executor/models/minimax_text_01.py +++ b/vllm/model_executor/models/minimax_text_01.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Inference-only MiniMaxText01 model.""" -import copy import math from collections.abc import Iterable from typing import Optional, Union @@ -16,12 +15,13 @@ from vllm import envs from vllm.attention import Attention, AttentionMetadata +from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config from vllm.distributed.communication_op import tensor_model_parallel_all_reduce from vllm.distributed.parallel_state import ( get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) -from vllm.forward_context import get_forward_context +from vllm.forward_context import ForwardContext, get_forward_context from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import FusedMoE @@ -44,7 +44,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.utils import maybe_prefix from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors +from vllm.utils import direct_register_custom_op from vllm.v1.attention.backends.linear_attn import LinearAttentionMetadata from .interfaces import HasInnerState, IsHybrid @@ -507,20 +509,41 @@ def _decode_infer(self, q, k, v, kv_cache, state_indices_tensor, slot_id, 32) return hidden - def forward(self, hidden_states: torch.Tensor, positions: torch.Tensor, - kv_caches: MinimaxCacheParams, **kwargs) -> torch.Tensor: - qkv, _ = self.qkv_proj(hidden_states) + def forward(self, hidden_states: torch.Tensor, output: torch.Tensor, + positions: torch.Tensor, + kv_caches: MinimaxCacheParams) -> torch.Tensor: + if not envs.VLLM_USE_V1: + self._forward(hidden_states, output, positions, kv_caches) + else: + torch.ops.vllm.linear_attention( + hidden_states, + output, + positions, + self.prefix, + ) + + def _forward(self, hidden_states: torch.Tensor, output: torch.Tensor, + positions: torch.Tensor, + kv_caches: MinimaxCacheParams) -> torch.Tensor: + forward_context = get_forward_context() + attn_metadata: AttentionMetadata = forward_context.attn_metadata + if envs.VLLM_USE_V1 and attn_metadata is not None: + assert isinstance(attn_metadata, dict) + attn_metadata = attn_metadata[self.prefix] + assert isinstance(attn_metadata, LinearAttentionMetadata) + num_actual_tokens = attn_metadata.num_prefill_tokens + \ + attn_metadata.num_decode_tokens + else: + num_actual_tokens = hidden_states.shape[0] + + qkv, _ = self.qkv_proj(hidden_states[:num_actual_tokens]) qkv32 = qkv.to(torch.float32) qkvact = torch.nn.functional.silu(qkv32) qkvact = qkvact.view((qkv.shape[0], self.tp_heads, -1)) q, k, v = torch.split(qkvact, [self.head_dim] * 3, dim=-1) - forward_context = get_forward_context() - attn_metadata = forward_context.attn_metadata + if envs.VLLM_USE_V1: if attn_metadata is not None: - assert isinstance(attn_metadata, dict) - attn_metadata = attn_metadata[self.prefix] - assert isinstance(attn_metadata, LinearAttentionMetadata) kv_cache = self.kv_cache[forward_context.virtual_engine][0] state_indices_tensor = attn_metadata.state_indices_tensor @@ -559,13 +582,11 @@ def forward(self, hidden_states: torch.Tensor, positions: torch.Tensor, hidden = self._decode_infer(q, k, v, kv_cache, state_indices_tensor, attn_metadata) - hidden = self.norm._forward(hidden) - gate, _ = self.output_gate(hidden_states) + gate, _ = self.output_gate(hidden_states[:num_actual_tokens]) hidden = F.sigmoid(gate) * hidden hidden = hidden.to(hidden_states.dtype) - hidden, _ = self.out_proj(hidden) - return hidden + output[:num_actual_tokens], _ = self.out_proj(hidden) class MiniMaxText01Attention(nn.Module): @@ -635,8 +656,8 @@ def __init__( ) return - def forward(self, hidden_states: torch.Tensor, positions: torch.Tensor, - **kwargs) -> torch.Tensor: + def forward(self, hidden_states: torch.Tensor, output: torch.Tensor, + positions: torch.Tensor, **kwargs) -> None: forward_context = get_forward_context() attn_metadata = forward_context.attn_metadata qkv, _ = self.qkv_proj(hidden_states) @@ -648,8 +669,7 @@ def forward(self, hidden_states: torch.Tensor, positions: torch.Tensor, else: q, k = attn_metadata.rotary_emb(positions, q, k) attn_output = self.attn(q, k, v) - output, _ = self.o_proj(attn_output) - return output + output[:], _ = self.o_proj(attn_output) class MiniMaxText01DecoderLayer(nn.Module): @@ -794,16 +814,15 @@ def forward(self, is_warmup: bool = False, **kwargs) -> tuple[torch.Tensor, torch.Tensor]: - forward_context = get_forward_context() - attn_metadata = forward_context.attn_metadata layernorm_input = hidden_states layernorm_output = self.input_layernorm(layernorm_input) residual = layernorm_output if self.postnorm else layernorm_input - self_attention_output = self.self_attn( + self_attention_output = torch.empty_like(layernorm_output) + self.self_attn( hidden_states=layernorm_output, + output=self_attention_output, positions=positions, kv_caches=kv_caches, - attn_metadata=attn_metadata, ) residual = residual * self.layernorm_attention_alpha @@ -817,8 +836,8 @@ def forward(self, if self.expert_num == 1: hidden_states = self.mlp(layernorm_output) else: - moe_hidden_states = self.block_sparse_moe( - copy.deepcopy(layernorm_output)) + moe_layernorm_output = layernorm_output.clone() + moe_hidden_states = self.block_sparse_moe(moe_layernorm_output) if self.shared_moe: before_moe_dtype = layernorm_output.dtype moe_hidden_fp32 = moe_hidden_states.to(torch.float32) @@ -856,17 +875,15 @@ def shared_moe_coefficient_loader(param: torch.Tensor, return +@support_torch_compile class MiniMaxText01Model(nn.Module): - def __init__( - self, - config: MiniMaxConfig, - quant_config: Optional[QuantizationConfig] = None, - cache_config: Optional[CacheConfig] = None, - scheduler_config=None, - prefix: str = "", - ) -> None: + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() + config: MiniMaxConfig = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + scheduler_config = vllm_config.scheduler_config self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size @@ -1019,12 +1036,11 @@ def forward(self, attn_metadata = forward_context.attn_metadata if not envs.VLLM_USE_V1 and attn_metadata is None: return None - if "request_ids_to_seq_ids" not in kwargs: - kwargs["request_ids_to_seq_ids"] = {} - if "finished_requests_ids" not in kwargs: - kwargs["finished_requests_ids"] = [] - if not envs.VLLM_USE_V1: + if "request_ids_to_seq_ids" not in kwargs: + kwargs["request_ids_to_seq_ids"] = {} + if "finished_requests_ids" not in kwargs: + kwargs["finished_requests_ids"] = [] ( minimax_cache_tensors, state_indices_tensor, @@ -1096,7 +1112,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: super().__init__() config = vllm_config.model_config.hf_config - quant_config = vllm_config.quant_config lora_config = vllm_config.lora_config self.config = config self.lora_config = lora_config @@ -1109,12 +1124,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: self.unpadded_vocab_size = self.config.vocab_size if hasattr(vllm_config.model_config, "max_model_len"): self.config.max_model_len = vllm_config.model_config.max_model_len - self.model = MiniMaxText01Model( - self.config, - quant_config, - cache_config=vllm_config.cache_config, - scheduler_config=vllm_config.scheduler_config, - prefix=maybe_prefix(prefix, "model")) + self.model = MiniMaxText01Model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) if get_pp_group().is_last_rank: self.lm_head = ParallelLMHead( self.unpadded_vocab_size, @@ -1433,3 +1444,36 @@ def get_mamba_state_shape_from_config( tp_size=parallel_config.tensor_parallel_size, head_dim=hf_config.head_dim, ) + + +def linear_attention( + hidden_states: torch.Tensor, + output: torch.Tensor, + positions: torch.Tensor, + layer_name: str, +) -> None: + forward_context: ForwardContext = get_forward_context() + print("layer_name: ", layer_name) + self = forward_context.no_compile_layers[layer_name] + self._forward(hidden_states=hidden_states, + output=output, + positions=positions, + kv_caches=None) + + +def linear_attention_fake( + hidden_states: torch.Tensor, + output: torch.Tensor, + positions: torch.Tensor, + layer_name: str, +) -> None: + return + + +direct_register_custom_op( + op_name="linear_attention", + op_func=linear_attention, + mutates_args=["output"], + fake_impl=linear_attention_fake, + dispatch_key=current_platform.dispatch_key, +) From c698db3280c6f756ff36086a14931a77acd43d66 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Sun, 10 Aug 2025 05:31:52 -0400 Subject: [PATCH 2/7] minor diff reduction Signed-off-by: Thomas Parnell --- vllm/model_executor/models/minimax_text_01.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/model_executor/models/minimax_text_01.py b/vllm/model_executor/models/minimax_text_01.py index 2cf52b0e5a44..6def7947c060 100644 --- a/vllm/model_executor/models/minimax_text_01.py +++ b/vllm/model_executor/models/minimax_text_01.py @@ -541,7 +541,6 @@ def _forward(self, hidden_states: torch.Tensor, output: torch.Tensor, qkvact = torch.nn.functional.silu(qkv32) qkvact = qkvact.view((qkv.shape[0], self.tp_heads, -1)) q, k, v = torch.split(qkvact, [self.head_dim] * 3, dim=-1) - if envs.VLLM_USE_V1: if attn_metadata is not None: kv_cache = self.kv_cache[forward_context.virtual_engine][0] From 07b3dd7b539ee199f1724e29afd37570b9c450ca Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Tue, 26 Aug 2025 04:27:16 -0400 Subject: [PATCH 3/7] Fix some things Signed-off-by: Thomas Parnell --- vllm/model_executor/models/minimax_text_01.py | 25 ++++++------------- 1 file changed, 8 insertions(+), 17 deletions(-) diff --git a/vllm/model_executor/models/minimax_text_01.py b/vllm/model_executor/models/minimax_text_01.py index 1a7073650c30..c613a8b20a98 100644 --- a/vllm/model_executor/models/minimax_text_01.py +++ b/vllm/model_executor/models/minimax_text_01.py @@ -899,16 +899,13 @@ def shared_moe_coefficient_loader(param: torch.Tensor, @support_torch_compile class MiniMaxText01Model(nn.Module): - def __init__( - self, - config: MiniMaxConfig, - model_config: Optional[ModelConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - cache_config: Optional[CacheConfig] = None, - scheduler_config=None, - prefix: str = "", - ) -> None: + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() + config: MiniMaxConfig = vllm_config.model_config.hf_config + model_config = vllm_config.model_config + quant_config = vllm_config.quant_config + cache_config = vllm_config.cache_config + scheduler_config = vllm_config.scheduler_config self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size @@ -1138,7 +1135,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: super().__init__() config = vllm_config.model_config.hf_config - quant_config = vllm_config.quant_config lora_config = vllm_config.lora_config self.config = config self.lora_config = lora_config @@ -1151,13 +1147,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: self.unpadded_vocab_size = self.config.vocab_size if hasattr(vllm_config.model_config, "max_model_len"): self.config.max_model_len = vllm_config.model_config.max_model_len - self.model = MiniMaxText01Model( - self.config, - model_config=vllm_config.model_config, - cache_config=vllm_config.cache_config, - quant_config=quant_config, - scheduler_config=vllm_config.scheduler_config, - prefix=maybe_prefix(prefix, "model")) + self.model = MiniMaxText01Model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) if get_pp_group().is_last_rank: self.lm_head = ParallelLMHead( self.unpadded_vocab_size, From 9de8c8fcdd0f0680a8a965dc73b84b10035c7ef2 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Tue, 26 Aug 2025 10:57:02 -0400 Subject: [PATCH 4/7] Fix issue with rope Signed-off-by: Thomas Parnell --- vllm/model_executor/models/minimax_text_01.py | 43 +++++-------------- 1 file changed, 10 insertions(+), 33 deletions(-) diff --git a/vllm/model_executor/models/minimax_text_01.py b/vllm/model_executor/models/minimax_text_01.py index c613a8b20a98..fc55828f9fa6 100644 --- a/vllm/model_executor/models/minimax_text_01.py +++ b/vllm/model_executor/models/minimax_text_01.py @@ -43,6 +43,7 @@ MambaStateDtypeCalculator, MambaStateShapeCalculator) from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) +from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -672,20 +673,21 @@ def __init__( quant_config=quant_config, prefix=f"{prefix}.attn", ) + self.rotary_emb = get_rope( + head_size=self.head_dim, + rotary_dim=rotary_dim, + max_position=max_position, + base=int(rope_theta), + is_neox_style=True, + dtype=torch.float32, + ) return def forward(self, hidden_states: torch.Tensor, output: torch.Tensor, positions: torch.Tensor, **kwargs) -> None: - forward_context = get_forward_context() - attn_metadata = forward_context.attn_metadata qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - if envs.VLLM_USE_V1: - if attn_metadata is not None: - q, k = attn_metadata[f"{self.prefix}.attn"].rotary_emb( - positions, q, k) - else: - q, k = attn_metadata.rotary_emb(positions, q, k) + q, k = self.rotary_emb(positions, q, k) attn_output = self.attn(q, k, v) output[:], _ = self.o_proj(attn_output) @@ -992,23 +994,9 @@ def layer_fn(prefix): self.minimax_cache = MinimaxCacheManager( dtype=torch.float32, cache_shape=self.cache_shape) - rope_theta = getattr(config, "rope_theta", 10000) head_dim = getattr(config, "head_dim", None) if head_dim is None: head_dim = config.hidden_size // config.num_attention_heads - if hasattr(config, "max_model_len") and isinstance( - config.max_model_len, int): - max_position_embeddings = min(config.max_position_embeddings, - config.max_model_len) - self.rotary_emb = MiniMaxText01RotaryEmbedding( - head_dim, - rotary_dim=config.rotary_dim - if hasattr(config, "rotary_dim") else head_dim, - max_position=max_position_embeddings, - base=int(rope_theta), - is_neox_style=True, - cache_dtype=torch.float32, - ) norm_kwargs = {} if hasattr(config, "rms_norm_eps"): @@ -1092,16 +1080,6 @@ def forward(self, for i in range(self.start_layer, self.end_layer): layer = self.layers[i] - if attn_metadata is not None: - # TODO (tdoublep): this whole thing with the rotary_emb is - # weird. we shouldn't be passing it via attn_metadata imo. - if envs.VLLM_USE_V1: - if isinstance(layer.self_attn, MiniMaxText01Attention): - attn_metadata[layer.prefix + - ".attn"].rotary_emb = self.rotary_emb - else: - attn_metadata.rotary_emb = self.rotary_emb - _caches = None if not envs.VLLM_USE_V1 and isinstance( layer.self_attn, MiniMaxText01LinearAttention): @@ -1487,7 +1465,6 @@ def linear_attention( layer_name: str, ) -> None: forward_context: ForwardContext = get_forward_context() - print("layer_name: ", layer_name) self = forward_context.no_compile_layers[layer_name] self._forward(hidden_states=hidden_states, output=output, From fed4c21261c7f9f65885a7701f71ae172c9cd077 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Tue, 26 Aug 2025 11:02:42 -0400 Subject: [PATCH 5/7] Further cleanup Signed-off-by: Thomas Parnell --- vllm/model_executor/models/minimax_text_01.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/vllm/model_executor/models/minimax_text_01.py b/vllm/model_executor/models/minimax_text_01.py index fc55828f9fa6..385debc73058 100644 --- a/vllm/model_executor/models/minimax_text_01.py +++ b/vllm/model_executor/models/minimax_text_01.py @@ -994,10 +994,6 @@ def layer_fn(prefix): self.minimax_cache = MinimaxCacheManager( dtype=torch.float32, cache_shape=self.cache_shape) - head_dim = getattr(config, "head_dim", None) - if head_dim is None: - head_dim = config.hidden_size // config.num_attention_heads - norm_kwargs = {} if hasattr(config, "rms_norm_eps"): norm_kwargs["eps"] = config.rms_norm_eps From 477f9bb4c47f09f00bbb30bbe8ac1f04595de3ad Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Tue, 26 Aug 2025 11:30:07 -0400 Subject: [PATCH 6/7] Remove custom rope Signed-off-by: Thomas Parnell --- vllm/model_executor/models/minimax_text_01.py | 55 ------------------- 1 file changed, 55 deletions(-) diff --git a/vllm/model_executor/models/minimax_text_01.py b/vllm/model_executor/models/minimax_text_01.py index 385debc73058..4f3ece97701b 100644 --- a/vllm/model_executor/models/minimax_text_01.py +++ b/vllm/model_executor/models/minimax_text_01.py @@ -146,61 +146,6 @@ def forward( return self._forward(x) -class MiniMaxText01RotaryEmbedding(CustomOp): - name = "MiniMaxText01RotaryEmbedding" - - def __init__( - self, - head_size: int, - rotary_dim: int, - max_position: int, - base: float, - is_neox_style: bool, - cache_dtype: torch.dtype, - ) -> None: - super().__init__() - self.head_size = head_size - self.rotary_dim = rotary_dim - self.max_position_embeddings = max_position - self.base = base - self.is_neox_style = is_neox_style - self.cache_dtype = cache_dtype - cache = self._compute_cos_sin_cache().to(cache_dtype) - self.register_buffer("cos_sin_cache", cache, persistent=False) - - def _compute_inv_freq(self, base: float) -> torch.Tensor: - """Compute the inverse frequency.""" - inv_freq = 1.0 / (base**(torch.arange( - 0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim)) - return inv_freq - - def _compute_cos_sin_cache(self) -> torch.Tensor: - """Compute the cos and sin cache.""" - inv_freq = self._compute_inv_freq(self.base) - t = torch.arange(self.max_position_embeddings, dtype=torch.float) - freqs = torch.einsum("i,j -> ij", t, inv_freq) - cos = freqs.cos() - sin = freqs.sin() - cache = torch.cat((cos, sin), dim=-1) - return cache - - def forward( - self, - positions: torch.Tensor, - query: torch.Tensor, - key: torch.Tensor, - ) -> tuple[torch.Tensor, torch.Tensor]: - from vllm import _custom_ops as ops - self.cos_sin_cache = self.cos_sin_cache.to(positions.device) - query_cast = query.to(self.cache_dtype) - key_cast = key.to(self.cache_dtype) - ops.rotary_embedding(positions, query_cast, key_cast, self.head_size, - self.cos_sin_cache, self.is_neox_style) - query = query_cast.to(query.dtype) - key = key_cast.to(key.dtype) - return query, key - - class MiniMaxText01MLP(nn.Module): def __init__( From 02c37d59e4d471b35d5f2f9c17afa499bd79f479 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Tue, 26 Aug 2025 13:58:31 -0400 Subject: [PATCH 7/7] Fix type hints Signed-off-by: Thomas Parnell --- vllm/model_executor/models/minimax_text_01.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/minimax_text_01.py b/vllm/model_executor/models/minimax_text_01.py index 4f3ece97701b..176a40179bca 100644 --- a/vllm/model_executor/models/minimax_text_01.py +++ b/vllm/model_executor/models/minimax_text_01.py @@ -476,7 +476,7 @@ def _decode_infer(self, q, k, v, kv_cache, state_indices_tensor, def forward(self, hidden_states: torch.Tensor, output: torch.Tensor, positions: torch.Tensor, - kv_caches: MinimaxCacheParams) -> torch.Tensor: + kv_caches: MinimaxCacheParams) -> None: if not envs.VLLM_USE_V1: self._forward(hidden_states, output, positions, kv_caches) else: @@ -489,7 +489,7 @@ def forward(self, hidden_states: torch.Tensor, output: torch.Tensor, def _forward(self, hidden_states: torch.Tensor, output: torch.Tensor, positions: torch.Tensor, - kv_caches: MinimaxCacheParams) -> torch.Tensor: + kv_caches: Optional[MinimaxCacheParams]) -> None: forward_context = get_forward_context() attn_metadata: AttentionMetadata = forward_context.attn_metadata if envs.VLLM_USE_V1 and attn_metadata is not None: