7 changes: 4 additions & 3 deletions tests/kernels/mamba/test_mamba_ssm_ssd.py
@@ -6,7 +6,7 @@
from einops import rearrange, repeat

from vllm.model_executor.layers.mamba.mamba2_metadata import (
_seq_idx_to_chunk_indices_offsets)
_query_start_loc_to_chunk_indices_offsets)
from vllm.model_executor.layers.mamba.ops.ssd_combined import (
mamba_chunk_scan_combined)
from vllm.platforms import current_platform
@@ -274,8 +274,9 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
last_taken, exhausted, n_heads,
d_head, itype):

chunk_indices, chunk_offsets = _seq_idx_to_chunk_indices_offsets(
seq_idx, chunk_size)
chunk_indices, chunk_offsets = \
_query_start_loc_to_chunk_indices_offsets(
cu_seqlens, chunk_size, cu_seqlens[-1])

Y, new_states = mamba_chunk_scan_combined(
X,
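For orientation, a minimal self-contained sketch (sequence lengths are made up, not taken from the test parametrization) of how the renamed helper is now driven by a `query_start_loc`-style tensor instead of `seq_idx`:

```python
import torch

from vllm.model_executor.layers.mamba.mamba2_metadata import (
    _query_start_loc_to_chunk_indices_offsets)

# Hypothetical continuous batch: three sequences of 5, 3 and 8 tokens.
seqlens = torch.tensor([5, 3, 8])
cu_seqlens = torch.cat(
    [torch.zeros(1, dtype=torch.long), seqlens.cumsum(0)])  # [0, 5, 8, 16]

chunk_size = 4
# chunk_indices / chunk_offsets gain one extra entry for every sequence
# boundary that does not fall on a multiple of chunk_size.
chunk_indices, chunk_offsets = _query_start_loc_to_chunk_indices_offsets(
    cu_seqlens, chunk_size, cu_seqlens[-1])
```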
79 changes: 40 additions & 39 deletions vllm/model_executor/layers/mamba/mamba2_metadata.py
@@ -13,7 +13,6 @@

@dataclass
class Mamba2Metadata:
has_prefill: bool

has_initial_states: torch.Tensor
prep_initial_states: bool
@@ -24,21 +23,23 @@ class Mamba2Metadata:
chunk_offsets: torch.Tensor


def _seq_idx_to_chunk_indices_offsets(seq_idx, chunk_size: int):
def _query_start_loc_to_chunk_indices_offsets(query_start_loc: torch.Tensor,
chunk_size: int,
total_seqlens: int):

# convert query_start_loc to chunk indices and offsets
# - derive the cu_seqlens
_, cu_seqlens = torch.where(seq_idx.diff())
cu_seqlens += 1
cu_seqlens = query_start_loc[1:] # remove prepended 0

# outputs will be longer than ceil(total_seqlens / chunk_size) because each
# sequence boundary not aligned with chunk_size adds an extra entry
N = math.ceil(seq_idx.shape[-1] / chunk_size) + (cu_seqlens % chunk_size
> 0).sum()
chunk_indices = torch.arange(N, dtype=torch.int, device=seq_idx.device)
chunk_offsets = torch.zeros((N, ), dtype=torch.int, device=seq_idx.device)
N = math.ceil(total_seqlens / chunk_size) + (cu_seqlens[:-1] % chunk_size
> 0).sum()
chunk_indices = torch.arange(N,
dtype=torch.int,
device=query_start_loc.device)
chunk_offsets = torch.zeros((N, ),
dtype=torch.int,
device=query_start_loc.device)

cu_seqlens = cu_seqlens.tolist() + [seq_idx.shape[-1]]
p = 0 # num of insertions
for s, e in zip(cu_seqlens[:-1], cu_seqlens[1:]):

@@ -60,48 +61,48 @@ def _seq_idx_to_chunk_indices_offsets(seq_idx, chunk_size: int):

def prepare_mamba2_metadata(
chunk_size: int,
input_ids: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> Mamba2Metadata:

# compute number of prefill and decode requests
# NOTE: in V0 we assume prefills are before decodes
num_prefills = attn_metadata.num_prefills
num_prefill_tokens = attn_metadata.num_prefill_tokens

# Need flags to indicate if there are initial states
# currently we really only support the FlashAttention backend
has_initial_states = None
prep_initial_states = False
if (isinstance(attn_metadata, (FlashAttentionMetadata, XFormersMetadata,
PlaceholderAttentionMetadata))
and attn_metadata.context_lens_tensor is not None):
has_initial_states = attn_metadata.context_lens_tensor > 0
# precompute flag to avoid device syncs later in mamba2 forwards
prep_initial_states = torch.any(has_initial_states).item()

has_prefill = attn_metadata.num_prefills > 0

# keep flags for both prefill and decode (used by the causal_conv1d varlen path)
has_initial_states = attn_metadata.context_lens_tensor > 0 # [batch,]
# precompute flag to avoid device syncs later in mamba2 layer forwards
# prep is only needed for mamba2 ssd prefill processing
prep_initial_states = torch.any(
has_initial_states[:num_prefills]).item()

# Compute seq_idx, chunk_indices and chunk_offsets for prefill only
seq_idx = None
chunk_indices, chunk_offsets = None, None
if has_prefill:
seq_idx = torch.zeros_like(input_ids, dtype=torch.int32)
for i, (srt, end) in enumerate(
zip(
attn_metadata.query_start_loc,
attn_metadata.query_start_loc[1:],
)):
seq_idx[srt:end] = i
if num_prefills > 0:
query_start_loc = attn_metadata.query_start_loc[:num_prefills + 1]
seq_idx = torch.repeat_interleave(torch.arange(
num_prefills, dtype=torch.int32, device=query_start_loc.device),
query_start_loc.diff(),
output_size=num_prefill_tokens)
seq_idx.unsqueeze_(0)

# compute metadata for chunked prefill.
# actually this is only needed if there are initial states,
# but this is determinable only from attention metadata yet
# unavailable from the top-level model forward. Rather than
# complicating things to extract said metadata, we simply just
# compute them once at the top level model forward and reuse
# them in mamba layers. If not needed, they will be ignored
# inside mamba kernels.
chunk_indices, chunk_offsets = _seq_idx_to_chunk_indices_offsets(
seq_idx, chunk_size)

return Mamba2Metadata(has_prefill=has_prefill,
has_initial_states=has_initial_states,
# We compute metadata for chunked prefill once at the top level model
# forward and reuse them in mamba layers. If not needed, they will be
# ignored inside mamba kernels.
if prep_initial_states:
chunk_indices, chunk_offsets = \
_query_start_loc_to_chunk_indices_offsets(
query_start_loc, chunk_size, num_prefill_tokens)

return Mamba2Metadata(has_initial_states=has_initial_states,
prep_initial_states=prep_initial_states,
chunk_size=chunk_size,
seq_idx=seq_idx,
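A minimal sketch of the `seq_idx` construction used above (the batch is hypothetical; the `torch.repeat_interleave` call mirrors the one in `prepare_mamba2_metadata`):

```python
import torch

# Hypothetical prefill batch: two prefill requests of 4 and 2 tokens.
query_start_loc = torch.tensor([0, 4, 6], dtype=torch.int32)
num_prefills = query_start_loc.numel() - 1
num_prefill_tokens = int(query_start_loc[-1])

# One sequence index per prefill token, as in prepare_mamba2_metadata.
seq_idx = torch.repeat_interleave(
    torch.arange(num_prefills, dtype=torch.int32),
    query_start_loc.diff(),
    output_size=num_prefill_tokens)
print(seq_idx)  # tensor([0, 0, 0, 0, 1, 1], dtype=torch.int32)
```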
126 changes: 85 additions & 41 deletions vllm/model_executor/layers/mamba/mamba_mixer2.py
@@ -388,9 +388,15 @@ def forward_cuda(
# mamba2_metadata contains metadata necessary for the mamba2 triton
# kernels to operate in continuous batching and in chunked prefill
# modes; they are computed at top-level model forward since they
# are the same and reused for all mamba layers in the same iteration
# stay the same and reused for all mamba layers in the same iteration
attn_metadata: AttentionMetadata = get_forward_context().attn_metadata

num_prefills = attn_metadata.num_prefills # request count
num_decodes = attn_metadata.num_decode_tokens # token count (= request count; one token per decode)
num_prefill_tokens = attn_metadata.num_prefill_tokens # token count
has_prefill = num_prefills > 0
has_decode = num_decodes > 0

seq_len, _ = hidden_states.shape
groups_time_state_size = self.n_groups * self.ssm_state_size

@@ -410,7 +416,9 @@ def forward_cuda(
conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
self.conv1d.weight.size(2))

if mamba2_metadata.has_prefill:
# causal_conv1d_fn handles both prefill and decode tokens when the
# input contains prefill requests.
if has_prefill:
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
@@ -454,31 +462,67 @@
)

# 3. State Space Model sequence transformation
if mamba2_metadata.has_prefill:

# Separate prefill and decode by splitting varlen input
# Split along token dimension
hidden_states_p, hidden_states_d = torch.split(
hidden_states,
[num_prefill_tokens, num_decodes],
dim=0,
)
B_p, B_d = torch.split(
B,
[num_prefill_tokens, num_decodes],
dim=0,
)
C_p, C_d = torch.split(
C,
[num_prefill_tokens, num_decodes],
dim=0,
)
dt_p, dt_d = torch.split(
dt,
[num_prefill_tokens, num_decodes],
dim=0,
)
# Split along batch dimension
state_indices_tensor_p, state_indices_tensor_d = torch.split(
mamba_cache_params.state_indices_tensor,
[num_prefills, num_decodes],
dim=0,
)

hidden_states_list = []

# Process prefill requests
if has_prefill:
initial_states = None
if (mamba2_metadata.has_initial_states is not None
and mamba2_metadata.prep_initial_states):
# making a copy of the states
initial_states = torch.where(
mamba2_metadata.has_initial_states[:, None, None, None],
mamba_cache_params.ssm_state[
mamba_cache_params.state_indices_tensor], 0)
mamba2_metadata.has_initial_states[:num_prefills, None,
None, None],
mamba_cache_params.ssm_state[state_indices_tensor_p], 0)

scan_output, varlen_state = mamba_chunk_scan_combined(
hidden_states.view(1, seq_len, self.num_heads // self.tp_size,
self.head_dim),
dt.unsqueeze(0),
hidden_states_p.view(1, num_prefill_tokens,
self.num_heads // self.tp_size,
self.head_dim),
dt_p.unsqueeze(0),
self.A,
B.view(1, seq_len, self.n_groups // self.tp_size, -1),
C.view(1, seq_len, self.n_groups // self.tp_size, -1),
B_p.view(1, num_prefill_tokens, self.n_groups // self.tp_size,
-1),
C_p.view(1, num_prefill_tokens, self.n_groups // self.tp_size,
-1),
chunk_size=mamba2_metadata.chunk_size,
D=self.D,
z=None,
dt_bias=self.dt_bias,
seq_idx=mamba2_metadata.seq_idx,
chunk_indices=mamba2_metadata.chunk_indices,
chunk_offsets=mamba2_metadata.chunk_offsets,
cu_seqlens=attn_metadata.query_start_loc,
cu_seqlens=attn_metadata.query_start_loc[:num_prefills + 1],
initial_states=initial_states,
return_varlen_states=True,
return_final_states=False,
@@ -487,52 +531,52 @@
)

# update ssm states
# - varlen state is a (batch, nheads, headdim, dstate) tensor
mamba_cache_params.ssm_state[
mamba_cache_params.state_indices_tensor] = varlen_state
# - varlen state is a (num_prefills, nheads, headdim, dstate) tensor
mamba_cache_params.ssm_state[state_indices_tensor_p] = varlen_state

# - reshape
hidden_states = scan_output.view(seq_len, -1)
else:
hidden_states_list.append(scan_output.view(num_prefill_tokens, -1))

# Process decode requests
if has_decode:
n_groups = self.n_groups // self.tp_size
A = self.A[:, None, ...][:, :, None].expand(
A_d = self.A[:, None, ...][:, :, None].expand(
    -1, self.head_dim, self.ssm_state_size).to(dtype=torch.float32)

[Inline review] Member: What does the suffix _d mean in this code? Oh, is it decode?
[Inline review] Contributor Author: yes, _p means prefill and _d means decode.
dt = dt[:, :, None].expand(-1, -1, self.head_dim)
dt_d = dt_d[:, :, None].expand(-1, -1, self.head_dim)
dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim)
D = self.D[:, None, ...].expand(-1, self.head_dim)
B = B.view(-1, n_groups, B.shape[1] // n_groups)
C = C.view(-1, n_groups, C.shape[1] // n_groups)
hidden_states_reshaped = hidden_states.view(
D_d = self.D[:, None, ...].expand(-1, self.head_dim)
B_d = B_d.view(-1, n_groups, B_d.shape[1] // n_groups)
C_d = C_d.view(-1, n_groups, C_d.shape[1] // n_groups)
hidden_states_d = hidden_states_d.view(
-1, self.num_heads // self.tp_size, self.head_dim)

# - the hidden is reshaped into number of current batches
# - in this case there is no more prefill, so the batches gen
# 1 token at a time
# - thus hidden will be (bs, num_heads, head_dim)
# - the hidden is reshaped into (bs, num_heads, head_dim)
# - mamba_cache_params.ssm_state's slots will be selected
# using "mamba_cache_params.state_indices_tensor", just as
# above in the prefill case
# using state_indices_tensor_d

hidden_states = selective_state_update(
hidden_states_d = selective_state_update(
mamba_cache_params.ssm_state,
hidden_states_reshaped,
dt,
A,
B,
C,
D,
hidden_states_d,
dt_d,
A_d,
B_d,
C_d,
D_d,
z=None,
dt_bias=dt_bias,
dt_softplus=True,
state_batch_indices=mamba_cache_params.state_indices_tensor,
state_batch_indices=state_indices_tensor_d,
)
hidden_states = hidden_states.view(
-1, (self.num_heads // self.tp_size) * self.head_dim)
hidden_states_list.append(
hidden_states_d.view(-1, (self.num_heads // self.tp_size) *
self.head_dim))

# Merge prefill and decode outputs before passing to gated MLP
hidden_states = torch.vstack(hidden_states_list)

# # 4. gated MLP
# 4. gated MLP
hidden_states = self.norm(hidden_states, gate)

# # 5. Final linear projection
# 5. Final linear projection
out, _ = self.out_proj(hidden_states)
return out
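The structural change in `forward_cuda` is the split of the flattened token stream into a prefill part and a decode part, with the two outputs stacked back together before the gated MLP. A stripped-down sketch of that pattern (shapes and the per-branch math are placeholders, not the real kernels):

```python
import torch

# Placeholder sizes: 6 prefill tokens followed by 2 decode tokens (V0 layout).
num_prefill_tokens, num_decodes, hidden_size = 6, 2, 8
hidden_states = torch.randn(num_prefill_tokens + num_decodes, hidden_size)

# Split along the token dimension.
hidden_states_p, hidden_states_d = torch.split(
    hidden_states, [num_prefill_tokens, num_decodes], dim=0)

hidden_states_list = []
if num_prefill_tokens > 0:
    # stand-in for the mamba_chunk_scan_combined prefill path
    hidden_states_list.append(hidden_states_p * 2.0)
if num_decodes > 0:
    # stand-in for the selective_state_update decode path
    hidden_states_list.append(hidden_states_d * 0.5)

# Merge prefill and decode outputs back into one (seq_len, hidden) tensor.
hidden_states = torch.vstack(hidden_states_list)
assert hidden_states.shape == (num_prefill_tokens + num_decodes, hidden_size)
```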
1 change: 0 additions & 1 deletion vllm/model_executor/layers/mamba/ops/ssd_combined.py
@@ -40,7 +40,6 @@ def _mamba_chunk_scan_combined_fwd(x,
_, _, ngroups, dstate = B.shape
assert nheads % ngroups == 0
assert B.shape == (batch, seqlen, ngroups, dstate)
assert x.shape == (batch, seqlen, nheads, headdim)
assert dt.shape == (batch, seqlen, nheads)
assert A.shape == (nheads, )
assert C.shape == B.shape
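After dropping the `x` shape assert, the remaining checks in `_mamba_chunk_scan_combined_fwd` still pin down the expected layouts. A toy illustration of those shape relations (sizes are arbitrary):

```python
import torch

# Arbitrary example sizes; only the shape relations matter.
batch, seqlen, nheads, headdim, ngroups, dstate = 1, 16, 8, 64, 2, 128

x = torch.randn(batch, seqlen, nheads, headdim)
dt = torch.randn(batch, seqlen, nheads)
A = torch.randn(nheads)
B = torch.randn(batch, seqlen, ngroups, dstate)
C = torch.randn(batch, seqlen, ngroups, dstate)

# The checks that remain in _mamba_chunk_scan_combined_fwd:
assert nheads % ngroups == 0
assert B.shape == (batch, seqlen, ngroups, dstate)
assert dt.shape == (batch, seqlen, nheads)
assert A.shape == (nheads, )
assert C.shape == B.shape
```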
1 change: 0 additions & 1 deletion vllm/model_executor/models/bamba.py
@@ -313,7 +313,6 @@ def forward(

mamba2_metadata = prepare_mamba2_metadata(
chunk_size=self.config.mamba_chunk_size,
input_ids=input_ids,
attn_metadata=attn_metadata,
)

1 change: 0 additions & 1 deletion vllm/model_executor/models/granitemoehybrid.py
@@ -338,7 +338,6 @@ def forward(
attn_metadata = get_forward_context().attn_metadata
mamba2_metadata = prepare_mamba2_metadata(
chunk_size=self.config.mamba_chunk_size,
input_ids=input_ids,
attn_metadata=attn_metadata,
)

1 change: 0 additions & 1 deletion vllm/model_executor/models/mamba2.py
@@ -142,7 +142,6 @@ def forward(

mamba2_metadata = prepare_mamba2_metadata(
chunk_size=self.config.chunk_size,
input_ids=input_ids,
attn_metadata=attn_metadata,
)

1 change: 0 additions & 1 deletion vllm/model_executor/models/zamba2.py
@@ -751,7 +751,6 @@ def forward(

mamba2_metadata = prepare_mamba2_metadata(
chunk_size=self.config.chunk_size,
input_ids=input_ids,
attn_metadata=attn_metadata,
)

Expand Down