
Commit a655eb3

jeejeelee and Isotr0py authored
[Misc]Add BNB quantization for Qwen2VL (#11719)
Signed-off-by: Jee Jee Li <[email protected]>
Signed-off-by: Isotr0py <[email protected]>
Co-authored-by: Isotr0py <[email protected]>
1 parent 1543914 commit a655eb3

File tree

1 file changed: +40 -29 lines changed

vllm/model_executor/models/qwen2_vl.py

Lines changed: 40 additions & 29 deletions
```diff
@@ -38,7 +38,7 @@
 from vllm.attention import AttentionMetadata
 from vllm.config import VllmConfig
-from vllm.distributed import parallel_state
+from vllm.distributed import parallel_state, tensor_model_parallel_all_gather
 from vllm.distributed import utils as dist_utils
 from vllm.logger import init_logger
 from vllm.model_executor import SamplingMetadata
```
```diff
@@ -239,6 +239,8 @@ def __init__(
         super().__init__()
         # Per attention head and per partition values.
         world_size = parallel_state.get_tensor_model_parallel_world_size()
+        self.tp_size = world_size
+        self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
         self.hidden_size_per_attention_head = dist_utils.divide(
             projection_size, num_heads)
         self.num_attention_heads_per_partition = dist_utils.divide(
```
```diff
@@ -261,24 +263,41 @@ def __init__(
             raise RuntimeError(
                 f"Qwen2-VL does not support {self.attn_backend} backend now.")
 
+    def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
+        # [s, b, 3 * head * head_dim]
+        seq_len, bs, _ = qkv.shape
+        if self.tp_size > 1:
+            qkv = tensor_model_parallel_all_gather(qkv)
+
+        # [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim]
+        q, k, v = qkv.chunk(3, dim=2)
+
+        # 3 * [s, b, head * head_dim]
+        if self.tp_size > 1:
+            splitter = partial(dist_utils.split_tensor_along_last_dim,
+                               num_partitions=self.tp_size)
+            q = splitter(q)[self.tp_rank]
+            k = splitter(k)[self.tp_rank]
+            v = splitter(v)[self.tp_rank]
+
+        # 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim]
+        new_shape = (seq_len, bs, self.num_attention_heads_per_partition,
+                     self.hidden_size_per_attention_head)
+        q, k, v = (x.view(*new_shape) for x in (q, k, v))
+        return q, k, v
+
     def forward(
         self,
         x: torch.Tensor,
         cu_seqlens: torch.Tensor,
         rotary_pos_emb: torch.Tensor,
     ) -> torch.Tensor:
-        # [s, b, c] --> [s, b, head * 3 * head_dim]
-        x, _ = self.qkv(x)
 
-        # [s, b, head * 3 * head_dim] --> [s, b, head, 3 * head_dim]
-        new_x_shape = x.size()[:-1] + (
-            self.num_attention_heads_per_partition,
-            3 * self.hidden_size_per_attention_head,
-        )
-        x = x.view(*new_x_shape)
+        # [s, b, c] --> [s, b, 3 * head * head_dim]
+        x, _ = self.qkv(x)
 
-        # [s, b, head, 3 * head_dim] --> 3 [s, b, head, head_dim]
-        q, k, v = dist_utils.split_tensor_along_last_dim(x, 3)
+        # [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim]
+        q, k, v = self.split_qkv(x)
         batch_size = q.shape[1]
 
         q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous()
```
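To see why this all-gather-then-re-split round-trip recovers the correct heads on each rank, here is a minimal single-process sketch of the shape bookkeeping in `split_qkv()`. The sharded projection and the all-gather are simulated with plain tensor ops; all sizes are illustrative, and nothing below is vLLM code.

```python
import torch

# Toy sizes; in the real model these come from the vision config.
seq_len, bs, num_heads, head_dim = 4, 2, 8, 16
tp_size, tp_rank = 2, 0
heads_per_rank = num_heads // tp_size

# Reference output of the full fused QKV projection, ordered q | k | v
# along the last dim, matching the checkpoint's native layout.
qkv_full = torch.randn(seq_len, bs, 3 * num_heads * head_dim)

# A column-parallel linear shards that output dim contiguously, so each
# rank computes one contiguous slice of the last dim...
shards = qkv_full.chunk(tp_size, dim=-1)

# ...and the all-gather concatenates the slices back in rank order,
# recovering the full q | k | v layout on every rank.
qkv = torch.cat(shards, dim=-1)
assert torch.equal(qkv, qkv_full)

# [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim]
q, k, v = qkv.chunk(3, dim=2)

# Re-split along the last dim so each rank keeps only its own heads.
q, k, v = (t.chunk(tp_size, dim=-1)[tp_rank] for t in (q, k, v))

# 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim]
q = q.view(seq_len, bs, heads_per_rank, head_dim)
print(q.shape)  # torch.Size([4, 2, 4, 16])
```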
```diff
@@ -614,24 +633,6 @@ def load_weights(self, weights: Iterable[Tuple[str,
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
-                if name.endswith("qkv.weight"):
-                    visual_num_heads = self.num_heads
-                    visual_embed_dim = self.embed_dim
-                    head_size = visual_embed_dim // visual_num_heads
-                    loaded_weight = loaded_weight.view(3, visual_num_heads,
-                                                       head_size,
-                                                       visual_embed_dim)
-                    loaded_weight = loaded_weight.transpose(0, 1)
-                    loaded_weight = loaded_weight.reshape(-1, visual_embed_dim)
-                elif name.endswith("qkv.bias"):
-                    visual_num_heads = self.num_heads
-                    visual_embed_dim = self.embed_dim
-                    head_size = visual_embed_dim // visual_num_heads
-                    loaded_weight = loaded_weight.view(3, visual_num_heads,
-                                                       head_size)
-                    loaded_weight = loaded_weight.transpose(0, 1)
-                    loaded_weight = loaded_weight.reshape(-1)
-
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
```
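For context, the branch deleted above permuted the checkpoint's q | k | v row blocks into the per-head interleaved layout that the old `forward()` expected. A toy reproduction of that permutation (sizes illustrative, not vLLM code):

```python
import torch

num_heads, head_size = 4, 8
embed_dim = num_heads * head_size

# The checkpoint stores the visual qkv.weight with rows ordered
# q (all heads) | k (all heads) | v (all heads).
w = torch.arange(3 * embed_dim * embed_dim, dtype=torch.float32)
w = w.view(3 * embed_dim, embed_dim)

# The removed branch interleaved the rows per head, so that row block h
# reads [q_h; k_h; v_h], matching the old [s, b, head, 3 * head_dim] view.
permuted = (w.view(3, num_heads, head_size, embed_dim)
            .transpose(0, 1)
            .reshape(-1, embed_dim))

# q rows of head 0 stay first; k rows of head 0 come from the old k block.
assert torch.equal(permuted[:head_size], w[:head_size])
assert torch.equal(permuted[head_size:2 * head_size],
                   w[embed_dim:embed_dim + head_size])
```

Bitsandbytes-quantized weights arrive packed and cannot be viewed and transposed like this at load time, which is presumably why the commit drops the permutation, keeps the checkpoint's native q | k | v ordering, and moves the layout handling into `split_qkv()` at runtime.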
```diff
@@ -935,6 +936,16 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
     embedding_modules = {}
     embedding_padding_modules = []
 
+    # BitandBytes specific attributes
+    bitsandbytes_stacked_params_mapping = {
+        # shard_name, weight_name, index
+        "q_proj": ("qkv_proj", 0),
+        "k_proj": ("qkv_proj", 1),
+        "v_proj": ("qkv_proj", 2),
+        "gate_proj": ("gate_up_proj", 0),
+        "up_proj": ("gate_up_proj", 1),
+    }
+
     # To ensure correct weight loading and mapping.
     hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={
         "lm_head.": "language_model.lm_head.",
```
