@@ -214,6 +214,10 @@ bool AttentionOp::convertMMHAParamsToXQAParams(tensorrt_llm::kernels::XQAParams&
     {
         xqaParams.kv_cache_data_type = DATA_TYPE_E4M3;
     }
+    else if (mKVCacheQuantMode.hasFp4KvCache())
+    {
+        xqaParams.kv_cache_data_type = DATA_TYPE_E2M1;
+    }
     else
     {
         xqaParams.kv_cache_data_type = xqaParams.data_type;
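For orientation: DATA_TYPE_E2M1 is the 4-bit NVFP4 element format (1 sign bit, 2 exponent bits, 1 mantissa bit), so two KV-cache elements pack into a single byte. A minimal sketch of that packing arithmetic; the helper names are hypothetical and not part of this patch:

```cpp
#include <cstddef>
#include <cstdint>

// Hypothetical illustration, not from this patch: two E2M1 (4-bit)
// values occupy one byte, low nibble first.
inline uint8_t packE2M1Pair(uint8_t lo4, uint8_t hi4)
{
    return static_cast<uint8_t>((lo4 & 0x0F) | ((hi4 & 0x0F) << 4));
}

// One token's KV payload at 4 bits per element.
constexpr size_t fp4BytesPerToken(size_t numKvHeads, size_t headSize)
{
    return numKvHeads * headSize * 4 /*bits*/ / 8;
}
static_assert(fp4BytesPerToken(8, 128) == 512, "1024 elements -> 512 bytes");
```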
@@ -959,6 +963,9 @@ int AttentionOp::mlaGeneration(
         generation_params.can_use_one_more_block, generation_params.host_primary_pool_pointer,
         generation_params.host_secondary_pool_pointer, generation_params.block_offsets);
 
+    // Currently NVFP4 KV cache is not supported for MLA. An empty placeholder is provided.
+    auto kv_scale_cache_buffer = KVBlockArray();
+
     // Workspace pointer shift
     int8_t* workspace_byte_ptr = reinterpret_cast<int8_t*>(params.workspace);
     size_t offset = 0;
@@ -1234,7 +1241,7 @@ int AttentionOp::mlaGeneration(
     {
         TLLM_LOG_DEBUG("XQA kernels are selected in the generation phase.");
         xqaParams.stream = stream;
-        mXqaDispatcher->run(xqaParams, kv_cache_buffer);
+        mXqaDispatcher->run(xqaParams, kv_cache_buffer, kv_scale_cache_buffer);
         return 0;
     }
     else if (mIsSpecDecodingEnabled && mUseSpecDecoding)
@@ -1308,8 +1315,10 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStream_t stream)
     float const q_scaling = mQScaling;
 
     KVCacheBuffer kv_cache_buffer;
-    auto const elemSize = mKVCacheQuantMode.hasKvCacheQuant() ? sizeof(int8_t) : sizeof(T);
-    auto sizePerToken = mNumAttnKVHeads * headSize * elemSize;
+    KVCacheBuffer kv_scale_cache_buffer;
+
+    auto sizePerToken = mNumAttnKVHeads * headSize * getKvCacheElemSizeInBits<T>() / 8 /*bits*/;
+
     if (useKVCache())
     {
         if constexpr (std::is_same_v<KVCacheBuffer, KVBlockArray>)
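The byte-based elemSize is replaced with a bit-granular helper so that a 4-bit KV element is representable; the helper's definition is not shown in this hunk. A plausible standalone sketch, with the quant-mode flags passed explicitly instead of read from mKVCacheQuantMode:

```cpp
#include <cstddef>

// Sketch only: the real getKvCacheElemSizeInBits<T>() lives elsewhere in
// AttentionOp. This restates the sizing logic implied by the call sites.
template <typename T>
constexpr size_t kvCacheElemSizeInBits(bool hasInt8Kv, bool hasFp8Kv, bool hasFp4Kv)
{
    if (hasInt8Kv || hasFp8Kv)
    {
        return 8; // INT8 / E4M3: one byte per element
    }
    if (hasFp4Kv)
    {
        return 4; // E2M1: two elements per byte
    }
    return sizeof(T) * 8; // unquantized cache stores T directly
}

// Matches the sizePerToken expression above: bits per token / 8 = bytes.
static_assert(8 /*kv heads*/ * 128 /*head size*/
        * kvCacheElemSizeInBits<float>(false, false, true) / 8 == 512,
    "FP4: 1024 elements per token -> 512 bytes");
```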
@@ -1318,6 +1327,14 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStream_t stream)
                 sizePerToken, params.cyclic_attention_window_size, params.max_cyclic_attention_window_size,
                 params.sink_token_length, params.can_use_one_more_block, params.host_primary_pool_pointer,
                 params.host_secondary_pool_pointer, params.block_offsets);
+            if (mKVCacheQuantMode.hasFp4KvCache())
+            {
+                kv_scale_cache_buffer = KVBlockArray(params.batch_size, params.max_blocks_per_sequence, mTokensPerBlock,
+                    sizePerToken / 8, params.cyclic_attention_window_size, params.max_cyclic_attention_window_size,
+                    params.sink_token_length, params.can_use_one_more_block,
+                    params.host_primary_block_scale_pool_pointer, params.host_secondary_block_scale_pool_pointer,
+                    params.block_offsets);
+            }
         }
         else if constexpr (std::is_same_v<KVCacheBuffer, KVLinearBuffer>)
         {
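The sizePerToken / 8 passed to the scale-pool KVBlockArray follows from NVFP4's block-scaling scheme: one 8-bit (E4M3) scale factor per 16-element block. With 4-bit elements, scale storage is therefore exactly one eighth of the data storage. A worked check under those assumptions:

```cpp
#include <cstddef>

// Assumed NVFP4 layout: 16 FP4 elements share one 1-byte E4M3 scale.
constexpr size_t kNvfp4BlockSize = 16;

constexpr size_t numKvHeads = 8, headSize = 128;
constexpr size_t dataBytesPerToken = numKvHeads * headSize * 4 /*bits*/ / 8;   // 512
constexpr size_t scaleBytesPerToken = numKvHeads * headSize / kNvfp4BlockSize; // 64

static_assert(scaleBytesPerToken == dataBytesPerToken / 8,
    "one scale byte per 16 FP4 elements is 1/8 of the 4-bit payload");
```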
@@ -1326,6 +1343,7 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStream_t stream)
                 isCrossAttention() ? params.cross_kv_length : params.max_attention_window_size, sizePerToken,
                 params.cyclic_attention_window_size, params.sink_token_length, false,
                 reinterpret_cast<BufferDataType*>(params.key_value_cache));
+            TLLM_CHECK_WITH_INFO(!(mKVCacheQuantMode.hasFp4KvCache()), "FP4 KV cache only supports paged KV.");
         }
     }
 
@@ -1490,8 +1508,8 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStream_t stream)
     decoder_params.blockSparseParams = mBlockSparseParams;
     decoder_params.fmhaTileCounter = fmha_tile_counter_ptr;
     decoder_params.quantScaleO = params.attention_output_orig_quant;
-    decoder_params.dequantScaleQ = params.kv_scale_quant_orig;
-    decoder_params.dequantScaleKv = params.kv_scale_quant_orig;
+    decoder_params.dequantScaleQkv = params.kv_scale_quant_orig;
+    decoder_params.separateQkvScales = mKVCacheQuantMode.hasFp4KvCache();
     decoder_params.fmhaHostBmm1Scale = 1.0f / (sqrtf(getHeadSize() * 1.0f) * q_scaling);
     decoder_params.fmhaBmm1Scale = fmha_bmm1_scale_ptr;
     decoder_params.fmhaBmm2Scale = fmha_bmm2_scale_ptr;
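The two per-tensor dequant fields collapse into a single dequantScaleQkv pointer plus a separateQkvScales flag, set only for FP4 KV cache. A hedged sketch of how a consumer could interpret the pair; the three-entry [q, k, v] ordering is an assumption here, the kernels define the real layout:

```cpp
// Hypothetical consumer-side reading of the new fields; the [q, k, v]
// layout behind dequantScaleQkv is an assumption, not taken from the patch.
struct QkvDequantScales
{
    float q, k, v;
};

inline QkvDequantScales readDequantScales(float const* dequantScaleQkv, bool separateQkvScales)
{
    if (separateQkvScales)
    {
        // FP4 KV cache: distinct dequant scales for Q, K and V.
        return {dequantScaleQkv[0], dequantScaleQkv[1], dequantScaleQkv[2]};
    }
    // Otherwise one shared scale, as with the old dequantScaleQ/Kv pair.
    float const s = dequantScaleQkv[0];
    return {s, s, s};
}
```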
@@ -1549,9 +1567,19 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStream_t stream)
         sync_check_cuda_error(stream);
     }
 
-    KvCacheDataType const cache_type = mKVCacheQuantMode.hasInt8KvCache()
-        ? KvCacheDataType::INT8
-        : (mKVCacheQuantMode.hasFp8KvCache() ? KvCacheDataType::FP8 : KvCacheDataType::BASE);
+    KvCacheDataType cache_type{KvCacheDataType::BASE};
+    if (mKVCacheQuantMode.hasInt8KvCache())
+    {
+        cache_type = KvCacheDataType::INT8;
+    }
+    else if (mKVCacheQuantMode.hasFp8KvCache())
+    {
+        cache_type = KvCacheDataType::FP8;
+    }
+    else if (mKVCacheQuantMode.hasFp4KvCache())
+    {
+        cache_type = KvCacheDataType::NVFP4;
+    }
 
     cudaDataType_t const gemm_data_type = tc::CudaDataType<T>::value;
     int const attention_seq_len_1 = params.input_seq_length; // q length
@@ -1600,6 +1628,7 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStream_t stream)
     preprocessingParams.quantized_qkv_output = fp8_qkv_buffer;
     preprocessingParams.q_output = q_buf_2_;
     preprocessingParams.kv_cache_buffer = kv_cache_buffer;
+    preprocessingParams.kv_cache_block_scales_buffer = kv_scale_cache_buffer;
     preprocessingParams.qkv_bias = params.qkv_bias;
     preprocessingParams.tokens_info = decoder_params.tokensInfo;
     preprocessingParams.seq_lens = params.context_lengths;
@@ -1612,7 +1641,7 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStream_t stream)
     preprocessingParams.rotary_embedding_inv_freq = rotary_inv_freq_buf;
     preprocessingParams.rotary_coef_cache_buffer = params.rotary_cos_sin;
     preprocessingParams.mrope_rotary_cos_sin = params.mrope_rotary_cos_sin;
-    preprocessingParams.kvScaleOrigQuant = params.kv_scale_orig_quant;
+    preprocessingParams.qkv_scale_orig_quant = params.kv_scale_orig_quant;
     preprocessingParams.spec_decoding_position_offsets = nullptr;
     preprocessingParams.logn_scaling = params.logn_scaling_ptr;
@@ -1781,6 +1810,7 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStream_t stream)
     if constexpr (std::is_same_v<KVCacheBuffer, KVBlockArray>)
     {
         fmhaParams.pagedKvCache = kv_cache_buffer;
+        fmhaParams.pagedKvSfCache = kv_scale_cache_buffer;
     }
     fmhaParams.cuQSeqLenPtr = cu_q_seqlens;
     fmhaParams.kvSeqLenPtr = decoder_params.seqKVLengths;
@@ -2126,8 +2156,10 @@ int AttentionOp::enqueueGeneration(EnqueueGenerationParams<T> const& params, cudaStream_t stream)
     int32_t const batch_beam = params.beam_width * params.num_requests;
 
     KVCacheBuffer kv_cache_buffer;
-    auto const elemSize = mKVCacheQuantMode.hasKvCacheQuant() ? sizeof(int8_t) : sizeof(T);
-    auto const sizePerToken = mNumAttnKVHeads * headSize * elemSize;
+    KVCacheBuffer kv_scale_cache_buffer;
+
+    auto const sizePerToken = mNumAttnKVHeads * headSize * getKvCacheElemSizeInBits<T>() / 8 /*bits*/;
+
     if (useKVCache())
     {
         if constexpr (std::is_same_v<KVCacheBuffer, KVBlockArray>)
@@ -2137,13 +2169,22 @@ int AttentionOp::enqueueGeneration(EnqueueGenerationParams<T> const& params, cudaStream_t stream)
                 params.cyclic_attention_window_size, params.max_cyclic_attention_window_size, params.sink_token_length,
                 params.can_use_one_more_block, params.host_primary_pool_pointer, params.host_secondary_pool_pointer,
                 reinterpret_cast<BufferDataType*>(params.block_offsets));
+            if (mKVCacheQuantMode.hasFp4KvCache())
+            {
+                kv_scale_cache_buffer = KVBlockArray(batch_beam, params.max_blocks_per_sequence, mTokensPerBlock,
+                    sizePerToken / 8, params.cyclic_attention_window_size, params.max_cyclic_attention_window_size,
+                    params.sink_token_length, params.can_use_one_more_block,
+                    params.host_primary_block_scale_pool_pointer, params.host_secondary_block_scale_pool_pointer,
+                    reinterpret_cast<BufferDataType*>(params.block_offsets));
+            }
         }
         else if constexpr (std::is_same_v<KVCacheBuffer, KVLinearBuffer>)
         {
             using BufferDataType = typename KVCacheBuffer::DataType;
             kv_cache_buffer = KVLinearBuffer(batch_beam, params.max_attention_window_size, sizePerToken,
                 params.cyclic_attention_window_size, params.sink_token_length, false,
                 reinterpret_cast<BufferDataType*>(params.key_value_cache));
+            TLLM_CHECK_WITH_INFO(!(mKVCacheQuantMode.hasFp4KvCache()), "FP4 KV cache only supports paged KV.");
         }
     }
     sync_check_cuda_error(stream);
@@ -2215,7 +2256,7 @@ int AttentionOp::enqueueGeneration(EnqueueGenerationParams<T> const& params, cudaStream_t stream)
             xqaParams.output = mhaOutput;
             xqaParams.qkv = attention_input;
         }
-        mXqaDispatcher->run(xqaParams, kv_cache_buffer);
+        mXqaDispatcher->run(xqaParams, kv_cache_buffer, kv_scale_cache_buffer);
         if (mCpSize > 1 && mAttnTpSize > 1 && mAttnCpSize == 1)
         {
             this->template ulyssesGenerationPostprocess<T>(
@@ -2232,6 +2273,10 @@ int AttentionOp::enqueueGeneration(EnqueueGenerationParams<T> const& params, cudaStream_t stream)
        {
            TLLM_CHECK_WITH_INFO(false, "No available kernels are found for FP4 output.");
        }
+        else if (mKVCacheQuantMode.hasFp4KvCache())
+        {
+            TLLM_CHECK_WITH_INFO(false, "No available kernels are found for FP4 KV cache.");
+        }
        else
        {
            TLLM_LOG_DEBUG("XQA kernels are not selected in the generation phase.");
@@ -2503,6 +2548,10 @@ int AttentionOp::initialize() noexcept
     TLLM_CHECK_WITH_INFO(!mFuseFp4Quant || mSM == 100 || mSM == 120 || mSM == 121,
         "fuse_fp4_quant only supports SM100 or SM120 or SM121 devices.");
 
+    // Check requirements for FP4 KV cache.
+    TLLM_CHECK_WITH_INFO(!mKVCacheQuantMode.hasFp4KvCache() || mFP8ContextFMHA,
+        "mFP8ContextFMHA must be enabled if FP4 KV cache is enabled.");
+
     TLLM_CHECK(isRoPE() == (mRotaryEmbeddingDim != 0));
     TLLM_CHECK_WITH_INFO((mSM >= 80) || (mType != nvinfer1::DataType::kBF16),
         "Unsupported data type, pre SM 80 GPUs do not support bfloat16");
@@ -2579,7 +2628,10 @@ int AttentionOp::initialize() noexcept
         {
             fmhaParams.dataTypeKv = DATA_TYPE_E4M3;
         }
-        // TODO: add FP4 KV cache support.
+        else if (mKVCacheQuantMode.hasFp4KvCache())
+        {
+            fmhaParams.dataTypeKv = DATA_TYPE_E2M1;
+        }
     }
     // The output dtype.
     fmhaParams.dataTypeOut = data_type;
@@ -2789,6 +2841,11 @@ int AttentionOp::initialize() noexcept
         fixedParams.kvDataType = DATA_TYPE_E4M3;
         fixedParams.mathDataType = DATA_TYPE_E4M3;
     }
+    else if (mKVCacheQuantMode.hasFp4KvCache())
+    {
+        fixedParams.kvDataType = DATA_TYPE_E2M1;
+        fixedParams.mathDataType = DATA_TYPE_E4M3;
+    }
     else
     {
         fixedParams.kvDataType = fixedParams.inputDataType;
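The final pairing, E2M1 KV data with E4M3 math, explains the initialize() check earlier in this diff: the FP4 KV-cache path dequantizes into FP8 for the attention math, hence the mFP8ContextFMHA requirement. A standalone restatement of the branches visible in the hunk above (the enum is a placeholder for tensorrt_llm's Data_type values, and the INT8 branch is omitted because it is not shown here):

```cpp
// Placeholder enum mirroring the Data_type values used in this diff.
enum class KvDataType { kInput, kE4M3, kE2M1 };

struct FixedTypeParams
{
    KvDataType kvDataType, mathDataType;
};

// Restates only the branches visible above; whether the unquantized
// branch also sets mathDataType to the input type is an assumption.
inline FixedTypeParams pickFixedTypes(bool hasFp8Kv, bool hasFp4Kv)
{
    if (hasFp8Kv)
    {
        return {KvDataType::kE4M3, KvDataType::kE4M3};
    }
    if (hasFp4Kv)
    {
        return {KvDataType::kE2M1, KvDataType::kE4M3}; // FP4 storage, FP8 math
    }
    return {KvDataType::kInput, KvDataType::kInput}; // unquantized fallback
}
```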