2 changes: 2 additions & 0 deletions benchmarks/kernels/benchmark_paged_attention.py
@@ -118,6 +118,7 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
block_size,
max_seq_len,
alibi_slopes,
None, # attn_bias (TODO: benchmark with a custom bias)
kv_cache_dtype,
k_scale,
v_scale,
@@ -138,6 +139,7 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
block_size,
max_seq_len,
alibi_slopes,
None, # attn_bias
kv_cache_dtype,
k_scale,
v_scale,
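If this benchmark is later extended to exercise the new argument, a minimal sketch of what could replace the `None` placeholders above is shown below. It assumes `torch` is imported and that `num_seqs`, `num_query_heads`, `block_size`, `max_seq_len`, and `device` are in scope in the benchmark (those names are assumptions, not taken from this diff); the shape, dtype, and block-aligned last dimension follow the checks added in the CUDA launchers in this PR.

```python
# Hedged sketch: build a zero-valued custom bias with the layout the kernels
# expect ([num_seqs, num_heads, padded_max_seq_len], float32, last dim padded
# to a multiple of block_size) and pass it where the calls above pass None.
padded_max_seq_len = ((max_seq_len + block_size - 1) // block_size) * block_size
attn_bias = torch.zeros(
    num_seqs,
    num_query_heads,
    padded_max_seq_len,
    dtype=torch.float32,
    device=device,
)
# e.g. ops.paged_attention_v1(..., alibi_slopes, attn_bias, kv_cache_dtype, ...)
```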
32 changes: 24 additions & 8 deletions csrc/attention/attention_kernels.cuh
@@ -104,6 +104,8 @@ __device__ void paged_attention_kernel(
const int* __restrict__ seq_lens, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const float* __restrict__ attn_bias, // [num_seqs, num_heads, padded_max_seq_len]
const int padded_max_seq_len, // Block-aligned max_seq_len, precomputed on the host.
const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float* k_scale, const float* v_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
@@ -154,6 +156,14 @@ __device__ void paged_attention_kernel(
const float alibi_slope =
alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];

// NOTE (NickLucche) Bias vector for the current sequence and head:
// `padded_max_seq_len` float values, one per (padded) token position.
const float* attn_bias_vec =
attn_bias == nullptr
? nullptr
: attn_bias + seq_idx * num_heads * padded_max_seq_len +
head_idx * padded_max_seq_len;

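The pointer arithmetic above treats `attn_bias` as a flat `[num_seqs, num_heads, padded_max_seq_len]` float32 array, so each (sequence, head) pair owns a contiguous run of `padded_max_seq_len` values; further down, the kernel adds `attn_bias_vec[token_idx]` to the scaled query-key logit right after the optional ALiBi term. A small Python sketch of the same indexing, purely illustrative (the helper name is not part of vLLM):

```python
import torch

def bias_for_token(attn_bias: torch.Tensor, seq_idx: int, head_idx: int,
                   token_idx: int) -> float:
    """Mirror the kernel's flat indexing into a
    [num_seqs, num_heads, padded_max_seq_len] float32 bias tensor."""
    num_seqs, num_heads, padded_max_seq_len = attn_bias.shape
    flat = attn_bias.reshape(-1)
    offset = (seq_idx * num_heads * padded_max_seq_len
              + head_idx * padded_max_seq_len)
    # Same element the kernel reads as attn_bias_vec[token_idx].
    value = flat[offset + token_idx]
    assert value == attn_bias[seq_idx, head_idx, token_idx]
    return float(value)
```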
// A vector type to store a part of a key or a query.
// The vector size is configured in such a way that the threads in a thread
// group fetch or compute 16 bytes at a time. For example, if the size of a
@@ -293,8 +303,10 @@ __device__ void paged_attention_kernel(
// This includes a reduction across the threads in the same thread group.
float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(
q_vecs[thread_group_offset], k_vecs);
// Add the ALiBi bias if slopes are given.
// Add the ALiBi bias if slopes are given, then add the custom bias if one
// was provided. TODO: should the two be mutually exclusive?
qk += (alibi_slope != 0) ? alibi_slope * (token_idx - seq_len + 1) : 0;
qk += (attn_bias_vec != nullptr) ? attn_bias_vec[token_idx] : 0;

if (thread_group_offset == 0) {
// Store the partial reductions to shared memory.
@@ -512,6 +524,8 @@ __global__ void paged_attention_v1_kernel(
const int* __restrict__ seq_lens, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const float* __restrict__ attn_bias, // [num_seqs, num_heads, padded_max_seq_len]
const int padded_max_seq_len, // Block-aligned max_seq_len, precomputed on the host.
const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float* k_scale, const float* v_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
@@ -520,9 +534,9 @@ __global__ void paged_attention_v1_kernel(
KV_DTYPE, IS_BLOCK_SPARSE>(
/* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache,
v_cache, num_kv_heads, scale, block_tables, seq_lens,
max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride,
kv_head_stride, k_scale, v_scale, tp_rank, blocksparse_local_blocks,
blocksparse_vert_stride, blocksparse_block_size,
max_num_blocks_per_seq, alibi_slopes, attn_bias, padded_max_seq_len,
q_stride, kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank,
blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size,
blocksparse_head_sliding_step);
}

@@ -548,17 +562,19 @@ __global__ void paged_attention_v2_kernel(
const int* __restrict__ seq_lens, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const float* __restrict__ attn_bias, // [num_seqs, num_heads, padded_max_seq_len]
const int padded_max_seq_len, // Block-aligned max_seq_len, precomputed on the host.
const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float* k_scale, const float* v_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
KV_DTYPE, IS_BLOCK_SPARSE, PARTITION_SIZE>(
exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride,
kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank,
blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size,
blocksparse_head_sliding_step);
block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, attn_bias,
padded_max_seq_len, q_stride, kv_block_stride, kv_head_stride, k_scale,
v_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride,
blocksparse_block_size, blocksparse_head_sliding_step);
}

// Grid: (num_heads, num_seqs).
62 changes: 38 additions & 24 deletions csrc/attention/paged_attention_v1.cu
@@ -29,21 +29,21 @@
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))

#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \
VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \
((void*)vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, \
BLOCK_SIZE, NUM_THREADS, \
KV_DTYPE, IS_BLOCK_SPARSE>), \
shared_mem_size); \
vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \
NUM_THREADS, KV_DTYPE, IS_BLOCK_SPARSE> \
<<<grid, block, shared_mem_size, stream>>>( \
out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, \
scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
k_scale_ptr, v_scale_ptr, tp_rank, blocksparse_local_blocks, \
blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_head_sliding_step);
#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \
VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \
((void*)vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, \
BLOCK_SIZE, NUM_THREADS, \
KV_DTYPE, IS_BLOCK_SPARSE>), \
shared_mem_size); \
vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \
NUM_THREADS, KV_DTYPE, IS_BLOCK_SPARSE> \
<<<grid, block, shared_mem_size, stream>>>( \
out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, \
scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
alibi_slopes_ptr, attn_bias_ptr, padded_max_seq_len, q_stride, \
kv_block_stride, kv_head_stride, k_scale_ptr, v_scale_ptr, tp_rank, \
blocksparse_local_blocks, blocksparse_vert_stride, \
blocksparse_block_size, blocksparse_head_sliding_step);

// TODO(woosuk): Tune NUM_THREADS.
template <typename T, typename CACHE_T, int BLOCK_SIZE,
@@ -53,7 +53,8 @@ void paged_attention_v1_launcher(
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
const std::optional<torch::Tensor>& alibi_slopes, torch::Tensor& k_scale,
const std::optional<torch::Tensor>& alibi_slopes,
const std::optional<torch::Tensor>& attn_bias, torch::Tensor& k_scale,
torch::Tensor& v_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
@@ -73,7 +74,21 @@ void paged_attention_v1_launcher(
alibi_slopes
? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
: nullptr;

const float* attn_bias_ptr =
attn_bias ? reinterpret_cast<const float*>(attn_bias.value().data_ptr())
: nullptr;
const int padded_max_seq_len =
DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE;
if (attn_bias_ptr) {
const torch::Tensor& abias = attn_bias.value();
TORCH_CHECK(abias.dtype() == torch::kFloat32,
"Unsupported bias dtype: ", abias.dtype());
TORCH_CHECK(abias.size(abias.dim() - 1) == padded_max_seq_len,
"The last dimension of the attention bias must "
"match the block-aligned maximum sequence length (",
padded_max_seq_len,
"). However, the given dimensions are: ", abias.sizes());
}
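Callers therefore have to hand the kernel a bias whose last dimension is already rounded up to a multiple of the block size (the v2 launcher below repeats the same check). A hedged host-side padding sketch; the helper name is illustrative, and zero is assumed to be a safe fill because positions at or past each sequence's length are masked out by the kernel:

```python
import torch
import torch.nn.functional as F

def pad_attn_bias(bias: torch.Tensor, block_size: int) -> torch.Tensor:
    """Pad a [num_seqs, num_heads, max_seq_len] float32 bias so its last dim
    equals DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE, matching the
    TORCH_CHECK above."""
    max_seq_len = bias.shape[-1]
    padded_len = ((max_seq_len + block_size - 1) // block_size) * block_size
    return F.pad(bias, (0, padded_len - max_seq_len), value=0.0)
```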
T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
@@ -84,13 +99,11 @@ void paged_attention_v1_launcher(
const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());

constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
int padded_max_seq_len =
DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE;
int logits_size = padded_max_seq_len * sizeof(float);
int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
const int logits_size = padded_max_seq_len * sizeof(float);
const int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
// Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len
// Keep that in sync with the logic here!
int shared_mem_size = std::max(logits_size, outputs_size);
const int shared_mem_size = std::max(logits_size, outputs_size);

dim3 grid(num_heads, num_seqs, 1);
dim3 block(NUM_THREADS);
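For reference, the launch sizing above in plain arithmetic: the logits buffer needs one float per block-aligned token, the cross-warp reduction buffer needs `head_size` floats per pair of warps, and the kernel takes whichever is larger. A sketch of the same computation; `NUM_THREADS = 128` and `WARP_SIZE = 32` are assumed defaults for NVIDIA builds, not values taken from this diff:

```python
def v1_shared_mem_bytes(max_seq_len: int, block_size: int, head_size: int,
                        num_threads: int = 128, warp_size: int = 32) -> int:
    """Mirror the shared-memory sizing in paged_attention_v1_launcher."""
    sizeof_float = 4
    num_warps = num_threads // warp_size
    padded_len = ((max_seq_len + block_size - 1) // block_size) * block_size
    logits_size = padded_len * sizeof_float
    outputs_size = (num_warps // 2) * head_size * sizeof_float
    return max(logits_size, outputs_size)

# Example: max_seq_len=4096, block_size=16, head_size=128
# -> max(16384, 1024) = 16384 bytes of dynamic shared memory.
```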
@@ -137,8 +150,8 @@ void paged_attention_v1_launcher(
paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, KV_DTYPE, \
IS_BLOCK_SPARSE>( \
out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \
seq_lens, max_seq_len, alibi_slopes, k_scale, v_scale, tp_rank, \
blocksparse_local_blocks, blocksparse_vert_stride, \
seq_lens, max_seq_len, alibi_slopes, attn_bias, k_scale, v_scale, \
tp_rank, blocksparse_local_blocks, blocksparse_vert_stride, \
blocksparse_block_size, blocksparse_head_sliding_step);

#define CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
@@ -179,6 +192,7 @@ void paged_attention_v1(
torch::Tensor& seq_lens, // [num_seqs]
int64_t block_size, int64_t max_seq_len,
const std::optional<torch::Tensor>& alibi_slopes,
const std::optional<torch::Tensor>& attn_bias,
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
torch::Tensor& v_scale, const int64_t tp_rank,
const int64_t blocksparse_local_blocks,
37 changes: 27 additions & 10 deletions csrc/attention/paged_attention_v2.cu
@@ -36,8 +36,9 @@
<<<grid, block, shared_mem_size, stream>>>( \
exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \
value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \
seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
kv_block_stride, kv_head_stride, k_scale_ptr, v_scale_ptr, tp_rank, \
seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, \
attn_bias_ptr, padded_max_seq_len, q_stride, kv_block_stride, \
kv_head_stride, k_scale_ptr, v_scale_ptr, tp_rank, \
blocksparse_local_blocks, blocksparse_vert_stride, \
blocksparse_block_size, blocksparse_head_sliding_step); \
vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, \
@@ -54,7 +55,8 @@ void paged_attention_v2_launcher(
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
const std::optional<torch::Tensor>& alibi_slopes, torch::Tensor& k_scale,
const std::optional<torch::Tensor>& alibi_slopes,
const std::optional<torch::Tensor>& attn_bias, torch::Tensor& k_scale,
torch::Tensor& v_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
@@ -74,7 +76,21 @@ void paged_attention_v2_launcher(
alibi_slopes
? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
: nullptr;

const float* attn_bias_ptr =
attn_bias ? reinterpret_cast<const float*>(attn_bias.value().data_ptr())
: nullptr;
const int padded_max_seq_len =
DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE;
if (attn_bias_ptr) {
const torch::Tensor& abias = attn_bias.value();
TORCH_CHECK(abias.dtype() == torch::kFloat32,
"Unsupported bias dtype: ", abias.dtype());
TORCH_CHECK(abias.size(abias.dim() - 1) == padded_max_seq_len,
"The last dimension of the attention bias must "
"match the block-aligned maximum sequence length (",
padded_max_seq_len,
"). However, the given dimensions are: ", abias.sizes());
}
T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
@@ -88,16 +104,16 @@ void paged_attention_v2_launcher(
const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());

constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
int logits_size = PARTITION_SIZE * sizeof(float);
int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
const int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
const int logits_size = PARTITION_SIZE * sizeof(float);
const int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);

// For paged attention v2 kernel.
dim3 grid(num_heads, num_seqs, max_num_partitions);
int shared_mem_size = std::max(logits_size, outputs_size);
const int shared_mem_size = std::max(logits_size, outputs_size);
// For paged attention v2 reduce kernel.
dim3 reduce_grid(num_heads, num_seqs);
int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float);
const int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float);

dim3 block(NUM_THREADS);
const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
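The v2 path works in fixed-size partitions instead: the grid gains a third dimension of `max_num_partitions`, the per-partition logits bound the main kernel's shared memory, and the reduce kernel keeps two floats per partition (its max logits and exp sums). A sketch of the same sizing; `PARTITION_SIZE = 512`, `NUM_THREADS = 128`, and `WARP_SIZE = 32` are assumed template parameters, not values taken from this diff:

```python
def v2_launch_sizes(num_heads: int, num_seqs: int, max_seq_len: int,
                    head_size: int, partition_size: int = 512,
                    num_threads: int = 128, warp_size: int = 32):
    """Mirror the grid and shared-memory sizing in paged_attention_v2_launcher."""
    sizeof_float = 4
    num_warps = num_threads // warp_size
    max_num_partitions = -(-max_seq_len // partition_size)  # DIVIDE_ROUND_UP
    logits_size = partition_size * sizeof_float
    outputs_size = (num_warps // 2) * head_size * sizeof_float
    grid = (num_heads, num_seqs, max_num_partitions)
    shared_mem_size = max(logits_size, outputs_size)
    reduce_grid = (num_heads, num_seqs)
    reduce_shared_mem_size = 2 * max_num_partitions * sizeof_float
    return grid, shared_mem_size, reduce_grid, reduce_shared_mem_size
```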
@@ -144,7 +160,7 @@ void paged_attention_v2_launcher(
IS_BLOCK_SPARSE>( \
out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \
k_scale, v_scale, tp_rank, blocksparse_local_blocks, \
attn_bias, k_scale, v_scale, tp_rank, blocksparse_local_blocks, \
blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_head_sliding_step);

@@ -190,6 +206,7 @@ void paged_attention_v2(
torch::Tensor& seq_lens, // [num_seqs]
int64_t block_size, int64_t max_seq_len,
const std::optional<torch::Tensor>& alibi_slopes,
const std::optional<torch::Tensor>& attn_bias,
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
torch::Tensor& v_scale, const int64_t tp_rank,
const int64_t blocksparse_local_blocks,
8 changes: 7 additions & 1 deletion csrc/cpu/attention.cpp
@@ -459,14 +459,17 @@ void paged_attention_v1(
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int64_t max_seq_len, const std::optional<torch::Tensor>& alibi_slopes,
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const c10::optional<torch::Tensor>& attn_bias,
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
torch::Tensor& v_scale, const int64_t tp_rank,
const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step) {
TORCH_CHECK(blocksparse_vert_stride <= 1,
"CPU backend does not support blocksparse attention yet.");
TORCH_CHECK(!attn_bias.has_value(),
"CPU backend does not support custom attention bias.");
VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v1_impl",
[&] {
CPU_KERNEL_GUARD_IN(paged_attention_v1_impl)
@@ -782,13 +785,16 @@ void paged_attention_v2(
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int64_t max_seq_len, const std::optional<torch::Tensor>& alibi_slopes,
const std::optional<torch::Tensor>& attn_bias,
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
torch::Tensor& v_scale, const int64_t tp_rank,
const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step) {
TORCH_CHECK(blocksparse_vert_stride <= 1,
"CPU backend does not support blocksparse attention yet.");
TORCH_CHECK(!attn_bias.has_value(),
"CPU backend does not support custom attention bias.");
VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v2_impl",
[&] {
CPU_KERNEL_GUARD_IN(paged_attention_v2_impl)
5 changes: 3 additions & 2 deletions csrc/cpu/torch_bindings.cpp
@@ -24,12 +24,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// Attention ops
// Compute the attention between an input query and the cached keys/values
// using PagedAttention.
// TODO: support attn_bias on the CPU backend.
ops.def(
"paged_attention_v1("
" Tensor! out, Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" int max_seq_len, Tensor? alibi_slopes, Tensor? attn_bias,"
" str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
" int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
Expand All @@ -43,7 +44,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor! tmp_out, Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" int max_seq_len, Tensor? alibi_slopes, Tensor? attn_bias,"
" str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
" int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"