
Commit aa26b7d

Merge branch 'main' into user/dongfengy/fix_ref
2 parents: e0356eb + aae5d22


107 files changed: +3895 -1085 lines


.dockerignore

Lines changed: 1 addition & 0 deletions
@@ -9,5 +9,6 @@ examples/**/.git
 examples/**/*.bin
 examples/**/*.engine
 examples/**/*.onnx
+examples/**/*.safetensors
 examples/**/c-model
 examples/models/core/gpt/gpt*

.github/CODEOWNERS

Lines changed: 6 additions & 5 deletions
@@ -1,10 +1,5 @@
 # This file defines code ownership rules for the repository.
 
-# The following rule should only be uncommented on release branches (e.g., release/0.19).
-# The rule below requires that any PR to release/**/* branches must be approved by at least one member
-# of the NVIDIA/trt-llm-release-branch-approval team, regardless of who else approves the PR.
-# Without approval from a member of this team, PRs cannot be merged to release branches.
-# * @NVIDIA/trt-llm-release-branch-approval
 
 ## TensorRT-LLM Infra
 ### CI
@@ -160,3 +155,9 @@ docs/source/performance/perf-benchmarking.md @NVIDIA/trtllm-bench-reviewers
 # from a member of this team, PRs affecting public APIs cannot be merged to main or release branches.
 /tests/unittest/api_stability/ @NVIDIA/trt-llm-noncommitted-api-review-committee
 /tests/unittest/api_stability/references_committed/ @NVIDIA/trt-llm-committed-api-review-committee
+
+# The following rule should only be uncommented on release branches (e.g., release/0.19).
+# The rule below requires that any PR to release/**/* branches must be approved by at least one member
+# of the NVIDIA/trt-llm-release-branch-approval team, regardless of who else approves the PR.
+# Without approval from a member of this team, PRs cannot be merged to release branches.
+# * @NVIDIA/trt-llm-release-branch-approval

README.md

Lines changed: 4 additions & 1 deletion
@@ -9,7 +9,7 @@ TensorRT-LLM
 [![python](https://img.shields.io/badge/python-3.10-green)](https://www.python.org/downloads/release/python-31012/)
 [![cuda](https://img.shields.io/badge/cuda-12.9.1-green)](https://developer.nvidia.com/cuda-downloads)
 [![trt](https://img.shields.io/badge/TRT-10.11.0-green)](https://developer.nvidia.com/tensorrt)
-[![version](https://img.shields.io/badge/release-1.1.0rc2-green)](./tensorrt_llm/version.py)
+[![version](https://img.shields.io/badge/release-1.1.0rc3-green)](./tensorrt_llm/version.py)
 [![license](https://img.shields.io/badge/license-Apache%202-blue)](./LICENSE)
 
 [Architecture](./docs/source/torch/arch_overview.md)   |   [Performance](./docs/source/performance/perf-overview.md)   |   [Examples](https://nvidia.github.io/TensorRT-LLM/quick-start-guide.html)   |   [Documentation](./docs/source/)   |   [Roadmap](https://github.com/NVIDIA/TensorRT-LLM/issues?q=is%3Aissue%20state%3Aopen%20label%3Aroadmap)
@@ -18,6 +18,9 @@ TensorRT-LLM
 <div align="left">
 
 ## Tech Blogs
+* [08/29] ADP Balance Strategy
+[➡️ link](./docs/source/blogs/tech_blog/blog10_ADP_Balance_Strategy.md)
+
 * [08/05] Running a High-Performance GPT-OSS-120B Inference Server with TensorRT-LLM
 [➡️ link](./docs/source/blogs/tech_blog/blog9_Deploying_GPT_OSS_on_TRTLLM.md)
 

cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h

Lines changed: 14 additions & 5 deletions
@@ -480,7 +480,6 @@ class KVCacheBlockPool
     SizeType32 numKvHeads;
     SizeType32 sizePerHead;
     SizeType32 tokensPerBlock;
-    SizeType32 quantSize;
     SizeType32 blockSize;
 
     // Memory pools. Primary is fast memory, secondary is slower memory used for offloading.
@@ -491,15 +490,14 @@ class KVCacheBlockPool
     bool containsBlockScales;
 
     KVCacheBlockPool(SizeType32 numLayers, SizeType32 kvFactor, SizeType32 numKvHeads, SizeType32 sizePerHead,
-        SizeType32 tokensPerBlock, SizeType32 quantSize, runtime::ITensor::SharedPtr primaryPtr = nullptr,
+        SizeType32 tokensPerBlock, runtime::ITensor::SharedPtr primaryPtr = nullptr,
         runtime::ITensor::SharedPtr secondaryPtr = nullptr, bool containsBlockScales = false)
         : numLayers(numLayers)
        , kvFactor(kvFactor)
        , numKvHeads(numKvHeads)
        , sizePerHead(sizePerHead)
        , tokensPerBlock(tokensPerBlock)
-       , quantSize(quantSize)
-       , blockSize((numKvHeads * sizePerHead * tokensPerBlock) / quantSize)
+       , blockSize(numKvHeads * sizePerHead * tokensPerBlock)
        , primaryPtr(std::move(primaryPtr))
        , secondaryPtr(std::move(secondaryPtr))
        , containsBlockScales(containsBlockScales)
@@ -648,6 +646,15 @@ class WindowBlockManager
         return mPools.at(poolIdx).blockSize;
     }
 
+    [[nodiscard]] SizeType32 getNumEltsPerContainer() const
+    {
+#ifdef ENABLE_FP4
+        return mDataType == nvinfer1::DataType::kFP4 ? 2 : 1;
+#else
+        return 1;
+#endif
+    }
+
     [[nodiscard]] SizeType32 getNumPools(bool includeBlockScalePools = true) const noexcept
     {
         if (includeBlockScalePools)
@@ -1243,6 +1250,8 @@ class BaseKVCacheManager
 
     [[nodiscard]] virtual runtime::ITensor::SharedPtr getBlockPoolPointers() const = 0;
 
+    [[nodiscard]] virtual runtime::ITensor::SharedPtr getBlockScalePoolPointers() const = 0;
+
     [[nodiscard]] virtual runtime::ITensor::SharedPtr getLayerToPoolMapping() const = 0;
 
     virtual void getBlockOffsetsOfBatch(
@@ -1552,7 +1561,7 @@ class KVCacheManager : public BaseKVCacheManager
         return mLayerToPoolMapping;
     }
 
-    [[nodiscard]] runtime::ITensor::SharedPtr getBlockScalePoolPointers() const
+    [[nodiscard]] runtime::ITensor::SharedPtr getBlockScalePoolPointers() const override
     {
         // TODO: add a new optional model input so the attention plugin can access these
         return mBlockScalePoolPointers;
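
Net effect of the header change: KVCacheBlockPool no longer carries a quantSize field, so blockSize is counted in storage containers and the caller pre-divides sizePerHead by the packing factor reported by getNumEltsPerContainer(). A minimal standalone sketch of that bookkeeping, assuming NVFP4's two-elements-per-byte packing (the geometry values below are illustrative, not taken from the library):

```cpp
// Illustrative sketch only: how block sizes are counted once sizePerHead is
// pre-divided by the packing factor (2 FP4 elements per uint8 container).
#include <cstdint>
#include <iostream>

int main()
{
    int32_t const numKvHeads = 8, logicalSizePerHead = 128, tokensPerBlock = 32;
    int32_t const numEltsPerContainer = 2; // what getNumEltsPerContainer() would return for kFP4

    // New scheme: the pool receives sizePerHead already expressed in containers...
    int32_t const sizePerHead = logicalSizePerHead / numEltsPerContainer;
    // ...so blockSize is a plain product, with no internal division by quantSize.
    int32_t const blockSize = numKvHeads * sizePerHead * tokensPerBlock;

    // Old scheme for comparison: (numKvHeads * logicalSizePerHead * tokensPerBlock) / quantSize.
    std::cout << "containers per block: " << blockSize << '\n'; // 8 * 64 * 32 = 16384
    return 0;
}
```

Counting in containers up front is also what lets allocatePools drop the FP4-specific halving removed in kvCacheManager.cpp below.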

cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp

Lines changed: 15 additions & 10 deletions
@@ -612,6 +612,14 @@ WindowBlockManager::WindowBlockManager(nvinfer1::DataType dtype, SizeType32 wind
         mLayerToIndexWithinPool[layerIdx] = layerIndexWithinPool;
     }
 
+    auto numEltsPerContainer = getNumEltsPerContainer();
+#ifdef ENABLE_FP4
+    if (numEltsPerContainer == 2)
+    {
+        TLLM_CHECK_WITH_INFO(sizePerHead % 2 == 0, "sizePerHead must be divisible by 2 for 4-bit KV cache.");
+    }
+#endif
+
     size_t poolIndex = 0;
     for (auto const [numKvHeads, numLayers] : numLayersPerPool)
     {
@@ -622,7 +630,7 @@ WindowBlockManager::WindowBlockManager(nvinfer1::DataType dtype, SizeType32 wind
                 mLayerToPoolIndex[layerIdx] = poolIndex;
             }
         }
-        mPools.emplace_back(numLayers, mKVFactor, numKvHeads, sizePerHead, tokensPerBlock, 1);
+        mPools.emplace_back(numLayers, mKVFactor, numKvHeads, sizePerHead / numEltsPerContainer, tokensPerBlock);
         ++poolIndex;
     }
 
@@ -707,15 +715,16 @@ void BlockManager::storeContextBlocks(GenerationRequest& sequence, LlmRequest co
 
 void WindowBlockManager::createBlockScalePools(SizeType32 quantBlockSize)
 {
+    SizeType32 const numEltsPerContainer = getNumEltsPerContainer();
     auto num_pools = mPools.size();
     for (size_t i = 0; i < num_pools; ++i)
     {
         auto& kv_pool = mPools[i];
-        TLLM_CHECK_WITH_INFO(kv_pool.blockSize % quantBlockSize == 0,
-            "Cannot use FP4 quantization since kv_pool.blockSize is not divisible by FP4 quantBlockSize.");
-
-        mPools.emplace_back(kv_pool.numLayers, kv_pool.kvFactor, kv_pool.numKvHeads, kv_pool.sizePerHead,
-            kv_pool.tokensPerBlock, quantBlockSize,
+        TLLM_CHECK_WITH_INFO((kv_pool.sizePerHead * numEltsPerContainer) % quantBlockSize == 0,
+            "Cannot use FP4 quantization since kv_pool.sizePerHead is not divisible by FP4 quantBlockSize.");
+        auto blockScaleSizePerHead = kv_pool.sizePerHead * numEltsPerContainer / quantBlockSize;
+        mPools.emplace_back(kv_pool.numLayers, kv_pool.kvFactor, kv_pool.numKvHeads, blockScaleSizePerHead,
+            kv_pool.tokensPerBlock,
             /*primaryPool=*/nullptr,
             /*secondaryPool=*/nullptr,
             /*containsBlockScales=*/true);
@@ -749,10 +758,6 @@ void WindowBlockManager::allocatePools(bool useUvm)
 
         if (poolIsFP4)
         {
-            TLLM_CHECK_WITH_INFO(blockSize % 2 == 0, "Block size must be divisible by 2 for FP4 KV cache.");
-            // Divide by 2. We can't create FP4 buffers directly, so we'll have to create a uint8 buffer with
-            // half the expected number of elements.
-            blockSize /= 2;
             poolDtype = nvinfer1::DataType::kINT8;
         }
 
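
The reworked createBlockScalePools sizes the companion scale pool per head rather than per block: logical elements are recovered by multiplying the stored container width by the packing factor, and one scale is kept per quantBlockSize of them. A small sketch of that arithmetic, assuming a 128-element head and a 16-element scaling block (illustrative values, not taken from the commit):

```cpp
// Illustrative sketch of the block-scale pool sizing: one scale per quantBlockSize
// logical elements, where the data pool stores FP4 values packed two per container.
#include <cassert>
#include <cstdint>
#include <iostream>

int main()
{
    int32_t const storedSizePerHead = 64;  // containers per head (128 logical FP4 elements / 2)
    int32_t const numEltsPerContainer = 2; // FP4 packing factor
    int32_t const quantBlockSize = 16;     // assumed number of elements covered by one scale

    int32_t const logicalSizePerHead = storedSizePerHead * numEltsPerContainer; // 128
    assert(logicalSizePerHead % quantBlockSize == 0); // mirrors the TLLM_CHECK in the diff
    int32_t const blockScaleSizePerHead = logicalSizePerHead / quantBlockSize;  // 8 scales per head

    std::cout << "scales per head per token: " << blockScaleSizePerHead << '\n';
    return 0;
}
```

Checking divisibility on sizePerHead instead of on the whole blockSize is what allows the scale count per head to be stored directly as the new pool's sizePerHead.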

cpp/tensorrt_llm/common/attentionOp.cpp

Lines changed: 70 additions & 13 deletions
@@ -214,6 +214,10 @@ bool AttentionOp::convertMMHAParamsToXQAParams(tensorrt_llm::kernels::XQAParams&
         }
         xqaParams.kv_cache_data_type = DATA_TYPE_E4M3;
     }
+    else if (mKVCacheQuantMode.hasFp4KvCache())
+    {
+        xqaParams.kv_cache_data_type = DATA_TYPE_E2M1;
+    }
     else
     {
         xqaParams.kv_cache_data_type = xqaParams.data_type;
@@ -959,6 +963,9 @@ int AttentionOp::mlaGeneration(
         generation_params.can_use_one_more_block, generation_params.host_primary_pool_pointer,
         generation_params.host_secondary_pool_pointer, generation_params.block_offsets);
 
+    // Currently NVFP4 KV cache is not supported for MLA. An empty placeholder is provided.
+    auto kv_scale_cache_buffer = KVBlockArray();
+
     // Workspace pointer shift
     int8_t* workspace_byte_ptr = reinterpret_cast<int8_t*>(params.workspace);
     size_t offset = 0;
@@ -1234,7 +1241,7 @@
     {
         TLLM_LOG_DEBUG("XQA kernels are selected in the generation phase.");
         xqaParams.stream = stream;
-        mXqaDispatcher->run(xqaParams, kv_cache_buffer);
+        mXqaDispatcher->run(xqaParams, kv_cache_buffer, kv_scale_cache_buffer);
         return 0;
     }
     else if (mIsSpecDecodingEnabled && mUseSpecDecoding)
@@ -1308,8 +1315,10 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
     float const q_scaling = mQScaling;
 
     KVCacheBuffer kv_cache_buffer;
-    auto const elemSize = mKVCacheQuantMode.hasKvCacheQuant() ? sizeof(int8_t) : sizeof(T);
-    auto sizePerToken = mNumAttnKVHeads * headSize * elemSize;
+    KVCacheBuffer kv_scale_cache_buffer;
+
+    auto sizePerToken = mNumAttnKVHeads * headSize * getKvCacheElemSizeInBits<T>() / 8 /*bits*/;
+
     if (useKVCache())
     {
         if constexpr (std::is_same_v<KVCacheBuffer, KVBlockArray>)
@@ -1318,6 +1327,14 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
                 sizePerToken, params.cyclic_attention_window_size, params.max_cyclic_attention_window_size,
                 params.sink_token_length, params.can_use_one_more_block, params.host_primary_pool_pointer,
                 params.host_secondary_pool_pointer, params.block_offsets);
+            if (mKVCacheQuantMode.hasFp4KvCache())
+            {
+                kv_scale_cache_buffer = KVBlockArray(params.batch_size, params.max_blocks_per_sequence, mTokensPerBlock,
+                    sizePerToken / 8, params.cyclic_attention_window_size, params.max_cyclic_attention_window_size,
+                    params.sink_token_length, params.can_use_one_more_block,
+                    params.host_primary_block_scale_pool_pointer, params.host_secondary_block_scale_pool_pointer,
+                    params.block_offsets);
+            }
         }
         else if constexpr (std::is_same_v<KVCacheBuffer, KVLinearBuffer>)
         {
@@ -1326,6 +1343,7 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
                 isCrossAttention() ? params.cross_kv_length : params.max_attention_window_size, sizePerToken,
                 params.cyclic_attention_window_size, params.sink_token_length, false,
                 reinterpret_cast<BufferDataType*>(params.key_value_cache));
+            TLLM_CHECK_WITH_INFO(!(mKVCacheQuantMode.hasFp4KvCache()), "FP4 KV cache only supports paged KV.");
         }
     }
 
@@ -1490,8 +1508,8 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
     decoder_params.blockSparseParams = mBlockSparseParams;
     decoder_params.fmhaTileCounter = fmha_tile_counter_ptr;
     decoder_params.quantScaleO = params.attention_output_orig_quant;
-    decoder_params.dequantScaleQ = params.kv_scale_quant_orig;
-    decoder_params.dequantScaleKv = params.kv_scale_quant_orig;
+    decoder_params.dequantScaleQkv = params.kv_scale_quant_orig;
+    decoder_params.separateQkvScales = mKVCacheQuantMode.hasFp4KvCache();
     decoder_params.fmhaHostBmm1Scale = 1.0f / (sqrtf(getHeadSize() * 1.0f) * q_scaling);
     decoder_params.fmhaBmm1Scale = fmha_bmm1_scale_ptr;
     decoder_params.fmhaBmm2Scale = fmha_bmm2_scale_ptr;
@@ -1549,9 +1567,19 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
         sync_check_cuda_error(stream);
     }
 
-    KvCacheDataType const cache_type = mKVCacheQuantMode.hasInt8KvCache()
-        ? KvCacheDataType::INT8
-        : (mKVCacheQuantMode.hasFp8KvCache() ? KvCacheDataType::FP8 : KvCacheDataType::BASE);
+    KvCacheDataType cache_type{KvCacheDataType::BASE};
+    if (mKVCacheQuantMode.hasInt8KvCache())
+    {
+        cache_type = KvCacheDataType::INT8;
+    }
+    else if (mKVCacheQuantMode.hasFp8KvCache())
+    {
+        cache_type = KvCacheDataType::FP8;
+    }
+    else if (mKVCacheQuantMode.hasFp4KvCache())
+    {
+        cache_type = KvCacheDataType::NVFP4;
+    }
 
     cudaDataType_t const gemm_data_type = tc::CudaDataType<T>::value;
     int const attention_seq_len_1 = params.input_seq_length; // q length
@@ -1600,6 +1628,7 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
     preprocessingParams.quantized_qkv_output = fp8_qkv_buffer;
     preprocessingParams.q_output = q_buf_2_;
     preprocessingParams.kv_cache_buffer = kv_cache_buffer;
+    preprocessingParams.kv_cache_block_scales_buffer = kv_scale_cache_buffer;
     preprocessingParams.qkv_bias = params.qkv_bias;
     preprocessingParams.tokens_info = decoder_params.tokensInfo;
     preprocessingParams.seq_lens = params.context_lengths;
@@ -1612,7 +1641,7 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
     preprocessingParams.rotary_embedding_inv_freq = rotary_inv_freq_buf;
     preprocessingParams.rotary_coef_cache_buffer = params.rotary_cos_sin;
     preprocessingParams.mrope_rotary_cos_sin = params.mrope_rotary_cos_sin;
-    preprocessingParams.kvScaleOrigQuant = params.kv_scale_orig_quant;
+    preprocessingParams.qkv_scale_orig_quant = params.kv_scale_orig_quant;
     preprocessingParams.spec_decoding_position_offsets = nullptr;
     preprocessingParams.logn_scaling = params.logn_scaling_ptr;
 
@@ -1781,6 +1810,7 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
     if constexpr (std::is_same_v<KVCacheBuffer, KVBlockArray>)
     {
         fmhaParams.pagedKvCache = kv_cache_buffer;
+        fmhaParams.pagedKvSfCache = kv_scale_cache_buffer;
     }
     fmhaParams.cuQSeqLenPtr = cu_q_seqlens;
     fmhaParams.kvSeqLenPtr = decoder_params.seqKVLengths;
@@ -2126,8 +2156,10 @@ int AttentionOp::enqueueGeneration(EnqueueGenerationParams<T> const& params, cud
     int32_t const batch_beam = params.beam_width * params.num_requests;
 
     KVCacheBuffer kv_cache_buffer;
-    auto const elemSize = mKVCacheQuantMode.hasKvCacheQuant() ? sizeof(int8_t) : sizeof(T);
-    auto const sizePerToken = mNumAttnKVHeads * headSize * elemSize;
+    KVCacheBuffer kv_scale_cache_buffer;
+
+    auto const sizePerToken = mNumAttnKVHeads * headSize * getKvCacheElemSizeInBits<T>() / 8 /*bits*/;
+
     if (useKVCache())
     {
         if constexpr (std::is_same_v<KVCacheBuffer, KVBlockArray>)
@@ -2137,13 +2169,22 @@ int AttentionOp::enqueueGeneration(EnqueueGenerationParams<T> const& params, cud
                 params.cyclic_attention_window_size, params.max_cyclic_attention_window_size, params.sink_token_length,
                 params.can_use_one_more_block, params.host_primary_pool_pointer, params.host_secondary_pool_pointer,
                 reinterpret_cast<BufferDataType*>(params.block_offsets));
+            if (mKVCacheQuantMode.hasFp4KvCache())
+            {
+                kv_scale_cache_buffer = KVBlockArray(batch_beam, params.max_blocks_per_sequence, mTokensPerBlock,
+                    sizePerToken / 8, params.cyclic_attention_window_size, params.max_cyclic_attention_window_size,
+                    params.sink_token_length, params.can_use_one_more_block,
+                    params.host_primary_block_scale_pool_pointer, params.host_secondary_block_scale_pool_pointer,
+                    reinterpret_cast<BufferDataType*>(params.block_offsets));
+            }
         }
         else if constexpr (std::is_same_v<KVCacheBuffer, KVLinearBuffer>)
        {
             using BufferDataType = typename KVCacheBuffer::DataType;
             kv_cache_buffer = KVLinearBuffer(batch_beam, params.max_attention_window_size, sizePerToken,
                 params.cyclic_attention_window_size, params.sink_token_length, false,
                 reinterpret_cast<BufferDataType*>(params.key_value_cache));
+            TLLM_CHECK_WITH_INFO(!(mKVCacheQuantMode.hasFp4KvCache()), "FP4 KV cache only supports paged KV.");
         }
     }
     sync_check_cuda_error(stream);
@@ -2215,7 +2256,7 @@ int AttentionOp::enqueueGeneration(EnqueueGenerationParams<T> const& params, cud
             xqaParams.output = mhaOutput;
             xqaParams.qkv = attention_input;
         }
-        mXqaDispatcher->run(xqaParams, kv_cache_buffer);
+        mXqaDispatcher->run(xqaParams, kv_cache_buffer, kv_scale_cache_buffer);
         if (mCpSize > 1 && mAttnTpSize > 1 && mAttnCpSize == 1)
         {
             this->template ulyssesGenerationPostprocess<T>(
@@ -2232,6 +2273,10 @@ int AttentionOp::enqueueGeneration(EnqueueGenerationParams<T> const& params, cud
         {
             TLLM_CHECK_WITH_INFO(false, "No available kernels are found for FP4 output.");
         }
+        else if (mKVCacheQuantMode.hasFp4KvCache())
+        {
+            TLLM_CHECK_WITH_INFO(false, "No available kernels are found for FP4 KV cache.");
+        }
         else
         {
             TLLM_LOG_DEBUG("XQA kernels are not selected in the generation phase.");
@@ -2503,6 +2548,10 @@ int AttentionOp::initialize() noexcept
     TLLM_CHECK_WITH_INFO(!mFuseFp4Quant || mSM == 100 || mSM == 120 || mSM == 121,
         "fuse_fp4_quant only supports SM100 or SM120 or SM121 devices.");
 
+    // Check requirements for FP4 KV cache.
+    TLLM_CHECK_WITH_INFO(!mKVCacheQuantMode.hasFp4KvCache() || mFP8ContextFMHA,
+        "mFP8ContextFMHA must enable if FP4 KV cache is enabled");
+
     TLLM_CHECK(isRoPE() == (mRotaryEmbeddingDim != 0));
     TLLM_CHECK_WITH_INFO((mSM >= 80) || (mType != nvinfer1::DataType::kBF16),
         "Unsupported data type, pre SM 80 GPUs do not support bfloat16");
@@ -2579,7 +2628,10 @@
         {
             fmhaParams.dataTypeKv = DATA_TYPE_E4M3;
         }
-        // TODO: add FP4 KV cache support.
+        else if (mKVCacheQuantMode.hasFp4KvCache())
+        {
+            fmhaParams.dataTypeKv = DATA_TYPE_E2M1;
+        }
     }
     // The output dtype.
     fmhaParams.dataTypeOut = data_type;
@@ -2789,6 +2841,11 @@
         fixedParams.kvDataType = DATA_TYPE_E4M3;
         fixedParams.mathDataType = DATA_TYPE_E4M3;
     }
+    else if (mKVCacheQuantMode.hasFp4KvCache())
+    {
+        fixedParams.kvDataType = DATA_TYPE_E2M1;
+        fixedParams.mathDataType = DATA_TYPE_E4M3;
+    }
     else
     {
         fixedParams.kvDataType = fixedParams.inputDataType;
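
Both enqueue paths now compute the per-token KV footprint in bits and, when FP4 KV cache is enabled, build the block-scale KVBlockArray with one eighth of that byte count. A short sketch of why the 1/8 ratio holds, assuming one 8-bit scale per 16 FP4 elements (the helper below is a simplified stand-in for getKvCacheElemSizeInBits, and the head geometry is illustrative):

```cpp
// Illustrative sketch: per-token sizes for an FP4 KV cache plus its block scales.
#include <cstdint>
#include <iostream>

// Stand-in for the bit-width helper used in the diff (FP4 = 4 bits, fp16 = 16 bits).
int32_t kvCacheElemSizeInBits(bool fp4) { return fp4 ? 4 : 16; }

int main()
{
    int32_t const numKvHeads = 8, headSize = 128;
    bool const fp4 = true;

    // Data buffer: bits per element converted to bytes per token.
    int32_t const sizePerToken = numKvHeads * headSize * kvCacheElemSizeInBits(fp4) / 8; // 512 bytes

    // Scale buffer: one 8-bit scale per 16 elements -> (numKvHeads * headSize / 16) bytes,
    // which works out to exactly sizePerToken / 8 for 4-bit data.
    int32_t const scaleSizePerToken = sizePerToken / 8; // 64 bytes

    std::cout << sizePerToken << " data bytes + " << scaleSizePerToken << " scale bytes per token\n";
    return 0;
}
```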
