 from torch import nn
 from transformers import FalconH1Config

+from vllm import envs
 from vllm.attention.layer import Attention
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import divide, get_tensor_model_parallel_world_size
...
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors

-from .interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP,
-                         SupportsV0Only)
+from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP
 from .utils import (PPMissingLayer, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
@@ -85,6 +85,7 @@ def __init__(
         config: FalconH1Config,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
@@ -107,6 +108,8 @@ def __init__(
             activation=config.hidden_act,
             quant_config=quant_config,
             use_rms_norm=config.mamba_rms_norm,
+            prefix=f"{prefix}.mixer",
+            chunk_size=config.mamba_chunk_size,
         )
         # n_groups is overridden later by `MambaMixer2`
         self.groups_time_state_size = self.mamba.n_groups * config.mamba_d_state
@@ -316,18 +319,26 @@ def __init__(
         prefix: str = "",
     ) -> None:
         super().__init__()
+
         # Instantiate the attention branch
         self.self_attn = FalconH1AttentionDecoderLayer(
             config=config,
             cache_config=cache_config,
             quant_config=quant_config,
             prefix=prefix,
         )
+
+        # In V1, every attention/SSM layer must have
+        # a distinct index in its prefix
+        ssm_layer_idx = config.num_hidden_layers + layer_idx
+        ssm_prefix = prefix.split(".")[0] + f".{ssm_layer_idx}"
+
         # Instantiate the SSM branch
         self.mamba = FalconH1SSMDecoderLayer(
             config=config,
             cache_config=cache_config,
             quant_config=quant_config,
+            prefix=ssm_prefix,
         )
         self.ssm_out_multiplier = config.ssm_out_multiplier
         self.ssm_in_multiplier = config.ssm_in_multiplier
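
A quick worked example of the prefix arithmetic above (illustrative only; the exact prefix string is produced by the model's layer construction and is an assumption here): with num_hidden_layers = 36 and layer_idx = 3, the attention branch keeps the ordinary per-layer prefix while the SSM branch is registered under an index shifted past the last attention layer, so the two branches never share a layer index in V1.

    num_hidden_layers = 36                    # hypothetical Falcon-H1 config value
    layer_idx = 3
    prefix = f"model.layers.{layer_idx}"      # hypothetical per-layer prefix

    ssm_layer_idx = num_hidden_layers + layer_idx            # 36 + 3 = 39
    ssm_prefix = prefix.split(".")[0] + f".{ssm_layer_idx}"  # -> "model.39"

    print(prefix, ssm_prefix)                 # model.layers.3 model.39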
@@ -452,10 +463,16 @@ def forward(
         # proper continuous batching computation including
         # chunked prefill
         attn_metadata = get_forward_context().attn_metadata
-        mamba2_metadata = prepare_mamba2_metadata(
-            chunk_size=self.config.mamba_chunk_size,
-            attn_metadata=attn_metadata,
-        )
+
+        if not envs.VLLM_USE_V1:
+            mamba2_metadata = prepare_mamba2_metadata(
+                chunk_size=self.config.mamba_chunk_size,
+                attn_metadata=attn_metadata,
+            )
+        else:
+            # In V1, mamba2_metadata is taken from the forward context
+            mamba2_metadata = None
+
         if get_pp_group().is_first_rank:
             if inputs_embeds is not None:
                 hidden_states = inputs_embeds * self.embedding_multiplier
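
Since the branch above keys off envs.VLLM_USE_V1, the code path can be selected with the VLLM_USE_V1 environment variable before the engine is built. A minimal usage sketch, assuming a Falcon-H1 checkpoint id (placeholder, not part of this diff):

    import os
    os.environ["VLLM_USE_V1"] = "1"  # "0" keeps the V0 path that calls prepare_mamba2_metadata

    from vllm import LLM, SamplingParams

    llm = LLM(model="tiiuae/Falcon-H1-1B-Instruct")  # placeholder model id
    outputs = llm.generate(["The capital of France is"],
                           SamplingParams(max_tokens=16))
    print(outputs[0].outputs[0].text)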
@@ -468,7 +485,9 @@ def forward(
         for i in range(self.start_layer, self.end_layer):
             layer = self.layers[i]
-            layer_mamba_cache_params = mamba_cache_params.at_layer_idx(i)
+            layer_mamba_cache_params = None
+            if mamba_cache_params:
+                layer_mamba_cache_params = mamba_cache_params.at_layer_idx(i)
             hidden_states = layer(
                 positions=positions,
                 hidden_states=hidden_states,
@@ -484,7 +503,7 @@ def forward(
 class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
-                          IsHybrid, SupportsV0Only):
+                          IsHybrid):
     packed_modules_mapping = {
         "qkv_proj": ["q_proj", "k_proj", "v_proj"],
         "gate_up_proj": ["gate_proj", "up_proj"],
@@ -558,15 +577,19 @@ def forward(
         inputs_embeds: Optional[torch.Tensor] = None,
         **kwargs,
     ):
-        if self.mamba_cache is None:
-            self.mamba_cache = MambaCacheManager(
-                self.vllm_config,
-                self.lm_head.weight.dtype
-                if hasattr(self.lm_head, 'weight') else torch.bfloat16,
-                self.config.num_hidden_layers,
-                *self._get_mamba_cache_shape(),
-            )
-        mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
+
+        mamba_cache_params = None
+        if not envs.VLLM_USE_V1:
+            if self.mamba_cache is None:
+                self.mamba_cache = MambaCacheManager(
+                    self.vllm_config,
+                    self.lm_head.weight.dtype if hasattr(
+                        self.lm_head, 'weight') else torch.bfloat16,
+                    self.config.num_hidden_layers,
+                    *self._get_mamba_cache_shape(),
+                )
+            mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
+
         hidden_states = self.model(
             input_ids,
             positions,
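
The dtype fallback inside the V0-only block above covers the case where lm_head carries no weight on the current pipeline-parallel rank (e.g. it is a PPMissingLayer placeholder). A minimal sketch of that pattern, using a stand-in placeholder class (an assumption for illustration, not the vLLM class):

    import torch
    from torch import nn

    class _MissingLMHead(nn.Module):
        # stand-in for an lm_head that lives on another PP rank
        pass

    def mamba_cache_dtype(lm_head: nn.Module) -> torch.dtype:
        # mirrors the expression used when building MambaCacheManager
        return lm_head.weight.dtype if hasattr(lm_head, "weight") else torch.bfloat16

    print(mamba_cache_dtype(nn.Linear(8, 8)))   # the real weight's dtype (float32 by default)
    print(mamba_cache_dtype(_MissingLMHead()))  # torch.bfloat16 fallback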