 from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear)
+from vllm.model_executor.layers.mamba.abstract import MambaBase
 from vllm.model_executor.layers.mamba.mamba2_metadata import (Mamba2Metadata,
                                                               update_metadata)
 from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
@@ -219,7 +220,7 @@ def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
 
 # Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
 @CustomOp.register("mamba_mixer2")
-class MambaMixer2(CustomOp):
+class MambaMixer2(MambaBase, CustomOp):
     """
     Compute ∆, A, B, C, and D the state space parameters and compute
     the `contextualized_states`. A, D are input independent
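The hunk above is the core of the change: MambaMixer2 keeps its CustomOp registration but now also derives from the shared MambaBase abstraction. A minimal, self-contained sketch of that multiple-inheritance pattern follows; the stub bodies for MambaBase and CustomOp and the get_state_shape method are assumptions made for illustration only, not vLLM's actual definitions.

from abc import ABC, abstractmethod


class MambaBase(ABC):
    # Stand-in for vllm.model_executor.layers.mamba.abstract.MambaBase;
    # the abstract method below is assumed for illustration only.
    @abstractmethod
    def get_state_shape(self):
        ...


class CustomOp:
    # Simplified stand-in for vllm's CustomOp name registry.
    _registry = {}

    @classmethod
    def register(cls, name):
        def wrap(op_cls):
            cls._registry[name] = op_cls
            return op_cls
        return wrap


@CustomOp.register("mamba_mixer2")
class MambaMixer2(MambaBase, CustomOp):
    def __init__(self, num_heads=128, head_dim=64, ssm_state_size=128):
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.ssm_state_size = ssm_state_size

    def get_state_shape(self):
        # (conv_state shape, ssm_state shape) -- illustrative values only.
        return ((self.num_heads, self.head_dim),
                (self.num_heads, self.head_dim, self.ssm_state_size))


# Because MambaBase is now in the MRO, generic code can discover mamba
# layers with a single isinstance() check instead of special-casing classes.
layer = MambaMixer2()
assert isinstance(layer, MambaBase)
assert CustomOp._registry["mamba_mixer2"] is MambaMixer2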
@@ -231,22 +232,21 @@ class MambaMixer2(CustomOp):
     """
 
     def __init__(
-        self,
-        hidden_size: int,
-        ssm_state_size: int,
-        conv_kernel_size: int,
-        intermediate_size: int,
-        use_conv_bias: bool,
-        use_bias: bool,
-        n_groups: int = 1,
-        num_heads: int = 128,
-        head_dim: int = 64,
-        rms_norm_eps: float = 1e-5,
-        activation: str = "silu",
-        use_rms_norm: bool = True,
-        quant_config: Optional[QuantizationConfig] = None,
-        prefix: str = "",
-        chunk_size: int = -1,  # the chunk size used by v1
+        self,
+        hidden_size: int,
+        ssm_state_size: int,
+        conv_kernel_size: int,
+        intermediate_size: int,
+        use_conv_bias: bool,
+        use_bias: bool,
+        n_groups: int = 1,
+        num_heads: int = 128,
+        head_dim: int = 64,
+        rms_norm_eps: float = 1e-5,
+        activation: str = "silu",
+        use_rms_norm: bool = True,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
 
@@ -428,10 +428,7 @@ def __init__(
         # of Attention + v0 PP.
         # The inner tuple is (conv_state, ssm_state)
         self.kv_cache = [(torch.tensor([]), torch.tensor([]))]
-        assert chunk_size != -1, "chunk_size must be set for v1"
 
-        # NOTE: chunk_size may be -1 for models without v1 support
-        self.chunk_size = chunk_size
         self.prefix = prefix
 
     def forward_native(
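With the assert and the stored self.chunk_size removed, the mixer no longer owns the chunk size; whatever drives the chunked scan is expected to obtain it from configuration and supply it at call time. The sketch below only illustrates that inversion of ownership; the SchedulerConfig name and the chunk_indices helper are hypothetical and do not come from vLLM.

from dataclasses import dataclass


@dataclass
class SchedulerConfig:
    # Hypothetical stand-in for wherever the chunk size is configured now
    # that MambaMixer2.__init__ no longer accepts it.
    chunk_size: int = 256


def chunk_indices(seq_len, cfg):
    # Split a prefill of seq_len tokens into (start, end) chunks of
    # cfg.chunk_size tokens, supplied by the caller rather than the layer.
    step = cfg.chunk_size
    return [(start, min(start + step, seq_len))
            for start in range(0, seq_len, step)]


print(chunk_indices(1000, SchedulerConfig(chunk_size=256)))
# [(0, 256), (256, 512), (512, 768), (768, 1000)]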