Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions tests/kernels/test_mha_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def clear_cache():
@pytest.mark.parametrize("device", ["cpu", "hip", "cuda"])
def test_mha_attn_platform(device: str):
"""
Test that the attention selector between different platform and device.
Test the attention selector between different platform and device.
"""
torch.set_default_dtype(torch.float16)

Expand All @@ -41,7 +41,7 @@ def test_mha_attn_platform(device: str):
else:
with patch("vllm.attention.selector.current_platform", CudaPlatform()):
attn = MultiHeadAttention(16, 64, scale=1)
assert attn.attn_backend == _Backend.FLASH_ATTN
assert attn.attn_backend == _Backend.XFORMERS

with patch("vllm.attention.selector.current_platform", CudaPlatform()):
attn = MultiHeadAttention(16, 72, scale=1)
Expand Down
8 changes: 8 additions & 0 deletions vllm/attention/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,9 @@ def __init__(
self.scale = scale
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads

assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads

dtype = torch.get_default_dtype()
attn_backend = get_attn_backend(head_size,
dtype,
Expand Down Expand Up @@ -240,6 +243,11 @@ def forward(
key = key.view(bsz, kv_len, self.num_kv_heads, self.head_size)
value = value.view(bsz, kv_len, self.num_kv_heads, self.head_size)

if (num_repeat := self.num_queries_per_kv) > 1:
# Handle MQA and GQA
key = torch.repeat_interleave(key, num_repeat, dim=2)
value = torch.repeat_interleave(value, num_repeat, dim=2)

if self.attn_backend == _Backend.XFORMERS:
from xformers import ops as xops

Expand Down
Loading