-
-
Notifications
You must be signed in to change notification settings - Fork 10.6k
register custom op for flash attn and use from torch.ops #7536
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
18 commits
Select commit
Hold shift + click to select a range
8649a5b
register
youkaichao c2c8ca6
use
youkaichao 679b18a
use
youkaichao 94a39cc
manually mutate all
youkaichao bdbbe76
manually mutate all tensors
youkaichao 5b64f2d
add tests
youkaichao 9d97f7b
add tests
youkaichao 506eed5
change import
youkaichao d9105aa
update tests
youkaichao f827ad3
change args
youkaichao f0fe288
change import
youkaichao 8c322b0
rename
youkaichao 755dbaf
fix register fake
youkaichao fc2a4c2
add opcheck
youkaichao 495d2f0
fix alibi_slopes
youkaichao 76c5cec
update mutates_args
youkaichao 45bb131
add schema tests
youkaichao ee8d426
reduce number of heads to avoid OOM
youkaichao File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
import os | ||
|
||
import pytest | ||
|
||
|
||
@pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B"]) | ||
def test_full_graph(model): | ||
# make sure these models can be captured in full graph mode | ||
os.environ["VLLM_TEST_DYNAMO_GRAPH_CAPTURE"] = "1" | ||
|
||
from vllm import LLM, SamplingParams | ||
prompts = [ | ||
"Hello, my name is", | ||
"The president of the United States is", | ||
"The capital of France is", | ||
"The future of AI is", | ||
] | ||
sampling_params = SamplingParams(temperature=0) | ||
llm = LLM(model="meta-llama/Meta-Llama-3-8B") | ||
llm.generate(prompts, sampling_params) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,7 +3,6 @@ | |
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type | ||
|
||
import torch | ||
from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache | ||
|
||
from vllm import _custom_ops as ops | ||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, | ||
|
@@ -18,6 +17,108 @@ | |
if TYPE_CHECKING: | ||
from vllm.worker.model_runner import ModelInputForGPUBuilder | ||
|
||
from vllm_flash_attn import flash_attn_varlen_func as _flash_attn_varlen_func | ||
from vllm_flash_attn import flash_attn_with_kvcache as _flash_attn_with_kvcache | ||
|
||
|
||
@torch.library.custom_op("vllm::flash_attn_varlen_func", mutates_args=[]) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. confirmed with @WoosukKwon , these two functions do not mutate the input. |
||
def flash_attn_varlen_func( | ||
q: torch.Tensor, | ||
k: torch.Tensor, | ||
v: torch.Tensor, | ||
cu_seqlens_q: torch.Tensor, | ||
cu_seqlens_k: torch.Tensor, | ||
max_seqlen_q: int, | ||
max_seqlen_k: int, | ||
softmax_scale: Optional[float] = None, | ||
causal: bool = False, | ||
window_size: Optional[List[int]] = None, | ||
softcap: float = 0.0, | ||
alibi_slopes: Optional[torch.Tensor] = None, | ||
block_table: Optional[torch.Tensor] = None, | ||
) -> torch.Tensor: | ||
# custom op does not support tuple input | ||
real_window_size: Tuple[int, int] | ||
if window_size is None: | ||
real_window_size = (-1, -1) | ||
else: | ||
assert len(window_size) == 2 | ||
real_window_size = (window_size[0], window_size[1]) | ||
return _flash_attn_varlen_func( | ||
q=q, | ||
k=k, | ||
v=v, | ||
cu_seqlens_q=cu_seqlens_q, | ||
cu_seqlens_k=cu_seqlens_k, | ||
max_seqlen_q=max_seqlen_q, | ||
max_seqlen_k=max_seqlen_k, | ||
softmax_scale=softmax_scale, | ||
causal=causal, | ||
window_size=real_window_size, | ||
softcap=softcap, | ||
alibi_slopes=alibi_slopes, | ||
block_table=block_table, | ||
) | ||
|
||
|
||
@flash_attn_varlen_func.register_fake # type: ignore | ||
def _( | ||
q: torch.Tensor, | ||
k: torch.Tensor, | ||
v: torch.Tensor, | ||
cu_seqlens_q: torch.Tensor, | ||
cu_seqlens_k: torch.Tensor, | ||
max_seqlen_q: int, | ||
max_seqlen_k: int, | ||
softmax_scale: Optional[float] = None, | ||
causal: bool = False, | ||
window_size: Optional[List[int]] = None, | ||
softcap: float = 0.0, | ||
alibi_slopes: Optional[torch.Tensor] = None, | ||
block_table: Optional[torch.Tensor] = None, | ||
) -> torch.Tensor: | ||
return torch.empty_like(q) | ||
|
||
|
||
@torch.library.custom_op("vllm::flash_attn_with_kvcache", mutates_args=[]) | ||
def flash_attn_with_kvcache( | ||
decode_query: torch.Tensor, | ||
key_cache: torch.Tensor, | ||
value_cache: torch.Tensor, | ||
cache_seqlens: Optional[torch.Tensor] = None, | ||
block_table: Optional[torch.Tensor] = None, | ||
softmax_scale: Optional[float] = None, | ||
causal: bool = False, | ||
alibi_slopes: Optional[torch.Tensor] = None, | ||
softcap: float = 0.0, | ||
) -> torch.Tensor: | ||
return _flash_attn_with_kvcache( | ||
decode_query, | ||
key_cache, | ||
value_cache, | ||
cache_seqlens=cache_seqlens, | ||
block_table=block_table, | ||
softmax_scale=softmax_scale, | ||
causal=causal, | ||
alibi_slopes=alibi_slopes, | ||
softcap=softcap, | ||
) | ||
|
||
|
||
@flash_attn_with_kvcache.register_fake # type: ignore | ||
def _( | ||
decode_query: torch.Tensor, | ||
key_cache: torch.Tensor, | ||
value_cache: torch.Tensor, | ||
cache_seqlens: Optional[torch.Tensor] = None, | ||
block_table: Optional[torch.Tensor] = None, | ||
softmax_scale: Optional[float] = None, | ||
causal: bool = False, | ||
alibi_slopes: Optional[torch.Tensor] = None, | ||
softcap: float = 0.0, | ||
) -> torch.Tensor: | ||
return torch.empty_like(decode_query) | ||
|
||
|
||
class FlashAttentionBackend(AttentionBackend): | ||
|
||
|
@@ -517,7 +618,7 @@ def forward( | |
# normal attention | ||
# When block_tables are not filled, it means q and k are the | ||
# prompt, and they have the same length. | ||
out = flash_attn_varlen_func( | ||
out = torch.ops.vllm.flash_attn_varlen_func( | ||
q=query, | ||
k=key, | ||
v=value, | ||
|
@@ -537,34 +638,36 @@ def forward( | |
# prefix-enabled attention | ||
assert prefill_meta.seq_lens is not None | ||
max_seq_len = max(prefill_meta.seq_lens) | ||
output[:num_prefill_tokens] = flash_attn_varlen_func( | ||
q=query, | ||
k=key_cache, | ||
v=value_cache, | ||
cu_seqlens_q=prefill_meta.query_start_loc, | ||
max_seqlen_q=prefill_meta.max_query_len, | ||
cu_seqlens_k=prefill_meta.seq_start_loc, | ||
max_seqlen_k=max_seq_len, | ||
output[: | ||
num_prefill_tokens] = torch.ops.vllm.flash_attn_varlen_func( # noqa | ||
youkaichao marked this conversation as resolved.
Show resolved
Hide resolved
|
||
q=query, | ||
k=key_cache, | ||
v=value_cache, | ||
cu_seqlens_q=prefill_meta.query_start_loc, | ||
max_seqlen_q=prefill_meta.max_query_len, | ||
cu_seqlens_k=prefill_meta.seq_start_loc, | ||
max_seqlen_k=max_seq_len, | ||
softmax_scale=self.scale, | ||
causal=True, | ||
alibi_slopes=self.alibi_slopes, | ||
block_table=prefill_meta.block_tables, | ||
softcap=self.logits_soft_cap, | ||
) | ||
|
||
if decode_meta := attn_metadata.decode_metadata: | ||
# Decoding run. | ||
output[ | ||
num_prefill_tokens:] = torch.ops.vllm.flash_attn_with_kvcache( | ||
decode_query.unsqueeze(1), | ||
key_cache, | ||
value_cache, | ||
block_table=decode_meta.block_tables, | ||
cache_seqlens=decode_meta.seq_lens_tensor, | ||
softmax_scale=self.scale, | ||
causal=True, | ||
alibi_slopes=self.alibi_slopes, | ||
block_table=prefill_meta.block_tables, | ||
softcap=self.logits_soft_cap, | ||
) | ||
|
||
if decode_meta := attn_metadata.decode_metadata: | ||
# Decoding run. | ||
output[num_prefill_tokens:] = flash_attn_with_kvcache( | ||
decode_query.unsqueeze(1), | ||
key_cache, | ||
value_cache, | ||
block_table=decode_meta.block_tables, | ||
cache_seqlens=decode_meta.seq_lens_tensor, | ||
softmax_scale=self.scale, | ||
causal=True, | ||
alibi_slopes=self.alibi_slopes, | ||
softcap=self.logits_soft_cap, | ||
).squeeze(1) | ||
).squeeze(1) | ||
|
||
# Reshape the output tensor. | ||
return output.view(num_tokens, hidden_size) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.