From c21782bc9a9011ea2a0911128d50ce04267a5dca Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Fri, 28 Feb 2025 08:32:08 -0500 Subject: [PATCH 01/11] Working changes Signed-off-by: Thomas Parnell --- examples/test.py | 37 ++++ vllm/attention/ops/prefix_prefill.py | 4 + vllm/platforms/cuda.py | 9 +- vllm/v1/attention/backends/rocm_attn.py | 278 +++++++++++++++++++++++- vllm/v1/core/scheduler.py | 8 + 5 files changed, 332 insertions(+), 4 deletions(-) create mode 100644 examples/test.py diff --git a/examples/test.py b/examples/test.py new file mode 100644 index 000000000000..02006d23bcbd --- /dev/null +++ b/examples/test.py @@ -0,0 +1,37 @@ +import os +from vllm import LLM, SamplingParams +import time +import numpy as np + +os.environ["VLLM_USE_V1"] = "1" + + +llm = LLM( + model="/net/storage149/autofs/css22/nmg/models/granite3.1-8b/base/", + dtype='float16', +) + +sampling_params = SamplingParams(temperature=0.0, max_tokens=100) + +doc = "Switzerland,[d] officially the Swiss Confederation,[e] is a landlocked country located in west-central Europe.[f][13] It is bordered by Italy to the south, France to the west, Germany to the north, and Austria and Liechtenstein to the east. Switzerland is geographically divided among the Swiss Plateau, the Alps and the Jura; the Alps occupy the greater part of the territory, whereas most of the country's nearly 9 million people are concentrated on the plateau, which hosts its largest cities and economic centres, including Zurich, Geneva, and Lausanne.[14]" + +batch_size = 64 + +docs = [] + +for i in range(batch_size): + docs.append(doc) + +res = [] +for i in range(10): + t0 = time.time() + responses = llm.generate(docs, sampling_params) + t_elap = time.time()-t0 + res.append(t_elap) + +print(res) + +print("t_elap: %.2f seconds" % (np.median(res))) + +#print(responses) + diff --git a/vllm/attention/ops/prefix_prefill.py b/vllm/attention/ops/prefix_prefill.py index 103c408ebbf4..01278a504026 100644 --- a/vllm/attention/ops/prefix_prefill.py +++ b/vllm/attention/ops/prefix_prefill.py @@ -76,6 +76,10 @@ def _fwd_kernel( cur_batch_in_all_stop_index = tl.load(B_Start_Loc + cur_batch + 1) cur_batch_query_len = (cur_batch_in_all_stop_index - cur_batch_in_all_start_index) + + if cur_batch_query_len == 1: + return + cur_batch_ctx_len = cur_batch_seq_len - cur_batch_query_len # start position inside of the query diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 2a4cac46c066..a2a8db2e8da2 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -194,9 +194,12 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype, logger.info("Using Triton MLA backend.") return "vllm.attention.backends.triton_mla.TritonMLABackend" if use_v1: - logger.info("Using Flash Attention backend on V1 engine.") - return ("vllm.v1.attention.backends.flash_attn." - "FlashAttentionBackend") + #logger.info("Using Flash Attention backend on V1 engine.") + #return ("vllm.v1.attention.backends.flash_attn." + # "FlashAttentionBackend") + logger.info("Using ROCm Attention backend on V1 engine.") + return ("vllm.v1.attention.backends.rocm_attn." + "ROCmAttentionBackend") if selected_backend == _Backend.FLASHINFER: logger.info("Using FlashInfer backend.") return "vllm.attention.backends.flashinfer.FlashInferBackend" diff --git a/vllm/v1/attention/backends/rocm_attn.py b/vllm/v1/attention/backends/rocm_attn.py index 0f3fabf05fc2..98609304ef80 100644 --- a/vllm/v1/attention/backends/rocm_attn.py +++ b/vllm/v1/attention/backends/rocm_attn.py @@ -9,10 +9,242 @@ from vllm.attention.ops.paged_attn import PagedAttention from vllm.attention.ops.prefix_prefill import context_attention_fwd from vllm.logger import init_logger -from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata +from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata, FlashAttentionMetadataBuilder logger = init_logger(__name__) +import os +import torch +import triton +import triton.language as tl + +@triton.jit +def cdiv_fn(x, y): + return (x + y - 1) // y + + +debug_flag = False + +@triton.jit +def kernel_paged_attention_2d( + output_ptr, # [num_seqs, num_query_heads, head_size] + query_ptr, # [num_seqs, num_query_heads, head_size] + key_cache_ptr, # [num_blocks, num_kv_heads, head_size, block_size] + value_cache_ptr, # [num_blocks, num_kv_heads, head_size, block_size] + block_tables_ptr, # [num_seqs, max_num_blocks_per_seq] + context_lens_ptr, # [num_seqs] + alibi_slopes_ptr, # [num_query_heads] + scale, # float32 + cu_q_len_ptr, # [num_seqs+1] + num_query_heads: tl.constexpr, # int + num_queries_per_kv: tl.constexpr, # int + cache_block_stride: tl.constexpr, # int + block_table_stride: tl.constexpr, # int, should be equal to max_num_blocks_per_seq + query_stride_0: tl.constexpr, # int + query_stride_1: tl.constexpr, # int, should be equal to head_size + output_stride_0: tl.constexpr, # int + output_stride_1: tl.constexpr, # int, should be equal to head_size + BLOCK_SIZE: tl.constexpr, # int + HEAD_SIZE: tl.constexpr, # int, must be power of 2 + USE_ALIBI_SLOPES: tl.constexpr, # bool + x: tl.constexpr, + stride_k_cache_0: tl.constexpr, + stride_k_cache_1: tl.constexpr, + stride_k_cache_2: tl.constexpr, + stride_k_cache_3: tl.constexpr, + stride_k_cache_4: tl.constexpr, + stride_v_cache_0: tl.constexpr, + stride_v_cache_1: tl.constexpr, + stride_v_cache_2: tl.constexpr, + stride_v_cache_3: tl.constexpr, +): + seq_idx = tl.program_id(0) + query_head_idx = tl.program_id(1) + kv_head_idx = query_head_idx // num_queries_per_kv + + cur_batch_in_all_start_index = tl.load(cu_q_len_ptr + seq_idx) + cur_batch_in_all_stop_index = tl.load(cu_q_len_ptr + seq_idx + 1) + cur_batch_query_len = (cur_batch_in_all_stop_index - + cur_batch_in_all_start_index) + + if cur_batch_query_len > 1: + return + + query_offset = seq_idx * query_stride_0 + query_head_idx * query_stride_1 + + # Q : (HEAD_SIZE,) + Q = tl.load(query_ptr + query_offset + tl.arange(0, HEAD_SIZE)) + + block_table_offset = seq_idx * block_table_stride + + m = tl.full([1], float("-inf"), dtype=tl.float32) + l = tl.full([1], 1.0, dtype=tl.float32) + acc = tl.zeros([HEAD_SIZE], dtype=tl.float32) + + # context len for this particualr sequence + context_len = tl.load(context_lens_ptr + seq_idx) + + # alibi slope for this head + if USE_ALIBI_SLOPES: + alibi_slope = tl.load(alibi_slopes_ptr + query_head_idx) + + num_blocks = cdiv_fn(context_len, BLOCK_SIZE) + + # iterate through tiles + for j in range(0, num_blocks): + physical_block_idx = tl.load(block_tables_ptr + block_table_offset + j) + + offs_n = tl.arange(0, BLOCK_SIZE) + offs_d = tl.arange(0, HEAD_SIZE) + + v_offset = (physical_block_idx * stride_v_cache_0 + + kv_head_idx * stride_v_cache_1 + + offs_d[:, None] * stride_v_cache_2 + + offs_n[None, :] * stride_v_cache_3) + + k_offset = (physical_block_idx * stride_k_cache_0 + + kv_head_idx * stride_k_cache_1 + + (offs_d[:, None] // x) * stride_k_cache_2 + + offs_n[None, :] * stride_k_cache_3 + + (offs_d[:, None] % x) * stride_k_cache_4) + + # K : (HEAD_SIZE, BLOCK_SIZE) + K = tl.load(key_cache_ptr + k_offset) + + # V : (HEAD_SIZE, BLOCK_SIZE) + V = tl.load(value_cache_ptr + v_offset) + + tmp = j * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + boundary = tl.full([BLOCK_SIZE], context_len, dtype=tl.int32) + mask_new = tmp < boundary + # S : (BLOCK_SIZE,) + S = tl.where(mask_new, 0.0, float("-inf")).to(tl.float32) + S += scale * tl.sum(K * Q[:, None], axis=0) + + if USE_ALIBI_SLOPES: + S += alibi_slope * (tmp - context_len + 1) + + # compute running maximum + # m_j : (1,) + m_j = tl.maximum(m, tl.max(S, axis=0)) + + # P : (BLOCK_SIZE,) + P = tl.exp(S - m_j) + + # l_j : (1,) + l_j = tl.sum(P, axis=0) + + # alpha : (1, ) + alpha = tl.exp(m - m_j) + + # acc : (BLOCK_SIZE,) + acc = acc * alpha + + # update constants + l = l * alpha + l_j + m = m_j + + # acc : (BLOCK_SIZE,) + acc += tl.sum(V * P[None, :], axis=1) + + # epilogue + acc = acc / l + + output_offset = seq_idx * output_stride_0 + query_head_idx * output_stride_1 + + tl.store(output_ptr + output_offset + tl.arange(0, HEAD_SIZE), acc) + + +def paged_attention_triton_2d( + output, + query, + key_cache, + value_cache, + scale, + block_tables, + context_lens, + alibi_slopes, + block_size, + num_seqs, + num_query_heads, + num_queries_per_kv, + head_size, + cu_q_len, +): + use_alibi_slopes = alibi_slopes is not None + + #if len(key_cache.shape) == 5 and key_cache.shape[4] != 1: + # raise RuntimeError("5d kv cache not supported") + + if debug_flag and not torch.cuda.is_current_stream_capturing(): + torch.set_printoptions(threshold=10_000) + print("\nnum_seqs: ", num_seqs) + print("query shape: ", query.shape) + print("num query heads: ", num_query_heads) + print("context_lens: ", context_lens) + print("block_tables.shape: ", block_tables.shape) + print("key_cache.shape: ", key_cache.shape) + print("value_cache.shape: ", value_cache.shape) + print(block_tables) + print("query strides: ", query.stride(0), query.stride(1), query.stride(2)) + print("block_tables strides: ", block_tables.stride(0), block_tables.stride(1)) + print( + "key_cache strides: ", + key_cache.stride(0), + key_cache.stride(1), + key_cache.stride(2), + key_cache.stride(3), + ) + print("output strides: ", output.stride(0), output.stride(1), output.stride(2)) + print( + "value_cache strides: ", + value_cache.stride(0), + value_cache.stride(1), + value_cache.stride(2), + value_cache.stride(3), + ) + print("context_lens stride: ", context_lens.stride(0)) + if alibi_slopes is not None: + print("alibi_slobes stride: ", alibi_slopes.stride(0)) + + kernel_paged_attention_2d[ + ( + num_seqs, + num_query_heads, + ) + ]( + output_ptr=output, + query_ptr=query, + key_cache_ptr=key_cache, + value_cache_ptr=value_cache, + block_tables_ptr=block_tables, + context_lens_ptr=context_lens, + alibi_slopes_ptr=alibi_slopes, + scale=scale, + cu_q_len_ptr=cu_q_len, + num_query_heads=num_query_heads, + num_queries_per_kv=num_queries_per_kv, + cache_block_stride=key_cache.stride(0), + block_table_stride=block_tables.stride(0), + query_stride_0=query.stride(0), + query_stride_1=query.stride(1), + output_stride_0=output.stride(0), + output_stride_1=output.stride(1), + BLOCK_SIZE=block_size, + HEAD_SIZE=head_size, + USE_ALIBI_SLOPES=use_alibi_slopes, + x=key_cache.shape[4], + stride_k_cache_0=key_cache.stride(0), + stride_k_cache_1=key_cache.stride(1), + stride_k_cache_2=key_cache.stride(2), + stride_k_cache_3=key_cache.stride(3), + stride_k_cache_4=key_cache.stride(4), + stride_v_cache_0=value_cache.stride(0), + stride_v_cache_1=value_cache.stride(1), + stride_v_cache_2=value_cache.stride(2), + stride_v_cache_3=value_cache.stride(3), + ) + class ROCmAttentionBackend(AttentionBackend): @@ -34,6 +266,10 @@ def get_impl_cls() -> Type["ROCmAttentionImpl"]: def get_metadata_cls() -> Type["AttentionMetadata"]: return FlashAttentionMetadata + @staticmethod + def get_builder_cls() -> Type["FlashAttentionMetadataBuilder"]: + return FlashAttentionMetadataBuilder + @staticmethod def get_kv_cache_shape( num_blocks: int, @@ -150,6 +386,28 @@ def forward( layer._v_scale, ) + num_queries_per_kv = (query.shape[1] // key.shape[1]) + + ''' + print("num_actual_tokens: ", num_actual_tokens) + print("query.shape: ", query.shape) + print("key.shape: ", key.shape) + print("value.shape: ", value.shape) + print("output.shape: ", output.shape) + print("key_cache.shape: ", key_cache.shape) + print("value_cache.shape: ", value_cache.shape) + print("query_start_loc: ", attn_metadata.query_start_loc) + print("seq_lens: ", attn_metadata.seq_lens) + print("num_seqs: ", len(attn_metadata.seq_lens)) + print("num_queries_per_kv: ", num_queries_per_kv) + print("block_table.shape: ", attn_metadata.block_table.shape) + print("block_table.stride: ", attn_metadata.block_table.stride()) + print("output.stride: ", output.stride()) + print("seq_lens.stride: ", attn_metadata.seq_lens.stride()) + print("alibi_slopes: ", self.alibi_slopes) + print("sliding_window: ", self.sliding_window[0]) + ''' + # Compute attention and update output up to `num_actual_tokens`. context_attention_fwd(q=query[:num_actual_tokens], k=key[:num_actual_tokens], @@ -167,4 +425,22 @@ def forward( alibi_slopes=self.alibi_slopes, sliding_window=self.sliding_window[0], sm_scale=self.scale) + + paged_attention_triton_2d( + output=output[:num_actual_tokens], + query=query[:num_actual_tokens], + key_cache=key_cache, + value_cache=value_cache, + scale=self.scale, + cu_q_len=attn_metadata.query_start_loc, + block_tables=attn_metadata.block_table, + context_lens=attn_metadata.seq_lens, + alibi_slopes=self.alibi_slopes, + block_size=16, + num_seqs=len(attn_metadata.seq_lens), + num_query_heads=query.shape[1], + num_queries_per_kv=num_queries_per_kv, + head_size=query.shape[2] + ) + return output diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index 87c9c0cd12b7..960e667c822e 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -245,6 +245,12 @@ def schedule(self) -> "SchedulerOutput": # Get already-cached tokens. computed_blocks, num_computed_tokens = \ self.kv_cache_manager.get_computed_blocks(request) + + #print("[schedule] request_id: ", request.request_id) + #print("[schedule] num_tokens: ", request.num_tokens) + #print("[schedule] num_computed_tokens: ", num_computed_tokens) + + # Number of tokens to be scheduled. # We use `request.num_tokens` instead of # `request.num_prompt_tokens` to consider the resumed requests, @@ -312,6 +318,8 @@ def schedule(self) -> "SchedulerOutput": # Check if the scheduling constraints are satisfied. total_num_scheduled_tokens = sum(num_scheduled_tokens.values()) + #print("[schedule] total_num_scheduled_tokens: ", total_num_scheduled_tokens) + assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens assert token_budget >= 0 assert len(self.running) <= self.max_num_running_reqs From 00bafc00ee412234855fd9719a506c6ff720b55c Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Fri, 28 Feb 2025 11:06:08 -0500 Subject: [PATCH 02/11] Use ibm_triton_lib Signed-off-by: Thomas Parnell --- examples/test.py | 2 +- vllm/v1/attention/backends/rocm_attn.py | 264 +----------------------- vllm/v1/core/scheduler.py | 8 - 3 files changed, 8 insertions(+), 266 deletions(-) diff --git a/examples/test.py b/examples/test.py index 02006d23bcbd..5e90e69de398 100644 --- a/examples/test.py +++ b/examples/test.py @@ -33,5 +33,5 @@ print("t_elap: %.2f seconds" % (np.median(res))) -#print(responses) +print(responses[0]) diff --git a/vllm/v1/attention/backends/rocm_attn.py b/vllm/v1/attention/backends/rocm_attn.py index 98609304ef80..41cd6d68752a 100644 --- a/vllm/v1/attention/backends/rocm_attn.py +++ b/vllm/v1/attention/backends/rocm_attn.py @@ -11,239 +11,9 @@ from vllm.logger import init_logger from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata, FlashAttentionMetadataBuilder -logger = init_logger(__name__) - -import os -import torch -import triton -import triton.language as tl - -@triton.jit -def cdiv_fn(x, y): - return (x + y - 1) // y - - -debug_flag = False - -@triton.jit -def kernel_paged_attention_2d( - output_ptr, # [num_seqs, num_query_heads, head_size] - query_ptr, # [num_seqs, num_query_heads, head_size] - key_cache_ptr, # [num_blocks, num_kv_heads, head_size, block_size] - value_cache_ptr, # [num_blocks, num_kv_heads, head_size, block_size] - block_tables_ptr, # [num_seqs, max_num_blocks_per_seq] - context_lens_ptr, # [num_seqs] - alibi_slopes_ptr, # [num_query_heads] - scale, # float32 - cu_q_len_ptr, # [num_seqs+1] - num_query_heads: tl.constexpr, # int - num_queries_per_kv: tl.constexpr, # int - cache_block_stride: tl.constexpr, # int - block_table_stride: tl.constexpr, # int, should be equal to max_num_blocks_per_seq - query_stride_0: tl.constexpr, # int - query_stride_1: tl.constexpr, # int, should be equal to head_size - output_stride_0: tl.constexpr, # int - output_stride_1: tl.constexpr, # int, should be equal to head_size - BLOCK_SIZE: tl.constexpr, # int - HEAD_SIZE: tl.constexpr, # int, must be power of 2 - USE_ALIBI_SLOPES: tl.constexpr, # bool - x: tl.constexpr, - stride_k_cache_0: tl.constexpr, - stride_k_cache_1: tl.constexpr, - stride_k_cache_2: tl.constexpr, - stride_k_cache_3: tl.constexpr, - stride_k_cache_4: tl.constexpr, - stride_v_cache_0: tl.constexpr, - stride_v_cache_1: tl.constexpr, - stride_v_cache_2: tl.constexpr, - stride_v_cache_3: tl.constexpr, -): - seq_idx = tl.program_id(0) - query_head_idx = tl.program_id(1) - kv_head_idx = query_head_idx // num_queries_per_kv - - cur_batch_in_all_start_index = tl.load(cu_q_len_ptr + seq_idx) - cur_batch_in_all_stop_index = tl.load(cu_q_len_ptr + seq_idx + 1) - cur_batch_query_len = (cur_batch_in_all_stop_index - - cur_batch_in_all_start_index) - - if cur_batch_query_len > 1: - return - - query_offset = seq_idx * query_stride_0 + query_head_idx * query_stride_1 - - # Q : (HEAD_SIZE,) - Q = tl.load(query_ptr + query_offset + tl.arange(0, HEAD_SIZE)) - - block_table_offset = seq_idx * block_table_stride - - m = tl.full([1], float("-inf"), dtype=tl.float32) - l = tl.full([1], 1.0, dtype=tl.float32) - acc = tl.zeros([HEAD_SIZE], dtype=tl.float32) - - # context len for this particualr sequence - context_len = tl.load(context_lens_ptr + seq_idx) - - # alibi slope for this head - if USE_ALIBI_SLOPES: - alibi_slope = tl.load(alibi_slopes_ptr + query_head_idx) - - num_blocks = cdiv_fn(context_len, BLOCK_SIZE) - - # iterate through tiles - for j in range(0, num_blocks): - physical_block_idx = tl.load(block_tables_ptr + block_table_offset + j) - - offs_n = tl.arange(0, BLOCK_SIZE) - offs_d = tl.arange(0, HEAD_SIZE) - - v_offset = (physical_block_idx * stride_v_cache_0 + - kv_head_idx * stride_v_cache_1 + - offs_d[:, None] * stride_v_cache_2 + - offs_n[None, :] * stride_v_cache_3) - - k_offset = (physical_block_idx * stride_k_cache_0 + - kv_head_idx * stride_k_cache_1 + - (offs_d[:, None] // x) * stride_k_cache_2 + - offs_n[None, :] * stride_k_cache_3 + - (offs_d[:, None] % x) * stride_k_cache_4) - - # K : (HEAD_SIZE, BLOCK_SIZE) - K = tl.load(key_cache_ptr + k_offset) - - # V : (HEAD_SIZE, BLOCK_SIZE) - V = tl.load(value_cache_ptr + v_offset) - - tmp = j * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - boundary = tl.full([BLOCK_SIZE], context_len, dtype=tl.int32) - mask_new = tmp < boundary - # S : (BLOCK_SIZE,) - S = tl.where(mask_new, 0.0, float("-inf")).to(tl.float32) - S += scale * tl.sum(K * Q[:, None], axis=0) +from ibm_triton_lib.kernels import paged_attention_2d - if USE_ALIBI_SLOPES: - S += alibi_slope * (tmp - context_len + 1) - - # compute running maximum - # m_j : (1,) - m_j = tl.maximum(m, tl.max(S, axis=0)) - - # P : (BLOCK_SIZE,) - P = tl.exp(S - m_j) - - # l_j : (1,) - l_j = tl.sum(P, axis=0) - - # alpha : (1, ) - alpha = tl.exp(m - m_j) - - # acc : (BLOCK_SIZE,) - acc = acc * alpha - - # update constants - l = l * alpha + l_j - m = m_j - - # acc : (BLOCK_SIZE,) - acc += tl.sum(V * P[None, :], axis=1) - - # epilogue - acc = acc / l - - output_offset = seq_idx * output_stride_0 + query_head_idx * output_stride_1 - - tl.store(output_ptr + output_offset + tl.arange(0, HEAD_SIZE), acc) - - -def paged_attention_triton_2d( - output, - query, - key_cache, - value_cache, - scale, - block_tables, - context_lens, - alibi_slopes, - block_size, - num_seqs, - num_query_heads, - num_queries_per_kv, - head_size, - cu_q_len, -): - use_alibi_slopes = alibi_slopes is not None - - #if len(key_cache.shape) == 5 and key_cache.shape[4] != 1: - # raise RuntimeError("5d kv cache not supported") - - if debug_flag and not torch.cuda.is_current_stream_capturing(): - torch.set_printoptions(threshold=10_000) - print("\nnum_seqs: ", num_seqs) - print("query shape: ", query.shape) - print("num query heads: ", num_query_heads) - print("context_lens: ", context_lens) - print("block_tables.shape: ", block_tables.shape) - print("key_cache.shape: ", key_cache.shape) - print("value_cache.shape: ", value_cache.shape) - print(block_tables) - print("query strides: ", query.stride(0), query.stride(1), query.stride(2)) - print("block_tables strides: ", block_tables.stride(0), block_tables.stride(1)) - print( - "key_cache strides: ", - key_cache.stride(0), - key_cache.stride(1), - key_cache.stride(2), - key_cache.stride(3), - ) - print("output strides: ", output.stride(0), output.stride(1), output.stride(2)) - print( - "value_cache strides: ", - value_cache.stride(0), - value_cache.stride(1), - value_cache.stride(2), - value_cache.stride(3), - ) - print("context_lens stride: ", context_lens.stride(0)) - if alibi_slopes is not None: - print("alibi_slobes stride: ", alibi_slopes.stride(0)) - - kernel_paged_attention_2d[ - ( - num_seqs, - num_query_heads, - ) - ]( - output_ptr=output, - query_ptr=query, - key_cache_ptr=key_cache, - value_cache_ptr=value_cache, - block_tables_ptr=block_tables, - context_lens_ptr=context_lens, - alibi_slopes_ptr=alibi_slopes, - scale=scale, - cu_q_len_ptr=cu_q_len, - num_query_heads=num_query_heads, - num_queries_per_kv=num_queries_per_kv, - cache_block_stride=key_cache.stride(0), - block_table_stride=block_tables.stride(0), - query_stride_0=query.stride(0), - query_stride_1=query.stride(1), - output_stride_0=output.stride(0), - output_stride_1=output.stride(1), - BLOCK_SIZE=block_size, - HEAD_SIZE=head_size, - USE_ALIBI_SLOPES=use_alibi_slopes, - x=key_cache.shape[4], - stride_k_cache_0=key_cache.stride(0), - stride_k_cache_1=key_cache.stride(1), - stride_k_cache_2=key_cache.stride(2), - stride_k_cache_3=key_cache.stride(3), - stride_k_cache_4=key_cache.stride(4), - stride_v_cache_0=value_cache.stride(0), - stride_v_cache_1=value_cache.stride(1), - stride_v_cache_2=value_cache.stride(2), - stride_v_cache_3=value_cache.stride(3), - ) +logger = init_logger(__name__) class ROCmAttentionBackend(AttentionBackend): @@ -386,29 +156,8 @@ def forward( layer._v_scale, ) - num_queries_per_kv = (query.shape[1] // key.shape[1]) - - ''' - print("num_actual_tokens: ", num_actual_tokens) - print("query.shape: ", query.shape) - print("key.shape: ", key.shape) - print("value.shape: ", value.shape) - print("output.shape: ", output.shape) - print("key_cache.shape: ", key_cache.shape) - print("value_cache.shape: ", value_cache.shape) - print("query_start_loc: ", attn_metadata.query_start_loc) - print("seq_lens: ", attn_metadata.seq_lens) - print("num_seqs: ", len(attn_metadata.seq_lens)) - print("num_queries_per_kv: ", num_queries_per_kv) - print("block_table.shape: ", attn_metadata.block_table.shape) - print("block_table.stride: ", attn_metadata.block_table.stride()) - print("output.stride: ", output.stride()) - print("seq_lens.stride: ", attn_metadata.seq_lens.stride()) - print("alibi_slopes: ", self.alibi_slopes) - print("sliding_window: ", self.sliding_window[0]) - ''' - # Compute attention and update output up to `num_actual_tokens`. + # do prefill and prefix prefills context_attention_fwd(q=query[:num_actual_tokens], k=key[:num_actual_tokens], v=value[:num_actual_tokens], @@ -426,7 +175,8 @@ def forward( sliding_window=self.sliding_window[0], sm_scale=self.scale) - paged_attention_triton_2d( + # Call second kernel (concurrently) to do decodes + paged_attention_2d( output=output[:num_actual_tokens], query=query[:num_actual_tokens], key_cache=key_cache, @@ -436,10 +186,10 @@ def forward( block_tables=attn_metadata.block_table, context_lens=attn_metadata.seq_lens, alibi_slopes=self.alibi_slopes, - block_size=16, + block_size=value_cache.shape[3], num_seqs=len(attn_metadata.seq_lens), num_query_heads=query.shape[1], - num_queries_per_kv=num_queries_per_kv, + num_queries_per_kv=(query.shape[1] // key.shape[1]), head_size=query.shape[2] ) diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index 960e667c822e..87c9c0cd12b7 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -245,12 +245,6 @@ def schedule(self) -> "SchedulerOutput": # Get already-cached tokens. computed_blocks, num_computed_tokens = \ self.kv_cache_manager.get_computed_blocks(request) - - #print("[schedule] request_id: ", request.request_id) - #print("[schedule] num_tokens: ", request.num_tokens) - #print("[schedule] num_computed_tokens: ", num_computed_tokens) - - # Number of tokens to be scheduled. # We use `request.num_tokens` instead of # `request.num_prompt_tokens` to consider the resumed requests, @@ -318,8 +312,6 @@ def schedule(self) -> "SchedulerOutput": # Check if the scheduling constraints are satisfied. total_num_scheduled_tokens = sum(num_scheduled_tokens.values()) - #print("[schedule] total_num_scheduled_tokens: ", total_num_scheduled_tokens) - assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens assert token_budget >= 0 assert len(self.running) <= self.max_num_running_reqs From 152e2347a17ea1a1b990c6ff3639122484fd01b2 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Mon, 3 Mar 2025 08:27:03 -0500 Subject: [PATCH 03/11] working changes Signed-off-by: Thomas Parnell --- examples/test.py | 8 ++-- vllm/attention/ops/prefix_prefill.py | 4 -- vllm/v1/attention/backends/rocm_attn.py | 54 ++++++++----------------- 3 files changed, 21 insertions(+), 45 deletions(-) diff --git a/examples/test.py b/examples/test.py index 5e90e69de398..70b1ec19cd93 100644 --- a/examples/test.py +++ b/examples/test.py @@ -7,7 +7,7 @@ llm = LLM( - model="/net/storage149/autofs/css22/nmg/models/granite3.1-8b/base/", + model="/net/storage149/autofs/css22/nmg/models/llama3.1-8b/instruct/", dtype='float16', ) @@ -15,7 +15,7 @@ doc = "Switzerland,[d] officially the Swiss Confederation,[e] is a landlocked country located in west-central Europe.[f][13] It is bordered by Italy to the south, France to the west, Germany to the north, and Austria and Liechtenstein to the east. Switzerland is geographically divided among the Swiss Plateau, the Alps and the Jura; the Alps occupy the greater part of the territory, whereas most of the country's nearly 9 million people are concentrated on the plateau, which hosts its largest cities and economic centres, including Zurich, Geneva, and Lausanne.[14]" -batch_size = 64 +batch_size = 2 docs = [] @@ -23,7 +23,7 @@ docs.append(doc) res = [] -for i in range(10): +for i in range(1): t0 = time.time() responses = llm.generate(docs, sampling_params) t_elap = time.time()-t0 @@ -33,5 +33,5 @@ print("t_elap: %.2f seconds" % (np.median(res))) -print(responses[0]) +print(responses) diff --git a/vllm/attention/ops/prefix_prefill.py b/vllm/attention/ops/prefix_prefill.py index 01278a504026..103c408ebbf4 100644 --- a/vllm/attention/ops/prefix_prefill.py +++ b/vllm/attention/ops/prefix_prefill.py @@ -76,10 +76,6 @@ def _fwd_kernel( cur_batch_in_all_stop_index = tl.load(B_Start_Loc + cur_batch + 1) cur_batch_query_len = (cur_batch_in_all_stop_index - cur_batch_in_all_start_index) - - if cur_batch_query_len == 1: - return - cur_batch_ctx_len = cur_batch_seq_len - cur_batch_query_len # start position inside of the query diff --git a/vllm/v1/attention/backends/rocm_attn.py b/vllm/v1/attention/backends/rocm_attn.py index 41cd6d68752a..e18684f1dce8 100644 --- a/vllm/v1/attention/backends/rocm_attn.py +++ b/vllm/v1/attention/backends/rocm_attn.py @@ -7,11 +7,10 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionType) from vllm.attention.ops.paged_attn import PagedAttention -from vllm.attention.ops.prefix_prefill import context_attention_fwd from vllm.logger import init_logger from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata, FlashAttentionMetadataBuilder -from ibm_triton_lib.kernels import paged_attention_2d +from ibm_triton_lib.kernels import chunked_prefill_paged_decode logger = init_logger(__name__) @@ -157,40 +156,21 @@ def forward( ) # Compute attention and update output up to `num_actual_tokens`. - # do prefill and prefix prefills - context_attention_fwd(q=query[:num_actual_tokens], - k=key[:num_actual_tokens], - v=value[:num_actual_tokens], - o=output[:num_actual_tokens], - kv_cache_dtype=self.kv_cache_dtype, - k_cache=key_cache, - v_cache=value_cache, - b_loc=attn_metadata.block_table, - b_start_loc=attn_metadata.query_start_loc, - b_seq_len=attn_metadata.seq_lens, - max_input_len=attn_metadata.max_query_len, - k_scale=layer._k_scale, - v_scale=layer._v_scale, - alibi_slopes=self.alibi_slopes, - sliding_window=self.sliding_window[0], - sm_scale=self.scale) - - # Call second kernel (concurrently) to do decodes - paged_attention_2d( - output=output[:num_actual_tokens], - query=query[:num_actual_tokens], - key_cache=key_cache, - value_cache=value_cache, - scale=self.scale, - cu_q_len=attn_metadata.query_start_loc, - block_tables=attn_metadata.block_table, - context_lens=attn_metadata.seq_lens, - alibi_slopes=self.alibi_slopes, - block_size=value_cache.shape[3], - num_seqs=len(attn_metadata.seq_lens), - num_query_heads=query.shape[1], - num_queries_per_kv=(query.shape[1] // key.shape[1]), - head_size=query.shape[2] - ) + chunked_prefill_paged_decode(query=query[:num_actual_tokens], + key=key[:num_actual_tokens], + value=value[:num_actual_tokens], + output=output[:num_actual_tokens], + kv_cache_dtype=self.kv_cache_dtype, + key_cache=key_cache, + value_cache=value_cache, + block_table=attn_metadata.block_table, + query_start_loc=attn_metadata.query_start_loc, + seq_lens=attn_metadata.seq_lens, + max_query_len=attn_metadata.max_query_len, + k_scale=layer._k_scale, + v_scale=layer._v_scale, + alibi_slopes=self.alibi_slopes, + sliding_window=self.sliding_window[0], + scale=self.scale) return output From c7087bcfe68982480dd4d162b405d181c39ae55f Mon Sep 17 00:00:00 2001 From: Burkhard Ringlein Date: Mon, 3 Mar 2025 12:14:26 -0500 Subject: [PATCH 04/11] making backend enabled via env variable Signed-off-by: Burkhard Ringlein --- vllm/platforms/cuda.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index a2a8db2e8da2..d68b771ff9c3 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -194,12 +194,13 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype, logger.info("Using Triton MLA backend.") return "vllm.attention.backends.triton_mla.TritonMLABackend" if use_v1: - #logger.info("Using Flash Attention backend on V1 engine.") - #return ("vllm.v1.attention.backends.flash_attn." - # "FlashAttentionBackend") - logger.info("Using ROCm Attention backend on V1 engine.") - return ("vllm.v1.attention.backends.rocm_attn." - "ROCmAttentionBackend") + if os.environ.get('VLLM_V1_USE_TRITON_BACKEND', '0') == '1': + logger.info("Using ROCm Attention backend on V1 engine.") + return ("vllm.v1.attention.backends.rocm_attn." + "ROCmAttentionBackend") + logger.info("Using Flash Attention backend on V1 engine.") + return ("vllm.v1.attention.backends.flash_attn." + "FlashAttentionBackend") if selected_backend == _Backend.FLASHINFER: logger.info("Using FlashInfer backend.") return "vllm.attention.backends.flashinfer.FlashInferBackend" From 576e1b15d3d54489f413560f09c81cb78387c727 Mon Sep 17 00:00:00 2001 From: Burkhard Ringlein Date: Mon, 3 Mar 2025 12:35:12 -0500 Subject: [PATCH 05/11] cleanup Signed-off-by: Burkhard Ringlein --- examples/test.py | 37 ------------------------------------- vllm/platforms/cuda.py | 2 +- 2 files changed, 1 insertion(+), 38 deletions(-) delete mode 100644 examples/test.py diff --git a/examples/test.py b/examples/test.py deleted file mode 100644 index 70b1ec19cd93..000000000000 --- a/examples/test.py +++ /dev/null @@ -1,37 +0,0 @@ -import os -from vllm import LLM, SamplingParams -import time -import numpy as np - -os.environ["VLLM_USE_V1"] = "1" - - -llm = LLM( - model="/net/storage149/autofs/css22/nmg/models/llama3.1-8b/instruct/", - dtype='float16', -) - -sampling_params = SamplingParams(temperature=0.0, max_tokens=100) - -doc = "Switzerland,[d] officially the Swiss Confederation,[e] is a landlocked country located in west-central Europe.[f][13] It is bordered by Italy to the south, France to the west, Germany to the north, and Austria and Liechtenstein to the east. Switzerland is geographically divided among the Swiss Plateau, the Alps and the Jura; the Alps occupy the greater part of the territory, whereas most of the country's nearly 9 million people are concentrated on the plateau, which hosts its largest cities and economic centres, including Zurich, Geneva, and Lausanne.[14]" - -batch_size = 2 - -docs = [] - -for i in range(batch_size): - docs.append(doc) - -res = [] -for i in range(1): - t0 = time.time() - responses = llm.generate(docs, sampling_params) - t_elap = time.time()-t0 - res.append(t_elap) - -print(res) - -print("t_elap: %.2f seconds" % (np.median(res))) - -print(responses) - diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index d68b771ff9c3..5383063b7307 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -194,7 +194,7 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype, logger.info("Using Triton MLA backend.") return "vllm.attention.backends.triton_mla.TritonMLABackend" if use_v1: - if os.environ.get('VLLM_V1_USE_TRITON_BACKEND', '0') == '1': + if os.environ.get("VLLM_V1_USE_TRITON_BACKEND", "0") == "1": logger.info("Using ROCm Attention backend on V1 engine.") return ("vllm.v1.attention.backends.rocm_attn." "ROCmAttentionBackend") From e9485d2fdee72d34d44a3cc0f4ca27d1ba6fce8c Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Mon, 3 Mar 2025 13:59:45 -0500 Subject: [PATCH 06/11] Fix merge error Signed-off-by: Thomas Parnell --- vllm/v1/attention/backends/rocm_attn.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/vllm/v1/attention/backends/rocm_attn.py b/vllm/v1/attention/backends/rocm_attn.py index 70896e0959a5..d75ff089fd6e 100644 --- a/vllm/v1/attention/backends/rocm_attn.py +++ b/vllm/v1/attention/backends/rocm_attn.py @@ -36,10 +36,6 @@ def get_impl_cls() -> type["ROCmAttentionImpl"]: def get_metadata_cls() -> type["AttentionMetadata"]: return FlashAttentionMetadata - @staticmethod - def get_builder_cls() -> Type["FlashAttentionMetadataBuilder"]: - return FlashAttentionMetadataBuilder - @staticmethod def get_kv_cache_shape( num_blocks: int, From 8a7a883f42bccc0f14dd40e79ca58a1ea07567fb Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Tue, 4 Mar 2025 14:46:20 -0500 Subject: [PATCH 07/11] Add chunked_prefill_paged_decode kernel. Co-authored-by: Burkhard Ringlein Co-authored-by: Jan van Lunteren Signed-off-by: Thomas Parnell --- .../ops/chunked_prefill_paged_decode.py | 355 ++++++++++++++++++ vllm/attention/ops/prefix_prefill.py | 14 +- vllm/v1/attention/backends/rocm_attn.py | 37 +- 3 files changed, 387 insertions(+), 19 deletions(-) create mode 100644 vllm/attention/ops/chunked_prefill_paged_decode.py diff --git a/vllm/attention/ops/chunked_prefill_paged_decode.py b/vllm/attention/ops/chunked_prefill_paged_decode.py new file mode 100644 index 000000000000..e26f0ef48275 --- /dev/null +++ b/vllm/attention/ops/chunked_prefill_paged_decode.py @@ -0,0 +1,355 @@ +# SPDX-License-Identifier: Apache-2.0 + +import torch +import triton +import triton.language as tl + +from .prefix_prefill import context_attention_fwd + + +@triton.jit +def cdiv_fn(x, y): + return (x + y - 1) // y + + +@triton.jit +def kernel_paged_attention_2d( + output_ptr, # [num_seqs, num_query_heads, head_size] + query_ptr, # [num_seqs, num_query_heads, head_size] + key_cache_ptr, # [num_blocks, num_kv_heads, head_size, block_size] + value_cache_ptr, # [num_blocks, num_kv_heads, head_size, block_size] + block_tables_ptr, # [num_seqs, max_num_blocks_per_seq] + context_lens_ptr, # [num_seqs] + alibi_slopes_ptr, # [num_query_heads] + scale, # float32 + k_scale, # float32 + v_scale, # float32 + num_query_heads: tl.constexpr, # int + num_queries_per_kv: tl.constexpr, # int + block_table_stride: tl.constexpr, # int + query_stride_0: tl.constexpr, # int + query_stride_1: tl.constexpr, # int, should be equal to head_size + output_stride_0: tl.constexpr, # int + output_stride_1: tl.constexpr, # int, should be equal to head_size + BLOCK_SIZE: tl.constexpr, # int + HEAD_SIZE: tl.constexpr, # int + HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2 + USE_ALIBI_SLOPES: tl.constexpr, # bool + x: tl.constexpr, # int + stride_k_cache_0: tl.constexpr, # int + stride_k_cache_1: tl.constexpr, # int + stride_k_cache_2: tl.constexpr, # int + stride_k_cache_3: tl.constexpr, # int + stride_k_cache_4: tl.constexpr, # int + stride_v_cache_0: tl.constexpr, # int + stride_v_cache_1: tl.constexpr, # int + stride_v_cache_2: tl.constexpr, # int + stride_v_cache_3: tl.constexpr, # int + filter_by_query_len: tl.constexpr, # bool + query_start_len_ptr, # [num_seqs+1] +): + seq_idx = tl.program_id(0) + query_head_idx = tl.program_id(1) + kv_head_idx = query_head_idx // num_queries_per_kv + + if filter_by_query_len: + cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx) + cur_batch_in_all_stop_index = tl.load(query_start_len_ptr + seq_idx + + 1) + cur_batch_query_len = cur_batch_in_all_stop_index \ + - cur_batch_in_all_start_index + + if cur_batch_query_len > 1: + return + else: + cur_batch_in_all_start_index = seq_idx + + query_offset = (cur_batch_in_all_start_index * query_stride_0 + + query_head_idx * query_stride_1) + + dim_mask = tl.where(tl.arange(0, HEAD_SIZE_PADDED) < HEAD_SIZE, 1, + 0).to(tl.int1) + + # Q : (HEAD_SIZE,) + Q = tl.load( + query_ptr + query_offset + tl.arange(0, HEAD_SIZE_PADDED), + mask=dim_mask, + other=0.0, + ) + + block_table_offset = seq_idx * block_table_stride + + M = tl.full([1], float("-inf"), dtype=tl.float32) + L = tl.full([1], 1.0, dtype=tl.float32) + acc = tl.zeros([HEAD_SIZE_PADDED], dtype=tl.float32) + + # context len for this particular sequence + context_len = tl.load(context_lens_ptr + seq_idx) + + # alibi slope for this head + if USE_ALIBI_SLOPES: + alibi_slope = tl.load(alibi_slopes_ptr + query_head_idx) + + num_blocks = cdiv_fn(context_len, BLOCK_SIZE) + + # iterate through tiles + for j in range(0, num_blocks): + + physical_block_idx = tl.load(block_tables_ptr + block_table_offset + j) + + offs_n = tl.arange(0, BLOCK_SIZE) + offs_d = tl.arange(0, HEAD_SIZE) + + v_offset = (physical_block_idx * stride_v_cache_0 + + kv_head_idx * stride_v_cache_1 + + offs_d[:, None] * stride_v_cache_2 + + offs_n[None, :] * stride_v_cache_3) + + k_offset = (physical_block_idx * stride_k_cache_0 + + kv_head_idx * stride_k_cache_1 + + (offs_d[:, None] // x) * stride_k_cache_2 + + offs_n[None, :] * stride_k_cache_3 + + (offs_d[:, None] % x) * stride_k_cache_4) + + # K : (HEAD_SIZE, BLOCK_SIZE) + K_load = tl.load(key_cache_ptr + k_offset, + mask=dim_mask[:, None], + other=0.0) + + if K_load.dtype.is_fp8(): + K = (K_load.to(tl.float32) * tl.load(k_scale)).to(Q.dtype) + else: + K = K_load + + # V : (HEAD_SIZE, BLOCK_SIZE) + V_load = tl.load(value_cache_ptr + v_offset, + mask=dim_mask[:, None], + other=0.0) + + if V_load.dtype.is_fp8(): + V = (V_load.to(tl.float32) * tl.load(v_scale)).to(Q.dtype) + else: + V = V_load + + tmp = j * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + boundary = tl.full([BLOCK_SIZE], context_len, dtype=tl.int32) + mask_new = tmp < boundary + # S : (BLOCK_SIZE,) + S = tl.where(mask_new, 0.0, float("-inf")).to(tl.float32) + S += scale * tl.sum(K * Q[:, None], axis=0) + + if USE_ALIBI_SLOPES: + S += alibi_slope * (tmp - context_len + 1) + + # compute running maximum + # m_j : (1,) + m_j = tl.maximum(M, tl.max(S, axis=0)) + + # P : (BLOCK_SIZE,) + P = tl.exp(S - m_j) + + # l_j : (1,) + l_j = tl.sum(P, axis=0) + + # alpha : (1, ) + alpha = tl.exp(M - m_j) + + # acc : (BLOCK_SIZE,) + acc = acc * alpha + + # update constants + L = L * alpha + l_j + M = m_j + + # acc : (BLOCK_SIZE,) + acc += tl.sum(V * P[None, :], axis=1) + + # epilogue + acc = acc / L + + output_offset = (cur_batch_in_all_start_index * output_stride_0 + + query_head_idx * output_stride_1) + + tl.store(output_ptr + output_offset + tl.arange(0, HEAD_SIZE_PADDED), + acc, + mask=dim_mask) + + +def paged_attention_triton_2d( + output, + query, + key_cache, + value_cache, + scale, + k_scale, + v_scale, + kv_cache_dtype, + block_tables, + context_lens, + alibi_slopes, + block_size, + num_seqs, + num_query_heads, + num_queries_per_kv, + head_size, +): + use_alibi_slopes = alibi_slopes is not None + + # Conversion of FP8 Tensor from uint8 storage to + # appropriate torch.dtype for interpretation by Triton + if "fp8" in kv_cache_dtype: + assert key_cache.dtype == torch.uint8 + assert value_cache.dtype == torch.uint8 + + if kv_cache_dtype in ("fp8", "fp8_e4m3"): + target_dtype = torch.float8_e4m3fn + elif kv_cache_dtype == "fp8_e5m2": + target_dtype = torch.float8_e5m2 + else: + raise ValueError("Unsupported FP8 dtype:", kv_cache_dtype) + + key_cache = key_cache.view(target_dtype) + value_cache = value_cache.view(target_dtype) + + kernel_paged_attention_2d[( + num_seqs, + num_query_heads, + )]( + output_ptr=output, + query_ptr=query, + key_cache_ptr=key_cache, + value_cache_ptr=value_cache, + block_tables_ptr=block_tables, + context_lens_ptr=context_lens, + alibi_slopes_ptr=alibi_slopes, + scale=scale, + k_scale=k_scale, + v_scale=v_scale, + num_query_heads=num_query_heads, + num_queries_per_kv=num_queries_per_kv, + block_table_stride=block_tables.stride(0), + query_stride_0=query.stride(0), + query_stride_1=query.stride(1), + output_stride_0=output.stride(0), + output_stride_1=output.stride(1), + BLOCK_SIZE=block_size, + HEAD_SIZE=head_size, + HEAD_SIZE_PADDED=triton.next_power_of_2(head_size), + USE_ALIBI_SLOPES=use_alibi_slopes, + x=key_cache.shape[4] if len(key_cache.shape) == 5 else 1, + stride_k_cache_0=key_cache.stride(0), + stride_k_cache_1=key_cache.stride(1), + stride_k_cache_2=key_cache.stride(2), + stride_k_cache_3=key_cache.stride(3), + stride_k_cache_4=key_cache.stride(4) + if len(key_cache.shape) == 5 else 1, + stride_v_cache_0=value_cache.stride(0), + stride_v_cache_1=value_cache.stride(1), + stride_v_cache_2=value_cache.stride(2), + stride_v_cache_3=value_cache.stride(3), + filter_by_query_len=False, + query_start_len_ptr=None, + ) + + +def chunked_prefill_paged_decode( + query, + key, + value, + output, + kv_cache_dtype, + key_cache, + value_cache, + block_table, + query_start_loc, + seq_lens, + max_query_len, + k_scale, + v_scale, + alibi_slopes, + sliding_window, + scale, +): + + use_alibi_slopes = alibi_slopes is not None + + context_attention_fwd( + q=query, + k=key, + v=value, + o=output, + kv_cache_dtype=kv_cache_dtype, + k_cache=key_cache, + v_cache=value_cache, + b_loc=block_table, + b_start_loc=query_start_loc, + b_seq_len=seq_lens, + max_input_len=max_query_len, + k_scale=k_scale, + v_scale=v_scale, + alibi_slopes=alibi_slopes, + sliding_window=sliding_window, + sm_scale=scale, + skip_decode=True, + ) + + block_size = value_cache.shape[3] + num_seqs = len(seq_lens) + num_query_heads = query.shape[1] + num_queries_per_kv = query.shape[1] // key.shape[1] + head_size = query.shape[2] + + # Conversion of FP8 Tensor from uint8 storage to + # appropriate torch.dtype for interpretation by Triton + if "fp8" in kv_cache_dtype: + assert key_cache.dtype == torch.uint8 + assert value_cache.dtype == torch.uint8 + + if kv_cache_dtype in ("fp8", "fp8_e4m3"): + target_dtype = torch.float8_e4m3fn + elif kv_cache_dtype == "fp8_e5m2": + target_dtype = torch.float8_e5m2 + else: + raise ValueError("Unsupported FP8 dtype:", kv_cache_dtype) + + key_cache = key_cache.view(target_dtype) + value_cache = value_cache.view(target_dtype) + + kernel_paged_attention_2d[( + num_seqs, + num_query_heads, + )]( + output_ptr=output, + query_ptr=query, + key_cache_ptr=key_cache, + value_cache_ptr=value_cache, + block_tables_ptr=block_table, + context_lens_ptr=seq_lens, + alibi_slopes_ptr=alibi_slopes, + scale=scale, + k_scale=k_scale, + v_scale=v_scale, + num_query_heads=num_query_heads, + num_queries_per_kv=num_queries_per_kv, + block_table_stride=block_table.stride(0), + query_stride_0=query.stride(0), + query_stride_1=query.stride(1), + output_stride_0=output.stride(0), + output_stride_1=output.stride(1), + BLOCK_SIZE=block_size, + HEAD_SIZE=head_size, + HEAD_SIZE_PADDED=triton.next_power_of_2(head_size), + USE_ALIBI_SLOPES=use_alibi_slopes, + x=key_cache.shape[4], + stride_k_cache_0=key_cache.stride(0), + stride_k_cache_1=key_cache.stride(1), + stride_k_cache_2=key_cache.stride(2), + stride_k_cache_3=key_cache.stride(3), + stride_k_cache_4=key_cache.stride(4), + stride_v_cache_0=value_cache.stride(0), + stride_v_cache_1=value_cache.stride(1), + stride_v_cache_2=value_cache.stride(2), + stride_v_cache_3=value_cache.stride(3), + filter_by_query_len=True, + query_start_len_ptr=query_start_loc, + ) diff --git a/vllm/attention/ops/prefix_prefill.py b/vllm/attention/ops/prefix_prefill.py index 103c408ebbf4..e85ec605ad2f 100644 --- a/vllm/attention/ops/prefix_prefill.py +++ b/vllm/attention/ops/prefix_prefill.py @@ -64,7 +64,9 @@ def _fwd_kernel( BLOCK_DMODEL_PADDED: tl.constexpr, # head size padded to a power of 2 BLOCK_N: tl.constexpr, SLIDING_WINDOW: tl.constexpr, + SKIP_DECODE: tl.constexpr, ): + cur_batch = tl.program_id(0) cur_head = tl.program_id(1) start_m = tl.program_id(2) @@ -78,6 +80,9 @@ def _fwd_kernel( cur_batch_in_all_start_index) cur_batch_ctx_len = cur_batch_seq_len - cur_batch_query_len + if SKIP_DECODE and cur_batch_query_len == 1: + return + # start position inside of the query # generally, N goes over kv, while M goes over query_len block_start_loc = BLOCK_M * start_m @@ -500,6 +505,7 @@ def _fwd_kernel_alibi( BLOCK_DMODEL: tl.constexpr, # head size BLOCK_DMODEL_PADDED: tl.constexpr, # head size padded to a power of 2 BLOCK_N: tl.constexpr, + SKIP_DECODE: tl.constexpr, ): # attn_bias[] cur_batch = tl.program_id(0) @@ -518,6 +524,9 @@ def _fwd_kernel_alibi( cur_batch_in_all_start_index) cur_batch_ctx_len = cur_batch_seq_len - cur_batch_query_len + if SKIP_DECODE and cur_batch_query_len == 1: + return + block_start_loc = BLOCK_M * start_m # initialize offsets @@ -721,7 +730,8 @@ def context_attention_fwd(q, v_scale: torch.Tensor, alibi_slopes=None, sliding_window=None, - sm_scale=None): + sm_scale=None, + skip_decode=False): q_dtype_is_f32 = q.dtype is torch.float32 # need to reduce num. blocks when using fp32 @@ -823,6 +833,7 @@ def context_attention_fwd(q, BLOCK_DMODEL=Lk, BLOCK_DMODEL_PADDED=Lk_padded, BLOCK_N=BLOCK, + SKIP_DECODE=skip_decode, num_warps=NUM_WARPS, num_stages=1, ) @@ -875,6 +886,7 @@ def context_attention_fwd(q, BLOCK_DMODEL_PADDED=Lk_padded, BLOCK_N=BLOCK, SLIDING_WINDOW=sliding_window, + SKIP_DECODE=skip_decode, num_warps=NUM_WARPS, num_stages=1, ) diff --git a/vllm/v1/attention/backends/rocm_attn.py b/vllm/v1/attention/backends/rocm_attn.py index d75ff089fd6e..18f9d4fcdc79 100644 --- a/vllm/v1/attention/backends/rocm_attn.py +++ b/vllm/v1/attention/backends/rocm_attn.py @@ -6,13 +6,13 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionType) +from vllm.attention.ops.chunked_prefill_paged_decode import ( + chunked_prefill_paged_decode) from vllm.attention.ops.paged_attn import PagedAttention from vllm.logger import init_logger from vllm.v1.attention.backends.flash_attn import ( FlashAttentionMetadata, FlashAttentionMetadataBuilder) -from ibm_triton_lib.kernels import chunked_prefill_paged_decode - logger = init_logger(__name__) @@ -157,21 +157,22 @@ def forward( ) # Compute attention and update output up to `num_actual_tokens`. - chunked_prefill_paged_decode(query=query[:num_actual_tokens], - key=key[:num_actual_tokens], - value=value[:num_actual_tokens], - output=output[:num_actual_tokens], - kv_cache_dtype=self.kv_cache_dtype, - key_cache=key_cache, - value_cache=value_cache, - block_table=attn_metadata.block_table, - query_start_loc=attn_metadata.query_start_loc, - seq_lens=attn_metadata.seq_lens, - max_query_len=attn_metadata.max_query_len, - k_scale=layer._k_scale, - v_scale=layer._v_scale, - alibi_slopes=self.alibi_slopes, - sliding_window=self.sliding_window[0], - scale=self.scale) + chunked_prefill_paged_decode( + query=query[:num_actual_tokens], + key=key[:num_actual_tokens], + value=value[:num_actual_tokens], + output=output[:num_actual_tokens], + kv_cache_dtype=self.kv_cache_dtype, + key_cache=key_cache, + value_cache=value_cache, + block_table=attn_metadata.block_table, + query_start_loc=attn_metadata.query_start_loc, + seq_lens=attn_metadata.seq_lens, + max_query_len=attn_metadata.max_query_len, + k_scale=layer._k_scale, + v_scale=layer._v_scale, + alibi_slopes=self.alibi_slopes, + sliding_window=self.sliding_window[0], + scale=self.scale) return output From dc1029ac42c277c9f7e20886da3a15f02f4ab199 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Tue, 4 Mar 2025 17:16:56 -0500 Subject: [PATCH 08/11] Added unit tests for chunked_prefill_paged_decode Signed-off-by: Thomas Parnell --- tests/kernels/test_prefix_prefill.py | 135 ++++++++++-------- .../ops/chunked_prefill_paged_decode.py | 101 +++---------- vllm/v1/attention/backends/rocm_attn.py | 2 +- 3 files changed, 94 insertions(+), 144 deletions(-) diff --git a/tests/kernels/test_prefix_prefill.py b/tests/kernels/test_prefix_prefill.py index c3ac6a37e717..3bdb8fff2c43 100644 --- a/tests/kernels/test_prefix_prefill.py +++ b/tests/kernels/test_prefix_prefill.py @@ -3,6 +3,7 @@ import math import random import time +from collections.abc import Callable import pytest import torch @@ -10,6 +11,8 @@ from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask from vllm.attention.backends.xformers import _make_alibi_bias +from vllm.attention.ops.chunked_prefill_paged_decode import ( + chunked_prefill_paged_decode) from vllm.attention.ops.prefix_prefill import context_attention_fwd from vllm.platforms import current_platform from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE @@ -24,6 +27,8 @@ SLIDING_WINDOW = [0, 16, 64, 128, 256, 512, 2048] KV_CACHE_DTYPES = ["auto", "fp8", "fp8_e5m2"] +OPS = [chunked_prefill_paged_decode, context_attention_fwd] + @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV) @@ -32,6 +37,7 @@ @pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES) @pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("sliding_window", SLIDING_WINDOW) +@pytest.mark.parametrize("op", OPS) @torch.inference_mode() def test_contexted_kv_attention( num_heads: int, @@ -41,6 +47,7 @@ def test_contexted_kv_attention( dtype: torch.dtype, kv_cache_dtype: str, device: str, + op: Callable, ) -> None: if 'fp8' in kv_cache_dtype and not current_platform.has_device_capability( @@ -65,6 +72,9 @@ def test_contexted_kv_attention( block_size = 32 max_block_per_request = 64 query_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)] + # ensure one sequence in batch is a decode + query_lens[-1] = 1 + ctx_lens = [random.randint(16, MAX_CTX_LEN) for _ in range(BS)] seq_lens = [a + b for a, b in zip(query_lens, ctx_lens)] num_kv_heads = num_heads // num_queries_per_kv @@ -144,36 +154,36 @@ def test_contexted_kv_attention( # Warm up the Triton kernel by calling it once before actually measuring # generation time - context_attention_fwd(query, - k, - v, - output, - kv_cache_dtype, - k_cache, - v_cache, - block_table, - b_start_loc, - b_seq_len, - max_input_len, - k_scale, - v_scale, - sliding_window=sliding_window) + op(query, + k, + v, + output, + kv_cache_dtype, + k_cache, + v_cache, + block_table, + b_start_loc, + b_seq_len, + max_input_len, + k_scale, + v_scale, + sliding_window=sliding_window) torch.cuda.synchronize() start_time = time.time() - context_attention_fwd(query, - k, - v, - output, - kv_cache_dtype, - k_cache, - v_cache, - block_table, - b_start_loc, - b_seq_len, - max_input_len, - k_scale, - v_scale, - sliding_window=sliding_window) + op(query, + k, + v, + output, + kv_cache_dtype, + k_cache, + v_cache, + block_table, + b_start_loc, + b_seq_len, + max_input_len, + k_scale, + v_scale, + sliding_window=sliding_window) torch.cuda.synchronize() end_time = time.time() print(f"triton Time: {(end_time - start_time)*1000:.2f} ms") @@ -228,7 +238,7 @@ def test_contexted_kv_attention( end_time = time.time() print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms") output_ref = output_ref.reshape(output.shape) - atol = 1e-3 if "fp8" in kv_cache_dtype else 1e-6 + atol = 1e-3 if "fp8" in kv_cache_dtype else 1e-4 torch.testing.assert_close(output, output_ref, atol=atol, rtol=0) @@ -238,6 +248,7 @@ def test_contexted_kv_attention( @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES) @pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("op", OPS) @torch.inference_mode() def test_contexted_kv_attention_alibi( num_heads: int, @@ -246,6 +257,7 @@ def test_contexted_kv_attention_alibi( dtype: torch.dtype, kv_cache_dtype: str, device: str, + op: Callable, ) -> None: if 'fp8' in kv_cache_dtype and not current_platform.has_device_capability( @@ -375,36 +387,36 @@ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: # Warm up the Triton kernel by calling it once before actually measuring # generation time - context_attention_fwd(query, - k, - v, - output, - kv_cache_dtype, - k_cache, - v_cache, - block_table, - b_start_loc, - b_seq_len, - max_input_len, - k_scale, - v_scale, - alibi_slopes=alibi_slopes) + op(query, + k, + v, + output, + kv_cache_dtype, + k_cache, + v_cache, + block_table, + b_start_loc, + b_seq_len, + max_input_len, + k_scale, + v_scale, + alibi_slopes=alibi_slopes) torch.cuda.synchronize() start_time = time.time() - context_attention_fwd(query, - k, - v, - output, - kv_cache_dtype, - k_cache, - v_cache, - block_table, - b_start_loc, - b_seq_len, - max_input_len, - k_scale, - v_scale, - alibi_slopes=alibi_slopes) + op(query, + k, + v, + output, + kv_cache_dtype, + k_cache, + v_cache, + block_table, + b_start_loc, + b_seq_len, + max_input_len, + k_scale, + v_scale, + alibi_slopes=alibi_slopes) torch.cuda.synchronize() end_time = time.time() print(f"triton Time: {(end_time - start_time)*1000:.2f} ms") @@ -501,6 +513,7 @@ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: @pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES) @pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("sliding_window", SLIDING_WINDOW) +@pytest.mark.parametrize("op", OPS) @torch.inference_mode() def test_contexted_kv_attention_f32( num_heads: int, @@ -510,9 +523,11 @@ def test_contexted_kv_attention_f32( dtype: torch.dtype, kv_cache_dtype: str, device: str, + op: Callable, ) -> None: test_contexted_kv_attention(num_heads, num_queries_per_kv, head_size, - sliding_window, dtype, kv_cache_dtype, device) + sliding_window, dtype, kv_cache_dtype, device, + op) @pytest.mark.optional @@ -522,6 +537,7 @@ def test_contexted_kv_attention_f32( @pytest.mark.parametrize("dtype", [torch.float32]) @pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES) @pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("op", OPS) @torch.inference_mode() def test_contexted_kv_attention_alibi_f32( num_heads: int, @@ -530,6 +546,7 @@ def test_contexted_kv_attention_alibi_f32( dtype: torch.dtype, kv_cache_dtype: str, device: str, + op: Callable, ) -> None: test_contexted_kv_attention_alibi(num_heads, num_queries_per_kv, head_size, - dtype, kv_cache_dtype, device) + dtype, kv_cache_dtype, device, op) diff --git a/vllm/attention/ops/chunked_prefill_paged_decode.py b/vllm/attention/ops/chunked_prefill_paged_decode.py index e26f0ef48275..55011b229898 100644 --- a/vllm/attention/ops/chunked_prefill_paged_decode.py +++ b/vllm/attention/ops/chunked_prefill_paged_decode.py @@ -35,6 +35,7 @@ def kernel_paged_attention_2d( HEAD_SIZE: tl.constexpr, # int HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2 USE_ALIBI_SLOPES: tl.constexpr, # bool + SLIDING_WINDOW: tl.constexpr, # int x: tl.constexpr, # int stride_k_cache_0: tl.constexpr, # int stride_k_cache_1: tl.constexpr, # int @@ -58,7 +59,6 @@ def kernel_paged_attention_2d( 1) cur_batch_query_len = cur_batch_in_all_stop_index \ - cur_batch_in_all_start_index - if cur_batch_query_len > 1: return else: @@ -98,7 +98,7 @@ def kernel_paged_attention_2d( physical_block_idx = tl.load(block_tables_ptr + block_table_offset + j) offs_n = tl.arange(0, BLOCK_SIZE) - offs_d = tl.arange(0, HEAD_SIZE) + offs_d = tl.arange(0, HEAD_SIZE_PADDED) v_offset = (physical_block_idx * stride_v_cache_0 + kv_head_idx * stride_v_cache_1 + @@ -138,6 +138,9 @@ def kernel_paged_attention_2d( S = tl.where(mask_new, 0.0, float("-inf")).to(tl.float32) S += scale * tl.sum(K * Q[:, None], axis=0) + if SLIDING_WINDOW > 0: + S = tl.where((context_len - 1 - tmp) < SLIDING_WINDOW, S, -10000) + if USE_ALIBI_SLOPES: S += alibi_slope * (tmp - context_len + 1) @@ -175,83 +178,6 @@ def kernel_paged_attention_2d( mask=dim_mask) -def paged_attention_triton_2d( - output, - query, - key_cache, - value_cache, - scale, - k_scale, - v_scale, - kv_cache_dtype, - block_tables, - context_lens, - alibi_slopes, - block_size, - num_seqs, - num_query_heads, - num_queries_per_kv, - head_size, -): - use_alibi_slopes = alibi_slopes is not None - - # Conversion of FP8 Tensor from uint8 storage to - # appropriate torch.dtype for interpretation by Triton - if "fp8" in kv_cache_dtype: - assert key_cache.dtype == torch.uint8 - assert value_cache.dtype == torch.uint8 - - if kv_cache_dtype in ("fp8", "fp8_e4m3"): - target_dtype = torch.float8_e4m3fn - elif kv_cache_dtype == "fp8_e5m2": - target_dtype = torch.float8_e5m2 - else: - raise ValueError("Unsupported FP8 dtype:", kv_cache_dtype) - - key_cache = key_cache.view(target_dtype) - value_cache = value_cache.view(target_dtype) - - kernel_paged_attention_2d[( - num_seqs, - num_query_heads, - )]( - output_ptr=output, - query_ptr=query, - key_cache_ptr=key_cache, - value_cache_ptr=value_cache, - block_tables_ptr=block_tables, - context_lens_ptr=context_lens, - alibi_slopes_ptr=alibi_slopes, - scale=scale, - k_scale=k_scale, - v_scale=v_scale, - num_query_heads=num_query_heads, - num_queries_per_kv=num_queries_per_kv, - block_table_stride=block_tables.stride(0), - query_stride_0=query.stride(0), - query_stride_1=query.stride(1), - output_stride_0=output.stride(0), - output_stride_1=output.stride(1), - BLOCK_SIZE=block_size, - HEAD_SIZE=head_size, - HEAD_SIZE_PADDED=triton.next_power_of_2(head_size), - USE_ALIBI_SLOPES=use_alibi_slopes, - x=key_cache.shape[4] if len(key_cache.shape) == 5 else 1, - stride_k_cache_0=key_cache.stride(0), - stride_k_cache_1=key_cache.stride(1), - stride_k_cache_2=key_cache.stride(2), - stride_k_cache_3=key_cache.stride(3), - stride_k_cache_4=key_cache.stride(4) - if len(key_cache.shape) == 5 else 1, - stride_v_cache_0=value_cache.stride(0), - stride_v_cache_1=value_cache.stride(1), - stride_v_cache_2=value_cache.stride(2), - stride_v_cache_3=value_cache.stride(3), - filter_by_query_len=False, - query_start_len_ptr=None, - ) - - def chunked_prefill_paged_decode( query, key, @@ -266,13 +192,19 @@ def chunked_prefill_paged_decode( max_query_len, k_scale, v_scale, - alibi_slopes, - sliding_window, - scale, + alibi_slopes=None, + sliding_window=None, + sm_scale=None, ): + if sm_scale is None: + sm_scale = 1.0 / (query.shape[1]**0.5) + use_alibi_slopes = alibi_slopes is not None + if sliding_window is None or sliding_window <= 0: + sliding_window = 0 + context_attention_fwd( q=query, k=key, @@ -289,7 +221,7 @@ def chunked_prefill_paged_decode( v_scale=v_scale, alibi_slopes=alibi_slopes, sliding_window=sliding_window, - sm_scale=scale, + sm_scale=sm_scale, skip_decode=True, ) @@ -326,7 +258,7 @@ def chunked_prefill_paged_decode( block_tables_ptr=block_table, context_lens_ptr=seq_lens, alibi_slopes_ptr=alibi_slopes, - scale=scale, + scale=sm_scale, k_scale=k_scale, v_scale=v_scale, num_query_heads=num_query_heads, @@ -340,6 +272,7 @@ def chunked_prefill_paged_decode( HEAD_SIZE=head_size, HEAD_SIZE_PADDED=triton.next_power_of_2(head_size), USE_ALIBI_SLOPES=use_alibi_slopes, + SLIDING_WINDOW=sliding_window, x=key_cache.shape[4], stride_k_cache_0=key_cache.stride(0), stride_k_cache_1=key_cache.stride(1), diff --git a/vllm/v1/attention/backends/rocm_attn.py b/vllm/v1/attention/backends/rocm_attn.py index 18f9d4fcdc79..640c3b3d4fbb 100644 --- a/vllm/v1/attention/backends/rocm_attn.py +++ b/vllm/v1/attention/backends/rocm_attn.py @@ -173,6 +173,6 @@ def forward( v_scale=layer._v_scale, alibi_slopes=self.alibi_slopes, sliding_window=self.sliding_window[0], - scale=self.scale) + sm_scale=self.scale) return output From b3b873bbbc4597871a70aeea3f3a26c6ac61cdd7 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Wed, 5 Mar 2025 08:57:11 -0500 Subject: [PATCH 09/11] Address review comments Signed-off-by: Thomas Parnell --- .../ops/chunked_prefill_paged_decode.py | 63 ++++++++++--------- vllm/platforms/cuda.py | 3 +- 2 files changed, 34 insertions(+), 32 deletions(-) diff --git a/vllm/attention/ops/chunked_prefill_paged_decode.py b/vllm/attention/ops/chunked_prefill_paged_decode.py index 55011b229898..807a270b43de 100644 --- a/vllm/attention/ops/chunked_prefill_paged_decode.py +++ b/vllm/attention/ops/chunked_prefill_paged_decode.py @@ -14,12 +14,12 @@ def cdiv_fn(x, y): @triton.jit def kernel_paged_attention_2d( - output_ptr, # [num_seqs, num_query_heads, head_size] - query_ptr, # [num_seqs, num_query_heads, head_size] - key_cache_ptr, # [num_blocks, num_kv_heads, head_size, block_size] - value_cache_ptr, # [num_blocks, num_kv_heads, head_size, block_size] + output_ptr, # [num_tokens, num_query_heads, head_size] + query_ptr, # [num_tokens, num_query_heads, head_size] + key_cache_ptr, # [num_blks, num_kv_heads, head_size // x, blk_size, x] + value_cache_ptr, # [num_blks, num_kv_heads, head_size, blk_size] block_tables_ptr, # [num_seqs, max_num_blocks_per_seq] - context_lens_ptr, # [num_seqs] + seq_lens_ptr, # [num_seqs] alibi_slopes_ptr, # [num_query_heads] scale, # float32 k_scale, # float32 @@ -83,14 +83,14 @@ def kernel_paged_attention_2d( L = tl.full([1], 1.0, dtype=tl.float32) acc = tl.zeros([HEAD_SIZE_PADDED], dtype=tl.float32) - # context len for this particular sequence - context_len = tl.load(context_lens_ptr + seq_idx) + # sequence len for this particular sequence + seq_len = tl.load(seq_lens_ptr + seq_idx) # alibi slope for this head if USE_ALIBI_SLOPES: alibi_slope = tl.load(alibi_slopes_ptr + query_head_idx) - num_blocks = cdiv_fn(context_len, BLOCK_SIZE) + num_blocks = cdiv_fn(seq_len, BLOCK_SIZE) # iterate through tiles for j in range(0, num_blocks): @@ -132,17 +132,17 @@ def kernel_paged_attention_2d( V = V_load tmp = j * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - boundary = tl.full([BLOCK_SIZE], context_len, dtype=tl.int32) + boundary = tl.full([BLOCK_SIZE], seq_len, dtype=tl.int32) mask_new = tmp < boundary # S : (BLOCK_SIZE,) S = tl.where(mask_new, 0.0, float("-inf")).to(tl.float32) S += scale * tl.sum(K * Q[:, None], axis=0) if SLIDING_WINDOW > 0: - S = tl.where((context_len - 1 - tmp) < SLIDING_WINDOW, S, -10000) + S = tl.where((seq_len - 1 - tmp) < SLIDING_WINDOW, S, -10000) if USE_ALIBI_SLOPES: - S += alibi_slope * (tmp - context_len + 1) + S += alibi_slope * (tmp - seq_len + 1) # compute running maximum # m_j : (1,) @@ -205,25 +205,26 @@ def chunked_prefill_paged_decode( if sliding_window is None or sliding_window <= 0: sliding_window = 0 - context_attention_fwd( - q=query, - k=key, - v=value, - o=output, - kv_cache_dtype=kv_cache_dtype, - k_cache=key_cache, - v_cache=value_cache, - b_loc=block_table, - b_start_loc=query_start_loc, - b_seq_len=seq_lens, - max_input_len=max_query_len, - k_scale=k_scale, - v_scale=v_scale, - alibi_slopes=alibi_slopes, - sliding_window=sliding_window, - sm_scale=sm_scale, - skip_decode=True, - ) + if max_query_len > 1: + context_attention_fwd( + q=query, + k=key, + v=value, + o=output, + kv_cache_dtype=kv_cache_dtype, + k_cache=key_cache, + v_cache=value_cache, + b_loc=block_table, + b_start_loc=query_start_loc, + b_seq_len=seq_lens, + max_input_len=max_query_len, + k_scale=k_scale, + v_scale=v_scale, + alibi_slopes=alibi_slopes, + sliding_window=sliding_window, + sm_scale=sm_scale, + skip_decode=True, + ) block_size = value_cache.shape[3] num_seqs = len(seq_lens) @@ -256,7 +257,7 @@ def chunked_prefill_paged_decode( key_cache_ptr=key_cache, value_cache_ptr=value_cache, block_tables_ptr=block_table, - context_lens_ptr=seq_lens, + seq_lens_ptr=seq_lens, alibi_slopes_ptr=alibi_slopes, scale=sm_scale, k_scale=k_scale, diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index a6d4de840b72..d11c173cc9ca 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -196,7 +196,8 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype, return "vllm.attention.backends.triton_mla.TritonMLABackend" if use_v1: if os.environ.get("VLLM_V1_USE_TRITON_BACKEND", "0") == "1": - logger.info_once("Using ROCm Attention backend on V1 engine.") + logger.info_once( + "Using Triton/ROCm Attention backend on V1 engine.") return ("vllm.v1.attention.backends.rocm_attn." "ROCmAttentionBackend") logger.info_once("Using Flash Attention backend on V1 engine.") From 5baa8790ed824b2509554acc9e5d0ad82295070a Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Wed, 5 Mar 2025 14:22:59 -0500 Subject: [PATCH 10/11] Revert changes in cuda platform. Signed-off-by: Thomas Parnell --- vllm/platforms/cuda.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index d11c173cc9ca..bffa113cab89 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -195,11 +195,6 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype, logger.info("Using Triton MLA backend.") return "vllm.attention.backends.triton_mla.TritonMLABackend" if use_v1: - if os.environ.get("VLLM_V1_USE_TRITON_BACKEND", "0") == "1": - logger.info_once( - "Using Triton/ROCm Attention backend on V1 engine.") - return ("vllm.v1.attention.backends.rocm_attn." - "ROCmAttentionBackend") logger.info_once("Using Flash Attention backend on V1 engine.") return ("vllm.v1.attention.backends.flash_attn." "FlashAttentionBackend") From 16500bf817f2fb1af641e363a2ddc06449286ff2 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Thu, 6 Mar 2025 02:39:05 -0500 Subject: [PATCH 11/11] Ensuring co-authors survive squash. Co-authored-by: Burkhard Ringlein Co-authored-by: Jan van Lunteren Signed-off-by: Thomas Parnell