Merged
44 commits
83c39e5
feat: Added MambaStateShapeCalculator
asafgardin Jul 20, 2025
2df9e52
feat: Added Mamba1AttentionMetadata
asafgardin Jul 20, 2025
8ddb42e
feat: Added V1 code to mamba
asafgardin Jul 20, 2025
52699f4
fix: Removed unnecessary mamba1 metadata dataclass
asafgardin Jul 20, 2025
57f7316
fix: Updated configs in mamba
asafgardin Jul 20, 2025
6f4b1db
refactor: Added v1 condition
asafgardin Jul 20, 2025
5165e0a
fix: Removed unnecessary ignore
asafgardin Jul 20, 2025
c4a3bcd
feat: Added mamba type to identify mamba version
asafgardin Jul 20, 2025
476ba5b
fix: Added mamba_type property to mamba_base
asafgardin Jul 20, 2025
e747f27
fix: Lint
asafgardin Jul 20, 2025
9289273
fix: Lint
asafgardin Jul 20, 2025
3402d3a
fix: Ruff long lines
asafgardin Jul 20, 2025
4d54012
fix: Added context_lens_tensor
asafgardin Jul 21, 2025
898306f
feat: Updated jamba code to support v1
asafgardin Jul 22, 2025
60c1840
fix: CR changes
asafgardin Jul 22, 2025
f0566bf
fix: Conflicts
asafgardin Jul 22, 2025
4a4f9b1
fix: Lint
asafgardin Jul 22, 2025
1b1bad2
fix: Jamba forward
asafgardin Jul 22, 2025
f6d9311
refactor: Removed unnecessary fields
asafgardin Jul 22, 2025
dd098a4
refactor: Moved mamba_selectors
asafgardin Jul 23, 2025
364ea41
refactor: Removed v1 from mamba1 state shape
asafgardin Jul 23, 2025
19d54a6
refactor: Added _create_mamba1_state_tensors
asafgardin Jul 23, 2025
e7d3e7d
fix: Lint
asafgardin Jul 23, 2025
e20516e
test: Updated ssm tests to work in test_hybrid.py
asafgardin Jul 23, 2025
e0619a4
fix: Extra extend to state tensors
asafgardin Jul 24, 2025
d1c7063
fix: Moved logic to create_mamba2_state_tensors
asafgardin Jul 24, 2025
c357731
fix: Order in conv state shape
asafgardin Jul 24, 2025
8178700
fix: Conflicts
asafgardin Jul 27, 2025
7525091
fix: Added transpose to mixer
asafgardin Aug 3, 2025
0ec3208
feat: Updated selective scan fwd to work with strides
asafgardin Aug 3, 2025
611a771
fix: Conflicted changes in gpu_model_runner
asafgardin Aug 3, 2025
4a57395
fix: Lint
asafgardin Aug 3, 2025
a0958b4
fix: Lint in tests
asafgardin Aug 3, 2025
9b8fd69
refactor: Added v1 jamba support
asafgardin Aug 3, 2025
adcb205
fix: Lint
asafgardin Aug 3, 2025
5db14c5
fix: Removed unnecessary props from mamba1_attn
asafgardin Aug 3, 2025
24a1aa6
test: Removed mamba1 from unsupported models in test_oracle
asafgardin Aug 3, 2025
23d09f8
fix: Added stride to non-varlen in kernel
asafgardin Aug 4, 2025
1cc34d1
docs: Updated docs to show mamba1 is supported in v1
asafgardin Aug 4, 2025
4a7e7f1
fix: Lint
asafgardin Aug 4, 2025
f64191b
fix: Added call to forward_cuda
asafgardin Aug 4, 2025
8868aad
test: Removed enforce_eager
asafgardin Aug 4, 2025
e6f7781
fix: Moved call to forward
asafgardin Aug 4, 2025
9ab94d4
fix: CR comments
asafgardin Aug 6, 2025
3 changes: 3 additions & 0 deletions csrc/mamba/mamba_ssm/selective_scan.h
@@ -45,6 +45,9 @@ struct SSMParamsBase {
index_t out_d_stride;
index_t out_z_batch_stride;
index_t out_z_d_stride;
index_t ssm_states_batch_stride;
index_t ssm_states_dim_stride;
index_t ssm_states_dstate_stride;

// Common data pointers.
void *__restrict__ A_ptr;
18 changes: 14 additions & 4 deletions csrc/mamba/mamba_ssm/selective_scan_fwd.cu
@@ -132,8 +132,10 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
input_t *Bvar = reinterpret_cast<input_t *>(params.B_ptr) + sequence_start_index * params.B_batch_stride + group_id * params.B_group_stride;
weight_t *C = reinterpret_cast<weight_t *>(params.C_ptr) + dim_id * kNRows * params.C_d_stride;
input_t *Cvar = reinterpret_cast<input_t *>(params.C_ptr) + sequence_start_index * params.C_batch_stride + group_id * params.C_group_stride;
input_t *ssm_states = reinterpret_cast<input_t *>(params.ssm_states_ptr) + (cache_index * params.dim + dim_id * kNRows) * params.dstate;

input_t *ssm_states = reinterpret_cast<input_t *>(params.ssm_states_ptr) +
cache_index * params.ssm_states_batch_stride +
dim_id * kNRows * params.ssm_states_dim_stride;

float D_val[kNRows] = {0};
if (params.D_ptr != nullptr) {
#pragma unroll
@@ -248,7 +250,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
}
// Initialize running total

scan_t running_prefix = chunk > 0 ? smem_running_prefix[state_idx + r * MAX_DSTATE] : make_float2(1.0, has_initial_state ? float(ssm_states[state_idx]): 0.0);
scan_t running_prefix = chunk > 0 ? smem_running_prefix[state_idx + r * MAX_DSTATE] : make_float2(1.0, has_initial_state ? float(ssm_states[state_idx * params.ssm_states_dstate_stride]): 0.0);

SSMScanPrefixCallbackOp<weight_t> prefix_op(running_prefix);
typename Ktraits::BlockScanT(smem_scan).InclusiveScan(
@@ -259,7 +261,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
if (threadIdx.x == 0) {
smem_running_prefix[state_idx] = prefix_op.running_prefix;
if (chunk == n_chunks - 1) {
ssm_states[state_idx] = input_t(prefix_op.running_prefix.y);
ssm_states[state_idx * params.ssm_states_dstate_stride] = input_t(prefix_op.running_prefix.y);
}
}
#pragma unroll
@@ -481,6 +483,10 @@ void set_ssm_params_fwd(SSMParamsBase &params,
params.out_batch_stride = out.stride(1);
params.out_d_stride = out.stride(0);

params.ssm_states_batch_stride = ssm_states.stride(0);
params.ssm_states_dim_stride = ssm_states.stride(1);
params.ssm_states_dstate_stride = ssm_states.stride(2);
Comment on lines +486 to +488 (Member): can this be pulled out of the if/else?


}
else{
if (!is_variable_B) {
Expand Down Expand Up @@ -509,6 +515,10 @@ void set_ssm_params_fwd(SSMParamsBase &params,
}
params.out_batch_stride = out.stride(0);
params.out_d_stride = out.stride(1);

params.ssm_states_batch_stride = ssm_states.stride(0);
params.ssm_states_dim_stride = ssm_states.stride(1);
params.ssm_states_dstate_stride = ssm_states.stride(2);
}
}
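
As an aside on why the new stride fields matter: a state cache that is viewed through a transpose (or carved out of a larger allocation) is not contiguous, so the old packed `(cache_index * dim + dim_id * kNRows) * dstate` arithmetic would read and write the wrong elements. A minimal Python sketch, with purely illustrative sizes (none of this code is part of the PR):

    import torch

    # Hypothetical per-layer cache allocated as (batch, dstate, dim) but consumed
    # by the kernel as (batch, dim, dstate) via a transpose.
    cache = torch.zeros(4, 16, 32)         # contiguous, strides (512, 32, 1)
    ssm_states = cache.transpose(1, 2)     # shape (4, 32, 16), strides (512, 1, 32)

    print(ssm_states.is_contiguous())      # False -> packed offsets are no longer valid
    print(ssm_states.stride())             # (512, 1, 32): the per-dimension strides
                                           # the kernel now receives explicitly

Passing `stride(0)`, `stride(1)`, and `stride(2)` keeps the indexing correct regardless of layout; the three assignments are identical in both branches, which is what the reviewer's question about hoisting them out of the if/else is getting at.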

4 changes: 2 additions & 2 deletions docs/models/supported_models.md
@@ -367,9 +367,9 @@ th {
| `InternLM2ForCausalLM` | InternLM2 | `internlm/internlm2-7b`, `internlm/internlm2-chat-7b`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `InternLM3ForCausalLM` | InternLM3 | `internlm/internlm3-8b-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `JAISLMHeadModel` | Jais | `inceptionai/jais-13b`, `inceptionai/jais-13b-chat`, `inceptionai/jais-30b-v3`, `inceptionai/jais-30b-chat-v3`, etc. | | ✅︎ | ✅︎ |
| `JambaForCausalLM` | Jamba | `ai21labs/AI21-Jamba-1.5-Large`, `ai21labs/AI21-Jamba-1.5-Mini`, `ai21labs/Jamba-v0.1`, etc. | ✅︎ | ✅︎ | |
| `JambaForCausalLM` | Jamba | `ai21labs/AI21-Jamba-1.5-Large`, `ai21labs/AI21-Jamba-1.5-Mini`, `ai21labs/Jamba-v0.1`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `LlamaForCausalLM` | Llama 3.1, Llama 3, Llama 2, LLaMA, Yi | `meta-llama/Meta-Llama-3.1-405B-Instruct`, `meta-llama/Meta-Llama-3.1-70B`, `meta-llama/Meta-Llama-3-70B-Instruct`, `meta-llama/Llama-2-70b-hf`, `01-ai/Yi-34B`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `MambaForCausalLM` | Mamba | `state-spaces/mamba-130m-hf`, `state-spaces/mamba-790m-hf`, `state-spaces/mamba-2.8b-hf`, etc. | | ✅︎ | |
| `MambaForCausalLM` | Mamba | `state-spaces/mamba-130m-hf`, `state-spaces/mamba-790m-hf`, `state-spaces/mamba-2.8b-hf`, etc. | | ✅︎ | ✅︎ |
| `Mamba2ForCausalLM` | Mamba2 | `mistralai/Mamba-Codestral-7B-v0.1`, etc. | | ✅︎ | ✅︎ |
| `MiMoForCausalLM` | MiMo | `XiaomiMiMo/MiMo-7B-RL`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `MiniCPMForCausalLM` | MiniCPM | `openbmb/MiniCPM-2B-sft-bf16`, `openbmb/MiniCPM-2B-dpo-bf16`, `openbmb/MiniCPM-S-1B-sft`, etc. | ✅︎ | ✅︎ | ✅︎ |
12 changes: 5 additions & 7 deletions docs/usage/v1_guide.md
@@ -83,7 +83,7 @@ based on assigned priority, with FCFS as a tie-breaker), configurable via the
| **Decoder-only Models** | <nobr>🚀 Optimized</nobr> |
| **Encoder-Decoder Models** | <nobr>🟠 Delayed</nobr> |
| **Embedding Models** | <nobr>🟢 Functional</nobr> |
| **Mamba Models** | <nobr>🟢 (Mamba-2), 🟡 (Mamba-1)</nobr> |
| **Mamba Models** | <nobr>🟢 (Mamba-2), 🟢 (Mamba-1)</nobr> |
| **Multimodal Models** | <nobr>🟢 Functional</nobr> |

vLLM V1 currently excludes model architectures with the `SupportsV0Only` protocol.
@@ -104,13 +104,11 @@ to enable simultaneous generation and embedding using the same engine instance i

#### Mamba Models

Models using selective state-space mechanisms instead of standard transformer attention are partially supported.
Models that use Mamba-2 layers (e.g., `Mamba2ForCausalLM`) are supported, but models that use older Mamba-1 layers
(e.g., `MambaForCausalLM`, `JambaForCausalLM`) are not yet supported. Please note that these models currently require
disabling prefix caching in V1.
Models using selective state-space mechanisms instead of standard transformer attention are supported.
This includes models that use Mamba-2 or Mamba-1 layers (e.g., `Mamba2ForCausalLM`, `MambaForCausalLM`). Please note that these models currently require disabling prefix caching in V1. Additionally, Mamba-1 models require `enforce_eager=True`.
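
As a usage illustration of the constraints above, a minimal sketch (not part of this PR; the model name and sampling settings are arbitrary) of running a Mamba-1 model under V1:

    import os
    os.environ["VLLM_USE_V1"] = "1"  # explicitly opt into the V1 engine

    from vllm import LLM, SamplingParams

    llm = LLM(
        model="state-spaces/mamba-130m-hf",
        enforce_eager=True,           # currently required for Mamba-1 under V1
        enable_prefix_caching=False,  # prefix caching must stay disabled
    )
    out = llm.generate(["Mamba models are"], SamplingParams(max_tokens=16))
    print(out[0].outputs[0].text)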

Models that combine Mamba-2 layers with standard attention layers are also supported (e.g., `BambaForCausalLM`,
`Zamba2ForCausalLM`, `NemotronHForCausalLM`, `FalconH1ForCausalLM` and `GraniteMoeHybridForCausalLM`). Please note that
Models that combine Mamba-2 and Mamba-1 layers with standard attention layers are also supported (e.g., `BambaForCausalLM`,
`Zamba2ForCausalLM`, `NemotronHForCausalLM`, `FalconH1ForCausalLM`, `GraniteMoeHybridForCausalLM` and `JambaForCausalLM`). Please note that
these models currently require disabling prefix caching and using the FlashInfer attention backend in V1.

#### Encoder-Decoder Models
2 changes: 2 additions & 0 deletions tests/models/language/generation/test_hybrid.py
@@ -53,6 +53,8 @@
]

V1_SUPPORTED_MODELS = [
"state-spaces/mamba-130m-hf",
"ai21labs/Jamba-tiny-dev",
"mistralai/Mamba-Codestral-7B-v0.1",
"ibm-ai-platform/Bamba-9B-v1",
"Zyphra/Zamba2-1.2B-instruct",
1 change: 0 additions & 1 deletion tests/v1/test_oracle.py
@@ -12,7 +12,6 @@
UNSUPPORTED_MODELS_V1 = [
"openai/whisper-large-v3", # transcription
"facebook/bart-large-cnn", # encoder decoder
"state-spaces/mamba-130m-hf", # mamba1
]

MODEL = "meta-llama/Llama-3.2-1B-Instruct"
144 changes: 108 additions & 36 deletions vllm/model_executor/layers/mamba/mamba_mixer.py
@@ -1,30 +1,37 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Optional

import torch
from torch import nn
from torch.nn.parameter import Parameter

from vllm.attention.backends.abstract import AttentionMetadata
from vllm import envs
from vllm.config import get_current_vllm_config
from vllm.distributed.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.forward_context import get_forward_context
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateShapeCalculator)
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_fn, causal_conv1d_update)
from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
selective_scan_fn, selective_state_update)
from vllm.model_executor.models.mamba_cache import MambaCacheParams
from vllm.model_executor.utils import set_weight_attrs
from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionMetadata


# Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
@CustomOp.register("mamba_mixer")
class MambaMixer(CustomOp):
class MambaMixer(MambaBase, CustomOp):
"""
Compute ∆, A, B, C, and D the state space parameters and compute
the `contextualized_states`. A, D are input independent
@@ -47,13 +54,16 @@ def __init__(self,
rms_norm_has_weight: bool = True,
rms_norm_eps: float = 1e-5,
activation="silu",
is_lora_enabled: bool = False):
is_lora_enabled: bool = False,
prefix: str = ""):
super().__init__()
self.time_step_rank = time_step_rank
self.ssm_state_size = ssm_state_size
self.use_rms_norm = use_rms_norm
self.activation = activation
self.is_lora_enabled = is_lora_enabled
self.conv_kernel_size = conv_kernel_size
self.intermediate_size = intermediate_size

self.conv1d = ColumnParallelLinear(
input_size=conv_kernel_size,
@@ -131,14 +141,62 @@ def A_weight_loader(param: Parameter, loaded_weight: torch.Tensor):
has_weight=rms_norm_has_weight,
) if use_rms_norm else None

def forward_native(self, hidden_states: torch.Tensor,
conv_state: torch.Tensor, ssm_state: torch.Tensor):
if envs.VLLM_USE_V1:
compilation_config = get_current_vllm_config().compilation_config
if prefix in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {prefix}")
compilation_config.static_forward_context[prefix] = self
# The outer list is for v0 PP virtual engine. Though this code path
# only runs for v1, we have to do this to unify with the interface
# of Attention + v0 PP.
# The inner tuple is (conv_state, ssm_state)
self.kv_cache = [(torch.tensor([]), torch.tensor([]))]

self.prefix = prefix

def forward(self,
hidden_states: torch.Tensor,
mamba_cache_params: Optional[MambaCacheParams] = None):
if not envs.VLLM_USE_V1:
return CustomOp.forward(self, hidden_states, mamba_cache_params)
else:
return self.forward_cuda(hidden_states, mamba_cache_params)

def forward_native(self,
hidden_states: torch.Tensor,
mamba_cache_params: Optional[MambaCacheParams] = None):
pass

def forward_cuda(self, hidden_states: torch.Tensor,
mamba_cache_params: MambaCacheParams):
def forward_cuda(self,
hidden_states: torch.Tensor,
mamba_cache_params: Optional[MambaCacheParams] = None):

forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata

if envs.VLLM_USE_V1:
if attn_metadata is not None:
assert isinstance(attn_metadata, dict)
attn_metadata = attn_metadata[self.prefix]
mamba1_metadata = attn_metadata
assert isinstance(mamba1_metadata, Mamba1AttentionMetadata)
query_start_loc = mamba1_metadata.query_start_loc
state_indices_tensor = mamba1_metadata.state_indices_tensor
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
conv_state = self_kv_cache[0].transpose(-1, -2)
ssm_state = self_kv_cache[1]
has_initial_state = mamba1_metadata.has_initial_states
context_lens_tensor = mamba1_metadata.context_lens_tensor
else:
assert mamba_cache_params is not None
conv_state = mamba_cache_params.conv_state
ssm_state = mamba_cache_params.ssm_state
state_indices_tensor = mamba_cache_params.state_indices_tensor
query_start_loc = attn_metadata.query_start_loc
context_lens_tensor = attn_metadata.context_lens_tensor

attn_metadata: AttentionMetadata = get_forward_context().attn_metadata
if context_lens_tensor is not None:
has_initial_state = context_lens_tensor > 0

# 1. Gated MLP's linear projection
projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1)
Expand All @@ -148,8 +206,12 @@ def forward_cuda(self, hidden_states: torch.Tensor,
conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
self.conv1d.weight.size(2))

if attn_metadata.query_start_loc is not None \
and attn_metadata.context_lens_tensor is not None:
if envs.VLLM_USE_V1 and attn_metadata is None:
# V1 profile run
hidden_states = hidden_states.contiguous()
return self.out_proj(hidden_states.transpose(-2, -1))[0]

if query_start_loc is not None and context_lens_tensor is not None:
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
@@ -161,18 +223,18 @@ def forward_cuda(self, hidden_states: torch.Tensor,
conv_weights,
bias=self.conv1d.bias,
activation=self.activation,
conv_states=mamba_cache_params.conv_state,
has_initial_state=attn_metadata.context_lens_tensor > 0,
cache_indices=mamba_cache_params.state_indices_tensor,
query_start_loc=attn_metadata.query_start_loc)
conv_states=conv_state,
has_initial_state=has_initial_state,
cache_indices=state_indices_tensor,
query_start_loc=query_start_loc)
else:
hidden_states = causal_conv1d_update(
hidden_states.transpose(0, 1),
mamba_cache_params.conv_state,
conv_state,
conv_weights,
self.conv1d.bias,
self.activation,
conv_state_indices=mamba_cache_params.state_indices_tensor)
conv_state_indices=state_indices_tensor)
hidden_states = hidden_states.transpose(0, 1)

# 3. State Space Model sequence transformation
@@ -203,11 +265,10 @@ def forward_cuda(self, hidden_states: torch.Tensor,
time_proj_bias = (self.dt_proj.bias.float() if hasattr(
self.dt_proj, "bias") else None)

if attn_metadata.query_start_loc is not None \
and attn_metadata.context_lens_tensor is not None:
if query_start_loc is not None and context_lens_tensor is not None:
scan_outputs = selective_scan_fn(
hidden_states,
mamba_cache_params.ssm_state,
ssm_state,
discrete_time_step,
self.A,
B.transpose(-2, -1),
Expand All @@ -216,24 +277,23 @@ def forward_cuda(self, hidden_states: torch.Tensor,
gate,
time_proj_bias,
delta_softplus=True,
cache_indices=mamba_cache_params.state_indices_tensor,
has_initial_state=attn_metadata.context_lens_tensor > 0,
query_start_loc=attn_metadata.query_start_loc)
cache_indices=state_indices_tensor,
has_initial_state=has_initial_state,
query_start_loc=query_start_loc)
else:
scan_outputs = torch.empty_like(hidden_states.transpose(0, 1))
selective_state_update(
mamba_cache_params.ssm_state,
hidden_states.transpose(0, 1),
discrete_time_step.transpose(0, 1),
self.A,
B,
C,
self.D,
gate.transpose(0, 1),
time_proj_bias,
dt_softplus=True,
state_batch_indices=mamba_cache_params.state_indices_tensor,
out=scan_outputs)
selective_state_update(ssm_state,
hidden_states.transpose(0, 1),
discrete_time_step.transpose(0, 1),
self.A,
B,
C,
self.D,
gate.transpose(0, 1),
time_proj_bias,
dt_softplus=True,
state_batch_indices=state_indices_tensor,
out=scan_outputs)
scan_outputs = scan_outputs.transpose(0, 1)

# 4. Final linear projection
@@ -245,3 +305,15 @@ def forward_cuda(self, hidden_states: torch.Tensor,
contextualized_states = self.out_proj(
scan_outputs.transpose(-2, -1))[0]
return contextualized_states

def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
return MambaStateShapeCalculator.mamba1_state_shape(
tp_world_size=get_tensor_model_parallel_world_size(),
intermediate_size=self.intermediate_size,
state_size=self.ssm_state_size,
conv_kernel=self.conv_kernel_size,
)

@property
def mamba_type(self) -> str:
return "mamba1"
7 changes: 4 additions & 3 deletions vllm/model_executor/layers/mamba/mamba_mixer2.py
@@ -21,7 +21,7 @@
from vllm.model_executor.layers.mamba.mamba2_metadata import (Mamba2Metadata,
update_metadata)
from vllm.model_executor.layers.mamba.mamba_utils import (
extra_groups_for_head_shards, get_mamba_state_shape)
MambaStateShapeCalculator)
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_fn, causal_conv1d_update)
from vllm.model_executor.layers.mamba.ops.layernorm_gated import rms_norm_gated
@@ -278,8 +278,9 @@ def __init__(
# - for TP we shard conv_dim by sharding on n_groups,
# - but if n_groups cannot divide tp_size, we need to
# extend some extra groups
self.n_groups = n_groups + extra_groups_for_head_shards(
groups = MambaStateShapeCalculator.extra_groups_for_head_shards(
n_groups, self.tp_size)
self.n_groups = n_groups + groups

self.conv_dim = intermediate_size + 2 * self.n_groups * ssm_state_size
self.conv1d = ColumnParallelLinear(
@@ -732,7 +733,7 @@ def forward_cuda(
output[:num_actual_tokens], _ = self.out_proj(hidden_states)

def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
return get_mamba_state_shape(
return MambaStateShapeCalculator.mamba2_state_shape(
intermediate_size=self.intermediate_size,
tp_world_size=get_tensor_model_parallel_world_size(),
n_groups=self.n_groups,
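
On the `n_groups` sharding change above: `extra_groups_for_head_shards` only moved into `MambaStateShapeCalculator`, and its body is not part of this diff, but as a rough, hypothetical sketch it can be thought of as padding `n_groups` so the (possibly replicated) groups split evenly across tensor-parallel ranks:

    # Assumption: groups are replicated up to the next multiple of tp_size when
    # n_groups does not already divide it evenly. Not the actual implementation.
    def extra_groups_for_head_shards_sketch(n_groups: int, tp_size: int) -> int:
        if n_groups % tp_size == 0:
            return 0
        return tp_size - (n_groups % tp_size)

For example, `n_groups=1` with `tp_size=4` would yield 3 extra groups, so `self.n_groups` becomes 4 and each rank owns exactly one group.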