Changes from all commits
@@ -79,7 +79,8 @@ bool DecoderXQAImplJIT::mayHavePerfGain(XQAParams const& xqaParams) const
    if (xqaParams.multi_block_mode)
    {
        int history_length = xqaParams.max_past_kv_length;
-       multi_block_count = history_length / kMinHistoryTokensPerBlock;
+       // Always use at least 1 block regardless of history length
+       multi_block_count = std::max(1, history_length / kMinHistoryTokensPerBlock);
    }
    int block_count = num_kv_heads * batch_size * multi_block_count;
    return static_cast<float>(block_count) * kEnableMinBlockFactor >= static_cast<float>(mRunner->mMultiProcessorCount);
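The std::max guard matters because integer division truncates toward zero: for any history shorter than kMinHistoryTokensPerBlock, the old expression yielded a multi_block_count of 0, which zeroed block_count and made the perf-gain heuristic reject XQA even for tiny workloads. A standalone sketch of the arithmetic, with illustrative constant values (the real ones live in the library, not here):

#include <algorithm>
#include <cstdio>

int main()
{
    int const kMinHistoryTokensPerBlock = 512; // assumed value for illustration only
    int const num_kv_heads = 8;                // illustrative
    int const batch_size = 1;                  // illustrative

    int history_length = 100; // shorter than one block

    int before = history_length / kMinHistoryTokensPerBlock;             // integer division -> 0
    int after = std::max(1, history_length / kMinHistoryTokensPerBlock); // clamped -> 1

    // With the old code, block_count collapsed to 0, so the heuristic could never pass.
    std::printf("block_count before: %d, after: %d\n",
        num_kv_heads * batch_size * before, num_kv_heads * batch_size * after);
    return 0;
}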
@@ -98,12 +99,25 @@ bool DecoderXQAImplJIT::shouldUse(XQAParams const& umbrellaXQAParams, bool forConfigurePlugin)
                return true;
            }
        }
+       TLLM_LOG_DEBUG("JIT XQA is not used: no supported configuration found for any beam_width");
        return false;
    }
    else
    {
        auto const& xqaParams = umbrellaXQAParams;
-       return supportConfig(xqaParams, forConfigurePlugin) && mayHavePerfGain(xqaParams);
+       bool isConfigSupported = supportConfig(xqaParams, forConfigurePlugin);
+       if (!isConfigSupported)
+       {
+           TLLM_LOG_DEBUG("JIT XQA is not used: unsupported configuration");
+           return false;
+       }
+       bool hasPerfGain = mayHavePerfGain(xqaParams);
+       if (!hasPerfGain)
+       {
+           TLLM_LOG_DEBUG("JIT XQA is not used: maybe no performance gain");
+           return false;
+       }
+       return true;
    }
}
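Functionally the rewrite is equivalent to the old single-expression return: operator&& already evaluated supportConfig first and short-circuited past mayHavePerfGain on failure. What the staged form buys is a distinct debug message per rejection path. A generic, standalone sketch of the pattern (stub names, not the library's):

#include <cstdio>

// Stub checks standing in for supportConfig and mayHavePerfGain.
static bool checkConfig() { return true; }
static bool checkPerfGain() { return false; }

// Before: one expression; a false result gives no hint which check failed.
static bool shouldUseTerse()
{
    return checkConfig() && checkPerfGain();
}

// After: same result and same evaluation order, but each failure logs why.
static bool shouldUseStaged()
{
    if (!checkConfig())
    {
        std::printf("not used: unsupported configuration\n");
        return false;
    }
    if (!checkPerfGain())
    {
        std::printf("not used: no expected performance gain\n");
        return false;
    }
    return true;
}

int main()
{
    std::printf("terse=%d staged=%d\n", shouldUseTerse(), shouldUseStaged());
    return 0;
}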

@@ -438,6 +438,7 @@ void DecoderXQAImplPrecompiled::runDispatchBuffer(

#define SUPPORT_RETURN_FALSE(X) \
    { \
+       TLLM_LOG_DEBUG("XQA is not used. Reason: %s", X); \
        return false; \
    }

@@ -522,8 +523,17 @@ bool DecoderXQAImplPrecompiled::shouldUse(XQAParams const& xqaParams, bool forConfigurePlugin)
    }

    XQAKernelList const* xqa_kernel = getXQAKernels(mRunner->mDataType, tensorrt_llm::common::getSMVersion());
-   return xqa_kernel->supportConfig(xqaParams)
-       && xqa_kernel->mayHavePerfGain(xqaParams, mRunner->mMultiProcessorCount);
+   bool supportConfig = xqa_kernel->supportConfig(xqaParams);
+   if (!supportConfig)
+   {
+       SUPPORT_RETURN_FALSE("supportConfig");
+   }
+   bool mayHavePerfGain = xqa_kernel->mayHavePerfGain(xqaParams, mRunner->mMultiProcessorCount);
+   if (!mayHavePerfGain)
+   {
+       SUPPORT_RETURN_FALSE("mayHavePerfGain");
+   }
+   return true;
}

#undef SUPPORT_RETURN_FALSE
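Since the macro hides a log-and-return inside braces, a compact standalone sketch of what a call site expands to may help; the logger here is a printf stub standing in for TLLM_LOG_DEBUG, and shouldUseStub is an invented name:

#include <cstdio>

// printf stub standing in for TensorRT-LLM's TLLM_LOG_DEBUG.
#define TLLM_LOG_DEBUG(fmt, ...) std::printf(fmt "\n", __VA_ARGS__)

#define SUPPORT_RETURN_FALSE(X) \
    { \
        TLLM_LOG_DEBUG("XQA is not used. Reason: %s", X); \
        return false; \
    }

// Invented call site mirroring DecoderXQAImplPrecompiled::shouldUse.
bool shouldUseStub(bool supportConfig)
{
    if (!supportConfig)
    {
        SUPPORT_RETURN_FALSE("supportConfig"); // expands to: log the reason, then return false
    }
    return true;
}

int main()
{
    std::printf("result=%d\n", shouldUseStub(false));
    return 0;
}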
tests/unittest/llmapi/test_llm.py (10 additions & 0 deletions)
@@ -3,6 +3,16 @@
import gc
import json
import os
+
+# Required for test_generate_with_seed to pass.
+# See the discussion in https://github.com/NVIDIA/TensorRT-LLM/pull/4264#issuecomment-2943269891
+# The following line must come before any tensorrt_llm imports,
+# since env util functions like getEnvForceDeterministic are currently implemented using static variables,
+# which means they are initialized only once, when the CPP translation unit is loaded
+# (they should be refactored to be non-static later).
+os.environ['TRTLLM_FORCE_XQA'] = '1'
+# Note that we cannot use os.environ['FORCE_DETERMINISTIC'] = '1' here,
+# since it would disable KV cache reuse and make test_llm_api_draft_target fail.
+
import random
import shutil
import sys
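The ordering constraint that comment describes comes from C++ function-local statics: they are initialized on first use and never re-read the environment afterwards. A minimal C++ sketch of the pattern, assuming a hypothetical getter (getEnvForceXQA is an invented name; the real helpers, such as getEnvForceDeterministic, live in TensorRT-LLM's C++ runtime):

#include <cstdio>
#include <cstdlib>

bool getEnvForceXQA()
{
    // Initialized exactly once, on the first call after the shared library
    // is loaded; later changes to the environment variable are ignored,
    // which is why the test must set TRTLLM_FORCE_XQA before importing
    // tensorrt_llm.
    static bool const forceXQA = (std::getenv("TRTLLM_FORCE_XQA") != nullptr);
    return forceXQA;
}

int main()
{
    std::printf("force XQA: %d\n", getEnvForceXQA());
    return 0;
}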