 from vllm.attention import AttentionMetadata
 from vllm.config import VllmConfig
-from vllm.distributed import parallel_state
+from vllm.distributed import parallel_state, tensor_model_parallel_all_gather
 from vllm.distributed import utils as dist_utils
 from vllm.logger import init_logger
 from vllm.model_executor import SamplingMetadata
@@ -239,6 +239,8 @@ def __init__(
         super().__init__()
         # Per attention head and per partition values.
         world_size = parallel_state.get_tensor_model_parallel_world_size()
+        self.tp_size = world_size
+        self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
         self.hidden_size_per_attention_head = dist_utils.divide(
             projection_size, num_heads)
         self.num_attention_heads_per_partition = dist_utils.divide(
@@ -261,24 +263,41 @@ def __init__(
             raise RuntimeError(
                 f"Qwen2-VL does not support {self.attn_backend} backend now.")
 
+    def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
+        # [s, b, 3 * head * head_dim]
+        seq_len, bs, _ = qkv.shape
+        if self.tp_size > 1:
+            qkv = tensor_model_parallel_all_gather(qkv)
+
+        # [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim]
+        q, k, v = qkv.chunk(3, dim=2)
+
+        # 3 * [s, b, head * head_dim]
+        if self.tp_size > 1:
+            splitter = partial(dist_utils.split_tensor_along_last_dim,
+                               num_partitions=self.tp_size)
+            q = splitter(q)[self.tp_rank]
+            k = splitter(k)[self.tp_rank]
+            v = splitter(v)[self.tp_rank]
+
+        # 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim]
+        new_shape = (seq_len, bs, self.num_attention_heads_per_partition,
+                     self.hidden_size_per_attention_head)
+        q, k, v = (x.view(*new_shape) for x in (q, k, v))
+        return q, k, v
+
     def forward(
         self,
         x: torch.Tensor,
         cu_seqlens: torch.Tensor,
         rotary_pos_emb: torch.Tensor,
     ) -> torch.Tensor:
-        # [s, b, c] --> [s, b, head * 3 * head_dim]
-        x, _ = self.qkv(x)
 
-        # [s, b, head * 3 * head_dim] --> [s, b, head, 3 * head_dim]
-        new_x_shape = x.size()[:-1] + (
-            self.num_attention_heads_per_partition,
-            3 * self.hidden_size_per_attention_head,
-        )
-        x = x.view(*new_x_shape)
+        # [s, b, c] --> [s, b, 3 * head * head_dim]
+        x, _ = self.qkv(x)
 
-        # [s, b, head, 3 * head_dim] --> 3 [s, b, head, head_dim]
-        q, k, v = dist_utils.split_tensor_along_last_dim(x, 3)
+        # [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim]
+        q, k, v = self.split_qkv(x)
         batch_size = q.shape[1]
 
         q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous()
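Assuming the fused `self.qkv` projection is sharded column-parallel across TP ranks (so each rank produces only a `1/tp_size` slice of the `3 * head * head_dim` output), the new `split_qkv` first all-gathers the full projection, chunks it into Q/K/V, then re-splits each tensor along the head dimension so every rank keeps just its own heads. The sketch below walks through that shape bookkeeping on a single process; `fake_all_gather` and all sizes are made-up stand-ins for `tensor_model_parallel_all_gather` and the real model dimensions, not vLLM code.

```python
# Single-process illustration of the gather-then-resplit shapes in split_qkv.
import torch

tp_size, tp_rank = 2, 0
seq_len, bs = 4, 1
num_heads, head_dim = 16, 80              # full (unsharded) head count
heads_per_rank = num_heads // tp_size

# Each rank's column-parallel QKV output: a 1/tp_size shard of 3 * H * D.
local_qkv = torch.randn(seq_len, bs, 3 * num_heads * head_dim // tp_size)

def fake_all_gather(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for tensor_model_parallel_all_gather: concatenate every rank's
    # shard along the last dim (here we simply tile the local shard).
    return torch.cat([x] * tp_size, dim=-1)

qkv = fake_all_gather(local_qkv)          # [s, b, 3 * H * D]
q, k, v = qkv.chunk(3, dim=2)             # 3 x [s, b, H * D]

# Keep only this rank's slice of the head dimension.
q, k, v = (x.chunk(tp_size, dim=-1)[tp_rank] for x in (q, k, v))

# [s, b, heads_per_rank * D] -> [s, b, heads_per_rank, D]
q, k, v = (x.view(seq_len, bs, heads_per_rank, head_dim) for x in (q, k, v))
print(q.shape)                            # torch.Size([4, 1, 8, 80])
```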
@@ -614,24 +633,6 @@ def load_weights(self, weights: Iterable[Tuple[str,
                     weight_loader(param, loaded_weight, shard_id)
                     break
             else:
-                if name.endswith("qkv.weight"):
-                    visual_num_heads = self.num_heads
-                    visual_embed_dim = self.embed_dim
-                    head_size = visual_embed_dim // visual_num_heads
-                    loaded_weight = loaded_weight.view(3, visual_num_heads,
-                                                       head_size,
-                                                       visual_embed_dim)
-                    loaded_weight = loaded_weight.transpose(0, 1)
-                    loaded_weight = loaded_weight.reshape(-1, visual_embed_dim)
-                elif name.endswith("qkv.bias"):
-                    visual_num_heads = self.num_heads
-                    visual_embed_dim = self.embed_dim
-                    head_size = visual_embed_dim // visual_num_heads
-                    loaded_weight = loaded_weight.view(3, visual_num_heads,
-                                                       head_size)
-                    loaded_weight = loaded_weight.transpose(0, 1)
-                    loaded_weight = loaded_weight.reshape(-1)
-
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
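The deleted branch used to permute the HuggingFace checkpoint's fused visual `qkv.weight`/`qkv.bias` from q/k/v-major order into per-head interleaved order before loading. Since `split_qkv` now slices the gathered projection q/k/v-major via `qkv.chunk(3, dim=2)`, the checkpoint's native layout appears to be loadable as-is, which is presumably why the reordering was dropped. For reference, a minimal sketch of what the removed weight transform did, with made-up sizes:

```python
# Reproduce the deleted qkv.weight reordering on a dummy tensor.
import torch

visual_num_heads, head_size, visual_embed_dim = 4, 8, 32   # embed_dim = heads * head_size
qkv_weight = torch.randn(3 * visual_num_heads * head_size, visual_embed_dim)

# Old behavior: regroup rows from (q|k|v, head, head_size) order into
# (head, q|k|v, head_size) order, i.e. interleave q/k/v per head.
reordered = (qkv_weight
             .view(3, visual_num_heads, head_size, visual_embed_dim)
             .transpose(0, 1)
             .reshape(-1, visual_embed_dim))
print(reordered.shape)   # torch.Size([96, 32]) -- same shape, rows permuted
```

The `qkv.bias` branch applied the same permutation, just without the trailing `visual_embed_dim` axis.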
@@ -935,6 +936,16 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
     embedding_modules = {}
     embedding_padding_modules = []
 
+    # BitandBytes specific attributes
+    bitsandbytes_stacked_params_mapping = {
+        # shard_name, weight_name, index
+        "q_proj": ("qkv_proj", 0),
+        "k_proj": ("qkv_proj", 1),
+        "v_proj": ("qkv_proj", 2),
+        "gate_proj": ("gate_up_proj", 0),
+        "up_proj": ("gate_up_proj", 1),
+    }
+
     # To ensure correct weight loading and mapping.
     hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={
         "lm_head.": "language_model.lm_head.",