From b302bfb103bbcc21ebece4ced6dff15ad9bb3751 Mon Sep 17 00:00:00 2001 From: Micah Williamson Date: Tue, 17 Dec 2024 17:43:18 -0600 Subject: [PATCH 01/16] Dynamic Scale Factor Calculations for Key/Value Scales With FP8 KV Caching (#317) * Changed _k_scale and _v_scale to tensors * fixed rocm paged attention with tensor kv scales * Added on the fly scale factor calculation * trying to fix attn metadata * fixed AttentionMetadata issue, updated description for calculate-kv-scales flag in arg_utils.py * Changed K and V scale constants * Removed unneeded comment * Changes to pass format.sh, also fixed lingering k_scale/v_scale : float * Fix for TP > 1 * Ran format.sh * Removed legacy kv_scale loading from the json file * Removed the outdated kv cache docs * Revert some unwanted changes --------- Co-authored-by: Gregory Shtrasberg Signed-off-by: Gregory Shtrasberg --- csrc/attention/attention_kernels.cuh | 10 +- csrc/attention/paged_attention_v1.cu | 17 +- csrc/attention/paged_attention_v2.cu | 17 +- csrc/cache.h | 6 +- csrc/cache_kernels.cu | 30 +- csrc/ops.h | 10 +- csrc/rocm/attention.cu | 17 +- csrc/rocm/ops.h | 4 +- csrc/rocm/torch_bindings.cpp | 2 +- csrc/torch_bindings.cpp | 8 +- docs/source/quantization/fp8_e4m3_kvcache.rst | 10 +- examples/fp8/README.md | 96 ----- examples/fp8/extract_scales.py | 367 ------------------ examples/fp8/quantizer/README.md | 32 -- examples/fp8/quantizer/quantize.py | 367 ------------------ .../llama2-70b-fp8-kv/kv_cache_scales.json | 90 ----- .../llama2-7b-fp8-kv/kv_cache_scales.json | 42 -- .../models/decoder_only/language/test_fp8.py | 15 +- vllm/_custom_ops.py | 20 +- vllm/attention/backends/abstract.py | 11 +- vllm/attention/backends/blocksparse_attn.py | 4 +- vllm/attention/backends/rocm_flash_attn.py | 4 +- vllm/attention/layer.py | 20 +- vllm/attention/ops/paged_attn.py | 8 +- vllm/attention/ops/prefix_prefill.py | 4 +- vllm/config.py | 15 +- vllm/engine/arg_utils.py | 24 +- vllm/envs.py | 9 + .../layers/quantization/kv_cache.py | 19 +- .../model_loader/weight_utils.py | 44 +-- vllm/model_executor/models/exaone.py | 34 +- vllm/model_executor/models/granite.py | 31 +- vllm/model_executor/models/llama.py | 33 +- vllm/model_executor/models/solar.py | 34 +- vllm/worker/model_runner.py | 38 +- vllm/worker/model_runner_base.py | 3 +- 36 files changed, 180 insertions(+), 1315 deletions(-) delete mode 100644 examples/fp8/README.md delete mode 100644 examples/fp8/extract_scales.py delete mode 100644 examples/fp8/quantizer/README.md delete mode 100644 examples/fp8/quantizer/quantize.py delete mode 100644 tests/fp8_kv/llama2-70b-fp8-kv/kv_cache_scales.json delete mode 100644 tests/fp8_kv/llama2-7b-fp8-kv/kv_cache_scales.json diff --git a/csrc/attention/attention_kernels.cuh b/csrc/attention/attention_kernels.cuh index 563e1438f0b0..eb216dc8baf1 100644 --- a/csrc/attention/attention_kernels.cuh +++ b/csrc/attention/attention_kernels.cuh @@ -105,7 +105,7 @@ __device__ void paged_attention_kernel( const int max_num_blocks_per_seq, const float* __restrict__ alibi_slopes, // [num_heads] const int q_stride, const int kv_block_stride, const int kv_head_stride, - const float k_scale, const float v_scale, const int tp_rank, + const float* k_scale, const float* v_scale, const int tp_rank, const int blocksparse_local_blocks, const int blocksparse_vert_stride, const int blocksparse_block_size, const int blocksparse_head_sliding_step) { const int seq_idx = blockIdx.y; @@ -285,7 +285,7 @@ __device__ void paged_attention_kernel( Quant_vec k_vec_quant = *reinterpret_cast( k_ptr 
+ offset1 * BLOCK_SIZE * x + offset2); k_vecs[j] = fp8::scaled_convert( - k_vec_quant, k_scale); + k_vec_quant, *k_scale); } } @@ -415,7 +415,7 @@ __device__ void paged_attention_kernel( *reinterpret_cast(v_ptr + offset); // Vector conversion from V_quant_vec to V_vec. v_vec = fp8::scaled_convert(v_quant_vec, - v_scale); + *v_scale); } if (block_idx == num_seq_blocks - 1) { // NOTE(woosuk): When v_vec contains the tokens that are out of the @@ -513,7 +513,7 @@ __global__ void paged_attention_v1_kernel( const int max_num_blocks_per_seq, const float* __restrict__ alibi_slopes, // [num_heads] const int q_stride, const int kv_block_stride, const int kv_head_stride, - const float k_scale, const float v_scale, const int tp_rank, + const float* k_scale, const float* v_scale, const int tp_rank, const int blocksparse_local_blocks, const int blocksparse_vert_stride, const int blocksparse_block_size, const int blocksparse_head_sliding_step) { paged_attention_kernel& alibi_slopes, float k_scale, - float v_scale, const int tp_rank, const int blocksparse_local_blocks, - const int blocksparse_vert_stride, const int blocksparse_block_size, - const int blocksparse_head_sliding_step) { + const c10::optional& alibi_slopes, torch::Tensor& k_scale, + torch::Tensor& v_scale, const int tp_rank, + const int blocksparse_local_blocks, const int blocksparse_vert_stride, + const int blocksparse_block_size, const int blocksparse_head_sliding_step) { int num_seqs = query.size(0); int num_heads = query.size(1); int head_size = query.size(2); @@ -80,6 +80,8 @@ void paged_attention_v1_launcher( CACHE_T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); int* block_tables_ptr = block_tables.data_ptr(); int* seq_lens_ptr = seq_lens.data_ptr(); + const float* k_scale_ptr = reinterpret_cast(k_scale.data_ptr()); + const float* v_scale_ptr = reinterpret_cast(v_scale.data_ptr()); constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; int padded_max_seq_len = @@ -177,8 +179,9 @@ void paged_attention_v1( torch::Tensor& seq_lens, // [num_seqs] int64_t block_size, int64_t max_seq_len, const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype, double k_scale, double v_scale, - const int64_t tp_rank, const int64_t blocksparse_local_blocks, + const std::string& kv_cache_dtype, torch::Tensor& k_scale, + torch::Tensor& v_scale, const int64_t tp_rank, + const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_head_sliding_step) { const bool is_block_sparse = (blocksparse_vert_stride > 1); diff --git a/csrc/attention/paged_attention_v2.cu b/csrc/attention/paged_attention_v2.cu index c457bdb89008..6d2e4ac74041 100644 --- a/csrc/attention/paged_attention_v2.cu +++ b/csrc/attention/paged_attention_v2.cu @@ -37,7 +37,7 @@ exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \ value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \ seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \ - kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, \ + kv_block_stride, kv_head_stride, k_scale_ptr, v_scale_ptr, tp_rank, \ blocksparse_local_blocks, blocksparse_vert_stride, \ blocksparse_block_size, blocksparse_head_sliding_step); \ vllm::paged_attention_v2_reduce_kernel& alibi_slopes, float k_scale, - float v_scale, const int tp_rank, const int blocksparse_local_blocks, - const int blocksparse_vert_stride, const int blocksparse_block_size, - const int blocksparse_head_sliding_step) { + const c10::optional& 
alibi_slopes, torch::Tensor& k_scale, + torch::Tensor& v_scale, const int tp_rank, + const int blocksparse_local_blocks, const int blocksparse_vert_stride, + const int blocksparse_block_size, const int blocksparse_head_sliding_step) { int num_seqs = query.size(0); int num_heads = query.size(1); int head_size = query.size(2); @@ -84,6 +84,8 @@ void paged_attention_v2_launcher( CACHE_T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); int* block_tables_ptr = block_tables.data_ptr(); int* seq_lens_ptr = seq_lens.data_ptr(); + const float* k_scale_ptr = reinterpret_cast(k_scale.data_ptr()); + const float* v_scale_ptr = reinterpret_cast(v_scale.data_ptr()); constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE); @@ -188,8 +190,9 @@ void paged_attention_v2( torch::Tensor& seq_lens, // [num_seqs] int64_t block_size, int64_t max_seq_len, const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype, double k_scale, double v_scale, - const int64_t tp_rank, const int64_t blocksparse_local_blocks, + const std::string& kv_cache_dtype, torch::Tensor& k_scale, + torch::Tensor& v_scale, const int64_t tp_rank, + const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_head_sliding_step) { const bool is_block_sparse = (blocksparse_vert_stride > 1); diff --git a/csrc/cache.h b/csrc/cache.h index 11c4c5001daa..eedad9fafa3c 100644 --- a/csrc/cache.h +++ b/csrc/cache.h @@ -18,15 +18,15 @@ void copy_blocks(std::vector const& key_caches, void reshape_and_cache(torch::Tensor& key, torch::Tensor& value, torch::Tensor& key_cache, torch::Tensor& value_cache, torch::Tensor& slot_mapping, - const std::string& kv_cache_dtype, const double k_scale, - const double v_scale); + const std::string& kv_cache_dtype, + torch::Tensor& k_scale, torch::Tensor& v_scale); void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value, torch::Tensor& key_cache, torch::Tensor& value_cache, torch::Tensor& slot_mapping, const std::string& kv_cache_dtype, - const double k_scale, const double v_scale); + torch::Tensor& k_scale, torch::Tensor& v_scale); // Just for unittest void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache, diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 8a95279f9a25..21a0aec0ecec 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -159,8 +159,8 @@ __global__ void reshape_and_cache_kernel( // block_size] const int64_t* __restrict__ slot_mapping, // [num_tokens] const int key_stride, const int value_stride, const int num_heads, - const int head_size, const int block_size, const int x, const float k_scale, - const float v_scale) { + const int head_size, const int block_size, const int x, + const float* k_scale, const float* v_scale) { const int64_t token_idx = blockIdx.x; const int64_t slot_idx = slot_mapping[token_idx]; if (slot_idx < 0) { @@ -196,9 +196,9 @@ __global__ void reshape_and_cache_kernel( value_cache[tgt_value_idx] = tgt_value; } else { key_cache[tgt_key_idx] = - fp8::scaled_convert(tgt_key, k_scale); + fp8::scaled_convert(tgt_key, *k_scale); value_cache[tgt_value_idx] = - fp8::scaled_convert(tgt_value, v_scale); + fp8::scaled_convert(tgt_value, *v_scale); } } } @@ -214,7 +214,7 @@ __global__ void reshape_and_cache_flash_kernel( const int64_t* __restrict__ slot_mapping, // [num_tokens] const int block_stride, const int key_stride, const int value_stride, const int num_heads, const int 
head_size, const int block_size, - const float k_scale, const float v_scale) { + const float* k_scale, const float* v_scale) { const int64_t token_idx = blockIdx.x; const int64_t slot_idx = slot_mapping[token_idx]; // NOTE: slot_idx can be -1 if the token is padded @@ -239,9 +239,9 @@ __global__ void reshape_and_cache_flash_kernel( value_cache[tgt_key_value_idx] = tgt_value; } else { key_cache[tgt_key_value_idx] = - fp8::scaled_convert(tgt_key, k_scale); + fp8::scaled_convert(tgt_key, *k_scale); value_cache[tgt_key_value_idx] = - fp8::scaled_convert(tgt_value, v_scale); + fp8::scaled_convert(tgt_value, *v_scale); } } } @@ -258,7 +258,9 @@ __global__ void reshape_and_cache_flash_kernel( reinterpret_cast(key_cache.data_ptr()), \ reinterpret_cast(value_cache.data_ptr()), \ slot_mapping.data_ptr(), key_stride, value_stride, \ - num_heads, head_size, block_size, x, k_scale, v_scale); + num_heads, head_size, block_size, x, \ + reinterpret_cast(k_scale.data_ptr()), \ + reinterpret_cast(v_scale.data_ptr())); void reshape_and_cache( torch::Tensor& key, // [num_tokens, num_heads, head_size] @@ -268,8 +270,8 @@ void reshape_and_cache( torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] torch::Tensor& slot_mapping, // [num_tokens] - const std::string& kv_cache_dtype, const double k_scale, - const double v_scale) { + const std::string& kv_cache_dtype, torch::Tensor& k_scale, + torch::Tensor& v_scale) { int num_tokens = key.size(0); int num_heads = key.size(1); int head_size = key.size(2); @@ -299,7 +301,9 @@ void reshape_and_cache( reinterpret_cast(key_cache.data_ptr()), \ reinterpret_cast(value_cache.data_ptr()), \ slot_mapping.data_ptr(), block_stride, key_stride, \ - value_stride, num_heads, head_size, block_size, k_scale, v_scale); + value_stride, num_heads, head_size, block_size, \ + reinterpret_cast(k_scale.data_ptr()), \ + reinterpret_cast(v_scale.data_ptr())); void reshape_and_cache_flash( torch::Tensor& key, // [num_tokens, num_heads, head_size] @@ -308,8 +312,8 @@ void reshape_and_cache_flash( torch::Tensor& value_cache, // [num_blocks, block_size, num_heads, head_size] torch::Tensor& slot_mapping, // [num_tokens] or [num_actual_tokens] - const std::string& kv_cache_dtype, const double k_scale, - const double v_scale) { + const std::string& kv_cache_dtype, torch::Tensor& k_scale, + torch::Tensor& v_scale) { // NOTE(woosuk): In vLLM V1, key.size(0) can be different from // slot_mapping.size(0) because of padding for CUDA graphs. 
// In vLLM V0, key.size(0) is always equal to slot_mapping.size(0) because diff --git a/csrc/ops.h b/csrc/ops.h index c145e4eda084..9102bdee90d3 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -34,8 +34,9 @@ void paged_attention_v1( torch::Tensor& value_cache, int64_t num_kv_heads, double scale, torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, int64_t max_seq_len, const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype, double k_scale, double v_scale, - const int64_t tp_rank, const int64_t blocksparse_local_blocks, + const std::string& kv_cache_dtype, torch::Tensor& k_scale, + torch::Tensor& v_scale, const int64_t tp_rank, + const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_head_sliding_step); @@ -45,8 +46,9 @@ void paged_attention_v2( torch::Tensor& value_cache, int64_t num_kv_heads, double scale, torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, int64_t max_seq_len, const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype, double k_scale, double v_scale, - const int64_t tp_rank, const int64_t blocksparse_local_blocks, + const std::string& kv_cache_dtype, torch::Tensor& k_scale, + torch::Tensor& v_scale, const int64_t tp_rank, + const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_head_sliding_step); diff --git a/csrc/rocm/attention.cu b/csrc/rocm/attention.cu index b48348a515c8..b638f130cd02 100644 --- a/csrc/rocm/attention.cu +++ b/csrc/rocm/attention.cu @@ -218,7 +218,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, // head_size] scalar_t* __restrict__ final_out, // [num_seqs, num_heads, head_size] - int max_ctx_blocks, float k_scale, float v_scale) { + int max_ctx_blocks, const float* k_scale, const float* v_scale) { constexpr int NWARPS = NUM_THREADS / WARP_SIZE; const int warpid = threadIdx.x / WARP_SIZE; const int laneid = threadIdx.x % WARP_SIZE; @@ -406,7 +406,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( // Vlocalb8[h][b * BLOCK_SIZE / 8 + d] = v_ptrh8be[d]; const _B8x8 Vlocalb8 = v_ptrh8be[d]; Vlocal[h][b * BLOCK_SIZE / 8 + d] = - scaled_convert_b8x8(Vlocalb8, v_scale); + scaled_convert_b8x8(Vlocalb8, *v_scale_ptr); } } } @@ -416,7 +416,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( #pragma unroll for (int d = 0; d < KHELOOP; d++) { Klocal[d] = - scaled_convert_b8x8(Klocalb8[d], k_scale); + scaled_convert_b8x8(Klocalb8[d], *k_scale_ptr); } } @@ -890,7 +890,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, // head_size] scalar_t* __restrict__ final_out, // [num_seqs, num_heads, head_size] - int max_ctx_blocks, float k_scale, float v_scale) { + int max_ctx_blocks, const float* k_scale, const float* v_scale) { UNREACHABLE_CODE } @@ -919,7 +919,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( block_tables_ptr, context_lens_ptr, max_num_blocks_per_seq, \ alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \ exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, max_ctx_blocks, \ - k_scale, v_scale); + k_scale_ptr, v_scale_ptr); template @@ -929,7 +929,7 @@ void paged_attention_custom_launcher( 
torch::Tensor& value_cache, const int num_kv_heads, float scale, torch::Tensor& block_tables, torch::Tensor& context_lens, int max_context_len, const c10::optional& alibi_slopes, - float k_scale, float v_scale) { + torch::Tensor& k_scale, torch::Tensor& v_scale) { int num_seqs = query.size(0); int num_heads = query.size(1); int head_size = query.size(2); @@ -953,6 +953,8 @@ void paged_attention_custom_launcher( KVT* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); int* block_tables_ptr = block_tables.data_ptr(); int* context_lens_ptr = context_lens.data_ptr(); + const float* k_scale_ptr = reinterpret_cast(k_scale.data_ptr()); + const float* v_scale_ptr = reinterpret_cast(v_scale.data_ptr()); const int max_ctx_blocks = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE); const int max_num_partitions = @@ -1087,7 +1089,8 @@ void paged_attention( torch::Tensor& context_lens, // [num_seqs] int64_t block_size, int64_t max_context_len, const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype, double k_scale, double v_scale) { + const std::string& kv_cache_dtype, torch::Tensor& k_scale, + torch::Tensor& v_scale) { const int head_size = query.size(2); if (kv_cache_dtype == "auto") { if (query.dtype() == at::ScalarType::Half) { diff --git a/csrc/rocm/ops.h b/csrc/rocm/ops.h index 9f085115a395..c5214d90a4ef 100644 --- a/csrc/rocm/ops.h +++ b/csrc/rocm/ops.h @@ -10,5 +10,5 @@ void paged_attention(torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& context_lens, int64_t block_size, int64_t max_context_len, const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype, double k_scale, - double v_scale); + const std::string& kv_cache_dtype, torch::Tensor& k_scale, + torch::Tensor& v_scale); diff --git a/csrc/rocm/torch_bindings.cpp b/csrc/rocm/torch_bindings.cpp index a283d4263d29..a5d2e2f97a3e 100644 --- a/csrc/rocm/torch_bindings.cpp +++ b/csrc/rocm/torch_bindings.cpp @@ -27,7 +27,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) { " int max_context_len," " Tensor? alibi_slopes," " str kv_cache_dtype," - " float k_scale, float v_scale) -> ()"); + " Tensor k_scale, Tensor v_scale) -> ()"); rocm_ops.impl("paged_attention", torch::kCUDA, &paged_attention); } diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 88a4e60c75cb..4d29de73668f 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -30,7 +30,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor value_cache, int num_kv_heads, float scale," " Tensor block_tables, Tensor seq_lens, int block_size," " int max_seq_len, Tensor? alibi_slopes," - " str kv_cache_dtype, float k_scale, float v_scale," + " str kv_cache_dtype, Tensor k_scale, Tensor v_scale," " int tp_rank, int blocksparse_local_blocks," " int blocksparse_vert_stride, int blocksparse_block_size," " int blocksparse_head_sliding_step) -> ()"); @@ -44,7 +44,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor value_cache, int num_kv_heads, float scale," " Tensor block_tables, Tensor seq_lens, int block_size," " int max_seq_len, Tensor? alibi_slopes," - " str kv_cache_dtype, float k_scale, float v_scale," + " str kv_cache_dtype, Tensor k_scale, Tensor v_scale," " int tp_rank, int blocksparse_local_blocks," " int blocksparse_vert_stride, int blocksparse_block_size," " int blocksparse_head_sliding_step) -> ()"); @@ -439,7 +439,7 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { " Tensor! key_cache, Tensor! 
value_cache," " Tensor slot_mapping," " str kv_cache_dtype," - " float k_scale, float v_scale) -> ()"); + " Tensor k_scale, Tensor v_scale) -> ()"); cache_ops.impl("reshape_and_cache", torch::kCUDA, &reshape_and_cache); // Reshape the key and value tensors and cache them. @@ -449,7 +449,7 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { " Tensor! value_cache," " Tensor slot_mapping," " str kv_cache_dtype," - " float k_scale, float v_scale) -> ()"); + " Tensor k_scale, Tensor v_scale) -> ()"); cache_ops.impl("reshape_and_cache_flash", torch::kCUDA, &reshape_and_cache_flash); diff --git a/docs/source/quantization/fp8_e4m3_kvcache.rst b/docs/source/quantization/fp8_e4m3_kvcache.rst index cc52d8f40af8..a9147f8fd8ff 100644 --- a/docs/source/quantization/fp8_e4m3_kvcache.rst +++ b/docs/source/quantization/fp8_e4m3_kvcache.rst @@ -30,18 +30,14 @@ Here is an example of how to enable this feature: .. code-block:: python - # two float8_e4m3fn kv cache scaling factor files are provided under tests/fp8_kv, please refer to - # https://github.com/vllm-project/vllm/blob/main/examples/fp8/README.md to generate kv_cache_scales.json of your own. + # To calculate kv cache scales on the fly enable the calculate_kv_scales + # parameter from vllm import LLM, SamplingParams sampling_params = SamplingParams(temperature=1.3, top_p=0.8) llm = LLM(model="meta-llama/Llama-2-7b-chat-hf", kv_cache_dtype="fp8", - quantization_param_path="./tests/fp8_kv/llama2-7b-fp8-kv/kv_cache_scales.json") + calculate_kv_scales=True) prompt = "London is the capital of" out = llm.generate(prompt, sampling_params)[0].outputs[0].text print(out) - - # output w/ scaling factors: England, the United Kingdom, and one of the world's leading financial, - # output w/o scaling factors: England, located in the southeastern part of the country. It is known - diff --git a/examples/fp8/README.md b/examples/fp8/README.md deleted file mode 100644 index 181c36558fcf..000000000000 --- a/examples/fp8/README.md +++ /dev/null @@ -1,96 +0,0 @@ -# FP8 KV Cache - -This utility extracts the KV cache scaling factors from a quantized HF (Hugging Face) model. The extracted scaling factors are saved to a JSON file, which can later be used by vLLM (variable-length language model) during runtime. This tool is particularly useful when the KV cache data type is FP8 and is intended for use on ROCm (AMD GPU) platforms. - -## Prerequisites - -- Python 3.x -- PyTorch -- NumPy -- Hugging Face Transformers -- Hugging Face Hub -- AMMO - -Before incorporating the FP8 datatype for inference workloads, you must adhere to the following steps: -1. Install all necessary prerequisites and dependencies. -2. Convert HF model into a quantized HF model. -3. Extract KV Cache Scaling Factors from quantized HF model. -4. Load KV Cache Scaling Factors into VLLM. - -### 2. Convert HF model into a quantized HF model. -Note: The following steps are adapted from the [TensorRT-LLM repository](https://github.com/NVIDIA/TensorRT-LLM/blob/main/examples/quantization/README.md). - -`quantize.py` (examples/fp8/quantizer/quantize.py) uses the quantization toolkit (AMMO) to calibrate the PyTorch models and export TensorRT-LLM checkpoints. Each TensorRT-LLM checkpoint contains a config file (in .json format) and one or several rank weight files (in .safetensors format). - -The detailed quantization toolkit (AMMO) conversion guide for FP8 can be found at `examples/fp8/quantizer/README.md`. - -### 3. Extract KV Cache Scaling Factors from quantized HF model. 
-`extract_scales.py` (examples/fp8/extract_scales.py) can be utilized to extract the KV cache scaling factors from your quantized HF model, however at the moment, this tool exclusively supports Llama 2 models. It is also important to note the following: -1. **File Structure**: The utility operates under the assumption that all parameters, including KV cache scaling factors, corresponding to a particular Tensor Parallelism (TP) rank are stored in a single file. These files must adhere to a specific naming convention where the TP rank is immediately identified after a specific keyword (e.g., "rank") in the filename. - -2. **TP Decomposition**: The utility assumes consistency between the TP decomposition employed by the quantizer tool and that used by vLLM. - -3. **AMMO Compatibility**: Currently, the generated KV cache scaling factors for AMMO remain uniform across all TP ranks. - -```python -# prerequisites: -# - Quantized HF LLaMa 2 model -python3 examples/fp8/extract_scales.py --help -Usage: extract_scales.py [-h] --quantized_model QUANTIZED_MODEL [--load_format {auto,safetensors,npz,pt}] [--output_dir OUTPUT_DIR] [--output_name OUTPUT_NAME] [--tp_size TP_SIZE] - -KV Scale Extraction Example - -optional arguments: ---quantized_model: Specify either the local path to, or name of, a quantized HF model. It is expected that the quantization format is FP8_E4M3, for use on ROCm (AMD GPU). -Optional arguments: ---cache_dir: Specify a cache directory to use in the event of a HF model download. (Default: None) ---load_format: Specify the format of the model's tensor files containing the KV cache scaling factors. (Choices: auto, safetensors, npz, pt; Default: auto) ---revision: Specify the model's revision number. (Default: None) ---output_dir: Specify the output directory. By default the KV cache scaling factors will be saved in the model directory. (Default: None) ---output_name: Specify the output filename. (Default: kv_cache_scales.json) ---tp_size: Specify the tensor-parallel (TP) size that the quantized model should correspond to. If specified, during KV cache scaling factor extraction the observed TP size will be checked against this and an error will be raised if there is a mismatch. (Default: None) -``` -```python -Example: -python3 examples/fp8/extract_scales.py --quantized_model --tp_size --output_dir -``` -### 4. Load KV Cache Scaling Factors into VLLM. -This script evaluates the inference throughput of language models using various backends such as vLLM. It measures the time taken to process a given number of prompts and generate sequences for each prompt. The recently generated KV cache scaling factors are now integrated into the benchmarking process and allow for KV cache scaling factors to be utilized for FP8. 
-```python -# prerequisites: -# - LLaMa 2 kv_cache_scales.json file - -python3 benchmarks/benchmark_throughput.py --help -usage: benchmark_throughput.py [-h] [--backend {vllm,hf,mii}] [--dataset DATASET] [--input-len INPUT_LEN] [--output-len OUTPUT_LEN] [--model MODEL] - [--tokenizer TOKENIZER] [--quantization {awq,gptq,None}] [--tensor-parallel-size TENSOR_PARALLEL_SIZE] [--n N] - [--use-beam-search] [--num-prompts NUM_PROMPTS] [--seed SEED] [--hf-max-batch-size HF_MAX_BATCH_SIZE] [--trust-remote-code] - [--max-model-len MAX_MODEL_LEN] [--dtype {auto,half,float16,bfloat16,float,float32}] [--enforce-eager] [--kv-cache-dtype {auto,fp8}] - [--quantization-param-path KV_CACHE_quantization_param_path] - -Benchmark Throughput Example -optional arguments: - -h, --help show this help message and exit - --backend {vllm,hf,mii} - --dataset DATASET Path to the dataset. - --input-len INPUT_LEN Input prompt length for each request - --output-len OUTPUT_LEN Output length for each request. Overrides the output length from the dataset. - --model MODEL - --tokenizer TOKENIZER - --quantization {awq,gptq,None}, -q {awq,gptq,None} - --tensor-parallel-size TENSOR_PARALLEL_SIZE, -tp TENSOR_PARALLEL_SIZE - --n N Number of generated sequences per prompt. - --use-beam-search - --num-prompts NUM_PROMPTS Number of prompts to process. - --seed SEED - --hf-max-batch-size HF_MAX_BATCH_SIZE Maximum batch size for HF backend. - --trust-remote-code trust remote code from huggingface - --max-model-len MAX_MODEL_LEN Maximum length of a sequence (including prompt and output). If None, will be derived from the model. - --dtype {auto,half,float16,bfloat16,float,float32} data type for model weights and activations. The "auto" option will use FP16 precision for FP32 and FP16 models, and BF16 precision for BF16 models. - --enforce-eager enforce eager execution - --kv-cache-dtype {auto,fp8} Data type for kv cache storage. If "auto", will use model data type. FP8_E5M2 (without scaling) is only supported on cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported ```for common inference criteria. - --quantization-param-path QUANT_PARAM_JSON Path to the JSON file containing the KV cache scaling factors. This should generally be supplied, when KV cache dtype is FP8. Otherwise, KV cache scaling factors default to 1.0, which may cause accuracy issues. FP8_E5M2 (without scaling) is only supported on cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for common inference criteria. -``` -``` -Example: -python3 benchmarks/benchmark_throughput.py --input-len --output-len -tp --kv-cache-dtype fp8 --quantization-param-path --model -```python diff --git a/examples/fp8/extract_scales.py b/examples/fp8/extract_scales.py deleted file mode 100644 index 1dce9d7e993a..000000000000 --- a/examples/fp8/extract_scales.py +++ /dev/null @@ -1,367 +0,0 @@ -import argparse -import glob -import json -import os -from typing import Any, Callable, Dict, List, Optional, Tuple - -import numpy as np -import torch -from safetensors.torch import safe_open - -from vllm.model_executor.layers.quantization.schema import QuantParamSchema - - -# Adapted from vllm/model_executor/model_loader/weight_utils.py -# The main differences are that we add the NPZ format and simplify -# its functionality drastically for our purposes (e.g. 
we assume that -# the quantized model exists locally and there is no need to download it) -def _prepare_hf_weights( - quantized_model_dir: str, - load_format: str = "auto", - fall_back_to_pt: bool = True, -) -> Tuple[List[str], bool]: - if not os.path.isdir(quantized_model_dir): - raise FileNotFoundError( - f"The quantized model directory `{quantized_model_dir}` " - "does not exist.") - use_safetensors = False - # Some quantized models use .pt files for storing the weights. - if load_format == "auto": - allow_patterns = ["*.safetensors", "*.bin"] - elif load_format == "safetensors": - use_safetensors = True - allow_patterns = ["*.safetensors"] - elif load_format == "pt": - allow_patterns = ["*.pt"] - elif load_format == "npz": - allow_patterns = ["*.npz"] - else: - raise ValueError(f"Unknown load_format: {load_format}") - if fall_back_to_pt: - allow_patterns += ["*.pt"] - - hf_weights_files: List[str] = [] - for pattern in allow_patterns: - hf_weights_files += glob.glob( - os.path.join(quantized_model_dir, pattern)) - if len(hf_weights_files) > 0: - if pattern == "*.safetensors": - use_safetensors = True - break - - if not use_safetensors: - # Exclude files that are not needed for inference. - # https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/trainer.py#L227-L233 - blacklist = [ - "training_args.bin", - "optimizer.bin", - "optimizer.pt", - "scheduler.pt", - "scaler.pt", - ] - hf_weights_files = [ - f for f in hf_weights_files - if not any(f.endswith(x) for x in blacklist) - ] - - if len(hf_weights_files) == 0: - raise RuntimeError( - f"Cannot find any model weights with `{quantized_model_dir}`") - - return hf_weights_files, use_safetensors - - -# Adapted from vllm/model_executor/model_loader/weight_utils.py -def _hf_tensorfile_iterator(filename: str, load_format: str, - use_safetensors: bool): - if load_format == "npz": - assert not use_safetensors - with np.load(filename) as data: - for name in data.files: - param = torch.from_numpy(data[name]) - yield name, param - elif use_safetensors: - with safe_open(filename, framework="pt") as f: - for name in f.keys(): # NOQA: SIM118 - param = f.get_tensor(name) - yield name, param - else: - state = torch.load(filename, map_location="cpu") - for name, param in state.items(): - yield name, param - del state - torch.cuda.empty_cache() - - -def _kv_scales_extractor( - hf_tensor_files: List[str], - use_safetensors: bool, - rank_keyword: str = "rank", - expected_tp_size: Optional[int] = None) -> Dict[int, Dict[int, float]]: - """ - Given a list of files containing tensor data, attempt to extract KV cache - scales from these files. Intended as a helper function taking in the output - from _prepare_hf_weights. - Args: - rank_keyword Matches the number immediately after this keyword in the - tensor filename to determine the TP rank corresponding - to said tensor file - expected_tp_size If specified, the TP size of the tensor files is checked - against this and an error is raised if they don't match. - Returns a dictionary mapping TP ranks to their relevant KV cache scales. - The per-rank scales are themselves represented as a dictionary of layer - indices to the respective per-layer scale. - """ - for char in rank_keyword: - assert not char.isdecimal( - ), f"Rank keyword {rank_keyword} contains a numeric character!" 
- rank_scales_map: Dict[int, Dict[int, float]] = {} - for tensor_file in hf_tensor_files: - try: - rank_idx = tensor_file.find(rank_keyword) - if rank_idx != -1: - start_idx = rank_idx + len(rank_keyword) - stop_idx = start_idx - while stop_idx < len( - tensor_file) and tensor_file[stop_idx].isdecimal(): - stop_idx += 1 - if stop_idx == start_idx: - raise RuntimeError("Did not find rank # in filename.") - rank = int(tensor_file[start_idx:stop_idx]) - elif len(hf_tensor_files) == 1: - # Since there is only one tensor file, we can assume - # that it's intended for TP rank 0 - rank = 0 - else: - raise RuntimeError( - f"Filename does not contain '{rank_keyword}'.") - except RuntimeError: - print("Unable to determine TP rank " - f"corresponding to file '{tensor_file}'") - raise - - if rank not in rank_scales_map: - layer_scales_map: Dict[int, float] = {} - rank_scales_map[rank] = layer_scales_map - else: - raise RuntimeError( - f"Tensor file '{tensor_file}' shares TP rank {rank} " - "with another tensor file.") - - module_delimiter = ":" if args.load_format == "npz" else "." - for name, param in _hf_tensorfile_iterator(tensor_file, - args.load_format, - use_safetensors): - if "kv_cache_scaling_factor" in name: - nums = [ - int(s) for s in name.split(module_delimiter) - if s.isdecimal() - ] - assert len( - nums) == 1, f"Could not determine layer idx for {name}" - layer_idx = nums[0] - assert layer_idx not in layer_scales_map, f"Duplicate scaling"\ - f" factor corresponding to layer {layer_idx}" - try: - layer_scales_map[layer_idx] = param.item() - except RuntimeError: - print( - "This utility supports only per-tensor scalar scales " - f"for now. The tensor\n {name} = {param} \nis an " - "invalid scale factor.") - raise - - if all( - len(layer_scales_map) == 0 - for layer_scales_map in rank_scales_map.values()): - # Note: this is true even if the rank_scales_map is empty - print("WARNING: No KV cache scale factors found. No output saved.") - return None - empirical_tp_world_size = max(rank_scales_map.keys()) + 1 - if expected_tp_size is not None: - assert expected_tp_size == empirical_tp_world_size, \ - f"User expected TP world size = {expected_tp_size} " \ - "from model but tool is expecting TP world size = " \ - f"{empirical_tp_world_size} from model instead." - for i in range(empirical_tp_world_size): - assert i in rank_scales_map, "Expected TP world size = "\ - f"{empirical_tp_world_size} but did not find KV " \ - f"cache scaling factors for TP rank {i}" - print(f"Found TP world size = {empirical_tp_world_size} " - "when extracting KV cache scales!") - return rank_scales_map - - -def _metadata_extractor(quantized_model_dir: str, - metadata_extract_fns: \ - Dict[str, Callable[[Dict[str, Any]], Any]]) \ - -> Dict[str, Any]: - """ - Given a directory containing quantized model files, this function - aims to extract metadata from the JSON files within this directory. - Each JSON file is expected to represent a dictionary in JSON - format (referred to as a "JSON-dictionary"). Metadata extraction is - defined by a dictionary called metadata_extract_fns, where each - metadata field name is mapped to an extraction function. - - These extraction functions are designed to take a JSON-dictionary - as their only argument and return the corresponding metadata. 
- While extraction functions are permitted to raise exceptions, they - should only raise a KeyError or ValueError if the metadata field - cannot be extracted from the current JSON-dictionary, yet there's - a possibility of finding it in another JSON-dictionary. - - The function returns a dictionary that maps metadata fields to - their extracted data. The keys of this dictionary correspond exactly - to those in metadata_extract_fns. If any fields fail to be extracted, - their corresponding values are set to None, and a warning is printed. - """ - if not os.path.isdir(quantized_model_dir): - raise FileNotFoundError( - f"The quantized model directory `{quantized_model_dir}` " - "does not exist.") - metadata_files = glob.glob(os.path.join(quantized_model_dir, "*.json")) - - result: Dict[str, Any] = {} - for file in metadata_files: - with open(file) as f: - try: - metadata = json.load(f) - except json.JSONDecodeError: - print(f"Could not parse `{file}` as a valid metadata file," - " skipping it.") - continue - if not isinstance(metadata, dict): - print(f"The file `{file}` does not correspond to a " - "JSON-serialized dictionary, skipping it.") - continue - for metadata_name, extract_fn in metadata_extract_fns.items(): - try: - metadata_info = extract_fn(metadata) - if metadata_name not in result: - result[metadata_name] = metadata_info - elif metadata_info != result[metadata_name]: - raise RuntimeError( - "Metadata mismatch! Originally found " - f"{metadata_name} = {result[metadata_name]} but " - f"now found {metadata_name} = {metadata_info} in " - f"`{file}`") - except KeyError: - # It is possible that a given file does not contain some - # of our selected metadata as it could be located in some - # other metadata file. - # 'EFINAE': extract_fn failure is not an error. - pass - except ValueError: - # See above. - pass - - # Warn if we cannot find any of the requested metadata - for metadata_name in metadata_extract_fns: - if metadata_name not in result: - print("WARNING: Unable to find requested metadata field " - f"`{metadata_name}`, setting it to None.") - result[metadata_name] = None - - return result - - -def main(args): - metadata_extract_fns = { - "model_type": lambda json_dict: json_dict["layers"][0]["decoder_type"], - "tp_size": lambda json_dict: int(json_dict["tensor_parallel"]), - "model_dtype": lambda json_dict: json_dict["dtype"] - } - recovered_metadata = _metadata_extractor(args.quantized_model, - metadata_extract_fns) - if args.tp_size is not None: - metadata_tp_size = recovered_metadata["tp_size"] - if metadata_tp_size is not None: - assert args.tp_size == metadata_tp_size, \ - f"User expected TP world size = {args.tp_size} " \ - f"but found TP world size = {metadata_tp_size} from metadata!" - expected_tp_size = args.tp_size or recovered_metadata["tp_size"] - rank_keyword = "rank" - hf_tensor_files, use_safetensors = _prepare_hf_weights( - args.quantized_model, args.load_format) - rank_scales_map = _kv_scales_extractor(hf_tensor_files, use_safetensors, - rank_keyword, expected_tp_size) - # Postprocess: formatting to the current schema. Consider pulling it - # out into a dedicated function should it ever become more complicated. 
- rank_scales_map = { - rank: {k: scale[k] - for k in sorted(scale.keys())} - for rank, scale in rank_scales_map.items() - } - # TODO: Expand this with activation and weights scaling factors when - # they are used in the future - schema = QuantParamSchema( - model_type=recovered_metadata["model_type"], - kv_cache={ - "dtype": ("float8_e4m3fn" if len(rank_scales_map) > 0 else - recovered_metadata["model_dtype"]), - "scaling_factor": - rank_scales_map - }, - ) - - if args.output_dir is None: - output_file = os.path.join(args.quantized_model, args.output_name) - else: - if not os.path.isdir(args.output_dir): - os.makedirs(args.output_dir, exist_ok=True) - output_file = os.path.join(args.output_dir, args.output_name) - - with open(output_file, 'w') as f: - f.write(schema.model_dump_json(indent=4)) - print(f"Completed! KV cache scaling factors saved to {output_file}") - - -if __name__ == "__main__": - parser = argparse.ArgumentParser( - description="This simple utility extracts the " - "KV cache scaling factors from a quantized HF model " - "and saves them to a JSON file compatible with later " - "use by vLLM (pass this file to the appropriate " - "runtime typically using the argument " - "--quantization-param-path ). This is only used " - "if the KV cache dtype is FP8 and on ROCm (AMD GPU).") - parser.add_argument( - "--quantized-model", - help="Specify the directory containing a single quantized HF model. " - "It is expected that the quantization format is FP8_E4M3, for use " - "on ROCm (AMD GPU).", - required=True) - parser.add_argument( - "--load_format", - help="Optionally specify the format of the model's tensor files " - "containing the KV cache scaling factors.", - choices=["auto", "safetensors", "npz", "pt"], - default="auto") - parser.add_argument( - "--output-dir", - help="Optionally specify the output directory. By default the " - "KV cache scaling factors will be saved in the model directory, " - "however you can override this behavior here.", - default=None) - parser.add_argument( - "--output-name", - help="Optionally specify the output filename.", - # TODO: Change this once additional scaling factors are enabled - default="kv_cache_scales.json") - parser.add_argument( - "--tp-size", - help="Optionally specify the tensor-parallel (TP) size that the " - "quantized model should correspond to. If specified, during KV " - "cache scaling factor extraction the observed TP size will be " - "checked against this and an error will be raised if there is " - "a mismatch. If not specified, the quantized model's expected " - "TP size is instead inferred from the largest TP rank observed. 
" - "The expected TP size is cross-checked against the TP ranks " - "observed in the quantized model and an error is raised if any " - "discrepancies are found.", - default=None, - type=int) - args = parser.parse_args() - - main(args) diff --git a/examples/fp8/quantizer/README.md b/examples/fp8/quantizer/README.md deleted file mode 100644 index d0895e97dc34..000000000000 --- a/examples/fp8/quantizer/README.md +++ /dev/null @@ -1,32 +0,0 @@ -### Quantizer Utilities -`quantize.py`: NVIDIA Quantization utilities using TensorRT-Model-Optimizer, ported -from TensorRT-LLM: [`examples/quantization/quantize.py`](https://github.com/NVIDIA/TensorRT-LLM/blob/main/examples/quantization/quantize.py) - -### Prerequisite - -#### AMMO (AlgorithMic Model Optimization) Installation: nvidia-ammo 0.7.1 or later -`pip install --no-cache-dir --extra-index-url https://pypi.nvidia.com nvidia-ammo` - -#### AMMO Download (code and docs) -`https://developer.nvidia.com/downloads/assets/cuda/files/nvidia-ammo/nvidia_ammo-0.5.0.tar.gz` -`https://developer.nvidia.com/downloads/assets/cuda/files/nvidia-ammo/nvidia_ammo-0.7.1.tar.gz` - -### Usage - -#### Run on H100 system for speed if FP8; number of GPUs depends on the model size - -#### Example: quantize Llama2-7b model from HF to FP8 with FP8 KV Cache: -`python quantize.py --model-dir ./ll2-7b --dtype float16 --qformat fp8 --kv-cache-dtype fp8 --output-dir ./ll2_7b_fp8 --calib-size 512 --tp-size 1` - -Outputs: model structure, quantized model & parameters (with scaling factors) are in JSON and Safetensors (npz is generated only for the reference) -``` -# ll ./ll2_7b_fp8/ -total 19998244 -drwxr-xr-x 2 root root 4096 Feb 7 01:08 ./ -drwxrwxr-x 8 1060 1061 4096 Feb 7 01:08 ../ --rw-r--r-- 1 root root 176411 Feb 7 01:08 llama_tp1.json --rw-r--r-- 1 root root 13477087480 Feb 7 01:09 llama_tp1_rank0.npz --rw-r--r-- 1 root root 7000893272 Feb 7 01:08 rank0.safetensors -# -``` - diff --git a/examples/fp8/quantizer/quantize.py b/examples/fp8/quantizer/quantize.py deleted file mode 100644 index d75cc8b3d1cf..000000000000 --- a/examples/fp8/quantizer/quantize.py +++ /dev/null @@ -1,367 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # noqa: E501 -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-""" -Adapted from examples/quantization/hf_ptq.py -""" - -import argparse -import copy -import json -import random -import time - -import ammo.torch.quantization as atq -import numpy as np -import torch -from ammo.torch.export import export_model_config -from datasets import load_dataset -from torch.utils.data import DataLoader -from transformers import AutoModelForCausalLM, AutoTokenizer - -RAND_SEED = 1234 -MAX_SEQ_LEN = 2048 - -EMPTY_CFG = { - "quant_cfg": { - "*weight_quantizer": { - "enable": False, - }, - "*input_quantizer": { - "enable": False - }, - "*lm_head*": { - "enable": False - }, - "*output_layer*": { - "enable": False - }, - "default": { - "enable": False - }, - }, - "algorithm": "max", -} - -KV_CACHE_CFG = { - "*.query_key_value.output_quantizer": { - "num_bits": 8, - "axis": None, - "enable": True - }, - "*.Wqkv.output_quantizer": { - "num_bits": 8, - "axis": None, - "enable": True - }, - "*.W_pack.output_quantizer": { - "num_bits": 8, - "axis": None, - "enable": True - }, - "*.c_attn.output_quantizer": { - "num_bits": 8, - "axis": None, - "enable": True - }, - "*.k_proj.output_quantizer": { - "num_bits": 8, - "axis": None, - "enable": True - }, - "*.v_proj.output_quantizer": { - "num_bits": 8, - "axis": None, - "enable": True - }, -} - -QUANT_CFG_CHOICES = { - "int8_sq": atq.INT8_SMOOTHQUANT_CFG, - "fp8": atq.FP8_DEFAULT_CFG, - "int4_awq": atq.INT4_AWQ_CFG, - "w4a8_awq": atq.W4A8_AWQ_BETA_CFG, - "int8_wo": EMPTY_CFG, - "int4_wo": EMPTY_CFG, - "full_prec": EMPTY_CFG, -} - -MODEL_NAME_PATTERN_MAP = { - "GPT2": "gpt2", - "Xverse": "llama", - "Llama": "llama", - "Mistral": "llama", - "GPTJ": "gptj", - "FalconForCausalLM": "falcon", - "RWForCausalLM": "falcon", - "baichuan": "baichuan", - "MPT": "mpt", - "Bloom": "bloom", - "ChatGLM": "chatglm", - "QWen": "qwen", -} - - -def get_tokenizer(ckpt_path, max_seq_len=MAX_SEQ_LEN, model_type=None): - print(f"Initializing tokenizer from {ckpt_path}") - tokenizer = AutoTokenizer.from_pretrained( - ckpt_path, - model_max_length=max_seq_len, - padding_side="left", - trust_remote_code=True, - ) - if model_type and model_type == "qwen": - # qwen use token id 151643 as pad and eos tokens - tokenizer.pad_token = tokenizer.convert_ids_to_tokens(151643) - tokenizer.eos_token = tokenizer.convert_ids_to_tokens(151643) - - # can't set attribute 'pad_token' for "" - if tokenizer.pad_token != "": - tokenizer.pad_token = tokenizer.eos_token - if tokenizer.pad_token is None: - tokenizer.pad_token = tokenizer.eos_token - assert (tokenizer.pad_token - is not None), f"Pad token for {model_type} cannot be set!" 
- - return tokenizer - - -def get_model(ckpt_path, dtype="fp16", device="cuda"): - print(f"Initializing model from {ckpt_path}") - if dtype == "bf16" or dtype == "bfloat16": - dtype = torch.bfloat16 - elif dtype == "fp16" or dtype == "float16": - dtype = torch.float16 - elif dtype == "fp32" or dtype == "float32": - dtype = torch.float32 - else: - raise NotImplementedError(f"Unknown dtype {dtype}") - - # model_kwargs = {"torch_dtype": dtype} - model_kwargs = {"torch_dtype": "auto"} - - model = AutoModelForCausalLM.from_pretrained(ckpt_path, - device_map="auto", - **model_kwargs, - trust_remote_code=True) - model.eval() - - model_dtype = next(model.parameters()).dtype - if dtype != model_dtype: - print("[TensorRT-LLM][WARNING] The manually set model data type is " - f"{dtype}, but the data type of the HuggingFace model is " - f"{model_dtype}.") - - return model - - -def get_model_type(model): - for k, v in MODEL_NAME_PATTERN_MAP.items(): - if k.lower() in type(model).__name__.lower(): - return v - return None - - -def get_calib_dataloader(data="cnn_dailymail", - tokenizer=None, - batch_size=1, - calib_size=512, - block_size=512, - device=None): - print("Loading calibration dataset") - if data == "pileval": - dataset = load_dataset( - "json", - data_files="https://the-eye.eu/public/AI/pile/val.jsonl.zst", - split="train") - dataset = dataset["text"][:calib_size] - elif data == "cnn_dailymail": - dataset = load_dataset("cnn_dailymail", name="3.0.0", split="train") - dataset = dataset["article"][:calib_size] - else: - raise NotImplementedError - - batch_encoded = tokenizer.batch_encode_plus(dataset, - return_tensors="pt", - padding="max_length", - truncation=True, - max_length=block_size) - if device: - batch_encoded = batch_encoded.to(device) - batch_encoded = batch_encoded["input_ids"] - - calib_dataloader = DataLoader(batch_encoded, - batch_size=batch_size, - shuffle=False) - - return calib_dataloader - - -def quantize_model(model, quant_cfg, calib_dataloader=None): - - def calibrate_loop(): - if calib_dataloader is None: - return - """Adjusts weights and scaling factors based on selected algorithms.""" - for idx, data in enumerate(calib_dataloader): - print(f"Calibrating batch {idx}") - model(data) - - print("Starting quantization...") - start_time = time.time() - atq.quantize(model, quant_cfg, forward_loop=calibrate_loop) - end_time = time.time() - print("Quantization done. Total time used: {:.2f} s.".format(end_time - - start_time)) - - return model - - -def main(args): - if not torch.cuda.is_available(): - raise OSError("GPU is required for inference.") - - random.seed(RAND_SEED) - np.random.seed(RAND_SEED) - - model = get_model(args.model_dir, args.dtype, args.device) - model_type = get_model_type(model) - tokenizer = get_tokenizer(args.model_dir, model_type=model_type) - - if args.qformat in ["full_prec", "int8_wo", "int4_wo" - ] and args.kv_cache_dtype is None: - print(f"No quantization applied, export {args.dtype} model") - else: - if "awq" in args.qformat: - if args.calib_size > 32: - print("AWQ calibration could take longer with calib_size = " - f"{args.calib_size}, Using calib_size=32 instead") - args.calib_size = 32 - print("\nAWQ calibration could take longer than other calibration " - "methods. Please increase the batch size to speed up the " - "calibration process. 
Batch size can be set by adding the " - "argument --batch_size to the command line.\n") - - calib_dataloader = get_calib_dataloader( - tokenizer=tokenizer, - batch_size=args.batch_size, - calib_size=args.calib_size, - device=args.device, - ) - - if args.qformat in QUANT_CFG_CHOICES: - quant_cfg = QUANT_CFG_CHOICES[args.qformat] - else: - raise ValueError( - f"Unsupported quantization format: {args.qformat}") - - if "awq" in args.qformat: - quant_cfg = copy.deepcopy(QUANT_CFG_CHOICES[args.qformat]) - weight_quantizer = quant_cfg["quant_cfg"][ - "*weight_quantizer"] # type: ignore - if isinstance(weight_quantizer, list): - weight_quantizer = weight_quantizer[0] - weight_quantizer["block_sizes"][-1] = args.awq_block_size - - if args.kv_cache_dtype is not None: - if args.kv_cache_dtype == "fp8": - for value in KV_CACHE_CFG.values(): - value.update({"num_bits": (4, 3)}) # type: ignore - quant_cfg["quant_cfg"].update(KV_CACHE_CFG) # type: ignore - - print(quant_cfg) - - model = quantize_model(model, quant_cfg, calib_dataloader) - - with torch.inference_mode(): - if model_type is None: - print(f"Unknown model type {type(model).__name__}. Continue " - "exporting...") - model_type = f"unknown:{type(model).__name__}" - - export_path = args.output_dir - start_time = time.time() - - if args.qformat == "int4_awq" and model_type == "qwen": - torch.save(model.state_dict(), export_path) - else: - export_npz = (model_type not in [ - 'gptj', 'falcon', 'chatglm', 'mpt', 'llama', 'baichuan' - ]) - - # export safetensors - export_model_config( - model, - model_type, - getattr(torch, args.dtype), - export_dir=export_path, - inference_tensor_parallel=args.tp_size, - inference_pipeline_parallel=args.pp_size, - # export_tensorrt_llm_config=(not export_npz), - export_tensorrt_llm_config=False, - export_npz=export_npz) - - # Workaround for wo quantization - if args.qformat in ["int8_wo", "int4_wo", "full_prec"]: - with open(f"{export_path}/config.json") as f: - tensorrt_llm_config = json.load(f) - if args.qformat == "int8_wo": - tensorrt_llm_config["quantization"]["quant_algo"] = 'W8A16' - elif args.qformat == "int4_wo": - tensorrt_llm_config["quantization"]["quant_algo"] = 'W4A16' - else: - tensorrt_llm_config["quantization"]["quant_algo"] = None - with open(f"{export_path}/config.json", "w") as f: - json.dump(tensorrt_llm_config, f, indent=4) - - end_time = time.time() - print("Quantized model exported to {} \nTotal time used {:.2f} s.". 
- format(export_path, end_time - start_time)) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description=__doc__) - parser.add_argument("--model-dir", - help="Specify where the HuggingFace model is", - required=True) - parser.add_argument("--device", default="cuda") - parser.add_argument("--dtype", help="Model data type.", default="float16") - parser.add_argument( - "--qformat", - help="Quantization format.", - default="full_prec", - choices=[ - "fp8", "int8_sq", "int4_awq", "w4a8_awq", "int8_wo", "int4_wo", - "full_prec" - ], - ) - parser.add_argument("--batch-size", - help="Batch size for calibration.", - type=int, - default=1) - parser.add_argument("--calib-size", - help="Number of samples for calibration.", - type=int, - default=512) - parser.add_argument("--output-dir", default="exported_model") - parser.add_argument("--tp-size", type=int, default=1) - parser.add_argument("--pp-size", type=int, default=1) - parser.add_argument("--awq-block-size", type=int, default=128) - parser.add_argument("--kv-cache-dtype", - help="KV Cache dtype.", - default=None, - choices=["int8", "fp8", None]) - args = parser.parse_args() - - main(args) diff --git a/tests/fp8_kv/llama2-70b-fp8-kv/kv_cache_scales.json b/tests/fp8_kv/llama2-70b-fp8-kv/kv_cache_scales.json deleted file mode 100644 index a548f0a9611f..000000000000 --- a/tests/fp8_kv/llama2-70b-fp8-kv/kv_cache_scales.json +++ /dev/null @@ -1,90 +0,0 @@ -{ - "model_type": "llama", - "kv_cache": { - "dtype": "float8_e4m3fn", - "scaling_factor": { - "0": { - "0": 0.0230364128947258, - "1": 0.01979283057153225, - "2": 0.0241350457072258, - "3": 0.0308314748108387, - "4": 0.0430733822286129, - "5": 0.0370396226644516, - "6": 0.0306222103536129, - "7": 0.0357491634786129, - "8": 0.0358189195394516, - "9": 0.0443289652466774, - "10": 0.0433175228536129, - "11": 0.0416782945394516, - "12": 0.0366908498108387, - "13": 0.0432477705180645, - "14": 0.0410505048930645, - "15": 0.0457589291036129, - "16": 0.0418526791036129, - "17": 0.0432477705180645, - "18": 0.0469447560608387, - "19": 0.0514787957072258, - "20": 0.0541294664144516, - "21": 0.0587681382894516, - "22": 0.0625, - "23": 0.0585588738322258, - "24": 0.0600237175822258, - "25": 0.0588030144572258, - "26": 0.0531180277466774, - "27": 0.06396484375, - "28": 0.0603027381002903, - "29": 0.0582101047039032, - "30": 0.0625348836183548, - "31": 0.0585588738322258, - "32": 0.0582798570394516, - "33": 0.0575125589966774, - "34": 0.0590820349752903, - "35": 0.0614188089966774, - "36": 0.0631975457072258, - "37": 0.0615931935608387, - "38": 0.0601283498108387, - "39": 0.0571986623108387, - "40": 0.0670340433716774, - "41": 0.0523507259786129, - "42": 0.0547223798930645, - "43": 0.0631975457072258, - "44": 0.0663713738322258, - "45": 0.0603376142680645, - "46": 0.0652204304933548, - "47": 0.0734514519572258, - "48": 0.0693708211183548, - "49": 0.0725446492433548, - "50": 0.0627790242433548, - "51": 0.0691266804933548, - "52": 0.0688825398683548, - "53": 0.068429134786129, - "54": 0.0605119988322258, - "55": 0.0799386203289032, - "56": 0.0853097140789032, - "57": 0.0661969929933548, - "58": 0.0689871683716774, - "59": 0.0724051371216774, - "60": 0.0541643425822258, - "61": 0.0626743882894516, - "62": 0.0628487765789032, - "63": 0.0607212632894516, - "64": 0.0589076466858387, - "65": 0.0451660193502903, - "66": 0.0453055277466774, - "67": 0.0414341539144516, - "68": 0.0385044664144516, - "69": 0.0414341539144516, - "70": 0.0466308631002903, - "71": 0.0399693101644516, - "72": 
0.0437011756002903, - "73": 0.0434221550822258, - "74": 0.0428989976644516, - "75": 0.0401785746216774, - "76": 0.0431082621216774, - "77": 0.0484444759786129, - "78": 0.0417829267680645, - "79": 0.0418178029358387 - } - } - } -} \ No newline at end of file diff --git a/tests/fp8_kv/llama2-7b-fp8-kv/kv_cache_scales.json b/tests/fp8_kv/llama2-7b-fp8-kv/kv_cache_scales.json deleted file mode 100644 index bb734039e982..000000000000 --- a/tests/fp8_kv/llama2-7b-fp8-kv/kv_cache_scales.json +++ /dev/null @@ -1,42 +0,0 @@ -{ - "model_type": "llama", - "kv_cache": { - "dtype": "float8_e4m3fn", - "scaling_factor": { - "0": { - "0": 0.0152239128947258, - "1": 0.0188860222697258, - "2": 0.0354178324341774, - "3": 0.0376674123108387, - "4": 0.0418526791036129, - "5": 0.0433175228536129, - "6": 0.0397600457072258, - "7": 0.0424455925822258, - "8": 0.0415387861430645, - "9": 0.0408412404358387, - "10": 0.0395856611430645, - "11": 0.0377371683716774, - "12": 0.0400739423930645, - "13": 0.040771484375, - "14": 0.0393415205180645, - "15": 0.0369001142680645, - "16": 0.03857421875, - "17": 0.0387486070394516, - "18": 0.0403180830180645, - "19": 0.0396205373108387, - "20": 0.0375627800822258, - "21": 0.0407366082072258, - "22": 0.0432477705180645, - "23": 0.0377022884786129, - "24": 0.0399693101644516, - "25": 0.0374581478536129, - "26": 0.0413295216858387, - "27": 0.0442243330180645, - "28": 0.0424804724752903, - "29": 0.0456891767680645, - "30": 0.0409109964966774, - "31": 0.0482352152466774 - } - } - } -} diff --git a/tests/models/decoder_only/language/test_fp8.py b/tests/models/decoder_only/language/test_fp8.py index 53f23e24511b..5f06f1e3a2fe 100644 --- a/tests/models/decoder_only/language/test_fp8.py +++ b/tests/models/decoder_only/language/test_fp8.py @@ -19,18 +19,17 @@ @pytest.mark.skipif(not is_quant_method_supported("fp8"), reason="fp8 is not supported on this GPU type.") @pytest.mark.parametrize( - "kv_cache_dtype,base_model,test_model,scale_path", + "kv_cache_dtype,base_model,test_model", [ # Test FP8 checkpoint w. fp8_e4m3 kv-cache scaling factors. ("fp8_e4m3", "meta-llama/Llama-3.2-1B-Instruct", - "nm-testing/Llama-3.2-1B-Instruct-FP8-KV", None), + "nm-testing/Llama-3.2-1B-Instruct-FP8-KV"), # Test FP16 checkpoint w. fp8_e5m2 kv-cache. ("fp8_e5m2", "meta-llama/Llama-3.2-1B-Instruct", - "meta-llama/Llama-3.2-1B-Instruct", None), + "meta-llama/Llama-3.2-1B-Instruct"), # Test FP16 checkpoint w. fp8_e4m3 kv-cache scaling factors in json. 
("fp8_e4m3", "meta-llama/Llama-2-7b-chat-hf", - "meta-llama/Llama-2-7b-chat-hf", - "./tests/fp8_kv/llama2-7b-fp8-kv/kv_cache_scales.json") + "meta-llama/Llama-2-7b-chat-hf") ]) # Due to low-precision numerical divergence, we only test logprob of 4 tokens @pytest.mark.parametrize("max_tokens", [4]) @@ -48,7 +47,6 @@ def test_models( kv_cache_dtype: str, base_model: str, test_model: str, - scale_path: Optional[str], max_tokens: int, enforce_eager: bool, backend: str, @@ -76,10 +74,6 @@ def test_models( baseline_outputs = vllm_model.generate_greedy_logprobs( example_prompts, max_tokens, NUM_LOG_PROBS) - extra_kwargs = {} - if scale_path is not None: - extra_kwargs["quantization_param_path"] = scale_path - with vllm_runner( test_model, max_model_len=MAX_MODEL_LEN, @@ -87,7 +81,6 @@ def test_models( enforce_eager=enforce_eager, kv_cache_dtype=kv_cache_dtype, disable_async_output_proc=disable_async_output_proc, - **extra_kwargs, ) as vllm_model: test_outputs = vllm_model.generate_greedy_logprobs( example_prompts, max_tokens, NUM_LOG_PROBS) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index f6b5514f8987..0c5c3df999be 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -109,8 +109,8 @@ def paged_attention_v1( max_seq_len: int, alibi_slopes: Optional[torch.Tensor], kv_cache_dtype: str, - k_scale: float, - v_scale: float, + k_scale: torch.Tensor, + v_scale: torch.Tensor, tp_rank: int = 0, blocksparse_local_blocks: int = 0, blocksparse_vert_stride: int = 0, @@ -141,8 +141,8 @@ def paged_attention_v2( max_seq_len: int, alibi_slopes: Optional[torch.Tensor], kv_cache_dtype: str, - k_scale: float, - v_scale: float, + k_scale: torch.Tensor, + v_scale: torch.Tensor, tp_rank: int = 0, blocksparse_local_blocks: int = 0, blocksparse_vert_stride: int = 0, @@ -173,8 +173,8 @@ def paged_attention_rocm( max_seq_len: int, alibi_slopes: Optional[torch.Tensor], kv_cache_dtype: str, - k_scale: float, - v_scale: float, + k_scale: torch.Tensor, + v_scale: torch.Tensor, ) -> None: torch.ops._rocm_C.paged_attention(out, exp_sum, max_logits, tmp_out, query, key_cache, value_cache, num_kv_heads, @@ -1012,8 +1012,8 @@ def reshape_and_cache( value_cache: torch.Tensor, slot_mapping: torch.Tensor, kv_cache_dtype: str, - k_scale: float, - v_scale: float, + k_scale: torch.Tensor, + v_scale: torch.Tensor, ) -> None: torch.ops._C_cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping, @@ -1027,8 +1027,8 @@ def reshape_and_cache_flash( value_cache: torch.Tensor, slot_mapping: torch.Tensor, kv_cache_dtype: str, - k_scale: float, - v_scale: float, + k_scale: torch.Tensor, + v_scale: torch.Tensor, ) -> None: torch.ops._C_cache_ops.reshape_and_cache_flash(key, value, key_cache, value_cache, slot_mapping, diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index aed04361e5fb..52343a71be45 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod from contextlib import contextmanager -from dataclasses import dataclass, fields +from dataclasses import dataclass, field, fields from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Set, Tuple, Type, TypeVar) @@ -124,6 +124,11 @@ class AttentionMetadata: multi_modal_placeholder_index_maps: Optional[Dict[ str, MultiModalPlaceholderMap.IndexMap]] + # Enable/disable KV scales calculation. This is so that we can disable the + # calculation until after prefill and cuda graph capture. 
+ enable_kv_scales_calculation: bool = field(init=False, + default_factory=lambda: True) + @property @abstractmethod def prefill_metadata(self) -> Optional["AttentionMetadata"]: @@ -244,8 +249,8 @@ def forward( value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: T, - k_scale: float = 1.0, - v_scale: float = 1.0, + k_scale: torch.Tensor, + v_scale: torch.Tensor, attn_type: str = AttentionType.DECODER, output: Optional[torch.Tensor] = None, ) -> torch.Tensor: diff --git a/vllm/attention/backends/blocksparse_attn.py b/vllm/attention/backends/blocksparse_attn.py index 99cb84346d84..7fed648bf1e7 100644 --- a/vllm/attention/backends/blocksparse_attn.py +++ b/vllm/attention/backends/blocksparse_attn.py @@ -357,8 +357,8 @@ def forward( value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: BlocksparseFlashAttentionMetadata, - k_scale: float = 1.0, - v_scale: float = 1.0, + k_scale: torch.Tensor, + v_scale: torch.Tensor, attn_type: str = AttentionType.DECODER, output: Optional[torch.Tensor] = None, ) -> torch.Tensor: diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 19daeb729ee6..1c9f715eb6c2 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -412,8 +412,8 @@ def forward( value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: ROCmFlashAttentionMetadata, - k_scale: float = 1.0, - v_scale: float = 1.0, + k_scale: torch.Tensor, + v_scale: torch.Tensor, attn_type: str = AttentionType.DECODER, output: Optional[torch.Tensor] = None, ) -> torch.Tensor: diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 05d997279893..27db056c987c 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -5,6 +5,7 @@ import torch.nn as nn import torch.nn.functional as F +import vllm.envs as envs from vllm.attention import AttentionMetadata, AttentionType from vllm.attention.selector import backend_name_to_enum, get_attn_backend from vllm.config import CacheConfig, get_current_vllm_config @@ -56,10 +57,12 @@ def __init__( kv_cache_dtype = cache_config.cache_dtype block_size = cache_config.block_size is_attention_free = cache_config.is_attention_free + calculate_kv_scales = cache_config.calculate_kv_scales else: kv_cache_dtype = "auto" block_size = 16 is_attention_free = False + calculate_kv_scales = False if num_kv_heads is None: num_kv_heads = num_heads @@ -69,8 +72,9 @@ def __init__( # expect the pre-quantized k/v_scale to be loaded along # with the model weights. 
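# Editorial sketch (not part of the patch): the hunk just below switches
# _k_scale/_v_scale from Python floats to scalar tensors. The likely reason,
# inferred from this series, is that the scales are later refreshed in place
# with copy_() while kernels (and any captured CUDA graph) already hold a
# reference to them, and the C++/Triton entry points above now take the scale
# by pointer rather than by value. TinyLayer below is a made-up, minimal
# illustration of that in-place-update pattern, not vLLM's actual class; the
# 200.0 divisor mirrors the K_SCALE_CONSTANT default introduced later in
# this patch.
import torch

class TinyLayer:
    def __init__(self) -> None:
        # Scalar tensor, so the value can change without rebinding the name.
        self._k_scale = torch.tensor(1.0, dtype=torch.float32)

    def calc_k_scale(self, key: torch.Tensor, k_range: float = 200.0) -> None:
        # In-place update: anything already holding this tensor (for example
        # a captured kernel argument) observes the new value.
        self._k_scale.copy_(key.abs().max() / k_range)

layer = TinyLayer()
captured_ref = layer._k_scale          # what a graph capture would keep
layer.calc_k_scale(torch.randn(4, 8))
assert captured_ref is layer._k_scale  # same storage, now holding the new scale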
self.kv_cache_dtype = kv_cache_dtype - self._k_scale = 1.0 - self._v_scale = 1.0 + self.calculate_kv_scales = calculate_kv_scales + self._k_scale = torch.tensor(1.0, dtype=torch.float32) + self._v_scale = torch.tensor(1.0, dtype=torch.float32) quant_method = quant_config.get_quant_method( self, prefix=prefix) if quant_config else None if quant_method is not None: @@ -120,6 +124,9 @@ def __init__( compilation_config.static_forward_context[prefix] = self self.layer_name = prefix + self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32) + self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32) + def forward( self, query: torch.Tensor, @@ -129,6 +136,9 @@ def forward( attn_metadata: AttentionMetadata, attn_type: str = AttentionType.DECODER, ) -> torch.Tensor: + if self.calculate_kv_scales and \ + attn_metadata.enable_kv_scales_calculation: + self.calc_kv_scales(key, value) if self.use_direct_call: return self.impl.forward(query, @@ -160,6 +170,12 @@ def forward( kv_cache, attn_type, self.layer_name) + def calc_kv_scales(self, key, value): + self._k_scale.copy_(torch.abs(key).max() / self.k_range) + self._v_scale.copy_(torch.abs(value).max() / self.v_range) + # We only calculate the scales once + self.calculate_kv_scales = False + def extra_repr(self) -> str: s = f"head_size={self.impl.head_size}" # type: ignore s += f", num_heads={self.impl.num_heads}" # type: ignore diff --git a/vllm/attention/ops/paged_attn.py b/vllm/attention/ops/paged_attn.py index 076f151ffcb6..644c6f9d7da2 100644 --- a/vllm/attention/ops/paged_attn.py +++ b/vllm/attention/ops/paged_attn.py @@ -69,8 +69,8 @@ def write_to_paged_cache( value_cache: torch.Tensor, slot_mapping: torch.Tensor, kv_cache_dtype: str, - k_scale: float, - v_scale: float, + k_scale: torch.Tensor, + v_scale: torch.Tensor, ) -> None: ops.reshape_and_cache( key, @@ -95,8 +95,8 @@ def forward_decode( num_kv_heads: int, scale: float, alibi_slopes: Optional[torch.Tensor], - k_scale: float, - v_scale: float, + k_scale: torch.Tensor, + v_scale: torch.Tensor, tp_rank: int = 0, blocksparse_local_blocks: int = 0, blocksparse_vert_stride: int = 0, diff --git a/vllm/attention/ops/prefix_prefill.py b/vllm/attention/ops/prefix_prefill.py index 9c11a8df5527..ed192d2e7e9d 100644 --- a/vllm/attention/ops/prefix_prefill.py +++ b/vllm/attention/ops/prefix_prefill.py @@ -133,7 +133,7 @@ def _fwd_kernel( other=0.0) # [D,N] if k_load.dtype.is_fp8(): - k = (k_load.to(tl.float32) * k_scale).to(q.dtype) + k = (k_load.to(tl.float32) * tl.load(k_scale)).to(q.dtype) else: k = k_load @@ -181,7 +181,7 @@ def _fwd_kernel( ((start_n + offs_n[:, None]) < cur_batch_ctx_len), other=0.0) # [N,D] if v_load.dtype.is_fp8(): - v = (v_load.to(tl.float32) * v_scale).to(q.dtype) + v = (v_load.to(tl.float32) * tl.load(v_scale)).to(q.dtype) else: v = v_load p = p.to(v.dtype) diff --git a/vllm/config.py b/vllm/config.py index 307cf9c8d5b2..6d633921d709 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -112,11 +112,6 @@ class ModelConfig: decoding draft models. quantization: Quantization method that was used to quantize the model weights. If None, we assume the model weights are not quantized. - quantization_param_path: Path to JSON file containing scaling factors. - Used to load KV cache scaling factors into the model when KV cache - type is FP8_E4M3 on ROCm (AMD GPU). In the future these will also - be used to load activation and weight scaling factors when the - model dtype is FP8_E4M3 on ROCm. enforce_eager: Whether to enforce eager execution. 
If True, we will disable CUDA graph and always execute the model in eager mode. If False, we will use CUDA graph and eager execution in hybrid. @@ -204,7 +199,6 @@ def __init__(self, max_model_len: Optional[int] = None, spec_target_max_model_len: Optional[int] = None, quantization: Optional[str] = None, - quantization_param_path: Optional[str] = None, enforce_eager: Optional[bool] = None, max_seq_len_to_capture: Optional[int] = None, max_logprobs: int = 20, @@ -260,7 +254,6 @@ def __init__(self, else: self.tokenizer_revision = tokenizer_revision self.quantization = quantization - self.quantization_param_path = quantization_param_path self.enforce_eager = enforce_eager self.max_seq_len_to_capture = max_seq_len_to_capture self.max_logprobs = max_logprobs @@ -888,6 +881,7 @@ def __init__( sliding_window: Optional[int] = None, enable_prefix_caching: bool = False, cpu_offload_gb: float = 0, + calculate_kv_scales: Optional[bool] = None, ) -> None: self.block_size = block_size self.gpu_memory_utilization = gpu_memory_utilization @@ -898,7 +892,7 @@ def __init__( self.sliding_window = sliding_window self.enable_prefix_caching = enable_prefix_caching self.cpu_offload_gb = cpu_offload_gb - + self.calculate_kv_scales = calculate_kv_scales self._verify_args() self._verify_cache_dtype() self._verify_prefix_caching() @@ -907,6 +901,10 @@ def __init__( self.num_gpu_blocks: Optional[int] = None self.num_cpu_blocks: Optional[int] = None + # Set calculate_kv_scales to False if the value is unset. + if self.calculate_kv_scales is None: + self.calculate_kv_scales = False + def metrics_info(self): # convert cache_config to dict(key: str, value: str) for prometheus # metrics info @@ -3144,7 +3142,6 @@ def __str__(self): f"quantization={self.model_config.quantization}, " f"enforce_eager={self.model_config.enforce_eager}, " f"kv_cache_dtype={self.cache_config.cache_dtype}, " - f"quantization_param_path={self.model_config.quantization_param_path}," f" device_config={self.device_config.device}, " f"decoding_config={self.decoding_config!r}, " f"observability_config={self.observability_config!r}, " diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 64cc4592c286..8a3edb43a56b 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -99,7 +99,6 @@ class EngineArgs: config_format: ConfigFormat = ConfigFormat.AUTO dtype: str = 'auto' kv_cache_dtype: str = 'auto' - quantization_param_path: Optional[str] = None seed: int = 0 max_model_len: Optional[int] = None worker_use_ray: bool = False @@ -196,6 +195,7 @@ class EngineArgs: worker_cls: str = "auto" kv_transfer_config: Optional[KVTransferConfig] = None + calculate_kv_scales: Optional[bool] = None def __post_init__(self): if not self.tokenizer: @@ -346,17 +346,6 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: help='Data type for kv cache storage. If "auto", will use model ' 'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. ' 'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)') - parser.add_argument( - '--quantization-param-path', - type=nullable_str, - default=None, - help='Path to the JSON file containing the KV cache ' - 'scaling factors. This should generally be supplied, when ' - 'KV cache dtype is FP8. Otherwise, KV cache scaling factors ' - 'default to 1.0, which may cause accuracy issues. ' - 'FP8_E5M2 (without scaling) is only supported on cuda version ' - 'greater than 11.8. 
On ROCm (AMD GPU), FP8_E4M3 is instead ' - 'supported for common inference criteria.') parser.add_argument('--max-model-len', type=int, default=EngineArgs.max_model_len, @@ -943,6 +932,15 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: default="auto", help='The worker class to use for distributed execution.') + parser.add_argument( + '--calculate-kv-scales', + action='store_true', + help='This enables dynamic calculation of ' + 'k_scale and v_scale when kv-cache-dtype is fp8. ' + 'If calculate-kv-scales is false, the scales will ' + 'be loaded from the model checkpoint if available. ' + 'Otherwise, the scales will default to 1.0.') + return parser @classmethod @@ -972,7 +970,6 @@ def create_model_config(self) -> ModelConfig: tokenizer_revision=self.tokenizer_revision, max_model_len=self.max_model_len, quantization=self.quantization, - quantization_param_path=self.quantization_param_path, enforce_eager=self.enforce_eager, max_seq_len_to_capture=self.max_seq_len_to_capture, max_logprobs=self.max_logprobs, @@ -1046,6 +1043,7 @@ def create_engine_config(self, sliding_window=model_config.get_sliding_window(), enable_prefix_caching=self.enable_prefix_caching, cpu_offload_gb=self.cpu_offload_gb, + calculate_kv_scales=self.calculate_kv_scales, ) parallel_config = ParallelConfig( pipeline_parallel_size=self.pipeline_parallel_size, diff --git a/vllm/envs.py b/vllm/envs.py index 18870c1c6b51..f1bba5cfacf2 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -72,6 +72,8 @@ VLLM_ENABLE_V1_MULTIPROCESSING: bool = True VLLM_LOG_BATCHSIZE_INTERVAL: float = -1 VLLM_DISABLE_COMPILE_CACHE: bool = False + K_SCALE_CONSTANT: int = 200 + V_SCALE_CONSTANT: int = 100 def get_default_cache_root(): @@ -459,6 +461,13 @@ def get_default_config_root(): "VLLM_USE_V1": lambda: bool(int(os.getenv("VLLM_USE_V1", "0"))), + # Divisor for dynamic key scale factor calculation for FP8 KV Cache + "K_SCALE_CONSTANT": + lambda: int(os.getenv("K_SCALE_CONSTANT", "200")), + + # Divisor for dynamic value scale factor calculation for FP8 KV Cache + "V_SCALE_CONSTANT": + lambda: int(os.getenv("V_SCALE_CONSTANT", "100")), # If set, enable multiprocessing in LLM for the V1 code path. "VLLM_ENABLE_V1_MULTIPROCESSING": lambda: bool(int(os.getenv("VLLM_ENABLE_V1_MULTIPROCESSING", "1"))), diff --git a/vllm/model_executor/layers/quantization/kv_cache.py b/vllm/model_executor/layers/quantization/kv_cache.py index d79536d196b9..b386a9f30963 100644 --- a/vllm/model_executor/layers/quantization/kv_cache.py +++ b/vllm/model_executor/layers/quantization/kv_cache.py @@ -2,7 +2,8 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) -from vllm.utils import print_warning_once +from vllm.platforms import current_platform +from vllm.utils import is_navi, print_warning_once class BaseKVCacheMethod(QuantizeMethodBase): @@ -38,11 +39,16 @@ def apply(self, layer: torch.nn.Module) -> torch.Tensor: def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # If the kv-cache dtype is auto, we enforce the k/v_scale to be 1.0 # regardless whether the kv-scale is available in the checkpoint. - if layer.kv_cache_dtype != "auto": + # No need to process kv scales after loading if we are going to + # calculate them on the fly. 
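# Editorial sketch (not part of the patch): how the dynamic scales introduced
# above are derived once the feature is enabled, e.g. by passing
# --kv-cache-dtype fp8 --calculate-kv-scales on the command line. Each
# attention layer computes its scales once from the first key/value tensors
# it sees, dividing the per-tensor amax by the new env-var constants
# K_SCALE_CONSTANT (default 200) and V_SCALE_CONSTANT (default 100). The
# helper below is a hypothetical standalone version of that formula, mirroring
# Attention.calc_kv_scales() from this patch.
import os
import torch

def dynamic_kv_scales(key: torch.Tensor, value: torch.Tensor):
    k_const = float(os.getenv("K_SCALE_CONSTANT", "200"))
    v_const = float(os.getenv("V_SCALE_CONSTANT", "100"))
    # amax / divisor for each of the key and value tensors.
    return key.abs().max() / k_const, value.abs().max() / v_const

k_scale, v_scale = dynamic_kv_scales(torch.randn(16, 8, 64),
                                     torch.randn(16, 8, 64))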
+ if layer.kv_cache_dtype != "auto" and not layer.calculate_kv_scales: if layer.k_scale > 0.0 and layer.v_scale > 0.0: # We prefer to use separate k_scale and v_scale if present k_scale = layer.k_scale.to("cpu").tolist() v_scale = layer.v_scale.to("cpu").tolist() + if current_platform.is_rocm() and not is_navi(): + k_scale *= 2 + v_scale *= 2 elif layer.k_scale < 0.0 and layer.v_scale < 0.0: # If no scales were loaded (both scales are invalid negative # values), use the default value of 1.0 @@ -56,6 +62,9 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: scale_to_duplicate = max(layer.k_scale, layer.v_scale) k_scale = scale_to_duplicate.to("cpu").tolist() v_scale = scale_to_duplicate.to("cpu").tolist() + if current_platform.is_rocm() and not is_navi(): + k_scale *= 2 + v_scale *= 2 if not isinstance(k_scale, float) or not isinstance( v_scale, float): @@ -63,9 +72,9 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: "for fp8 KV cache") # These are used in the final Attention.forward() - layer._k_scale = k_scale - layer._v_scale = v_scale - if (layer._k_scale == 1.0 and layer._v_scale == 1.0 + layer._k_scale.copy_(k_scale) + layer._v_scale.copy_(v_scale) + if (k_scale == 1.0 and v_scale == 1.0 and "e5m2" not in layer.kv_cache_dtype): print_warning_once( "Using KV cache scaling factor 1.0 for fp8_e4m3. This " diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index 9488d54edf36..0fd5840cb79f 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -6,8 +6,7 @@ import os import tempfile from collections import defaultdict -from typing import (Any, Callable, Dict, Generator, Iterable, List, Optional, - Tuple, Union) +from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union import filelock import gguf @@ -23,7 +22,6 @@ from vllm.logger import init_logger from vllm.model_executor.layers.quantization import (QuantizationConfig, get_quantization_config) -from vllm.model_executor.layers.quantization.schema import QuantParamSchema from vllm.platforms import current_platform from vllm.utils import print_warning_once @@ -469,46 +467,6 @@ def gguf_quant_weights_iterator( yield name, param -def kv_cache_scales_loader( - filename: str, tp_rank: int, tp_size: int, num_hidden_layers: int, - model_type: Optional[str]) -> Iterable[Tuple[int, float]]: - """ - A simple utility to read in KV cache scaling factors that have been - previously serialized to disk. Used by the model to populate the appropriate - KV cache scaling factors. The serialization should represent a dictionary - whose keys are the TP ranks and values are another dictionary mapping layers - to their KV cache scaling factors. 
- Keep this function in sync with the output of examples/fp8/extract_scales.py - """ - try: - with open(filename) as f: - context = { - "model_type": model_type, - "num_hidden_layers": num_hidden_layers, - "tp_rank": tp_rank, - "tp_size": tp_size, - } - schema_dct = json.load(f) - schema = QuantParamSchema.model_validate(schema_dct, - context=context) - layer_scales_map = schema.kv_cache.scaling_factor[tp_rank] - return layer_scales_map.items() - - except FileNotFoundError: - logger.error("File or directory '%s' not found.", filename) - except json.JSONDecodeError: - logger.error("Error decoding JSON in file '%s'.", filename) - except Exception: - logger.exception("An error occurred while reading '%s'.", filename) - # This section is reached if and only if any of the excepts are hit - # Return an empty iterable (list) => no KV cache scales are loaded - # which ultimately defaults to 1.0 scales - logger.warning( - "Defaulting to KV cache scaling factors = 1.0 for all " - "layers in TP rank %d as an error occurred during loading.", tp_rank) - return [] - - def convert_pyslice_to_tensor(x: Any) -> torch.Tensor: """convert PySafeSlice object from safetensors to torch.Tensor diff --git a/vllm/model_executor/models/exaone.py b/vllm/model_executor/models/exaone.py index 0398f0943a70..6926de007bc8 100644 --- a/vllm/model_executor/models/exaone.py +++ b/vllm/model_executor/models/exaone.py @@ -30,8 +30,7 @@ from vllm.attention import Attention, AttentionMetadata from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size) +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, @@ -46,9 +45,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name) + default_weight_loader, maybe_remap_kv_scale_name) from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs.exaone import ExaoneConfig @@ -583,31 +581,3 @@ def load_weights(self, weights: Iterable[Tuple[str, weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params - - # If this function is called, it should always initialize KV cache scale - # factors (or else raise an exception). 
Thus, handled exceptions should - # make sure to leave KV cache scale factors in a known good (dummy) state - def load_kv_cache_scales(self, quantization_param_path: str) -> None: - tp_size = get_tensor_model_parallel_world_size() - tp_rank = get_tensor_model_parallel_rank() - for layer_idx, scaling_factor in kv_cache_scales_loader( - quantization_param_path, - tp_rank, - tp_size, - self.config.num_hidden_layers, - self.config.__class__.model_type, - ): - if not isinstance(self.transformer.h[layer_idx], nn.Identity): - layer_self_attn = self.transformer.h[layer_idx].attn - - if current_platform.is_rocm(): - # The scaling factor convention we are assuming is - # quantized_value * scaling_factor ~= true_value - # which is consistent with the practice of setting - # scaling_factor = tensor_amax / FPtype_max - scaling_factor *= 2 - if hasattr(layer_self_attn, "kv_scale"): - layer_self_attn.attn._kv_scale = scaling_factor - else: - raise RuntimeError("Self attention has no KV cache scaling " - "factor attribute!") diff --git a/vllm/model_executor/models/granite.py b/vllm/model_executor/models/granite.py index f9e0443b9a50..6353a7703d6c 100644 --- a/vllm/model_executor/models/granite.py +++ b/vllm/model_executor/models/granite.py @@ -29,8 +29,7 @@ from vllm.attention import Attention, AttentionMetadata from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size) +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, @@ -46,9 +45,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name) + default_weight_loader, maybe_remap_kv_scale_name) from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP @@ -525,28 +523,3 @@ def load_weights(self, weights: Iterable[Tuple[str, weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params - - # If this function is called, it should always initialize KV cache scale - # factors (or else raise an exception). 
Thus, handled exceptions should - # make sure to leave KV cache scale factors in a known good (dummy) state - def load_kv_cache_scales(self, quantization_param_path: str) -> None: - tp_size = get_tensor_model_parallel_world_size() - tp_rank = get_tensor_model_parallel_rank() - for layer_idx, scaling_factor in kv_cache_scales_loader( - quantization_param_path, tp_rank, tp_size, - self.config.num_hidden_layers, - self.config.__class__.model_type): - if not isinstance(self.model.layers[layer_idx], nn.Identity): - layer_self_attn = self.model.layers[layer_idx].self_attn - - if current_platform.is_rocm(): - # The scaling factor convention we are assuming is - # quantized_value * scaling_factor ~= true_value - # which is consistent with the practice of setting - # scaling_factor = tensor_amax / FPtype_max - scaling_factor *= 2 - if hasattr(layer_self_attn, "kv_scale"): - layer_self_attn.attn._kv_scale = scaling_factor - else: - raise RuntimeError("Self attention has no KV cache scaling " - "factor attribute!") diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 2902e6999c2f..e28b04264ce9 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -29,8 +29,7 @@ from vllm.attention import Attention, AttentionMetadata from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size) +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, @@ -45,7 +44,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name) + default_weight_loader, maybe_remap_kv_scale_name) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors @@ -433,31 +432,6 @@ def load_weights(self, weights: Iterable[Tuple[str, loaded_params.add(name) return loaded_params - # If this function is called, it should always initialize KV cache scale - # factors (or else raise an exception). 
Thus, handled exceptions should - # make sure to leave KV cache scale factors in a known good (dummy) state - def load_kv_cache_scales(self, quantization_param_path: str) -> None: - tp_size = get_tensor_model_parallel_world_size() - tp_rank = get_tensor_model_parallel_rank() - for layer_idx, scaling_factor in kv_cache_scales_loader( - quantization_param_path, tp_rank, tp_size, - self.config.num_hidden_layers, - self.config.__class__.model_type): - if not isinstance(self.layers[layer_idx], nn.Identity): - layer_self_attn = self.layers[layer_idx].self_attn - - if current_platform.is_rocm(): - # The scaling factor convention we are assuming is - # quantized_value * scaling_factor ~= true_value - # which is consistent with the practice of setting - # scaling_factor = tensor_amax / FPtype_max - scaling_factor *= 2 - if hasattr(layer_self_attn, "kv_scale"): - layer_self_attn.attn._kv_scale = scaling_factor - else: - raise RuntimeError("Self attention has no KV cache scaling " - "factor attribute!") - class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): packed_modules_mapping = { @@ -595,9 +569,6 @@ def load_weights(self, weights: Iterable[Tuple[str, self.maybe_remap_mistral(name, loaded_weight) for name, loaded_weight in weights) - def load_kv_cache_scales(self, quantization_param_path: str) -> None: - self.model.load_kv_cache_scales(quantization_param_path) - # This function is used to remap the mistral format as # used by Mistral and Llama <=2 def maybe_remap_mistral( diff --git a/vllm/model_executor/models/solar.py b/vllm/model_executor/models/solar.py index caae0b65d7d1..087697bc45e6 100644 --- a/vllm/model_executor/models/solar.py +++ b/vllm/model_executor/models/solar.py @@ -30,8 +30,7 @@ from vllm.attention import Attention, AttentionMetadata from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size) +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, @@ -46,9 +45,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name) + default_weight_loader, maybe_remap_kv_scale_name) from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP @@ -542,31 +540,3 @@ def load_weights(self, weights: Iterable[Tuple[str, weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params - - # If this function is called, it should always initialize KV cache scale - # factors (or else raise an exception). 
Thus, handled exceptions should - # make sure to leave KV cache scale factors in a known good (dummy) state - def load_kv_cache_scales(self, quantization_param_path: str) -> None: - tp_size = get_tensor_model_parallel_world_size() - tp_rank = get_tensor_model_parallel_rank() - for layer_idx, scaling_factor in kv_cache_scales_loader( - quantization_param_path, - tp_rank, - tp_size, - self.config.num_hidden_layers, - self.config.__class__.model_type, - ): - if not isinstance(self.model.layers[layer_idx], nn.Identity): - layer_self_attn = self.model.layers[layer_idx].self_attn - - if current_platform.is_rocm(): - # The scaling factor convention we are assuming is - # quantized_value * scaling_factor ~= true_value - # which is consistent with the practice of setting - # scaling_factor = tensor_amax / FPtype_max - scaling_factor *= 2 - if hasattr(layer_self_attn, "kv_scale"): - layer_self_attn.attn._kv_scale = scaling_factor - else: - raise RuntimeError("Self attention has no KV cache scaling " - "factor attribute!") diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 6ff98a8f1bab..42d29254063e 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -3,7 +3,6 @@ import inspect import itertools import time -import warnings import weakref from dataclasses import dataclass from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, @@ -38,7 +37,6 @@ from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, MultiModalKwargs, MultiModalPlaceholderMap, MultiModalRegistry) -from vllm.platforms import current_platform from vllm.prompt_adapter.layers import PromptAdapterMapping from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.worker_manager import ( @@ -176,6 +174,8 @@ def from_broadcasted_tensor_dict( if attn_backend is not None: tensor_dict = _init_attn_metadata_from_tensor_dict( attn_backend, tensor_dict) + if "enable_kv_scales_calculation" in tensor_dict: + tensor_dict.pop("enable_kv_scales_calculation") return cls(**tensor_dict) @@ -1134,33 +1134,6 @@ def load_model(self) -> None: self.prompt_adapter_manager.create_prompt_adapter_manager( self.model)) - if self.kv_cache_dtype == "fp8" and current_platform.is_rocm(): - # Currently only ROCm accepts kv-cache scaling factors - # via quantization_param_path and this will be deprecated - # in the future. - if self.model_config.quantization_param_path is not None: - if callable(getattr(self.model, "load_kv_cache_scales", None)): - warnings.warn( - "Loading kv cache scaling factor from JSON is " - "deprecated and will be removed. Please include " - "kv cache scaling factors in the model checkpoint.", - FutureWarning, - stacklevel=2) - self.model.load_kv_cache_scales( - self.model_config.quantization_param_path) - logger.info("Loaded KV cache scaling factors from %s", - self.model_config.quantization_param_path) - else: - raise RuntimeError( - "Using FP8 KV cache and scaling factors provided but " - "model %s does not support loading scaling factors.", - self.model.__class__) - else: - logger.warning( - "Using FP8 KV cache but no scaling factors " - "provided. Defaulting to scaling factors of 1.0. 
" - "This may lead to less accurate results!") - if self.vllm_config.compilation_config.level ==\ CompilationLevel.DYNAMO_AS_IS and supports_dynamo(): backend = self.vllm_config.compilation_config.init_backend( @@ -1326,6 +1299,10 @@ def profile_run(self) -> None: dtype=self.model_config.dtype, device=self.device) + # Disable KV Scale Calculation for dummy data during profile run + if model_input.attn_metadata is not None: + model_input.attn_metadata.enable_kv_scales_calculation = False + self.execute_model(model_input, kv_caches, intermediate_tensors) torch.cuda.synchronize() return @@ -1458,7 +1435,8 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: batch_size, is_encoder_decoder_model=self.model_config. is_encoder_decoder)) - + # Disable KV Scale Calculation for graph capture + attn_metadata.enable_kv_scales_calculation = False if self.lora_config: lora_mapping = LoRAMapping( **dict(index_mapping=[0] * batch_size, diff --git a/vllm/worker/model_runner_base.py b/vllm/worker/model_runner_base.py index cd4770202a18..4a3f711e3a9e 100644 --- a/vllm/worker/model_runner_base.py +++ b/vllm/worker/model_runner_base.py @@ -47,7 +47,8 @@ def _init_attn_metadata_from_tensor_dict( # Extract the fields used to create AttentionMetadata. valid_attn_kwargs = {} for field in dataclasses.fields(attn_backend.get_metadata_cls()): - if field.name in tensor_dict: + if field.name in tensor_dict and field.name != \ + 'enable_kv_scales_calculation': valid_attn_kwargs[field.name] = tensor_dict.pop(field.name) attn_metadata = attn_backend.make_metadata(**valid_attn_kwargs) From ef181a9d535d6bd62f07f6bb0a8d833e6f37d2cb Mon Sep 17 00:00:00 2001 From: Gregory Shtrasberg <156009573+gshtras@users.noreply.github.com> Date: Wed, 18 Dec 2024 11:53:50 -0500 Subject: [PATCH 02/16] Mllama kv scale fix (#335) * Using tensors in the explicit cache function calls from mllama implementation * Properly creating the tensor Signed-off-by: Gregory Shtrasberg --- vllm/model_executor/models/mllama.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py index 6536f9807730..34e2ff55af65 100644 --- a/vllm/model_executor/models/mllama.py +++ b/vllm/model_executor/models/mllama.py @@ -827,6 +827,7 @@ def _attention_with_mask( ) -> torch.Tensor: # Skip writing kv-cache for the initial profiling run. if len(kv_cache.shape) > 1: + i = torch.ones(1, dtype=torch.float32) if self.attn.backend in (_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1): cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode]) @@ -839,8 +840,8 @@ def _attention_with_mask( attn_metadata. 
cross_slot_mapping, # type: ignore[union-attr] "auto", - 1.0, - 1.0, + i, + i, ) elif self.attn.backend in (_Backend.XFORMERS, _Backend.TORCH_SDPA): key_cache, value_cache = PagedAttention.split_kv_cache( @@ -849,7 +850,7 @@ def _attention_with_mask( cached_v = torch.cat([v[s:e] for s, e in kv_range_for_decode]) PagedAttention.write_to_paged_cache( cached_k, cached_v, key_cache, value_cache, - attn_metadata.cross_slot_mapping, "auto", 1.0, 1.0) + attn_metadata.cross_slot_mapping, "auto", i, i) else: raise ValueError( f"Unsupported Attention backend {self.attn.backend} " From f9645bfffc8a1f1238f307e573d20a5dd8003575 Mon Sep 17 00:00:00 2001 From: Gregory Shtrasberg Date: Wed, 18 Dec 2024 19:44:21 +0000 Subject: [PATCH 03/16] format Signed-off-by: Gregory Shtrasberg --- vllm/config.py | 1 - vllm/model_executor/models/llama.py | 1 - 2 files changed, 2 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 6d633921d709..a07f0e885c6e 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -174,7 +174,6 @@ def compute_hash(self) -> str: factors.append(self.model) factors.append(self.dtype) factors.append(self.quantization) - factors.append(self.quantization_param_path) factors.append(self.revision) factors.append(self.code_revision) factors.append(self.trust_remote_code) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index e28b04264ce9..7ba92cd6c582 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -46,7 +46,6 @@ from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name) from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP From 6ad050e050e2a9327dc704ab429983fba5a14a2a Mon Sep 17 00:00:00 2001 From: Gregory Shtrasberg Date: Wed, 18 Dec 2024 19:58:58 +0000 Subject: [PATCH 04/16] Fix for different attention types Signed-off-by: Gregory Shtrasberg --- csrc/rocm/attention.cu | 2 +- vllm/attention/backends/flash_attn.py | 4 ++-- vllm/attention/backends/flashinfer.py | 4 ++-- vllm/attention/backends/hpu_attn.py | 4 ++-- vllm/attention/backends/ipex_attn.py | 4 ++-- vllm/attention/backends/pallas.py | 4 ++-- vllm/attention/backends/torch_sdpa.py | 4 ++-- vllm/attention/backends/xformers.py | 4 ++-- vllm/attention/ops/ipex_attn.py | 16 ++++++++-------- vllm/attention/ops/paged_attn.py | 4 ++-- vllm/attention/ops/prefix_prefill.py | 8 ++++---- vllm/v1/attention/backends/flash_attn.py | 4 ++-- 12 files changed, 31 insertions(+), 31 deletions(-) diff --git a/csrc/rocm/attention.cu b/csrc/rocm/attention.cu index b638f130cd02..5f3c9fed5f2c 100644 --- a/csrc/rocm/attention.cu +++ b/csrc/rocm/attention.cu @@ -218,7 +218,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, // head_size] scalar_t* __restrict__ final_out, // [num_seqs, num_heads, head_size] - int max_ctx_blocks, const float* k_scale, const float* v_scale) { + int max_ctx_blocks, const float* k_scale_ptr, const float* v_scale_ptr) { constexpr int NWARPS = NUM_THREADS / WARP_SIZE; const int warpid = threadIdx.x / WARP_SIZE; const int laneid = threadIdx.x % WARP_SIZE; diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index c69e12ad78c4..24c4bd182500 100644 --- a/vllm/attention/backends/flash_attn.py +++ 
b/vllm/attention/backends/flash_attn.py @@ -635,8 +635,8 @@ def forward( value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: FlashAttentionMetadata, - k_scale: float = 1.0, - v_scale: float = 1.0, + k_scale: torch.Tensor, + v_scale: torch.Tensor, attn_type: str = AttentionType.DECODER, output: Optional[torch.Tensor] = None, ) -> torch.Tensor: diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index e367468d05d2..775fc8ee8079 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -771,8 +771,8 @@ def forward( value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: FlashInferMetadata, - k_scale: float = 1.0, - v_scale: float = 1.0, + k_scale: torch.Tensor, + v_scale: torch.Tensor, attn_type: str = AttentionType.DECODER, output: Optional[torch.Tensor] = None, ) -> torch.Tensor: diff --git a/vllm/attention/backends/hpu_attn.py b/vllm/attention/backends/hpu_attn.py index f90d15d4207e..decfa3e60b2f 100644 --- a/vllm/attention/backends/hpu_attn.py +++ b/vllm/attention/backends/hpu_attn.py @@ -150,8 +150,8 @@ def forward( value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: HPUAttentionMetadata, - k_scale: float = 1.0, - v_scale: float = 1.0, + k_scale: torch.Tensor, + v_scale: torch.Tensor, attn_type: str = AttentionType.DECODER, output: Optional[torch.Tensor] = None, ) -> torch.Tensor: diff --git a/vllm/attention/backends/ipex_attn.py b/vllm/attention/backends/ipex_attn.py index 21949874bea4..cf01e73046d2 100644 --- a/vllm/attention/backends/ipex_attn.py +++ b/vllm/attention/backends/ipex_attn.py @@ -170,8 +170,8 @@ def forward( value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: IpexAttnMetadata, # type: ignore - k_scale: float = 1.0, - v_scale: float = 1.0, + k_scale: torch.Tensor, + v_scale: torch.Tensor, attn_type: str = AttentionType.DECODER, output: Optional[torch.Tensor] = None, ) -> torch.Tensor: diff --git a/vllm/attention/backends/pallas.py b/vllm/attention/backends/pallas.py index 9809aed0e66f..7c114518d276 100644 --- a/vllm/attention/backends/pallas.py +++ b/vllm/attention/backends/pallas.py @@ -148,8 +148,8 @@ def forward( value: torch.Tensor, kv_cache: Tuple[torch.Tensor, torch.Tensor], attn_metadata: PallasMetadata, - k_scale: float = 1.0, - v_scale: float = 1.0, + k_scale: torch.Tensor, + v_scale: torch.Tensor, attn_type: str = AttentionType.DECODER, output: Optional[torch.Tensor] = None, ) -> torch.Tensor: diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index 0cff6f5952ab..a33069031ed2 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -429,8 +429,8 @@ def forward( value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: TorchSDPAMetadata, # type: ignore - k_scale: float = 1.0, - v_scale: float = 1.0, + k_scale: torch.Tensor, + v_scale: torch.Tensor, attn_type: str = AttentionType.DECODER, output: Optional[torch.Tensor] = None, ) -> torch.Tensor: diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 3e59b3603d2c..c4d5f40ae27e 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -412,8 +412,8 @@ def forward( value: Optional[torch.Tensor], kv_cache: torch.Tensor, attn_metadata: "XFormersMetadata", - k_scale: float = 1.0, - v_scale: float = 1.0, + k_scale: torch.Tensor, + v_scale: torch.Tensor, attn_type: str = AttentionType.DECODER, output: Optional[torch.Tensor] = None, ) -> torch.Tensor: diff --git 
a/vllm/attention/ops/ipex_attn.py b/vllm/attention/ops/ipex_attn.py index cbc6c74acf09..3a07184ed31f 100644 --- a/vllm/attention/ops/ipex_attn.py +++ b/vllm/attention/ops/ipex_attn.py @@ -52,8 +52,8 @@ def write_to_paged_cache( value_cache: torch.Tensor, slot_mapping: torch.Tensor, kv_cache_dtype: str, - k_scale: float, - v_scale: float, + k_scale: torch.Tensor, + v_scale: torch.Tensor, *args, ) -> None: ops.reshape_and_cache( @@ -80,8 +80,8 @@ def forward_decode( num_kv_heads: int, scale: float, alibi_slopes: Optional[torch.Tensor], - k_scale: float, - v_scale: float, + k_scale: torch.Tensor, + v_scale: torch.Tensor, *args, ) -> None: tp_rank: int = 0 @@ -149,8 +149,8 @@ def write_to_paged_cache( value_cache: torch.Tensor, slot_mapping: torch.Tensor, kv_cache_dtype: str, - k_scale: float, - v_scale: float, + k_scale: torch.Tensor, + v_scale: torch.Tensor, *args, ) -> None: ipex_modules.PagedAttention.reshape_and_cache( @@ -170,8 +170,8 @@ def forward_decode( num_kv_heads: int, scale: float, alibi_slopes: Optional[torch.Tensor], - k_scale: float, - v_scale: float, + k_scale: torch.Tensor, + v_scale: torch.Tensor, *args, ) -> None: block_size = value_cache.shape[2] diff --git a/vllm/attention/ops/paged_attn.py b/vllm/attention/ops/paged_attn.py index 644c6f9d7da2..fd62329141f6 100644 --- a/vllm/attention/ops/paged_attn.py +++ b/vllm/attention/ops/paged_attn.py @@ -204,8 +204,8 @@ def forward_prefix( max_query_len: int, alibi_slopes: Optional[torch.Tensor], sliding_window: Optional[int], - k_scale: float, - v_scale: float, + k_scale: torch.Tensor, + v_scale: torch.Tensor, ) -> torch.Tensor: output = torch.empty_like(query) context_attention_fwd( diff --git a/vllm/attention/ops/prefix_prefill.py b/vllm/attention/ops/prefix_prefill.py index ed192d2e7e9d..e2f2b66dfc90 100644 --- a/vllm/attention/ops/prefix_prefill.py +++ b/vllm/attention/ops/prefix_prefill.py @@ -564,7 +564,7 @@ def _fwd_kernel_alibi( other=0.0) # [D,N] if k_load.dtype.is_fp8(): - k = (k_load.to(tl.float32) * k_scale).to(q.dtype) + k = (k_load.to(tl.float32) * tl.load(k_scale)).to(q.dtype) else: k = k_load @@ -604,7 +604,7 @@ def _fwd_kernel_alibi( ((start_n + offs_n[:, None]) < cur_batch_ctx_len), other=0.0) if v_load.dtype.is_fp8(): - v = (v_load.to(tl.float32) * v_scale).to(q.dtype) + v = (v_load.to(tl.float32) * tl.load(v_scale)).to(q.dtype) else: v = v_load p = p.to(v.dtype) @@ -713,8 +713,8 @@ def context_attention_fwd(q, b_seq_len, b_ctx_len, max_input_len, - k_scale: float = 1.0, - v_scale: float = 1.0, + k_scale: torch.Tensor, + v_scale: torch.Tensor, alibi_slopes=None, sliding_window=None): diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 026a0292cc33..4023e4d2d599 100644 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -109,8 +109,8 @@ def forward( value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: FlashAttentionMetadata, - k_scale: float = 1.0, - v_scale: float = 1.0, + k_scale: torch.Tensor, + v_scale: torch.Tensor, attn_type: AttentionType = AttentionType.DECODER, output: Optional[torch.Tensor] = None, ) -> torch.Tensor: From 3eaca59bae1edc20579c00c6a27a762ab46591a2 Mon Sep 17 00:00:00 2001 From: Gregory Shtrasberg <156009573+gshtras@users.noreply.github.com> Date: Wed, 18 Dec 2024 17:13:18 -0500 Subject: [PATCH 05/16] Properly initializing the new field in the attn metadata (#337) Signed-off-by: Gregory Shtrasberg --- tests/kernels/utils.py | 2 ++ tests/worker/test_model_input.py | 3 +++ 
vllm/attention/backends/abstract.py | 5 ++--- vllm/attention/backends/blocksparse_attn.py | 2 ++ vllm/attention/backends/flash_attn.py | 3 +++ vllm/attention/backends/flashinfer.py | 2 ++ vllm/attention/backends/placeholder_attn.py | 3 +++ vllm/attention/backends/rocm_flash_attn.py | 2 ++ vllm/attention/backends/torch_sdpa.py | 1 + vllm/attention/backends/utils.py | 2 ++ vllm/attention/backends/xformers.py | 2 ++ vllm/worker/hpu_model_runner.py | 7 +++++-- vllm/worker/model_runner.py | 2 -- vllm/worker/model_runner_base.py | 3 +-- vllm/worker/openvino_model_runner.py | 1 + vllm/worker/tpu_model_runner.py | 5 +++++ vllm/worker/xpu_model_runner.py | 2 ++ 17 files changed, 38 insertions(+), 9 deletions(-) diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index 848eea7f54ca..8011398551b9 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -909,6 +909,7 @@ def make_test_metadata( num_prefills=num_prefills, slot_mapping=(None if kv_mmap is None else kv_mmap.slot_mapping), multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=True, num_prefill_tokens=num_prefill_tokens, num_decode_tokens=num_decode_tokens, seq_lens=seq_lens, @@ -958,6 +959,7 @@ def make_test_metadata( num_prefills=num_prefills, slot_mapping=kv_mmap.slot_mapping, multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=True, num_prefill_tokens=num_prefill_tokens, num_decode_tokens=num_decode_tokens, seq_lens=seq_lens, diff --git a/tests/worker/test_model_input.py b/tests/worker/test_model_input.py index 309854e6babf..57f1fd47a600 100644 --- a/tests/worker/test_model_input.py +++ b/tests/worker/test_model_input.py @@ -74,6 +74,7 @@ def test_model_runner_input(): num_decode_tokens=3, slot_mapping=torch.zeros(1), multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=True, ) model_input = ModelInputForGPUWithSamplingMetadata( input_tokens=torch.ones(10), @@ -126,6 +127,7 @@ def test_embedding_model_runner_input(): num_decode_tokens=3, slot_mapping=torch.zeros(1), multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=True, ) model_input = ModelInputForGPUWithPoolingMetadata( input_tokens=torch.ones(10), @@ -177,6 +179,7 @@ def test_multi_step_model_runner_input(): num_decode_tokens=3, slot_mapping=torch.zeros(1), multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=True, ) frozen_model_input = ModelInputForGPUWithSamplingMetadata( input_tokens=torch.ones(10), diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 4e58e6db91a4..796017c1c2b2 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod from contextlib import contextmanager -from dataclasses import dataclass, field, fields +from dataclasses import dataclass, fields from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Set, Tuple, Type, TypeVar) @@ -126,8 +126,7 @@ class AttentionMetadata: # Enable/disable KV scales calculation. This is so that we can disable the # calculation until after prefill and cuda graph capture. 
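# Editorial sketch (not part of the patch): the hunk below replaces the
# field(init=False, default_factory=...) declaration with a plain required
# attribute. A dataclass field declared with init=False cannot be supplied
# through the constructor, so backends could not pass the flag when building
# their metadata and the broadcast path had to filter it out; this change
# makes every backend set it explicitly. The classes below are minimal,
# hypothetical stand-ins that illustrate the constraint, not vLLM's real
# AttentionMetadata.
from dataclasses import dataclass, field

@dataclass
class MetaWithInitFalse:
    num_prefills: int
    enable_kv_scales_calculation: bool = field(init=False,
                                               default_factory=lambda: True)

try:
    MetaWithInitFalse(num_prefills=1, enable_kv_scales_calculation=False)
except TypeError:
    # Rejected: init=False fields are not constructor arguments.
    pass

@dataclass
class MetaRequired:
    num_prefills: int
    enable_kv_scales_calculation: bool  # now passed explicitly by each backend

MetaRequired(num_prefills=1, enable_kv_scales_calculation=False)  # accepted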
- enable_kv_scales_calculation: bool = field(init=False, - default_factory=lambda: True) + enable_kv_scales_calculation: bool @property @abstractmethod diff --git a/vllm/attention/backends/blocksparse_attn.py b/vllm/attention/backends/blocksparse_attn.py index 0bef046e1680..aa7cc981372a 100644 --- a/vllm/attention/backends/blocksparse_attn.py +++ b/vllm/attention/backends/blocksparse_attn.py @@ -222,6 +222,7 @@ def prefill_metadata( slot_mapping=self.slot_mapping[:self.num_prefill_tokens], multi_modal_placeholder_index_maps=self. multi_modal_placeholder_index_maps, + enable_kv_scales_calculation=self.enable_kv_scales_calculation, seq_lens=self.seq_lens[:self.num_prefills], seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], max_query_len=self.max_query_len, @@ -251,6 +252,7 @@ def decode_metadata(self) -> Optional["BlocksparseFlashAttentionMetadata"]: num_decode_tokens=self.num_decode_tokens, slot_mapping=self.slot_mapping[self.num_prefill_tokens:], multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=False, seq_lens=None, seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], max_query_len=None, diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index c2af8f5216e0..c7850009a954 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -224,6 +224,7 @@ def prefill_metadata(self) -> Optional["FlashAttentionMetadata"]: slot_mapping=slot_mapping, multi_modal_placeholder_index_maps=self. multi_modal_placeholder_index_maps, + enable_kv_scales_calculation=self.enable_kv_scales_calculation, seq_lens=seq_lens, seq_lens_tensor=seq_lens_tensor, max_query_len=self.max_query_len, @@ -268,6 +269,7 @@ def decode_metadata(self) -> Optional["FlashAttentionMetadata"]: num_decode_tokens=self.num_decode_tokens, slot_mapping=slot_mapping, multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=True, seq_lens=None, seq_lens_tensor=seq_lens_tensor, max_decode_query_len=self.max_decode_query_len, @@ -550,6 +552,7 @@ def build(self, seq_lens: List[int], query_lens: List[int], num_decode_tokens=num_decode_tokens, seq_lens=seq_lens, multi_modal_placeholder_index_maps=placeholder_index_maps, + enable_kv_scales_calculation=True, seq_lens_tensor=seq_lens_tensor, max_query_len=max_query_len, max_decode_query_len=max_decode_query_len, diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index f1d1cd330373..d9d07e6be173 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -218,6 +218,7 @@ def graph_capture_get_metadata_for_batch( num_prefills=0, slot_mapping=self._graph_slot_mapping[:batch_size], multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=False, num_prefill_tokens=0, num_decode_tokens=batch_size, max_prefill_seq_len=0, @@ -711,6 +712,7 @@ def build(self, seq_lens: List[int], query_lens: List[int], num_prefills=self.num_prefills, slot_mapping=slot_mapping_tensor, multi_modal_placeholder_index_maps=placeholder_index_maps, + enable_kv_scales_calculation=False, num_prefill_tokens=self.num_prefill_tokens, num_decode_tokens=num_decode_tokens, max_prefill_seq_len=max_prefill_seq_len, diff --git a/vllm/attention/backends/placeholder_attn.py b/vllm/attention/backends/placeholder_attn.py index 534f79b3a60b..d2dc0d6cf0a5 100644 --- a/vllm/attention/backends/placeholder_attn.py +++ b/vllm/attention/backends/placeholder_attn.py @@ -140,6 +140,7 @@ def prefill_metadata(self) -> 
Optional["PlaceholderAttentionMetadata"]: slot_mapping=slot_mapping, multi_modal_placeholder_index_maps=self. multi_modal_placeholder_index_maps, + enable_kv_scales_calculation=self.enable_kv_scales_calculation, seq_lens=self.seq_lens[:self.num_prefills], seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], max_decode_query_len=0, @@ -173,6 +174,7 @@ def decode_metadata(self) -> Optional["PlaceholderAttentionMetadata"]: num_decode_tokens=self.num_decode_tokens, slot_mapping=slot_mapping, multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=True, seq_lens=None, seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], max_decode_query_len=self.max_decode_query_len, @@ -378,6 +380,7 @@ def build(self, seq_lens: List[int], query_lens: List[int], num_prefills=self.num_prefills, slot_mapping=slot_mapping, multi_modal_placeholder_index_maps=placeholder_index_maps, + enable_kv_scales_calculation=True, num_prefill_tokens=self.num_prefill_tokens, num_decode_tokens=num_decode_tokens, seq_lens=seq_lens, diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index e9ef8e52aa2f..bbfcad0a3d62 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -152,6 +152,7 @@ def prefill_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]: slot_mapping=self.slot_mapping[:self.num_prefill_tokens], multi_modal_placeholder_index_maps=self. multi_modal_placeholder_index_maps, + enable_kv_scales_calculation=self.enable_kv_scales_calculation, seq_lens=self.seq_lens[:self.num_prefills], seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], max_query_len=self.max_query_len, @@ -181,6 +182,7 @@ def decode_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]: num_decode_tokens=self.num_decode_tokens, slot_mapping=self.slot_mapping[self.num_prefill_tokens:], multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=True, seq_lens=None, seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], max_query_len=None, diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index cb2f4d5adeaa..6f4b2ffb2371 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -372,6 +372,7 @@ def build(self, seq_lens: List[int], query_lens: List[int], prefill_block_tables=prefill_block_tables, slot_mapping=slot_mapping, multi_modal_placeholder_index_maps=placeholder_index_maps, + enable_kv_scales_calculation=False, ) return attn_metadata diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index 56cc43430301..8ceeaf48bb19 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -264,6 +264,7 @@ def build(self, seq_lens: List[int], query_lens: List[int], num_prefills=self.num_prefills, slot_mapping=slot_mapping_tensor, multi_modal_placeholder_index_maps=placeholder_index_maps, + enable_kv_scales_calculation=True, num_prefill_tokens=self.num_prefill_tokens, num_decode_tokens=num_decode_tokens, seq_lens=seq_lens, @@ -316,6 +317,7 @@ def graph_capture_get_metadata_for_batch( num_decode_tokens=batch_size, slot_mapping=self._graph_slot_mapping[:batch_size], multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=True, seq_lens=None, seq_lens_tensor=self._graph_seq_lens[:batch_size], max_query_len=1, diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 29d4571e0583..88b3f56e0848 100644 --- 
a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -215,6 +215,7 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]: slot_mapping=slot_mapping, multi_modal_placeholder_index_maps=self. multi_modal_placeholder_index_maps, + enable_kv_scales_calculation=self.enable_kv_scales_calculation, seq_lens=seq_lens, seq_lens_tensor=seq_lens_tensor, max_query_len=self.max_query_len, @@ -259,6 +260,7 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]: num_decode_tokens=self.num_decode_tokens, slot_mapping=slot_mapping, multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=True, seq_lens_tensor=seq_lens_tensor, max_prefill_seq_len=0, max_decode_seq_len=self.max_decode_seq_len, diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index 9d479f412af4..9d49c6d0a098 100644 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -892,7 +892,8 @@ def _prepare_prompt( num_decode_tokens=0, slot_mapping=slot_mapping, multi_modal_placeholder_index_maps= - None # FIXME(kzawora): mutli-modality will not work here + None, # FIXME(kzawora): mutli-modality will not work here + enable_kv_scales_calculation=False, ) multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list) @@ -1046,7 +1047,9 @@ def _prepare_decode( num_prefill_tokens=0, num_decode_tokens=num_decode_tokens, slot_mapping=slot_mapping, - multi_modal_placeholder_index_maps=None) + multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=False, + ) return PrepareDecodeMetadata(input_tokens=input_tokens, input_positions=input_positions, attn_metadata=attn_metadata, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 6fb3ca1936b3..4155f3f27508 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -176,8 +176,6 @@ def from_broadcasted_tensor_dict( if attn_backend is not None: tensor_dict = _init_attn_metadata_from_tensor_dict( attn_backend, tensor_dict) - if "enable_kv_scales_calculation" in tensor_dict: - tensor_dict.pop("enable_kv_scales_calculation") return cls(**tensor_dict) diff --git a/vllm/worker/model_runner_base.py b/vllm/worker/model_runner_base.py index 20b91ead0828..c7abad7e0258 100644 --- a/vllm/worker/model_runner_base.py +++ b/vllm/worker/model_runner_base.py @@ -46,8 +46,7 @@ def _init_attn_metadata_from_tensor_dict( # Extract the fields used to create AttentionMetadata. 
valid_attn_kwargs = {} for field in dataclasses.fields(attn_backend.get_metadata_cls()): - if field.name in tensor_dict and field.name != \ - 'enable_kv_scales_calculation': + if field.name in tensor_dict: valid_attn_kwargs[field.name] = tensor_dict.pop(field.name) attn_metadata = attn_backend.make_metadata(**valid_attn_kwargs) diff --git a/vllm/worker/openvino_model_runner.py b/vllm/worker/openvino_model_runner.py index 6000e5dfe4e3..23a48e29731a 100644 --- a/vllm/worker/openvino_model_runner.py +++ b/vllm/worker/openvino_model_runner.py @@ -278,6 +278,7 @@ def _prepare_model_input( block_indices_begins=block_indices_begins_tensor, max_context_len=max_context_len_tensor, multi_modal_placeholder_index_maps=placeholder_index_maps, + enable_kv_scales_calculation=False, ) multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list) diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index 7bdb7f0e2d6a..4a9c06b1d476 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -186,6 +186,7 @@ def _dummy_run( num_decode_tokens=0, slot_mapping=slot_mapping, multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=False, block_tables=None, context_lens=None, effective_query_lens=None, @@ -204,6 +205,7 @@ def _dummy_run( num_decode_tokens=0, slot_mapping=slot_mapping, multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=False, block_tables=block_tables, context_lens=context_lens, effective_query_lens=effective_query_lens, @@ -235,6 +237,7 @@ def _dummy_run( num_decode_tokens=batch_size * seq_len, slot_mapping=slot_mapping, multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=False, block_tables=block_tables, context_lens=context_lens, ) @@ -420,6 +423,7 @@ def _prepare_prompt( num_decode_tokens=0, slot_mapping=slot_mapping, multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=False, block_tables=block_tables, context_lens=context_lens, effective_query_lens=prompt_lens, @@ -491,6 +495,7 @@ def _prepare_decode( num_decode_tokens=batch_size, slot_mapping=slot_mapping, multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=False, block_tables=block_tables, context_lens=context_lens, ) diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py index 9cf25387560d..a4a03050a46c 100644 --- a/vllm/worker/xpu_model_runner.py +++ b/vllm/worker/xpu_model_runner.py @@ -257,6 +257,7 @@ def _prepare_prompt( is_prompt=True, slot_mapping=slot_mapping, multi_modal_placeholder_index_maps=placeholder_index_maps, + enable_kv_scales_calculation=False, seq_lens=seq_lens, seqlen_q=seqlen_q, max_seqlen=max_seqlen, @@ -341,6 +342,7 @@ def _prepare_decode( is_prompt=False, slot_mapping=slot_mapping, multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=False, seq_lens=seq_lens, seqlen_q=torch.tensor([]), max_seqlen=0, From 3a18c313c8bd44109927360c3cad8661841baf05 Mon Sep 17 00:00:00 2001 From: Gregory Shtrasberg Date: Tue, 7 Jan 2025 19:30:52 +0000 Subject: [PATCH 06/16] Upstream doesn't have fp8 navi support yet Signed-off-by: Gregory Shtrasberg --- vllm/model_executor/layers/quantization/kv_cache.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/quantization/kv_cache.py b/vllm/model_executor/layers/quantization/kv_cache.py index b386a9f30963..9c5f4040810f 100644 --- a/vllm/model_executor/layers/quantization/kv_cache.py +++ b/vllm/model_executor/layers/quantization/kv_cache.py @@ -3,7 
+3,7 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) from vllm.platforms import current_platform -from vllm.utils import is_navi, print_warning_once +from vllm.utils import print_warning_once class BaseKVCacheMethod(QuantizeMethodBase): @@ -46,7 +46,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # We prefer to use separate k_scale and v_scale if present k_scale = layer.k_scale.to("cpu").tolist() v_scale = layer.v_scale.to("cpu").tolist() - if current_platform.is_rocm() and not is_navi(): + if current_platform.is_rocm(): k_scale *= 2 v_scale *= 2 elif layer.k_scale < 0.0 and layer.v_scale < 0.0: @@ -62,7 +62,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: scale_to_duplicate = max(layer.k_scale, layer.v_scale) k_scale = scale_to_duplicate.to("cpu").tolist() v_scale = scale_to_duplicate.to("cpu").tolist() - if current_platform.is_rocm() and not is_navi(): + if current_platform.is_rocm(): k_scale *= 2 v_scale *= 2 From 390bdaa2226c86836cf072bedf48d1b6b0751bd0 Mon Sep 17 00:00:00 2001 From: Gregory Shtrasberg Date: Tue, 7 Jan 2025 23:06:30 +0000 Subject: [PATCH 07/16] Cannot reference tensor contents during graph capture Signed-off-by: Gregory Shtrasberg --- vllm/attention/backends/flash_attn.py | 4 ---- vllm/attention/backends/ipex_attn.py | 1 - vllm/attention/backends/pallas.py | 1 - vllm/attention/backends/torch_sdpa.py | 1 - vllm/v1/attention/backends/flash_attn.py | 4 ---- 5 files changed, 11 deletions(-) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index c7850009a954..561529bdcdeb 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -657,10 +657,6 @@ def forward( attn_metadata: Metadata for attention. NOTE: It in-place updates the output tensor. """ - # NOTE(woosuk): FlashAttention does not support FP8 KV cache. - assert k_scale == 1.0 and v_scale == 1.0, ( - "key/v_scale is not supported in FlashAttention.") - assert output is not None, "Output tensor must be provided." attn_type = self.attn_type diff --git a/vllm/attention/backends/ipex_attn.py b/vllm/attention/backends/ipex_attn.py index 6327166e5232..3838621c423e 100644 --- a/vllm/attention/backends/ipex_attn.py +++ b/vllm/attention/backends/ipex_attn.py @@ -193,7 +193,6 @@ def forward( Returns: shape = [num_tokens, num_heads * head_size] """ - assert k_scale == 1.0 and v_scale == 1.0 num_tokens, hidden_size = query.shape # Reshape the query, key, and value tensors. 
query = query.view(-1, self.num_heads, self.head_size) diff --git a/vllm/attention/backends/pallas.py b/vllm/attention/backends/pallas.py index 1a2bf4ca135b..1703ef1ec4a5 100644 --- a/vllm/attention/backends/pallas.py +++ b/vllm/attention/backends/pallas.py @@ -173,7 +173,6 @@ def forward( Returns: shape = [batch_size, seq_len, num_heads * head_size] """ - assert k_scale == 1.0 and v_scale == 1.0 batch_size, seq_len, hidden_size = query.shape query = query.view(batch_size, seq_len, self.num_heads, self.head_size) key = key.view(batch_size, seq_len, self.num_kv_heads, self.head_size) diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index 6f4b2ffb2371..56896bad4d33 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -449,7 +449,6 @@ def forward( Returns: shape = [num_tokens, num_heads * head_size] """ - assert k_scale == 1.0 and v_scale == 1.0 attn_type = self.attn_type if (attn_type == AttentionType.ENCODER and (not attn_metadata.is_all_encoder_attn_metadata_set)): diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 2809b4d95738..8838e4439d86 100644 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -148,10 +148,6 @@ def forward( Returns: shape = [num_tokens, num_heads * head_size] """ - # NOTE(woosuk): FlashAttention does not support FP8 KV cache. - assert k_scale == 1.0 and v_scale == 1.0, ( - "key/v_scale is not supported in FlashAttention.") - assert output is not None, "Output tensor must be provided." if attn_metadata is None: From 681ceb3493e2a94c2e2591f04493ff42129e2ff1 Mon Sep 17 00:00:00 2001 From: Gregory Shtrasberg Date: Wed, 8 Jan 2025 17:47:23 +0000 Subject: [PATCH 08/16] Adjusted cpu implementation datatypes Signed-off-by: Gregory Shtrasberg --- csrc/cpu/attention.cpp | 12 ++++++------ csrc/cpu/cache.cpp | 6 ++---- csrc/cpu/torch_bindings.cpp | 6 +++--- 3 files changed, 11 insertions(+), 13 deletions(-) diff --git a/csrc/cpu/attention.cpp b/csrc/cpu/attention.cpp index ef5b14088c63..b9764056e8a2 100644 --- a/csrc/cpu/attention.cpp +++ b/csrc/cpu/attention.cpp @@ -460,11 +460,11 @@ void paged_attention_v1( torch::Tensor& value_cache, int64_t num_kv_heads, double scale, torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, int64_t max_seq_len, const std::optional& alibi_slopes, - const std::string& kv_cache_dtype, double k_scale, double v_scale, - const int64_t tp_rank, const int64_t blocksparse_local_blocks, + const std::string& kv_cache_dtype, torch::Tensor& k_scale, + torch::Tensor& v_scale, const int64_t tp_rank, + const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_head_sliding_step) { - TORCH_CHECK(k_scale == 1.0f && v_scale == 1.0f); TORCH_CHECK(blocksparse_vert_stride <= 1, "CPU backend does not support blocksparse attention yet."); VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v1_impl", @@ -782,11 +782,11 @@ void paged_attention_v2( torch::Tensor& value_cache, int64_t num_kv_heads, double scale, torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, int64_t max_seq_len, const std::optional& alibi_slopes, - const std::string& kv_cache_dtype, double k_scale, double v_scale, - const int64_t tp_rank, const int64_t blocksparse_local_blocks, + const std::string& kv_cache_dtype, torch::Tensor& k_scale, + torch::Tensor& v_scale, const int64_t 
tp_rank, + const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_head_sliding_step) { - TORCH_CHECK(k_scale == 1.0f && v_scale == 1.0f); TORCH_CHECK(blocksparse_vert_stride <= 1, "CPU backend does not support blocksparse attention yet."); VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v2_impl", diff --git a/csrc/cpu/cache.cpp b/csrc/cpu/cache.cpp index 31d454328b2c..e3809acad745 100644 --- a/csrc/cpu/cache.cpp +++ b/csrc/cpu/cache.cpp @@ -107,10 +107,8 @@ void copy_blocks(std::vector const& key_caches, void reshape_and_cache(torch::Tensor& key, torch::Tensor& value, torch::Tensor& key_cache, torch::Tensor& value_cache, torch::Tensor& slot_mapping, - const std::string& kv_cache_dtype, double k_scale, - double v_scale) { - TORCH_CHECK(k_scale == 1.0f && v_scale == 1.0f); - + const std::string& kv_cache_dtype, + torch::Tensor& k_scale, torch::Tensor& v_scale) { int num_tokens = key.size(0); int num_heads = key.size(1); int head_size = key.size(2); diff --git a/csrc/cpu/torch_bindings.cpp b/csrc/cpu/torch_bindings.cpp index 74e4d8189d40..5d1c5f4c83d3 100644 --- a/csrc/cpu/torch_bindings.cpp +++ b/csrc/cpu/torch_bindings.cpp @@ -30,7 +30,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor value_cache, int num_kv_heads, float scale," " Tensor block_tables, Tensor seq_lens, int block_size," " int max_seq_len, Tensor? alibi_slopes," - " str kv_cache_dtype, float k_scale, float v_scale," + " str kv_cache_dtype, Tensor k_scale, Tensor v_scale," " int tp_rank, int blocksparse_local_blocks," " int blocksparse_vert_stride, int blocksparse_block_size," " int blocksparse_head_sliding_step) -> ()"); @@ -44,7 +44,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor value_cache, int num_kv_heads, float scale," " Tensor block_tables, Tensor seq_lens, int block_size," " int max_seq_len, Tensor? alibi_slopes," - " str kv_cache_dtype, float k_scale, float v_scale," + " str kv_cache_dtype, Tensor k_scale, Tensor v_scale," " int tp_rank, int blocksparse_local_blocks," " int blocksparse_vert_stride, int blocksparse_block_size," " int blocksparse_head_sliding_step) -> ()"); @@ -148,7 +148,7 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { " Tensor! key_cache, Tensor! 
value_cache," " Tensor slot_mapping," " str kv_cache_dtype," - " float k_scale, float v_scale) -> ()"); + " Tensor k_scale, Tensor v_scale) -> ()"); cache_ops.impl("reshape_and_cache", torch::kCPU, &reshape_and_cache); } From 31e0041e157908a122a01ac7fd923181329b5d35 Mon Sep 17 00:00:00 2001 From: Gregory Shtrasberg Date: Wed, 15 Jan 2025 19:08:06 +0000 Subject: [PATCH 09/16] Adjust API to the latest attention revert Signed-off-by: Gregory Shtrasberg --- vllm/attention/layer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 70034ba542ec..57bbb7fabb7f 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -141,7 +141,7 @@ def forward( attn_metadata: AttentionMetadata, ) -> torch.Tensor: if self.calculate_kv_scales and \ - _attn_metadata.enable_kv_scales_calculation: + attn_metadata.enable_kv_scales_calculation: self.calc_kv_scales(key, value) if self.use_output: output = torch.empty_like(query) From 5bc162c1ca1d2f10f5f32705153795dccd9081a7 Mon Sep 17 00:00:00 2001 From: Gregory Shtrasberg Date: Tue, 21 Jan 2025 18:23:29 +0000 Subject: [PATCH 10/16] Fix kv scales value in the tests Signed-off-by: Gregory Shtrasberg --- benchmarks/kernels/benchmark_paged_attention.py | 4 +++- tests/kernels/test_attention.py | 2 +- tests/kernels/test_blocksparse_attention.py | 2 +- tests/kernels/test_cache.py | 2 +- 4 files changed, 6 insertions(+), 4 deletions(-) diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py index 14eef00b855a..219013a38134 100644 --- a/benchmarks/kernels/benchmark_paged_attention.py +++ b/benchmarks/kernels/benchmark_paged_attention.py @@ -98,7 +98,9 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: start_time = time.perf_counter() # Using default kv_scale - k_scale = v_scale = 1.0 + k_scale = v_scale = torch.tensor(1.0, + dtype=torch.float32, + device=device) for _ in range(num_iters): if version == "v1": diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index 124d5d297a57..574a0f223ef0 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -182,7 +182,7 @@ def test_paged_attention( key_cache, value_cache = key_caches[0], value_caches[0] # Using default kv_scale - k_scale = v_scale = 1.0 + k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device) # Call the paged attention kernel. output = torch.empty_like(query) diff --git a/tests/kernels/test_blocksparse_attention.py b/tests/kernels/test_blocksparse_attention.py index fad342d1b592..08f31219e357 100644 --- a/tests/kernels/test_blocksparse_attention.py +++ b/tests/kernels/test_blocksparse_attention.py @@ -210,7 +210,7 @@ def test_paged_attention( key_cache, value_cache = key_caches[0], value_caches[0] # Using default kv_scale - k_scale = v_scale = 1.0 + k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device) tp_rank = 0 # Call the paged attention kernel. diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py index 40550ed51e2c..2579516fa205 100644 --- a/tests/kernels/test_cache.py +++ b/tests/kernels/test_cache.py @@ -160,7 +160,7 @@ def test_reshape_and_cache( cloned_value_cache = value_cache.clone() # Using default kv_scale - k_scale = v_scale = 1.0 + k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device) # Call the reshape_and_cache kernel. 
opcheck(torch.ops._C_cache_ops.reshape_and_cache, From 84573bf0c5810cfb47cfae2ca9a4fac2022c1585 Mon Sep 17 00:00:00 2001 From: Gregory Shtrasberg Date: Tue, 21 Jan 2025 21:05:54 +0000 Subject: [PATCH 11/16] Add scales to test prefix prefill Signed-off-by: Gregory Shtrasberg --- tests/kernels/test_prefix_prefill.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tests/kernels/test_prefix_prefill.py b/tests/kernels/test_prefix_prefill.py index 3fdb7996ba4e..10e73ab950b0 100644 --- a/tests/kernels/test_prefix_prefill.py +++ b/tests/kernels/test_prefix_prefill.py @@ -138,6 +138,7 @@ def test_contexted_kv_attention( # to V_cache[num_blocks, num_kv_heads, head_size, block_size] v_cache = v_cache.view(-1, block_size, num_kv_heads, head_size).permute(0, 2, 3, 1).contiguous() + k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device) # Warm up the Triton kernel by calling it once before actually measuring # generation time @@ -153,6 +154,8 @@ def test_contexted_kv_attention( b_seq_len, b_ctx_len, max_input_len, + k_scale, + v_scale, sliding_window=sliding_window) torch.cuda.synchronize() start_time = time.time() @@ -168,6 +171,8 @@ def test_contexted_kv_attention( b_seq_len, b_ctx_len, max_input_len, + k_scale, + v_scale, sliding_window=sliding_window) torch.cuda.synchronize() end_time = time.time() @@ -366,6 +371,7 @@ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: # to V_cache[num_blocks, num_kv_heads, head_size, block_size] v_cache = v_cache.view(-1, block_size, num_kv_heads, head_size).permute(0, 2, 3, 1).contiguous() + k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device) # Warm up the Triton kernel by calling it once before actually measuring # generation time @@ -381,6 +387,8 @@ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: b_seq_len, b_ctx_len, max_input_len, + k_scale, + v_scale, alibi_slopes=alibi_slopes) torch.cuda.synchronize() start_time = time.time() @@ -396,6 +404,8 @@ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: b_seq_len, b_ctx_len, max_input_len, + k_scale, + v_scale, alibi_slopes=alibi_slopes) torch.cuda.synchronize() end_time = time.time() From 1ac64b08b3f264fb3eba59e3ddb3f78cba9f6680 Mon Sep 17 00:00:00 2001 From: Gregory Shtrasberg Date: Tue, 21 Jan 2025 22:30:02 +0000 Subject: [PATCH 12/16] Preserve the original float scale values to pass to the attention backends that don't support tensors (Flashinfer), since referencing tensor values is impossible on CUDA during the graph capture phase Signed-off-by: Gregory Shtrasberg --- vllm/attention/backends/abstract.py | 2 ++ vllm/attention/backends/flashinfer.py | 8 ++++---- vllm/attention/layer.py | 6 ++++++ vllm/model_executor/layers/quantization/kv_cache.py | 2 ++ 4 files changed, 14 insertions(+), 4 deletions(-) diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 45b7d6d4eb21..e55c3f7e8795 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -231,6 +231,8 @@ class AttentionLayer(Protocol): _k_scale: torch.Tensor _v_scale: torch.Tensor + _k_scale_float: float + _v_scale_float: float def forward( self, diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index e6343d1dcb0c..be869a84b253 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -888,8 +888,8 @@ def forward( kv_cache, logits_soft_cap=logits_soft_cap, causal=True, - k_scale=layer._k_scale, - v_scale=layer._v_scale, +
k_scale=layer._k_scale_float, + v_scale=layer._v_scale_float, window_left=window_left) if decode_meta := attn_metadata.decode_metadata: assert decode_meta is not None @@ -899,8 +899,8 @@ def forward( kv_cache, sm_scale=softmax_scale, logits_soft_cap=logits_soft_cap, - k_scale=layer._k_scale, - v_scale=layer._v_scale, + k_scale=layer._k_scale_float, + v_scale=layer._v_scale_float, window_left=window_left) if prefill_output is None and decode_output is not None: diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 6a35322193d7..93c4ae71bdc3 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -76,6 +76,12 @@ def __init__( self.calculate_kv_scales = calculate_kv_scales self._k_scale = torch.tensor(1.0, dtype=torch.float32) self._v_scale = torch.tensor(1.0, dtype=torch.float32) + + # We also keep the float32 versions of k/v_scale for attention + # backends that don't support tensors (Flashinfer) + self._k_scale_float = 1.0 + self._v_scale_float = 1.0 + quant_method = quant_config.get_quant_method( self, prefix=prefix) if quant_config else None if quant_method is not None: diff --git a/vllm/model_executor/layers/quantization/kv_cache.py b/vllm/model_executor/layers/quantization/kv_cache.py index 86b368b65b45..e1870c73cc93 100644 --- a/vllm/model_executor/layers/quantization/kv_cache.py +++ b/vllm/model_executor/layers/quantization/kv_cache.py @@ -76,6 +76,8 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # These are used in the final Attention.forward() layer._k_scale.copy_(k_scale) layer._v_scale.copy_(v_scale) + layer._k_scale_float = k_scale + layer._v_scale_float = v_scale if (k_scale == 1.0 and v_scale == 1.0 and "e5m2" not in layer.kv_cache_dtype): logger.warning_once( From d9edda421a8027a9a21e9593ae3abe08155236eb Mon Sep 17 00:00:00 2001 From: Gregory Shtrasberg Date: Tue, 21 Jan 2025 22:36:12 +0000 Subject: [PATCH 13/16] Return assertions now that we have float values to use Signed-off-by: Gregory Shtrasberg --- vllm/attention/backends/flash_attn.py | 4 ++++ vllm/attention/backends/ipex_attn.py | 1 + vllm/attention/backends/pallas.py | 1 + vllm/attention/layer.py | 2 ++ 4 files changed, 8 insertions(+) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 5ca548810655..34c837a18078 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -659,6 +659,10 @@ def forward( attn_metadata: Metadata for attention. NOTE: It in-place updates the output tensor. """ + # NOTE(woosuk): FlashAttention does not support FP8 KV cache. + assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0, ( + "key/v_scale is not supported in FlashAttention.") + assert output is not None, "Output tensor must be provided." attn_type = self.attn_type diff --git a/vllm/attention/backends/ipex_attn.py b/vllm/attention/backends/ipex_attn.py index 4cd3471715c2..57916a3c6a34 100644 --- a/vllm/attention/backends/ipex_attn.py +++ b/vllm/attention/backends/ipex_attn.py @@ -193,6 +193,7 @@ def forward( Returns: shape = [num_tokens, num_heads * head_size] """ + assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0 num_tokens, hidden_size = query.shape # Reshape the query, key, and value tensors. 
query = query.view(-1, self.num_heads, self.head_size) diff --git a/vllm/attention/backends/pallas.py b/vllm/attention/backends/pallas.py index dde19c137b84..facdee6b29e3 100644 --- a/vllm/attention/backends/pallas.py +++ b/vllm/attention/backends/pallas.py @@ -173,6 +173,7 @@ def forward( Returns: shape = [batch_size, seq_len, num_heads * head_size] """ + assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0 batch_size, seq_len, hidden_size = query.shape query = query.view(batch_size, seq_len, self.num_heads, self.head_size) key = key.view(batch_size, seq_len, self.num_kv_heads, self.head_size) diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 93c4ae71bdc3..79ea9b666c7e 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -180,6 +180,8 @@ def forward( def calc_kv_scales(self, key, value): self._k_scale.copy_(torch.abs(key).max() / self.k_range) self._v_scale.copy_(torch.abs(value).max() / self.v_range) + self._k_scale_float = self._k_scale.item() + self._v_scale_float = self._v_scale.item() # We only calculate the scales once self.calculate_kv_scales = False From 6efd98ff710ccf1fcd0c97c44b6d14d74da80399 Mon Sep 17 00:00:00 2001 From: Gregory Shtrasberg Date: Wed, 22 Jan 2025 00:40:45 +0000 Subject: [PATCH 14/16] Using tensors in test_reshape_and_cache_flash Signed-off-by: Gregory Shtrasberg --- tests/kernels/test_cache.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py index 2579516fa205..14524a79166f 100644 --- a/tests/kernels/test_cache.py +++ b/tests/kernels/test_cache.py @@ -258,8 +258,8 @@ def test_reshape_and_cache_flash( del key_caches del value_caches - k_scale = key.amax().item() / 256 - v_scale = value.amax().item() / 256 + k_scale = key.amax() / 256.0 + v_scale = value.amax() / 256.0 # Clone the KV caches. if kv_cache_dtype == "fp8": From 79811daff3b6c94ad900153a3b92adb9bf167830 Mon Sep 17 00:00:00 2001 From: Gregory Shtrasberg Date: Wed, 22 Jan 2025 09:50:10 -0800 Subject: [PATCH 15/16] Using correct dtype for scales in the test Signed-off-by: Gregory Shtrasberg --- tests/kernels/test_cache.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py index 14524a79166f..c848be4f9d80 100644 --- a/tests/kernels/test_cache.py +++ b/tests/kernels/test_cache.py @@ -258,8 +258,8 @@ def test_reshape_and_cache_flash( del key_caches del value_caches - k_scale = key.amax() / 256.0 - v_scale = value.amax() / 256.0 + k_scale = (key.amax() / 256.0).to(torch.float32) + v_scale = (value.amax() / 256.0).to(torch.float32) # Clone the KV caches. if kv_cache_dtype == "fp8": @@ -284,12 +284,12 @@ def test_reshape_and_cache_flash( result_key_cache = torch.empty_like(key_cache, dtype=torch.float16) ops.convert_fp8(result_key_cache, key_cache, - k_scale, + k_scale.item(), kv_dtype=kv_cache_dtype) result_value_cache = torch.empty_like(value_cache, dtype=torch.float16) ops.convert_fp8(result_value_cache, value_cache, - v_scale, + v_scale.item(), kv_dtype=kv_cache_dtype) # Run the reference implementation. 
From 55ef92cb9e39cc6a5435c8751059e9b25d150ced Mon Sep 17 00:00:00 2001 From: Gregory Shtrasberg Date: Thu, 23 Jan 2025 15:37:19 +0000 Subject: [PATCH 16/16] Move kv scales documentation into the new file Signed-off-by: Gregory Shtrasberg --- docs/source/features/quantization/quantized_kvcache.md | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/docs/source/features/quantization/quantized_kvcache.md b/docs/source/features/quantization/quantized_kvcache.md index 95fa5e81e2f7..9f36c2949e0d 100644 --- a/docs/source/features/quantization/quantized_kvcache.md +++ b/docs/source/features/quantization/quantized_kvcache.md @@ -35,16 +35,18 @@ Studies have shown that FP8 E4M3 quantization typically only minimally degrades Here is an example of how to enable FP8 quantization: ```python +# To calculate kv cache scales on the fly, enable the calculate_kv_scales +# parameter + from vllm import LLM, SamplingParams sampling_params = SamplingParams(temperature=0.7, top_p=0.8) -llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct", kv_cache_dtype="fp8") +llm = LLM(model="meta-llama/Llama-2-7b-chat-hf", + kv_cache_dtype="fp8", + calculate_kv_scales=True) prompt = "London is the capital of" out = llm.generate(prompt, sampling_params)[0].outputs[0].text print(out) - -# output w/ scaling factors: England, the United Kingdom, and one of the world's leading financial, -# output w/o scaling factors: England, located in the southeastern part of the country. It is known ``` The `kv_cache_dtype` argument specifies the data type for KV cache storage: