From 095fea3330a26e39589f516827a104d6e99f87d1 Mon Sep 17 00:00:00 2001
From: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
Date: Mon, 21 Apr 2025 15:36:46 -0400
Subject: [PATCH 01/13] remove unnecessary assert

Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
---
 vllm/model_executor/layers/mamba/ops/ssd_combined.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/vllm/model_executor/layers/mamba/ops/ssd_combined.py b/vllm/model_executor/layers/mamba/ops/ssd_combined.py
index e9efe6428252..79a1663b85bb 100644
--- a/vllm/model_executor/layers/mamba/ops/ssd_combined.py
+++ b/vllm/model_executor/layers/mamba/ops/ssd_combined.py
@@ -40,7 +40,6 @@ def _mamba_chunk_scan_combined_fwd(x,
     _, _, ngroups, dstate = B.shape
     assert nheads % ngroups == 0
     assert B.shape == (batch, seqlen, ngroups, dstate)
-    assert x.shape == (batch, seqlen, nheads, headdim)
     assert dt.shape == (batch, seqlen, nheads)
     assert A.shape == (nheads, )
     assert C.shape == B.shape

From d8879c5a61e6d755b4d2d79b109e9c8379ff9a16 Mon Sep 17 00:00:00 2001
From: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
Date: Mon, 21 Apr 2025 15:43:25 -0400
Subject: [PATCH 02/13] draft refactoring

Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
---
 .../layers/mamba/mamba2_metadata.py  | 40 ++++++---
 .../layers/mamba/mamba_mixer2.py     | 84 ++++++++++++++-----
 vllm/model_executor/models/bamba.py  |  1 -
 vllm/model_executor/models/zamba2.py |  1 -
 4 files changed, 89 insertions(+), 37 deletions(-)

diff --git a/vllm/model_executor/layers/mamba/mamba2_metadata.py b/vllm/model_executor/layers/mamba/mamba2_metadata.py
index b1c46190403d..abbc5fa07398 100644
--- a/vllm/model_executor/layers/mamba/mamba2_metadata.py
+++ b/vllm/model_executor/layers/mamba/mamba2_metadata.py
@@ -23,6 +23,9 @@ class Mamba2Metadata:
     chunk_indices: torch.Tensor
     chunk_offsets: torch.Tensor
 
+    num_prefills: int
+    num_decodes: int
+
 
 def _seq_idx_to_chunk_indices_offsets(seq_idx, chunk_size: int):
 
@@ -60,33 +63,41 @@ def _seq_idx_to_chunk_indices_offsets(seq_idx, chunk_size: int):
 
 def prepare_mamba2_metadata(
     chunk_size: int,
-    input_ids: torch.Tensor,
+    # input_ids: torch.Tensor,
     attn_metadata: AttentionMetadata,
 ) -> Mamba2Metadata:
 
+    # compute number of prefill and decode requests
+    # NOTE: in V0 we assume prefills are before decodes
+    num_prefills = attn_metadata.num_prefills
+    num_prefill_tokens = attn_metadata.num_prefill_tokens
+    num_decodes = attn_metadata.num_decode_tokens
+    # print(f"{num_prefills=}, {num_prefill_tokens=}, {num_decodes=}=")
+    # print(f"{attn_metadata.query_start_loc=}=")
+
+    has_prefill = num_prefills > 0
+
     # Need flags to indicate if there are initial states
     # currently we really only support the FlashAttention backend
+    # initial states are only relevant for prefills
     has_initial_states = None
     prep_initial_states = False
     if (isinstance(attn_metadata,
                    (FlashAttentionMetadata, XFormersMetadata,
                     PlaceholderAttentionMetadata))
            and attn_metadata.context_lens_tensor is not None):
-        has_initial_states = attn_metadata.context_lens_tensor > 0
+        has_initial_states = attn_metadata.context_lens_tensor > 0  # [batch,]
         # precompute flag to avoid device syncs later in mamba2 forwards
         prep_initial_states = torch.any(has_initial_states).item()
 
-    has_prefill = attn_metadata.num_prefills > 0
-
+    # Compute seq_idx, chunk_indices and chunk_offsets for prefill only
     seq_idx = None
     chunk_indices, chunk_offsets = None, None
     if has_prefill:
-        seq_idx = torch.zeros_like(input_ids, dtype=torch.int32)
-        for i, (srt, end) in enumerate(
-                zip(
-                    attn_metadata.query_start_loc,
-                    attn_metadata.query_start_loc[1:],
-                )):
-            seq_idx[srt:end] = i
+        query_start_loc = attn_metadata.query_start_loc[:num_prefills + 1]
+        seq_idx = torch.repeat_interleave(torch.arange(
+            num_prefills, dtype=torch.int32, device=query_start_loc.device),
+                                          query_start_loc.diff(),
+                                          output_size=num_prefill_tokens)
         seq_idx.unsqueeze_(0)
 
         # compute metadata for chunked prefill.
@@ -99,6 +110,9 @@ def prepare_mamba2_metadata(
         # inside mamba kernels.
         chunk_indices, chunk_offsets = _seq_idx_to_chunk_indices_offsets(
             seq_idx, chunk_size)
+        # print(f"{seq_idx=}")
+        # print(f"{chunk_indices=}")
+        # print(f"{chunk_offsets=}")
 
     return Mamba2Metadata(has_prefill=has_prefill,
                           has_initial_states=has_initial_states,
@@ -106,4 +120,6 @@ def prepare_mamba2_metadata(
                           chunk_size=chunk_size,
                           seq_idx=seq_idx,
                           chunk_indices=chunk_indices,
-                          chunk_offsets=chunk_offsets)
+                          chunk_offsets=chunk_offsets,
+                          num_prefills=num_prefills,
+                          num_decodes=num_decodes)
diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py
index d459c93a26b2..adb54c59cb61 100644
--- a/vllm/model_executor/layers/mamba/mamba_mixer2.py
+++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py
@@ -453,24 +453,56 @@ def forward_cuda(
             dim=-1,
         )
 
-        # 3. State Space Model sequence transformation
-        if mamba2_metadata.has_prefill:
+        # Separate prefill and decode by slicing hidden_states
+        num_prefills = mamba2_metadata.num_prefills  # requests
+        num_decodes = mamba2_metadata.num_decodes  # requests (also tokens)
+        num_prefill_tokens = attn_metadata.num_prefill_tokens  # tokens
+        hidden_states_p, hidden_states_d = torch.split(
+            hidden_states,
+            [num_prefill_tokens, num_decodes],
+            dim=0,
+        )
+        B_p, B_d = torch.split(
+            B,
+            [num_prefill_tokens, num_decodes],
+            dim=0,
+        )
+        C_p, C_d = torch.split(
+            B,
+            [num_prefill_tokens, num_decodes],
+            dim=0,
+        )
+        dt_p, dt_d = torch.split(
+            dt,
+            [num_prefill_tokens, num_decodes],
+            dim=0,
+        )
+
+        hidden_states_list = []
+
+        # Process Prefills
+        if num_prefills > 0:
             initial_states = None
             if (mamba2_metadata.has_initial_states is not None
                     and mamba2_metadata.prep_initial_states):
                 # making a copy of the states
                 initial_states = torch.where(
-                    mamba2_metadata.has_initial_states[:, None, None, None],
+                    mamba2_metadata.has_initial_states[:num_prefills, None,
+                                                       None, None],
                     mamba_cache_params.ssm_state[
-                        mamba_cache_params.state_indices_tensor], 0)
+                        mamba_cache_params.
+                        state_indices_tensor[:num_prefills]], 0)
 
             scan_output, varlen_state = mamba_chunk_scan_combined(
-                hidden_states.view(1, seq_len, self.num_heads // self.tp_size,
-                                   self.head_dim),
-                dt.unsqueeze(0),
+                hidden_states_p.view(1, num_prefill_tokens,
+                                     self.num_heads // self.tp_size,
+                                     self.head_dim),
+                dt_p.unsqueeze(0),
                 self.A,
-                B.view(1, seq_len, self.n_groups // self.tp_size, -1),
-                C.view(1, seq_len, self.n_groups // self.tp_size, -1),
+                B_p.view(1, num_prefill_tokens, self.n_groups // self.tp_size,
+                         -1),
+                C_p.view(1, num_prefill_tokens, self.n_groups // self.tp_size,
+                         -1),
                 chunk_size=mamba2_metadata.chunk_size,
                 D=self.D,
                 z=None,
@@ -478,7 +510,7 @@ def forward_cuda(
                 seq_idx=mamba2_metadata.seq_idx,
                 chunk_indices=mamba2_metadata.chunk_indices,
                 chunk_offsets=mamba2_metadata.chunk_offsets,
-                cu_seqlens=attn_metadata.query_start_loc,
+                cu_seqlens=attn_metadata.query_start_loc[:num_prefills + 1],
                 initial_states=initial_states,
                 return_varlen_states=True,
                 return_final_states=False,
@@ -487,23 +519,24 @@ def forward_cuda(
             )
 
             # update ssm states
-            # - varlen state is a (batch, nheads, headdim, dstate) tensor
+            # - varlen state is a (num_prefills, nheads, headdim, dstate) tensor
             mamba_cache_params.ssm_state[
-                mamba_cache_params.state_indices_tensor] = varlen_state
+                mamba_cache_params.
+                state_indices_tensor[:num_prefills]] = varlen_state
 
             # - reshape
-            hidden_states = scan_output.view(seq_len, -1)
-        else:
+            hidden_states_list.append(scan_output.view(num_prefill_tokens, -1))
+        if num_decodes > 0:
             n_groups = self.n_groups // self.tp_size
             A = self.A[:, None, ...][:, :, None].expand(
                 -1, self.head_dim, self.ssm_state_size).to(dtype=torch.float32)
-            dt = dt[:, :, None].expand(-1, -1, self.head_dim)
+            dt_d = dt_d[:, :, None].expand(-1, -1, self.head_dim)
             dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim)
             D = self.D[:, None, ...].expand(-1, self.head_dim)
-            B = B.view(-1, n_groups, B.shape[1] // n_groups)
-            C = C.view(-1, n_groups, C.shape[1] // n_groups)
-            hidden_states_reshaped = hidden_states.view(
+            B = B_d.view(-1, n_groups, B.shape[1] // n_groups)
+            C = C_d.view(-1, n_groups, C.shape[1] // n_groups)
+            hidden_states_reshaped = hidden_states_d.view(
                 -1, self.num_heads // self.tp_size, self.head_dim)
 
             # - the hidden is reshaped into number of current batches
@@ -514,10 +547,10 @@ def forward_cuda(
             #   using "mamba_cache_params.state_indices_tensor", just as
             #   above in the prefill case
 
-            hidden_states = selective_state_update(
+            hidden_states_d = selective_state_update(
                 mamba_cache_params.ssm_state,
                 hidden_states_reshaped,
-                dt,
+                dt_d,
                 A,
                 B,
                 C,
@@ -525,10 +558,15 @@ def forward_cuda(
                 z=None,
                 dt_bias=dt_bias,
                 dt_softplus=True,
-                state_batch_indices=mamba_cache_params.state_indices_tensor,
+                state_batch_indices=mamba_cache_params.
+                state_indices_tensor[num_prefills:],
             )
-            hidden_states = hidden_states.view(
-                -1, (self.num_heads // self.tp_size) * self.head_dim)
+            hidden_states_list.append(
+                hidden_states_d.view(-1, (self.num_heads // self.tp_size) *
+                                     self.head_dim))
+
+        # Merge states output
+        hidden_states = torch.vstack(hidden_states_list)
 
         # # 4. gated MLP
         hidden_states = self.norm(hidden_states, gate)
diff --git a/vllm/model_executor/models/bamba.py b/vllm/model_executor/models/bamba.py
index 16dac6123d66..87e1e102efd8 100644
--- a/vllm/model_executor/models/bamba.py
+++ b/vllm/model_executor/models/bamba.py
@@ -313,7 +313,6 @@ def forward(
 
         mamba2_metadata = prepare_mamba2_metadata(
             chunk_size=self.config.mamba_chunk_size,
-            input_ids=input_ids,
             attn_metadata=attn_metadata,
         )
 
diff --git a/vllm/model_executor/models/zamba2.py b/vllm/model_executor/models/zamba2.py
index d34033e3ac90..eddccbba5a2d 100644
--- a/vllm/model_executor/models/zamba2.py
+++ b/vllm/model_executor/models/zamba2.py
@@ -751,7 +751,6 @@ def forward(
 
         mamba2_metadata = prepare_mamba2_metadata(
             chunk_size=self.config.chunk_size,
-            input_ids=input_ids,
             attn_metadata=attn_metadata,
         )
 

From 8a3f4bf9ae89cbb5cf4a030ec37bb159d46bb1e3 Mon Sep 17 00:00:00 2001
From: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
Date: Mon, 21 Apr 2025 16:02:45 -0400
Subject: [PATCH 03/13] fix bug in C splitting

Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
---
 vllm/model_executor/layers/mamba/mamba_mixer2.py | 2 +-
 vllm/model_executor/models/mamba2.py             | 1 -
 2 files changed, 1 insertion(+), 2 deletions(-)

diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py
index adb54c59cb61..5876489ccde9 100644
--- a/vllm/model_executor/layers/mamba/mamba_mixer2.py
+++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py
@@ -468,7 +468,7 @@ def forward_cuda(
             dim=0,
         )
         C_p, C_d = torch.split(
-            B,
+            C,
             [num_prefill_tokens, num_decodes],
             dim=0,
         )
diff --git a/vllm/model_executor/models/mamba2.py b/vllm/model_executor/models/mamba2.py
index 78303733f6bb..72daf34c4412 100644
--- a/vllm/model_executor/models/mamba2.py
+++ b/vllm/model_executor/models/mamba2.py
@@ -142,7 +142,6 @@ def forward(
 
         mamba2_metadata = prepare_mamba2_metadata(
             chunk_size=self.config.chunk_size,
-            input_ids=input_ids,
             attn_metadata=attn_metadata,
         )
 

From d9c2755e7af8d23950f80d1947a6e11ffd52f2d6 Mon Sep 17 00:00:00 2001
From: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
Date: Mon, 21 Apr 2025 16:17:08 -0400
Subject: [PATCH 04/13] clean up and renaming for clarity

Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
---
 .../layers/mamba/mamba2_metadata.py | 21 +++++-----------
 .../layers/mamba/mamba_mixer2.py    | 24 +++++++++----------
 2 files changed, 18 insertions(+), 27 deletions(-)

diff --git a/vllm/model_executor/layers/mamba/mamba2_metadata.py b/vllm/model_executor/layers/mamba/mamba2_metadata.py
index abbc5fa07398..0239f4a01376 100644
--- a/vllm/model_executor/layers/mamba/mamba2_metadata.py
+++ b/vllm/model_executor/layers/mamba/mamba2_metadata.py
@@ -72,8 +72,6 @@ def prepare_mamba2_metadata(
     num_prefills = attn_metadata.num_prefills
     num_prefill_tokens = attn_metadata.num_prefill_tokens
     num_decodes = attn_metadata.num_decode_tokens
-    # print(f"{num_prefills=}, {num_prefill_tokens=}, {num_decodes=}=")
-    # print(f"{attn_metadata.query_start_loc=}=")
 
     has_prefill = num_prefills > 0
 
@@ -100,19 +98,12 @@ def prepare_mamba2_metadata(
                                           output_size=num_prefill_tokens)
         seq_idx.unsqueeze_(0)
 
-        # compute metadata for chunked prefill.
-        # actually this is only needed if there are initial states,
-        # but this is determinable only from attention metadata yet
-        # unavailable from the top-level model forward. Rather than
-        # complicating things to extract said metadata, we simply just
-        # compute them once at the top level model forward and reuse
-        # them in mamba layers. If not needed, they will be ignored
-        # inside mamba kernels.
-        chunk_indices, chunk_offsets = _seq_idx_to_chunk_indices_offsets(
-            seq_idx, chunk_size)
-        # print(f"{seq_idx=}")
-        # print(f"{chunk_indices=}")
-        # print(f"{chunk_offsets=}")
+        # We compute metadata for chunked prefill once at the top level model
+        # forward and reuse them in mamba layers. If not needed, they will be
+        # ignored inside mamba kernels.
+        if prep_initial_states:
+            chunk_indices, chunk_offsets = _seq_idx_to_chunk_indices_offsets(
+                seq_idx, chunk_size)
 
     return Mamba2Metadata(has_prefill=has_prefill,
                           has_initial_states=has_initial_states,
diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py
index 5876489ccde9..ed425a3de844 100644
--- a/vllm/model_executor/layers/mamba/mamba_mixer2.py
+++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py
@@ -529,14 +529,14 @@ def forward_cuda(
 
         if num_decodes > 0:
             n_groups = self.n_groups // self.tp_size
-            A = self.A[:, None, ...][:, :, None].expand(
+            A_d = self.A[:, None, ...][:, :, None].expand(
                 -1, self.head_dim, self.ssm_state_size).to(dtype=torch.float32)
             dt_d = dt_d[:, :, None].expand(-1, -1, self.head_dim)
             dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim)
-            D = self.D[:, None, ...].expand(-1, self.head_dim)
-            B = B_d.view(-1, n_groups, B.shape[1] // n_groups)
-            C = C_d.view(-1, n_groups, C.shape[1] // n_groups)
-            hidden_states_reshaped = hidden_states_d.view(
+            D_d = self.D[:, None, ...].expand(-1, self.head_dim)
+            B_d = B_d.view(-1, n_groups, B.shape[1] // n_groups)
+            C_d = C_d.view(-1, n_groups, C.shape[1] // n_groups)
+            hidden_states_d = hidden_states_d.view(
                 -1, self.num_heads // self.tp_size, self.head_dim)
 
             # - the hidden is reshaped into number of current batches
@@ -549,23 +549,23 @@ def forward_cuda(
 
             hidden_states_d = selective_state_update(
                 mamba_cache_params.ssm_state,
-                hidden_states_reshaped,
+                hidden_states_d,
                 dt_d,
-                A,
-                B,
-                C,
-                D,
+                A_d,
+                B_d,
+                C_d,
+                D_d,
                 z=None,
                 dt_bias=dt_bias,
                 dt_softplus=True,
                 state_batch_indices=mamba_cache_params.
-                state_indices_tensor[num_prefills:],
+                state_indices_tensor[num_prefills:],  # take decodes only
             )
             hidden_states_list.append(
                 hidden_states_d.view(-1, (self.num_heads // self.tp_size) *
                                      self.head_dim))
 
-        # Merge states output
+        # Merge prefill and decode outputs
         hidden_states = torch.vstack(hidden_states_list)
 
         # # 4. gated MLP

From 6a363c8e31a8a968250c4b0518d4786efa38b020 Mon Sep 17 00:00:00 2001
From: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
Date: Mon, 21 Apr 2025 16:38:13 -0400
Subject: [PATCH 05/13] minor improvement

Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
---
 .../layers/mamba/mamba_mixer2.py | 22 ++++++++++---------
 1 file changed, 12 insertions(+), 10 deletions(-)

diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py
index ed425a3de844..464cc975359f 100644
--- a/vllm/model_executor/layers/mamba/mamba_mixer2.py
+++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py
@@ -453,7 +453,8 @@ def forward_cuda(
             dim=-1,
         )
 
-        # Separate prefill and decode by slicing hidden_states
+        # 3. State Space Model sequence transformation
+        # Separate prefill and decode by slicing varlen input
         num_prefills = mamba2_metadata.num_prefills  # requests
         num_decodes = mamba2_metadata.num_decodes  # requests (also tokens)
         num_prefill_tokens = attn_metadata.num_prefill_tokens  # tokens
@@ -477,10 +478,15 @@ def forward_cuda(
             [num_prefill_tokens, num_decodes],
             dim=0,
         )
+        state_indices_tensor_p, state_indices_tensor_d = torch.split(
+            mamba_cache_params.state_indices_tensor,
+            [num_prefills, num_decodes],
+            dim=0,
+        )
 
         hidden_states_list = []
 
-        # Process Prefills
+        # Process prefill requests
         if num_prefills > 0:
             initial_states = None
             if (mamba2_metadata.has_initial_states is not None
@@ -489,9 +495,7 @@ def forward_cuda(
                 initial_states = torch.where(
                     mamba2_metadata.has_initial_states[:num_prefills, None,
                                                        None, None],
-                    mamba_cache_params.ssm_state[
-                        mamba_cache_params.
-                        state_indices_tensor[:num_prefills]], 0)
+                    mamba_cache_params.ssm_state[state_indices_tensor_p], 0)
 
             scan_output, varlen_state = mamba_chunk_scan_combined(
                 hidden_states_p.view(1, num_prefill_tokens,
@@ -520,13 +524,12 @@ def forward_cuda(
 
             # update ssm states
             # - varlen state is a (num_prefills, nheads, headdim, dstate) tensor
-            mamba_cache_params.ssm_state[
-                mamba_cache_params.
-                state_indices_tensor[:num_prefills]] = varlen_state
+            mamba_cache_params.ssm_state[state_indices_tensor_p] = varlen_state
 
             # - reshape
             hidden_states_list.append(scan_output.view(num_prefill_tokens, -1))
 
+        # Process decode requests
         if num_decodes > 0:
             n_groups = self.n_groups // self.tp_size
             A_d = self.A[:, None, ...][:, :, None].expand(
@@ -558,8 +561,7 @@ def forward_cuda(
                 z=None,
                 dt_bias=dt_bias,
                 dt_softplus=True,
-                state_batch_indices=mamba_cache_params.
-                state_indices_tensor[num_prefills:],  # take decodes only
+                state_batch_indices=state_indices_tensor_d,
             )
             hidden_states_list.append(
                 hidden_states_d.view(-1, (self.num_heads // self.tp_size) *

From 70875ae58ef62167f3c90e847b7dbde50266ff76 Mon Sep 17 00:00:00 2001
From: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
Date: Tue, 22 Apr 2025 09:04:37 -0400
Subject: [PATCH 06/13] clean up and remove some redundant info in
 mamba2metadata

Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
---
 .../layers/mamba/mamba2_metadata.py | 17 ++-------
 .../layers/mamba/mamba_mixer2.py    | 38 ++++++++++---------
 2 files changed, 24 insertions(+), 31 deletions(-)

diff --git a/vllm/model_executor/layers/mamba/mamba2_metadata.py b/vllm/model_executor/layers/mamba/mamba2_metadata.py
index 0239f4a01376..9be6b1058616 100644
--- a/vllm/model_executor/layers/mamba/mamba2_metadata.py
+++ b/vllm/model_executor/layers/mamba/mamba2_metadata.py
@@ -13,7 +13,6 @@
 
 @dataclass
 class Mamba2Metadata:
-    has_prefill: bool
 
     has_initial_states: torch.Tensor
     prep_initial_states: bool
@@ -23,9 +22,6 @@ class Mamba2Metadata:
     chunk_indices: torch.Tensor
     chunk_offsets: torch.Tensor
 
-    num_prefills: int
-    num_decodes: int
-
 
 def _seq_idx_to_chunk_indices_offsets(seq_idx, chunk_size: int):
 
@@ -63,7 +59,6 @@ def _seq_idx_to_chunk_indices_offsets(seq_idx, chunk_size: int):
 
 def prepare_mamba2_metadata(
     chunk_size: int,
-    # input_ids: torch.Tensor,
     attn_metadata: AttentionMetadata,
 ) -> Mamba2Metadata:
 
@@ -71,9 +66,6 @@ def prepare_mamba2_metadata(
     # NOTE: in V0 we assume prefills are before decodes
     num_prefills = attn_metadata.num_prefills
     num_prefill_tokens = attn_metadata.num_prefill_tokens
-    num_decodes = attn_metadata.num_decode_tokens
-
-    has_prefill = num_prefills > 0
 
     # Need flags to indicate if there are initial states
     # currently we really only support the FlashAttention backend
@@ -90,7 +82,7 @@ def prepare_mamba2_metadata(
     # Compute seq_idx, chunk_indices and chunk_offsets for prefill only
     seq_idx = None
     chunk_indices, chunk_offsets = None, None
-    if has_prefill:
+    if num_prefills > 0:
         query_start_loc = attn_metadata.query_start_loc[:num_prefills + 1]
         seq_idx = torch.repeat_interleave(torch.arange(
             num_prefills, dtype=torch.int32, device=query_start_loc.device),
@@ -105,12 +97,9 @@ def prepare_mamba2_metadata(
         chunk_indices, chunk_offsets = _seq_idx_to_chunk_indices_offsets(
             seq_idx, chunk_size)
 
-    return Mamba2Metadata(has_prefill=has_prefill,
-                          has_initial_states=has_initial_states,
+    return Mamba2Metadata(has_initial_states=has_initial_states,
                           prep_initial_states=prep_initial_states,
                           chunk_size=chunk_size,
                           seq_idx=seq_idx,
                           chunk_indices=chunk_indices,
-                          chunk_offsets=chunk_offsets,
-                          num_prefills=num_prefills,
-                          num_decodes=num_decodes)
+                          chunk_offsets=chunk_offsets)
diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py
index 464cc975359f..ff356f69db5a 100644
--- a/vllm/model_executor/layers/mamba/mamba_mixer2.py
+++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py
@@ -388,9 +388,15 @@ def forward_cuda(
         # mamba2_metadata contains metadata necessary for the mamba2 triton
        # kernels to operate in continuous batching and in chunked prefill
         # modes; they are computed at top-level model forward since they
-        # are the same and reused for all mamba layers in the same iteration
+        # stay the same and reused for all mamba layers in the same iteration
         attn_metadata: AttentionMetadata = get_forward_context().attn_metadata
 
+        num_prefills = attn_metadata.num_prefills  # #requests
+        num_decodes = attn_metadata.num_decode_tokens  # #tokens==#requests
+        num_prefill_tokens = attn_metadata.num_prefill_tokens  # #tokens
+        has_prefill = num_prefills > 0
+        has_decode = num_decodes > 0
+
         seq_len, _ = hidden_states.shape
 
         groups_time_state_size = self.n_groups * self.ssm_state_size
@@ -410,7 +416,9 @@ def forward_cuda(
         conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
                                                self.conv1d.weight.size(2))
 
-        if mamba2_metadata.has_prefill:
+        # causal_conv1d_fn deals with both prefill and decode if input
+        # has prefill requests.
+        if has_prefill:
             # |---------- N-1 iteration --------|
             # |---------------- N iteration ---------------------|
             # |- tokenA -|......................|-- newTokens ---|
@@ -454,10 +462,9 @@ def forward_cuda(
             dim=-1,
         )
 
         # 3. State Space Model sequence transformation
-        # Separate prefill and decode by slicing varlen input
-        num_prefills = mamba2_metadata.num_prefills  # requests
-        num_decodes = mamba2_metadata.num_decodes  # requests (also tokens)
-        num_prefill_tokens = attn_metadata.num_prefill_tokens  # tokens
+
+        # Separate prefill and decode by splitting varlen input
+        # Split along token dimension
         hidden_states_p, hidden_states_d = torch.split(
             hidden_states,
@@ -479,6 +486,7 @@ def forward_cuda(
             [num_prefill_tokens, num_decodes],
             dim=0,
         )
+        # Split along batch dimension
         state_indices_tensor_p, state_indices_tensor_d = torch.split(
             mamba_cache_params.state_indices_tensor,
             [num_prefills, num_decodes],
@@ -489,7 +497,7 @@ def forward_cuda(
 
         hidden_states_list = []
 
         # Process prefill requests
-        if num_prefills > 0:
+        if has_prefill:
             initial_states = None
             if (mamba2_metadata.has_initial_states is not None
                     and mamba2_metadata.prep_initial_states):
@@ -530,7 +538,7 @@ def forward_cuda(
             hidden_states_list.append(scan_output.view(num_prefill_tokens, -1))
 
         # Process decode requests
-        if num_decodes > 0:
+        if has_decode:
             n_groups = self.n_groups // self.tp_size
             A_d = self.A[:, None, ...][:, :, None].expand(
                 -1, self.head_dim, self.ssm_state_size).to(dtype=torch.float32)
@@ -542,13 +550,9 @@ def forward_cuda(
             hidden_states_d = hidden_states_d.view(
                 -1, self.num_heads // self.tp_size, self.head_dim)
 
-            # - the hidden is reshaped into number of current batches
-            # - in this case there is no more prefill, so the batches gen
-            #   1 token at a time
-            # - thus hidden will be (bs, num_heads, head_dim)
+            # - the hidden is reshaped into (bs, num_heads, head_dim)
             # - mamba_cache_params.ssm_state's slots will be selected
-            #   using "mamba_cache_params.state_indices_tensor", just as
-            #   above in the prefill case
+            #   using state_indices_tensor_d
 
             hidden_states_d = selective_state_update(
                 mamba_cache_params.ssm_state,
@@ -567,12 +571,12 @@ def forward_cuda(
                 hidden_states_d.view(-1, (self.num_heads // self.tp_size) *
                                      self.head_dim))
 
-        # Merge prefill and decode outputs
+        # Merge prefill and decode outputs before passing to gated MLP
         hidden_states = torch.vstack(hidden_states_list)
 
-        # # 4. gated MLP
+        # 4. gated MLP
         hidden_states = self.norm(hidden_states, gate)
 
-        # # 5. Final linear projection
+        # 5. Final linear projection
         out, _ = self.out_proj(hidden_states)
 
         return out

From 03275127f06f55ae9dedb4ff20152b70ea29b37b Mon Sep 17 00:00:00 2001
From: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
Date: Tue, 22 Apr 2025 09:10:18 -0400
Subject: [PATCH 07/13] narrow down prep initial state

Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
---
 vllm/model_executor/layers/mamba/mamba2_metadata.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/vllm/model_executor/layers/mamba/mamba2_metadata.py b/vllm/model_executor/layers/mamba/mamba2_metadata.py
index 9be6b1058616..87479f9d6616 100644
--- a/vllm/model_executor/layers/mamba/mamba2_metadata.py
+++ b/vllm/model_executor/layers/mamba/mamba2_metadata.py
@@ -77,7 +77,9 @@ def prepare_mamba2_metadata(
            and attn_metadata.context_lens_tensor is not None):
         has_initial_states = attn_metadata.context_lens_tensor > 0  # [batch,]
         # precompute flag to avoid device syncs later in mamba2 forwards
-        prep_initial_states = torch.any(has_initial_states).item()
+        # prep is only needed for mamba2 ssd prefill processing
+        prep_initial_states = torch.any(
+            has_initial_states[:num_prefills]).item()
 
     # Compute seq_idx, chunk_indices and chunk_offsets for prefill only
     seq_idx = None

From 4d37e5942ff30a21be23bbaa6958f07fe7621975 Mon Sep 17 00:00:00 2001
From: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
Date: Tue, 22 Apr 2025 09:30:02 -0400
Subject: [PATCH 08/13] use query_start_loc to compute chunk_indices and
 chunk_offsets

Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
---
 .../layers/mamba/mamba2_metadata.py | 27 ++++++++++---------
 1 file changed, 15 insertions(+), 12 deletions(-)

diff --git a/vllm/model_executor/layers/mamba/mamba2_metadata.py b/vllm/model_executor/layers/mamba/mamba2_metadata.py
index 87479f9d6616..df9800130e21 100644
--- a/vllm/model_executor/layers/mamba/mamba2_metadata.py
+++ b/vllm/model_executor/layers/mamba/mamba2_metadata.py
@@ -23,21 +23,23 @@ class Mamba2Metadata:
     chunk_offsets: torch.Tensor
 
 
-def _seq_idx_to_chunk_indices_offsets(seq_idx, chunk_size: int):
+def _query_start_loc_to_chunk_indices_offsets(query_start_loc: torch.Tensor,
+                                              chunk_size: int,
+                                              total_seqlens: int):
 
-    # convert seq_idx to chunk indices and offsets
-    # - derive the cu_seqlens
-    _, cu_seqlens = torch.where(seq_idx.diff())
-    cu_seqlens += 1
+    cu_seqlens = query_start_loc[1:]  # remove prepended 0
 
     # outputs will have length expansion of chunks that do not divide
     # chunk_size
-    N = math.ceil(seq_idx.shape[-1] / chunk_size) + (cu_seqlens % chunk_size
-                                                     > 0).sum()
-    chunk_indices = torch.arange(N, dtype=torch.int, device=seq_idx.device)
-    chunk_offsets = torch.zeros((N, ), dtype=torch.int, device=seq_idx.device)
+    N = math.ceil(total_seqlens / chunk_size) + (cu_seqlens[:-1] % chunk_size
+                                                 > 0).sum()
+    chunk_indices = torch.arange(N,
+                                 dtype=torch.int,
+                                 device=query_start_loc.device)
+    chunk_offsets = torch.zeros((N, ),
+                                dtype=torch.int,
+                                device=query_start_loc.device)
 
-    cu_seqlens = cu_seqlens.tolist() + [seq_idx.shape[-1]]
     p = 0  # num of insertions
     for s, e in zip(cu_seqlens[:-1], cu_seqlens[1:]):
 
@@ -96,9 +98,10 @@ def prepare_mamba2_metadata(
     # We compute metadata for chunked prefill once at the top level model
     # forward and reuse them in mamba layers. If not needed, they will be
     # ignored inside mamba kernels.
     if prep_initial_states:
-        chunk_indices, chunk_offsets = _seq_idx_to_chunk_indices_offsets(
-            seq_idx, chunk_size)
+        chunk_indices, chunk_offsets = \
+            _query_start_loc_to_chunk_indices_offsets(
+                query_start_loc, chunk_size, num_prefill_tokens)
 
     return Mamba2Metadata(has_initial_states=has_initial_states,
                           prep_initial_states=prep_initial_states,

From 2ea3051e207db7153419bdbc891b8f55e6745702 Mon Sep 17 00:00:00 2001
From: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
Date: Wed, 23 Apr 2025 10:40:57 -0400
Subject: [PATCH 09/13] refer to the correct B and C tensor in decode path

Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
---
 vllm/model_executor/layers/mamba/mamba_mixer2.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py
index ff356f69db5a..0dc8def6ce37 100644
--- a/vllm/model_executor/layers/mamba/mamba_mixer2.py
+++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py
@@ -545,8 +545,8 @@ def forward_cuda(
             dt_d = dt_d[:, :, None].expand(-1, -1, self.head_dim)
             dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim)
             D_d = self.D[:, None, ...].expand(-1, self.head_dim)
-            B_d = B_d.view(-1, n_groups, B.shape[1] // n_groups)
-            C_d = C_d.view(-1, n_groups, C.shape[1] // n_groups)
+            B_d = B_d.view(-1, n_groups, B_d.shape[1] // n_groups)
+            C_d = C_d.view(-1, n_groups, C_d.shape[1] // n_groups)
             hidden_states_d = hidden_states_d.view(
                 -1, self.num_heads // self.tp_size, self.head_dim)
 

From 087c88824355162ca18e65dd65119109a8237380 Mon Sep 17 00:00:00 2001
From: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
Date: Wed, 23 Apr 2025 12:59:52 -0400
Subject: [PATCH 10/13] improve comments

Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
---
 vllm/model_executor/layers/mamba/mamba2_metadata.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/vllm/model_executor/layers/mamba/mamba2_metadata.py b/vllm/model_executor/layers/mamba/mamba2_metadata.py
index df9800130e21..872d78bf570a 100644
--- a/vllm/model_executor/layers/mamba/mamba2_metadata.py
+++ b/vllm/model_executor/layers/mamba/mamba2_metadata.py
@@ -71,14 +71,14 @@ def prepare_mamba2_metadata(
 
     # Need flags to indicate if there are initial states
     # currently we really only support the FlashAttention backend
-    # initial states are only relevant for prefills
     has_initial_states = None
     prep_initial_states = False
     if (isinstance(attn_metadata,
                    (FlashAttentionMetadata, XFormersMetadata,
                     PlaceholderAttentionMetadata))
            and attn_metadata.context_lens_tensor is not None):
+        # keeping flags for both prefill and decode causal_conv1d varlen
         has_initial_states = attn_metadata.context_lens_tensor > 0  # [batch,]
-        # precompute flag to avoid device syncs later in mamba2 forwards
+        # precompute flag to avoid device syncs later in mamba2 layer forwards
         # prep is only needed for mamba2 ssd prefill processing
         prep_initial_states = torch.any(
             has_initial_states[:num_prefills]).item()

From aa3b8fa1c142a62d518498294e9f6c81487cb6e2 Mon Sep 17 00:00:00 2001
From: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
Date: Thu, 24 Apr 2025 14:08:00 -0400
Subject: [PATCH 11/13] helper function interface change

Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
---
 tests/kernels/mamba/test_mamba_ssm_ssd.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/tests/kernels/mamba/test_mamba_ssm_ssd.py b/tests/kernels/mamba/test_mamba_ssm_ssd.py
index ee908105f557..f5e751bea414 100644
--- a/tests/kernels/mamba/test_mamba_ssm_ssd.py
+++ b/tests/kernels/mamba/test_mamba_ssm_ssd.py
@@ -6,7 +6,7 @@
 from einops import rearrange, repeat
 
 from vllm.model_executor.layers.mamba.mamba2_metadata import (
-    _seq_idx_to_chunk_indices_offsets)
+    _query_start_loc_to_chunk_indices_offsets)
 from vllm.model_executor.layers.mamba.ops.ssd_combined import (
     mamba_chunk_scan_combined)
 from vllm.platforms import current_platform
@@ -274,8 +274,9 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
                                      last_taken, exhausted, n_heads, d_head,
                                      itype):
 
-        chunk_indices, chunk_offsets = _seq_idx_to_chunk_indices_offsets(
-            seq_idx, chunk_size)
+        chunk_indices, chunk_offsets = \
+            _query_start_loc_to_chunk_indices_offsets(
+                cu_seqlens, chunk_size, cu_seqlens[-1])
 
         Y, new_states = mamba_chunk_scan_combined(
             X,

From 4bdff845cae6714f64a06855020e0345815611b7 Mon Sep 17 00:00:00 2001
From: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
Date: Mon, 5 May 2025 15:13:48 -0400
Subject: [PATCH 12/13] improve comment

Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
---
 vllm/model_executor/layers/mamba/mamba_mixer2.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py
index 0dc8def6ce37..e8519521f763 100644
--- a/vllm/model_executor/layers/mamba/mamba_mixer2.py
+++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py
@@ -391,9 +391,9 @@ def forward_cuda(
         # stay the same and reused for all mamba layers in the same iteration
         attn_metadata: AttentionMetadata = get_forward_context().attn_metadata
 
-        num_prefills = attn_metadata.num_prefills  # #requests
-        num_decodes = attn_metadata.num_decode_tokens  # #tokens==#requests
-        num_prefill_tokens = attn_metadata.num_prefill_tokens  # #tokens
+        num_prefills = attn_metadata.num_prefills  # request count
+        num_decodes = attn_metadata.num_decode_tokens  # token count (=request)
+        num_prefill_tokens = attn_metadata.num_prefill_tokens  # token count
         has_prefill = num_prefills > 0
         has_decode = num_decodes > 0
 

From 20452d3899fcb95fd9515f39b50e4852c0d91bd3 Mon Sep 17 00:00:00 2001
From: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
Date: Tue, 6 May 2025 08:15:05 -0400
Subject: [PATCH 13/13] mamba2 metadata func arg change

Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
---
 vllm/model_executor/models/granitemoehybrid.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/vllm/model_executor/models/granitemoehybrid.py b/vllm/model_executor/models/granitemoehybrid.py
index dea9a0da3127..706e648f1b4f 100644
--- a/vllm/model_executor/models/granitemoehybrid.py
+++ b/vllm/model_executor/models/granitemoehybrid.py
@@ -338,7 +338,6 @@ def forward(
         attn_metadata = get_forward_context().attn_metadata
         mamba2_metadata = prepare_mamba2_metadata(
             chunk_size=self.config.mamba_chunk_size,
-            input_ids=input_ids,
             attn_metadata=attn_metadata,
         )
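
The patch series above hinges on one layout assumption: within a batch, all prefill tokens come first and each decode request contributes exactly one token, so every varlen tensor can be split along the token dimension and the cache-slot indices along the request dimension. Below is a minimal, self-contained sketch of that split and of the seq_idx construction used in patch 02; it is illustrative only (toy sizes and made-up cache slots, not code taken from vLLM).

import torch

# Toy varlen batch: 2 prefill requests (5 and 3 tokens) followed by
# 2 decode requests (1 token each), flattened along the token dimension.
num_prefills, num_decodes = 2, 2
num_prefill_tokens = 5 + 3
hidden_size = 4

hidden_states = torch.randn(num_prefill_tokens + num_decodes, hidden_size)
state_indices_tensor = torch.tensor([7, 2, 9, 4])  # one cache slot per request

# Split along the token dimension (prefill tokens first, then decode tokens).
hidden_states_p, hidden_states_d = torch.split(
    hidden_states, [num_prefill_tokens, num_decodes], dim=0)

# Split along the batch (request) dimension.
state_indices_p, state_indices_d = torch.split(
    state_indices_tensor, [num_prefills, num_decodes], dim=0)

# seq_idx for the prefill tokens, built as in patch 02: repeat each request
# index once per token of that request.
query_start_loc = torch.tensor([0, 5, 8])  # cu_seqlens of the prefill requests
seq_idx = torch.repeat_interleave(
    torch.arange(num_prefills, dtype=torch.int32),
    query_start_loc.diff(),
    output_size=num_prefill_tokens)

print(seq_idx)  # tensor([0, 0, 0, 0, 0, 1, 1, 1], dtype=torch.int32)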