Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ set(PYTHON_SUPPORTED_VERSIONS "3.9" "3.10" "3.11" "3.12")
set(CUDA_SUPPORTED_ARCHS "7.0;7.2;7.5;8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0")

# Supported AMD GPU architectures.
set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101")
set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1200;gfx1201")

#
# Supported/expected torch versions for CUDA/ROCm.
Expand Down
5 changes: 4 additions & 1 deletion benchmarks/kernels/benchmark_paged_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
NUM_BLOCKS = 128 * 1024
PARTITION_SIZE = 512
PARTITION_SIZE_ROCM = 256
GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
ON_NAVI = "gfx1" in GPU_ARCH


@torch.inference_mode()
Expand Down Expand Up @@ -83,7 +85,7 @@ def main(
if version == "v2":
if current_platform.is_rocm():
global PARTITION_SIZE
if not args.custom_paged_attn:
if not args.custom_paged_attn and not ON_NAVI:
PARTITION_SIZE = 1024
else:
PARTITION_SIZE = PARTITION_SIZE_ROCM
Expand Down Expand Up @@ -169,6 +171,7 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
kv_cache_dtype,
k_scale,
v_scale,
ON_NAVI,
)
else:
raise ValueError(f"Invalid version: {version}")
Expand Down
1,976 changes: 1,803 additions & 173 deletions csrc/rocm/attention.cu

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion csrc/rocm/ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@ void paged_attention(torch::Tensor& out, torch::Tensor& exp_sums,
int64_t max_context_len,
const std::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
torch::Tensor& v_scale);
torch::Tensor& v_scale, bool is_navi);
3 changes: 2 additions & 1 deletion csrc/rocm/torch_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) {
" int max_context_len,"
" Tensor? alibi_slopes,"
" str kv_cache_dtype,"
" Tensor k_scale, Tensor v_scale) -> ()");
" Tensor k_scale, Tensor v_scale,"
" bool is_navi) -> ()");
rocm_ops.impl("paged_attention", torch::kCUDA, &paged_attention);
}

Expand Down
13 changes: 12 additions & 1 deletion tests/kernels/test_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,16 @@ def test_paged_attention(
or (version == "rocm" and head_size not in (64, 128))):
pytest.skip()

is_rocm_navi = False
if current_platform.is_rocm():
is_rocm_navi = "gfx1" in torch.cuda.get_device_properties(
"cuda").gcnArchName

if (version == "rocm" and is_rocm_navi
and (kv_cache_dtype == "fp8" or head_size != 128
or block_size != 16 or use_alibi)):
pytest.skip()

global PARTITION_SIZE

current_platform.seed_everything(seed)
Expand Down Expand Up @@ -282,13 +292,14 @@ def test_paged_attention(
kv_cache_dtype,
k_scale,
v_scale,
is_rocm_navi,
)

opcheck(torch.ops._rocm_C.paged_attention,
(output, exp_sums, max_logits, tmp_output, query,
key_cache, value_cache, num_kv_heads, scale, block_tables,
seq_lens, block_size, max_seq_len, alibi_slopes,
kv_cache_dtype, k_scale, v_scale),
kv_cache_dtype, k_scale, v_scale, is_rocm_navi),
cond=(head_size == HEAD_SIZES[0]
and block_size == BLOCK_SIZES[0]))

Expand Down
4 changes: 3 additions & 1 deletion vllm/_custom_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,12 +116,14 @@ def paged_attention_rocm(
kv_cache_dtype: str,
k_scale: torch.Tensor,
v_scale: torch.Tensor,
is_navi: bool = False,
) -> None:
torch.ops._rocm_C.paged_attention(out, exp_sum, max_logits, tmp_out, query,
key_cache, value_cache, num_kv_heads,
scale, block_tables, seq_lens,
block_size, max_seq_len, alibi_slopes,
kv_cache_dtype, k_scale, v_scale)
kv_cache_dtype, k_scale, v_scale,
is_navi)


# pos encoding ops
Expand Down
35 changes: 25 additions & 10 deletions vllm/attention/backends/rocm_flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
_PARTITION_SIZE_ROCM = 256
_GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
_ON_NAVI = "gfx1" in _GPU_ARCH
_ON_NAVI3_NAVI4 = any(arch in _GPU_ARCH for arch in ["gfx11", "gfx12"])
_ON_MI250_MI300 = any(arch in _GPU_ARCH for arch in ["gfx90a", "gfx942"])


Expand Down Expand Up @@ -792,7 +793,8 @@ def forward(
gqa_ratio = num_heads // self.num_kv_heads
use_custom = _use_rocm_custom_paged_attention(
decode_query.dtype, head_size, block_size, gqa_ratio,
decode_meta.max_decode_seq_len)
decode_meta.max_decode_seq_len, self.kv_cache_dtype,
self.alibi_slopes)
if use_custom:
max_seq_len = (decode_meta.max_decode_seq_len if self.attn_type
!= AttentionType.ENCODER_DECODER else
Expand Down Expand Up @@ -839,6 +841,7 @@ def forward(
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
_ON_NAVI,
)
else:
output[num_prefill_tokens:] = PagedAttention.forward_decode(
Expand Down Expand Up @@ -901,12 +904,24 @@ def _sdpa_attention(
return output


def _use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int,
block_size: int, gqa_ratio: int,
max_seq_len: int) -> bool:
# rocm custom page attention not support on navi (gfx1*)
return (_ON_MI250_MI300 and not _ON_NAVI
and (qtype == torch.half or qtype == torch.bfloat16)
and (head_size == 64 or head_size == 128)
and (block_size == 16 or block_size == 32)
and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 32768)
def _use_rocm_custom_paged_attention(
qtype: torch.dtype,
head_size: int,
block_size: int,
gqa_ratio: int,
max_seq_len: int,
kv_cache_dtype: str,
alibi_slopes: Optional[torch.Tensor] = None) -> bool:
if _ON_NAVI3_NAVI4:
return ((qtype == torch.half or qtype == torch.bfloat16)
and head_size == 128 and block_size == 16
and (gqa_ratio >= 3 and gqa_ratio <= 16)
and max_seq_len <= 32768 and alibi_slopes is None
and kv_cache_dtype == "auto")
else:
return (_ON_MI250_MI300 and not _ON_NAVI
and (qtype == torch.half or qtype == torch.bfloat16)
and (head_size == 64 or head_size == 128)
and (block_size == 16 or block_size == 32)
and (gqa_ratio >= 1 and gqa_ratio <= 16)
and max_seq_len <= 32768)