1 change: 1 addition & 0 deletions vllm/config/compilation.py
@@ -339,6 +339,7 @@ class CompilationConfig:
"vllm.mamba_mixer2",
"vllm.mamba_mixer",
"vllm.short_conv",
"vllm.linear_attention",
]

def compute_hash(self) -> str:
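Note: the new `"vllm.linear_attention"` entry names the custom op registered at the bottom of `minimax_text_01.py` below, so the piecewise compiler treats it as a graph-split point, same as the mamba and short-conv ops above it. A quick sanity check that registration ran (assumes a build with this patch applied; importing the model module triggers `direct_register_custom_op`):

```python
import torch

# Importing the model module executes direct_register_custom_op(...),
# which exposes the op under the `vllm` namespace.
import vllm.model_executor.models.minimax_text_01  # noqa: F401

print(torch.ops.vllm.linear_attention)  # resolves once the op is registered
```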
234 changes: 97 additions & 137 deletions vllm/model_executor/models/minimax_text_01.py
@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Inference-only MiniMaxText01 model."""
import copy
import math
from collections.abc import Iterable
from typing import TYPE_CHECKING, Optional, Union
@@ -19,13 +18,14 @@

from vllm import envs
from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import (CacheConfig, ModelConfig, VllmConfig,
get_current_vllm_config)
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import (
get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.forward_context import get_forward_context
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE
@@ -43,12 +43,15 @@
MambaStateDtypeCalculator, MambaStateShapeCalculator)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.utils import maybe_prefix
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.utils import direct_register_custom_op
from vllm.v1.attention.backends.linear_attn import LinearAttentionMetadata

from .interfaces import HasInnerState, IsHybrid
@@ -143,61 +146,6 @@ def forward(
return self._forward(x)


class MiniMaxText01RotaryEmbedding(CustomOp):
name = "MiniMaxText01RotaryEmbedding"

def __init__(
self,
head_size: int,
rotary_dim: int,
max_position: int,
base: float,
is_neox_style: bool,
cache_dtype: torch.dtype,
) -> None:
super().__init__()
self.head_size = head_size
self.rotary_dim = rotary_dim
self.max_position_embeddings = max_position
self.base = base
self.is_neox_style = is_neox_style
self.cache_dtype = cache_dtype
cache = self._compute_cos_sin_cache().to(cache_dtype)
self.register_buffer("cos_sin_cache", cache, persistent=False)

def _compute_inv_freq(self, base: float) -> torch.Tensor:
"""Compute the inverse frequency."""
inv_freq = 1.0 / (base**(torch.arange(
0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim))
return inv_freq

def _compute_cos_sin_cache(self) -> torch.Tensor:
"""Compute the cos and sin cache."""
inv_freq = self._compute_inv_freq(self.base)
t = torch.arange(self.max_position_embeddings, dtype=torch.float)
freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = freqs.cos()
sin = freqs.sin()
cache = torch.cat((cos, sin), dim=-1)
return cache

def forward(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
from vllm import _custom_ops as ops
self.cos_sin_cache = self.cos_sin_cache.to(positions.device)
query_cast = query.to(self.cache_dtype)
key_cast = key.to(self.cache_dtype)
ops.rotary_embedding(positions, query_cast, key_cast, self.head_size,
self.cos_sin_cache, self.is_neox_style)
query = query_cast.to(query.dtype)
key = key_cast.to(key.dtype)
return query, key


class MiniMaxText01MLP(nn.Module):

def __init__(
@@ -526,20 +474,40 @@ def _decode_infer(self, q, k, v, kv_cache, state_indices_tensor,
slot_id, 32)
return hidden

def forward(self, hidden_states: torch.Tensor, positions: torch.Tensor,
kv_caches: MinimaxCacheParams, **kwargs) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
def forward(self, hidden_states: torch.Tensor, output: torch.Tensor,
positions: torch.Tensor,
kv_caches: MinimaxCacheParams) -> None:
if not envs.VLLM_USE_V1:
self._forward(hidden_states, output, positions, kv_caches)
else:
torch.ops.vllm.linear_attention(
hidden_states,
output,
positions,
self.prefix,
)

def _forward(self, hidden_states: torch.Tensor, output: torch.Tensor,
positions: torch.Tensor,
kv_caches: Optional[MinimaxCacheParams]) -> None:
forward_context = get_forward_context()
attn_metadata: AttentionMetadata = forward_context.attn_metadata
if envs.VLLM_USE_V1 and attn_metadata is not None:
assert isinstance(attn_metadata, dict)
attn_metadata = attn_metadata[self.prefix]
assert isinstance(attn_metadata, LinearAttentionMetadata)
num_actual_tokens = attn_metadata.num_prefill_tokens + \
attn_metadata.num_decode_tokens
else:
num_actual_tokens = hidden_states.shape[0]

qkv, _ = self.qkv_proj(hidden_states[:num_actual_tokens])
qkv32 = qkv.to(torch.float32)
qkvact = torch.nn.functional.silu(qkv32)
qkvact = qkvact.view((qkv.shape[0], self.tp_heads, -1))
q, k, v = torch.split(qkvact, [self.head_dim] * 3, dim=-1)
forward_context = get_forward_context()
attn_metadata = forward_context.attn_metadata
if envs.VLLM_USE_V1:
if attn_metadata is not None:
assert isinstance(attn_metadata, dict)
attn_metadata = attn_metadata[self.prefix]
assert isinstance(attn_metadata, LinearAttentionMetadata)
kv_cache = self.kv_cache[forward_context.virtual_engine][0]
state_indices_tensor = attn_metadata.state_indices_tensor

@@ -578,13 +546,11 @@ def forward(self, hidden_states: torch.Tensor, positions: torch.Tensor,
hidden = self._decode_infer(q, k, v, kv_cache,
state_indices_tensor,
attn_metadata)

hidden = self.norm._forward(hidden)
gate, _ = self.output_gate(hidden_states)
gate, _ = self.output_gate(hidden_states[:num_actual_tokens])
hidden = F.sigmoid(gate) * hidden
hidden = hidden.to(hidden_states.dtype)
hidden, _ = self.out_proj(hidden)
return hidden
output[:num_actual_tokens], _ = self.out_proj(hidden)


class MiniMaxText01Attention(nn.Module):
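Note on the calling convention introduced above: `forward` no longer returns the attention output; the caller preallocates `output` (via `torch.empty_like` in the decoder layer) and the layer writes into it, which is what allows the custom op to be registered with `mutates_args=["output"]`. A minimal sketch of the pattern — `write_through` is a hypothetical stand-in, not code from this PR:

```python
import torch

def write_through(hidden_states: torch.Tensor, output: torch.Tensor) -> None:
    # Compute into the caller's buffer; nothing is returned.
    output[:] = torch.nn.functional.silu(hidden_states)

x = torch.randn(4, 8)
out = torch.empty_like(x)
write_through(x, out)  # `out` now holds the result
```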
@@ -652,23 +618,23 @@ def __init__(
quant_config=quant_config,
prefix=f"{prefix}.attn",
)
self.rotary_emb = get_rope(
head_size=self.head_dim,
rotary_dim=rotary_dim,
max_position=max_position,
base=int(rope_theta),
is_neox_style=True,
dtype=torch.float32,
)
return

def forward(self, hidden_states: torch.Tensor, positions: torch.Tensor,
**kwargs) -> torch.Tensor:
forward_context = get_forward_context()
attn_metadata = forward_context.attn_metadata
def forward(self, hidden_states: torch.Tensor, output: torch.Tensor,
positions: torch.Tensor, **kwargs) -> None:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
if envs.VLLM_USE_V1:
if attn_metadata is not None:
q, k = attn_metadata[f"{self.prefix}.attn"].rotary_emb(
positions, q, k)
else:
q, k = attn_metadata.rotary_emb(positions, q, k)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output)
return output
output[:], _ = self.o_proj(attn_output)


class MiniMaxText01DecoderLayer(nn.Module):
@@ -816,16 +782,15 @@ def forward(self,
is_warmup: bool = False,
**kwargs) -> tuple[torch.Tensor, torch.Tensor]:

forward_context = get_forward_context()
attn_metadata = forward_context.attn_metadata
layernorm_input = hidden_states
layernorm_output = self.input_layernorm(layernorm_input)
residual = layernorm_output if self.postnorm else layernorm_input
self_attention_output = self.self_attn(
self_attention_output = torch.empty_like(layernorm_output)
self.self_attn(
hidden_states=layernorm_output,
output=self_attention_output,
positions=positions,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
)

residual = residual * self.layernorm_attention_alpha
@@ -839,8 +804,8 @@ def forward(self,
if self.expert_num == 1:
hidden_states = self.mlp(layernorm_output)
else:
moe_hidden_states = self.block_sparse_moe(
copy.deepcopy(layernorm_output))
moe_layernorm_output = layernorm_output.clone()
moe_hidden_states = self.block_sparse_moe(moe_layernorm_output)
if self.shared_moe:
before_moe_dtype = layernorm_output.dtype
moe_hidden_fp32 = moe_hidden_states.to(torch.float32)
@@ -878,18 +843,16 @@ def shared_moe_coefficient_loader(param: torch.Tensor,
return


@support_torch_compile
class MiniMaxText01Model(nn.Module):

def __init__(
self,
config: MiniMaxConfig,
model_config: Optional[ModelConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
cache_config: Optional[CacheConfig] = None,
scheduler_config=None,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config: MiniMaxConfig = vllm_config.model_config.hf_config
model_config = vllm_config.model_config
quant_config = vllm_config.quant_config
cache_config = vllm_config.cache_config
scheduler_config = vllm_config.scheduler_config

self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
@@ -976,24 +939,6 @@ def layer_fn(prefix):
self.minimax_cache = MinimaxCacheManager(
dtype=torch.float32, cache_shape=self.cache_shape)

rope_theta = getattr(config, "rope_theta", 10000)
head_dim = getattr(config, "head_dim", None)
if head_dim is None:
head_dim = config.hidden_size // config.num_attention_heads
if hasattr(config, "max_model_len") and isinstance(
config.max_model_len, int):
max_position_embeddings = min(config.max_position_embeddings,
config.max_model_len)
self.rotary_emb = MiniMaxText01RotaryEmbedding(
head_dim,
rotary_dim=config.rotary_dim
if hasattr(config, "rotary_dim") else head_dim,
max_position=max_position_embeddings,
base=int(rope_theta),
is_neox_style=True,
cache_dtype=torch.float32,
)

norm_kwargs = {}
if hasattr(config, "rms_norm_eps"):
norm_kwargs["eps"] = config.rms_norm_eps
@@ -1043,12 +988,11 @@ def forward(self,
attn_metadata = forward_context.attn_metadata
if not envs.VLLM_USE_V1 and attn_metadata is None:
return None
if "request_ids_to_seq_ids" not in kwargs:
kwargs["request_ids_to_seq_ids"] = {}
if "finished_requests_ids" not in kwargs:
kwargs["finished_requests_ids"] = []

if not envs.VLLM_USE_V1:
if "request_ids_to_seq_ids" not in kwargs:
kwargs["request_ids_to_seq_ids"] = {}
if "finished_requests_ids" not in kwargs:
kwargs["finished_requests_ids"] = []
(
minimax_cache_tensors,
state_indices_tensor,
@@ -1077,16 +1021,6 @@ def forward(self,

for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
if attn_metadata is not None:
# TODO (tdoublep): this whole thing with the rotary_emb is
# weird. we shouldn't be passing it via attn_metadata imo.
if envs.VLLM_USE_V1:
if isinstance(layer.self_attn, MiniMaxText01Attention):
attn_metadata[layer.prefix +
".attn"].rotary_emb = self.rotary_emb
else:
attn_metadata.rotary_emb = self.rotary_emb

_caches = None
if not envs.VLLM_USE_V1 and isinstance(
layer.self_attn, MiniMaxText01LinearAttention):
@@ -1120,7 +1054,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:

super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config
self.lora_config = lora_config
@@ -1133,13 +1066,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
self.unpadded_vocab_size = self.config.vocab_size
if hasattr(vllm_config.model_config, "max_model_len"):
self.config.max_model_len = vllm_config.model_config.max_model_len
self.model = MiniMaxText01Model(
self.config,
model_config=vllm_config.model_config,
cache_config=vllm_config.cache_config,
quant_config=quant_config,
scheduler_config=vllm_config.scheduler_config,
prefix=maybe_prefix(prefix, "model"))
self.model = MiniMaxText01Model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
if get_pp_group().is_last_rank:
self.lm_head = ParallelLMHead(
self.unpadded_vocab_size,
@@ -1469,3 +1397,35 @@ def get_mamba_state_shape_from_config(
tp_size=parallel_config.tensor_parallel_size,
head_dim=hf_config.head_dim,
)


def linear_attention(
hidden_states: torch.Tensor,
output: torch.Tensor,
positions: torch.Tensor,
layer_name: str,
) -> None:
forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name]
self._forward(hidden_states=hidden_states,
output=output,
positions=positions,
kv_caches=None)


def linear_attention_fake(
hidden_states: torch.Tensor,
output: torch.Tensor,
positions: torch.Tensor,
layer_name: str,
) -> None:
return


direct_register_custom_op(
op_name="linear_attention",
op_func=linear_attention,
mutates_args=["output"],
fake_impl=linear_attention_fake,
dispatch_key=current_platform.dispatch_key,
)
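For reference, this registration follows the register-then-dispatch pattern: the real impl looks up the layer by its prefix string and mutates `output` in place, while the no-op fake impl lets `torch.compile` trace shapes without running the kernel. A self-contained illustration of the same mechanics using plain `torch.library` (standing in for vLLM's `direct_register_custom_op` wrapper; the `demo::` namespace and doubling kernel are made up for the example):

```python
import torch


@torch.library.custom_op("demo::linear_attention", mutates_args=("output", ))
def demo_linear_attention(hidden_states: torch.Tensor,
                          output: torch.Tensor) -> None:
    output.copy_(hidden_states * 2.0)  # stand-in for the real kernel


@demo_linear_attention.register_fake
def _(hidden_states: torch.Tensor, output: torch.Tensor) -> None:
    return  # no-op: only shapes and aliasing matter during tracing


x = torch.randn(4, 8)
out = torch.empty_like(x)
torch.ops.demo.linear_attention(x, out)
assert torch.allclose(out, x * 2.0)
```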