Merged

26 commits
b6b00d7
init
SageMoore Feb 5, 2025
fa52268
temporarily remove torch from requirements-build
SageMoore Feb 5, 2025
f563276
move rocm logic to its own attention backend
SageMoore Feb 6, 2025
2a03b92
actually add backend
SageMoore Feb 6, 2025
4bdf7de
more rocm refactoring
SageMoore Feb 7, 2025
875fcfc
Merge branch 'main' of https://github.com/neuralmagic/vllm into sage/…
SageMoore Feb 7, 2025
e507e30
more rocm refactoring
SageMoore Feb 7, 2025
b9ce259
hack to fix the multiprocessing issue
SageMoore Feb 7, 2025
f2cc5e3
minor print fix
SageMoore Feb 7, 2025
d6f6c5c
remove cruft
SageMoore Feb 7, 2025
2bf214a
format
SageMoore Feb 7, 2025
11411cb
modify requirements files
SageMoore Feb 7, 2025
c2499bf
remove basic.py changes
SageMoore Feb 7, 2025
cf6f691
cleanup
SageMoore Feb 7, 2025
4505f53
add support for passing in softmax scales to the context_attn_fwd
SageMoore Feb 7, 2025
9a0416a
Merge branch 'main' of https://github.com/neuralmagic/vllm into sage/…
SageMoore Feb 7, 2025
ef9ae86
added requirements-rocm-build
SageMoore Feb 10, 2025
0ccef65
Merge branch 'main' of https://github.com/neuralmagic/vllm into sage/…
SageMoore Feb 10, 2025
a00a2d9
minor setup.py fix
SageMoore Feb 10, 2025
afb15f5
add batch size back in
SageMoore Feb 10, 2025
08a25b7
revert setup.py change
SageMoore Feb 10, 2025
55eb036
update setup.py
SageMoore Feb 10, 2025
4b62de2
Merge branch 'main' of https://github.com/neuralmagic/vllm into sage/…
SageMoore Feb 11, 2025
442bc7b
Merge branch 'main' of https://github.com/neuralmagic/vllm into sage/…
SageMoore Feb 12, 2025
9472636
minor fix
SageMoore Feb 12, 2025
21d8d6a
update error messages
SageMoore Feb 12, 2025
16 changes: 16 additions & 0 deletions requirements-rocm-build.txt
@@ -0,0 +1,16 @@
# Common dependencies
-r requirements-common.txt

--extra-index-url https://download.pytorch.org/whl/rocm6.2
torch==2.5.1
torchvision==0.20.1
torchaudio==2.5.1

cmake>=3.26
ninja
packaging
setuptools>=61
setuptools-scm>=8
wheel
jinja2
amdsmi==6.2.4
6 changes: 4 additions & 2 deletions vllm/attention/ops/prefix_prefill.py
@@ -718,7 +718,8 @@ def context_attention_fwd(q,
k_scale: torch.Tensor,
v_scale: torch.Tensor,
alibi_slopes=None,
sliding_window=None):
sliding_window=None,
sm_scale=None):

q_dtype_is_f32 = q.dtype is torch.float32
# need to reduce num. blocks when using fp32
@@ -759,7 +760,8 @@ def context_attention_fwd(q,
# round up Lk to a power of 2 - this is required for Triton block size
Lk_padded = triton.next_power_of_2(Lk)

sm_scale = 1.0 / (Lq**0.5)
if sm_scale is None:
sm_scale = 1.0 / (Lq**0.5)
batch, head = b_seq_len.shape[0], q.shape[1]
num_queries_per_kv = q.shape[1] // k.shape[1]

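For readers skimming this hunk: context_attention_fwd now takes an optional sm_scale and only falls back to the default scale when none is supplied. A minimal sketch of that fallback, using a hypothetical helper name that is not part of the diff:

    import math

    def resolve_sm_scale(sm_scale, head_dim):
        # Mirrors the fallback added to context_attention_fwd: use the caller's
        # softmax scale when given, otherwise default to 1/sqrt(head_dim).
        return sm_scale if sm_scale is not None else 1.0 / math.sqrt(head_dim)

The new ROCm V1 backend below relies on this to pass its configured self.scale straight through to the Triton kernel.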
45 changes: 30 additions & 15 deletions vllm/platforms/rocm.py
@@ -1,6 +1,5 @@
# SPDX-License-Identifier: Apache-2.0

import os
from functools import lru_cache
from typing import TYPE_CHECKING, Dict, List, Optional

@@ -29,12 +28,6 @@
except ImportError as e:
logger.warning("Failed to import from vllm._rocm_C with %r", e)

if os.environ.get("VLLM_WORKER_MULTIPROC_METHOD", None) in ["fork", None]:
logger.warning("`fork` method is not supported by ROCm. "
"VLLM_WORKER_MULTIPROC_METHOD is overridden to"
" `spawn` instead.")
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
Comment on lines -32 to -36

Member:
Does ROCm support fork now?

Contributor Author:
It's unclear to me exactly what was broken before but fork does seem to work.


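As an aside on this thread: the override removed above simply forced the spawn start method. If fork does misbehave on a particular ROCm setup, the same existing VLLM_WORKER_MULTIPROC_METHOD knob can still be set by hand before starting vLLM; a minimal sketch:

    import os

    # Force the spawn start method for vLLM workers, mirroring what the
    # removed override in rocm.py used to do automatically.
    os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"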
# Models not supported by ROCm.
_ROCM_UNSUPPORTED_MODELS: List[str] = []

@@ -84,6 +77,9 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
return "vllm.attention.backends.triton_mla.TritonMLABackend"
selected_backend = (_Backend.ROCM_FLASH if selected_backend
== _Backend.FLASH_ATTN else selected_backend)
if envs.VLLM_USE_V1:
logger.info("Using ROCm Attention backend on V1 engine.")
return "vllm.v1.attention.backends.rocm_attn.ROCmAttentionBackend"
if selected_backend == _Backend.ROCM_FLASH:
if not cls.has_device_capability(90):
# not Instinct series GPUs.
@@ -102,7 +98,11 @@ def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
@classmethod
@lru_cache(maxsize=8)
def get_device_name(cls, device_id: int = 0) -> str:
return torch.cuda.get_device_name(device_id)
# NOTE: When using V1 this function is called when overriding the
# engine args. Calling torch.cuda.get_device_name(device_id) here
# will result in the ROCm context being initialized before other
# processes can be created.
return "AMD"
Comment on lines +101 to +105

Member:
Why is this an issue / do you know why this isn't an issue on CUDA?

Contributor Author:
Good question. CUDA uses NVML to get the device name, which doesn't initialize the CUDA context. The problem is that this is called when we override the engine args in V1; mreso gave a good explanation in a comment on this PR.

I think just returning "AMD" is fine for now. We can investigate whether AMD has an NVML-like system in the future if we find a need to add dispatching logic that's specific to individual AMD GPU models. The CPU platforms do something similar, i.e. they return "CPU" or "openvino".


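As an aside on this thread: on the CUDA platform the device name is read through NVML, which talks to the driver's management interface and therefore does not create a CUDA context in the calling process. A rough illustration using pynvml, not code from this PR (the ROCm path simply returns "AMD" instead):

    import pynvml

    def get_device_name_via_nvml(device_id: int = 0) -> str:
        # NVML queries the management interface, so reading the name does not
        # initialize a CUDA context in this process.
        pynvml.nvmlInit()
        try:
            handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
            name = pynvml.nvmlDeviceGetName(handle)
            # Older pynvml versions return bytes, newer ones return str.
            return name.decode() if isinstance(name, bytes) else name
        finally:
            pynvml.nvmlShutdown()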
@classmethod
def get_device_total_memory(cls, device_id: int = 0) -> int:
@@ -129,15 +129,30 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
scheduler_config = vllm_config.scheduler_config
if parallel_config.worker_cls == "auto":
if scheduler_config.is_multi_step:
parallel_config.worker_cls = \
"vllm.worker.multi_step_worker.MultiStepWorker"
if envs.VLLM_USE_V1:
raise NotImplementedError(
"Multi-step scheduling is not supported (and not "
"needed) on VLLM V1. Please launch without "
"--num-scheduler-steps.")
else:
parallel_config.worker_cls = \
"vllm.worker.multi_step_worker.MultiStepWorker"
elif vllm_config.speculative_config:
parallel_config.worker_cls = \
"vllm.spec_decode.spec_decode_worker.create_spec_worker"
parallel_config.sd_worker_cls = \
"vllm.worker.worker.Worker"
if envs.VLLM_USE_V1:
raise NotImplementedError(
"Speculative decoding is not yet supported on VLLM V1."
)
else:
parallel_config.worker_cls = \
"vllm.spec_decode.spec_decode_worker.create_spec_worker"
parallel_config.sd_worker_cls = \
"vllm.worker.worker.Worker"
else:
parallel_config.worker_cls = "vllm.worker.worker.Worker"
if envs.VLLM_USE_V1:
parallel_config.worker_cls = \
"vllm.v1.worker.gpu_worker.Worker"
else:
parallel_config.worker_cls = "vllm.worker.worker.Worker"

@classmethod
def verify_model_arch(cls, model_arch: str) -> None:
5 changes: 4 additions & 1 deletion vllm/v1/attention/backends/flash_attn.py
@@ -12,8 +12,11 @@
AttentionMetadata, AttentionType)
from vllm.attention.backends.utils import get_flash_attn_version
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import cdiv
from vllm.vllm_flash_attn import flash_attn_varlen_func

if current_platform.is_cuda():
from vllm.vllm_flash_attn import flash_attn_varlen_func

logger = init_logger(__name__)

182 changes: 182 additions & 0 deletions vllm/v1/attention/backends/rocm_attn.py
@@ -0,0 +1,182 @@
# SPDX-License-Identifier: Apache-2.0
"""Attention layer with PagedAttention on rocm"""
from typing import Any, Dict, List, Optional, Tuple, Type

import torch

from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType)
from vllm.attention.ops.paged_attn import PagedAttention
from vllm.attention.ops.prefix_prefill import context_attention_fwd
from vllm.logger import init_logger
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata

logger = init_logger(__name__)


class ROCmAttentionBackend(AttentionBackend):

accept_output_buffer: bool = True

@staticmethod
def get_supported_head_sizes() -> List[int]:
return [32, 64, 96, 128, 160, 192, 224, 256]

@staticmethod
def get_name() -> str:
return "ROCM_ATTN_VLLM_V1"

@staticmethod
def get_impl_cls() -> Type["ROCmAttentionImpl"]:
return ROCmAttentionImpl

@staticmethod
def get_metadata_cls() -> Type["AttentionMetadata"]:
return FlashAttentionMetadata

@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
) -> Tuple[int, ...]:
if block_size % 16 != 0:
raise ValueError("Block size must be a multiple of 16.")
return (2, num_blocks, block_size, num_kv_heads, head_size)

@staticmethod
def use_cascade_attention(*args, **kwargs) -> bool:
return False


class ROCmAttentionImpl(AttentionImpl):

def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[List[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: AttentionType = AttentionType.DECODER,
) -> None:
if blocksparse_params is not None:
raise ValueError(
"ROCmAttention does not support block-sparse attention.")
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_kv_heads
if alibi_slopes is not None:
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
self.alibi_slopes = alibi_slopes
if sliding_window is None:
self.sliding_window = (-1, -1)
else:
self.sliding_window = (sliding_window - 1, 0)
self.kv_cache_dtype = kv_cache_dtype

assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads

support_head_sizes = ROCmAttentionBackend.get_supported_head_sizes()
if head_size not in support_head_sizes:
raise ValueError(
f"Head size {head_size} is not supported by ROCmAttention. "
f"Supported head sizes are: {support_head_sizes}.")

if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"ROCmAttentionImpl")

def forward(
self,
layer: torch.nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: FlashAttentionMetadata,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with FlashAttention.

Args:
query: shape = [num_tokens, num_heads, head_size]
key: shape = [num_tokens, num_kv_heads, head_size]
value: shape = [num_tokens, num_kv_heads, head_size]
kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size]
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
"""
assert output is not None, "Output tensor must be provided."

if attn_metadata is None:
# Profiling run.
return output

assert attn_metadata.use_cascade is False

# IMPORTANT!
# NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
# eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
# in this method. For example, `view` and `slice` (or `[:n]`) operations
# are surprisingly slow even in the case they do not invoke any GPU ops.
# Minimize the PyTorch ops in this method as much as possible.
# Whenever making a change in this method, please benchmark the
# performance to make sure it does not introduce any overhead.

num_actual_tokens = attn_metadata.num_actual_tokens
key_cache, value_cache = PagedAttention.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size)

# Reshape the input keys and values and store them in the cache.
PagedAttention.write_to_paged_cache(
key,
value,
key_cache,
value_cache,
attn_metadata.slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)

# TODO(sage): Refactor the context_attention_fwd kernel so that this
# overhead can be removed
context_lens = torch.empty_like(attn_metadata.seq_lens)
batch_size = len(attn_metadata.query_start_loc) - 1
assert len(context_lens) == batch_size
for i in range(batch_size):
query_start = attn_metadata.query_start_loc[i]
query_end = attn_metadata.query_start_loc[i + 1]
context_lens[i] = attn_metadata.seq_lens[i] - (query_end -
query_start)

# Compute attention and update output up to `num_actual_tokens`.
context_attention_fwd(q=query[:num_actual_tokens],
k=key[:num_actual_tokens],
v=value[:num_actual_tokens],
o=output[:num_actual_tokens],
kv_cache_dtype=self.kv_cache_dtype,
k_cache=key_cache,
v_cache=value_cache,
b_loc=attn_metadata.block_table,
b_start_loc=attn_metadata.query_start_loc,
b_seq_len=attn_metadata.seq_lens,
b_ctx_len=context_lens,
max_input_len=attn_metadata.max_query_len,
k_scale=layer._k_scale,
v_scale=layer._v_scale,
alibi_slopes=self.alibi_slopes,
sliding_window=self.sliding_window[0],
sm_scale=self.scale)
return output
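A closing note on the TODO above the context_lens loop: assuming query_start_loc and seq_lens are 1-D tensors with len(query_start_loc) == len(seq_lens) + 1, the per-request Python loop could in principle be replaced by a single vectorized subtraction. A rough sketch of that alternative, not part of this PR:

    import torch

    def compute_context_lens(query_start_loc: torch.Tensor,
                             seq_lens: torch.Tensor) -> torch.Tensor:
        # Number of new query tokens per request, from consecutive offsets.
        query_lens = query_start_loc[1:] - query_start_loc[:-1]
        # Context length = total sequence length minus this step's query tokens.
        return seq_lens - query_lens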