[ROCm][V1] Add initial ROCm support to V1 #12790
Changes from all commits
b6b00d7
fa52268
f563276
2a03b92
4bdf7de
875fcfc
e507e30
b9ce259
f2cc5e3
d6f6c5c
2bf214a
11411cb
c2499bf
cf6f691
4505f53
9a0416a
ef9ae86
0ccef65
a00a2d9
afb15f5
08a25b7
55eb036
4b62de2
442bc7b
9472636
21d8d6a
@@ -0,0 +1,16 @@
# Common dependencies
-r requirements-common.txt

--extra-index-url https://download.pytorch.org/whl/rocm6.2
torch==2.5.1
torchvision==0.20.1
torchaudio==2.5.1

cmake>=3.26
ninja
packaging
setuptools>=61
setuptools-scm>=8
wheel
jinja2
amdsmi==6.2.4
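As a quick sanity check after installing these pins, the ROCm build of PyTorch and the amdsmi bindings can be verified from Python. This is only an illustrative sketch, not part of the PR; it assumes the amdsmi package's amdsmi_init / amdsmi_get_processor_handles / amdsmi_shut_down entry points and the torch.version.hip attribute that ROCm builds of PyTorch expose.

import torch
import amdsmi

# torch.version.hip is a version string on ROCm builds and None on CUDA builds.
assert torch.version.hip is not None, "expected a ROCm (HIP) build of PyTorch"
print("torch", torch.__version__, "built against ROCm/HIP", torch.version.hip)

# amdsmi requires explicit init/teardown around any query.
amdsmi.amdsmi_init()
try:
    print("visible AMD GPUs:", len(amdsmi.amdsmi_get_processor_handles()))
finally:
    amdsmi.amdsmi_shut_down()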
@@ -1,6 +1,5 @@
# SPDX-License-Identifier: Apache-2.0

import os
from functools import lru_cache
from typing import TYPE_CHECKING, Dict, List, Optional

@@ -29,12 +28,6 @@
except ImportError as e:
    logger.warning("Failed to import from vllm._rocm_C with %r", e)

-if os.environ.get("VLLM_WORKER_MULTIPROC_METHOD", None) in ["fork", None]:
-    logger.warning("`fork` method is not supported by ROCm. "
-                   "VLLM_WORKER_MULTIPROC_METHOD is overridden to"
-                   " `spawn` instead.")
-    os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

# Models not supported by ROCm.
_ROCM_UNSUPPORTED_MODELS: List[str] = []

@@ -84,6 +77,9 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
            return "vllm.attention.backends.triton_mla.TritonMLABackend"
        selected_backend = (_Backend.ROCM_FLASH if selected_backend
                            == _Backend.FLASH_ATTN else selected_backend)
+        if envs.VLLM_USE_V1:
+            logger.info("Using ROCm Attention backend on V1 engine.")
+            return "vllm.v1.attention.backends.rocm_attn.ROCmAttentionBackend"
        if selected_backend == _Backend.ROCM_FLASH:
            if not cls.has_device_capability(90):
                # not Instinct series GPUs.
@@ -102,7 +98,11 @@ def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
    @classmethod
    @lru_cache(maxsize=8)
    def get_device_name(cls, device_id: int = 0) -> str:
-        return torch.cuda.get_device_name(device_id)
+        # NOTE: When using V1 this function is called when overriding the
+        # engine args. Calling torch.cuda.get_device_name(device_id) here
+        # will result in the ROCm context being initialized before other
+        # processes can be created.
+        return "AMD"
Comment on lines +101 to +105

Why is this an issue? Do you know why this isn't an issue on CUDA?

Good question. CUDA uses NVML to get the device name, which doesn't initialize the CUDA context. The problem is that this method is called when we override the engine args in V1; mreso gave a good explanation in a comment on this PR. I think just returning "AMD" is fine for now. We can investigate whether AMD has an NVML-like system in the future if we find a need to add dispatching logic specific to individual AMD GPU models. The CPU platforms do something similar, i.e. return "CPU" or "openvino".
    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:

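For context on the review discussion above, here is a minimal sketch of the NVML-style lookup that the CUDA platform relies on, assuming the pynvml package (the helper name below is illustrative, not vLLM's actual API). Because the name is read out-of-band through NVML rather than through torch.cuda, no device context is created in the parent process, which is what the ROCm path avoids by simply returning "AMD".

import pynvml  # NVML bindings; queries do not create a CUDA context

def get_device_name_via_nvml(device_id: int = 0) -> str:
    # Illustrative helper: look up the device name via NVML only.
    pynvml.nvmlInit()
    try:
        name = pynvml.nvmlDeviceGetName(
            pynvml.nvmlDeviceGetHandleByIndex(device_id))
        # Older pynvml releases return bytes, newer ones return str.
        return name.decode() if isinstance(name, bytes) else name
    finally:
        pynvml.nvmlShutdown()

Whether amdsmi (already pinned in the new requirements file) could serve the same role on AMD GPUs is left to the future investigation mentioned in the reply.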
@@ -129,15 +129,30 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
        scheduler_config = vllm_config.scheduler_config
        if parallel_config.worker_cls == "auto":
            if scheduler_config.is_multi_step:
-                parallel_config.worker_cls = \
-                    "vllm.worker.multi_step_worker.MultiStepWorker"
+                if envs.VLLM_USE_V1:
+                    raise NotImplementedError(
+                        "Multi-step scheduling is not supported (and not "
+                        "needed) on VLLM V1. Please launch without "
+                        "--num-scheduler-steps.")
+                else:
+                    parallel_config.worker_cls = \
+                        "vllm.worker.multi_step_worker.MultiStepWorker"
            elif vllm_config.speculative_config:
-                parallel_config.worker_cls = \
-                    "vllm.spec_decode.spec_decode_worker.create_spec_worker"
-                parallel_config.sd_worker_cls = \
-                    "vllm.worker.worker.Worker"
+                if envs.VLLM_USE_V1:
+                    raise NotImplementedError(
+                        "Speculative decoding is not yet supported on VLLM V1."
+                    )
+                else:
+                    parallel_config.worker_cls = \
+                        "vllm.spec_decode.spec_decode_worker.create_spec_worker"
+                    parallel_config.sd_worker_cls = \
+                        "vllm.worker.worker.Worker"
            else:
-                parallel_config.worker_cls = "vllm.worker.worker.Worker"
+                if envs.VLLM_USE_V1:
+                    parallel_config.worker_cls = \
+                        "vllm.v1.worker.gpu_worker.Worker"
+                else:
+                    parallel_config.worker_cls = "vllm.worker.worker.Worker"

    @classmethod
    def verify_model_arch(cls, model_arch: str) -> None:

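Taken together, these changes let the V1 engine be opted into on ROCm via VLLM_USE_V1=1, provided multi-step scheduling (--num-scheduler-steps) and speculative decoding are not requested. A minimal usage sketch, not taken from the PR; the model name is only a placeholder.

import os

# Opt into the V1 engine before vLLM is imported; envs.VLLM_USE_V1 reads this.
os.environ["VLLM_USE_V1"] = "1"

from vllm import LLM, SamplingParams

llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct")  # placeholder model name
outputs = llm.generate(["Hello from an AMD GPU"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)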
@@ -0,0 +1,182 @@
# SPDX-License-Identifier: Apache-2.0
"""Attention layer with PagedAttention on rocm"""
from typing import Any, Dict, List, Optional, Tuple, Type

import torch

from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                              AttentionMetadata, AttentionType)
from vllm.attention.ops.paged_attn import PagedAttention
from vllm.attention.ops.prefix_prefill import context_attention_fwd
from vllm.logger import init_logger
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata

logger = init_logger(__name__)


class ROCmAttentionBackend(AttentionBackend):

    accept_output_buffer: bool = True

    @staticmethod
    def get_supported_head_sizes() -> List[int]:
        return [32, 64, 96, 128, 160, 192, 224, 256]

    @staticmethod
    def get_name() -> str:
        return "ROCM_ATTN_VLLM_V1"

    @staticmethod
    def get_impl_cls() -> Type["ROCmAttentionImpl"]:
        return ROCmAttentionImpl

    @staticmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        return FlashAttentionMetadata

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        if block_size % 16 != 0:
            raise ValueError("Block size must be a multiple of 16.")
        return (2, num_blocks, block_size, num_kv_heads, head_size)

    @staticmethod
    def use_cascade_attention(*args, **kwargs) -> bool:
        return False

class ROCmAttentionImpl(AttentionImpl):

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
        attn_type: AttentionType = AttentionType.DECODER,
    ) -> None:
        if blocksparse_params is not None:
            raise ValueError(
                "ROCmAttention does not support block-sparse attention.")
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        if sliding_window is None:
            self.sliding_window = (-1, -1)
        else:
            self.sliding_window = (sliding_window - 1, 0)
        self.kv_cache_dtype = kv_cache_dtype

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        support_head_sizes = ROCmAttentionBackend.get_supported_head_sizes()
        if head_size not in support_head_sizes:
            raise ValueError(
                f"Head size {head_size} is not supported by ROCmAttention. "
                f"Supported head sizes are: {support_head_sizes}.")

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "ROCmAttentionImpl")

    def forward(
        self,
        layer: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: FlashAttentionMetadata,
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with FlashAttention.

        Args:
            query: shape = [num_tokens, num_heads, head_size]
            key: shape = [num_tokens, num_kv_heads, head_size]
            value: shape = [num_tokens, num_kv_heads, head_size]
            kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size]
            attn_metadata: Metadata for attention.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
        assert output is not None, "Output tensor must be provided."

        if attn_metadata is None:
            # Profiling run.
            return output

        assert attn_metadata.use_cascade is False

        # IMPORTANT!
        # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
        # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
        # in this method. For example, `view` and `slice` (or `[:n]`) operations
        # are surprisingly slow even in the case they do not invoke any GPU ops.
        # Minimize the PyTorch ops in this method as much as possible.
        # Whenever making a change in this method, please benchmark the
        # performance to make sure it does not introduce any overhead.

        num_actual_tokens = attn_metadata.num_actual_tokens
        key_cache, value_cache = PagedAttention.split_kv_cache(
            kv_cache, self.num_kv_heads, self.head_size)

        # Reshape the input keys and values and store them in the cache.
        PagedAttention.write_to_paged_cache(
            key,
            value,
            key_cache,
            value_cache,
            attn_metadata.slot_mapping,
            self.kv_cache_dtype,
            layer._k_scale,
            layer._v_scale,
        )

        # TODO(sage): Refactor the context_attention_fwd kernel so that this
        # overhead can be removed
        context_lens = torch.empty_like(attn_metadata.seq_lens)
        batch_size = len(attn_metadata.query_start_loc) - 1
        assert len(context_lens) == batch_size
        for i in range(batch_size):
            query_start = attn_metadata.query_start_loc[i]
            query_end = attn_metadata.query_start_loc[i + 1]
            context_lens[i] = attn_metadata.seq_lens[i] - (query_end -
                                                           query_start)

        # Compute attention and update output up to `num_actual_tokens`.
        context_attention_fwd(q=query[:num_actual_tokens],
                              k=key[:num_actual_tokens],
                              v=value[:num_actual_tokens],
                              o=output[:num_actual_tokens],
                              kv_cache_dtype=self.kv_cache_dtype,
                              k_cache=key_cache,
                              v_cache=value_cache,
                              b_loc=attn_metadata.block_table,
                              b_start_loc=attn_metadata.query_start_loc,
                              b_seq_len=attn_metadata.seq_lens,
                              b_ctx_len=context_lens,
                              max_input_len=attn_metadata.max_query_len,
                              k_scale=layer._k_scale,
                              v_scale=layer._v_scale,
                              alibi_slopes=self.alibi_slopes,
                              sliding_window=self.sliding_window[0],
                              sm_scale=self.scale)
        return output
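As a side note on the TODO(sage) above: under the same metadata layout, the per-request Python loop could also be collapsed into a single vectorized subtraction, which would reduce (though not remove) the overhead the comment refers to. A sketch of the idea only, not code from the PR:

import torch

def compute_context_lens(seq_lens: torch.Tensor,
                         query_start_loc: torch.Tensor) -> torch.Tensor:
    # context_len[i] = seq_len[i] - (query_end[i] - query_start[i]),
    # i.e. how many of request i's tokens are already in the KV cache.
    query_lens = query_start_loc[1:] - query_start_loc[:-1]
    return seq_lens - query_lens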
Does ROCm support fork now?
It's unclear to me exactly what was broken before, but fork does seem to work.
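For anyone who does hit a fork-related problem on a particular ROCm setup, the previous behaviour can still be restored by hand. This is a hypothetical workaround reusing the environment variable from the removed block, not something the PR requires:

import os

# Force the spawn start method before vLLM creates its worker processes
# (must run before the LLM engine is constructed).
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"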