
Commit 57d68bc

Merge branch 'main' into fix/input_prep_for_OOT_models
2 parents 90f1b6d + e968f98

2,189 files changed: +37,396 −9,728 lines


README.md

Lines changed: 2 additions & 2 deletions
@@ -9,7 +9,7 @@ TensorRT-LLM
 [![python](https://img.shields.io/badge/python-3.10-green)](https://www.python.org/downloads/release/python-31012/)
 [![cuda](https://img.shields.io/badge/cuda-12.9.1-green)](https://developer.nvidia.com/cuda-downloads)
 [![trt](https://img.shields.io/badge/TRT-10.11.0-green)](https://developer.nvidia.com/tensorrt)
-[![version](https://img.shields.io/badge/release-1.0.0rc6-green)](./tensorrt_llm/version.py)
+[![version](https://img.shields.io/badge/release-1.1.0rc0-green)](./tensorrt_llm/version.py)
 [![license](https://img.shields.io/badge/license-Apache%202-blue)](./LICENSE)

 [Architecture](./docs/source/torch/arch_overview.md)   |   [Performance](./docs/source/performance/perf-overview.md)   |   [Examples](https://nvidia.github.io/TensorRT-LLM/quick-start-guide.html)   |   [Documentation](./docs/source/)   |   [Roadmap](https://github.com/NVIDIA/TensorRT-LLM/issues?q=is%3Aissue%20state%3Aopen%20label%3Aroadmap)
@@ -253,5 +253,5 @@ Deprecation is used to inform developers that some APIs and tools are no longer
 ## Useful Links
 - [Quantized models on Hugging Face](https://huggingface.co/collections/nvidia/model-optimizer-66aa84f7966b3150262481a4): A growing collection of quantized (e.g., FP8, FP4) and optimized LLMs, including [DeepSeek FP4](https://huggingface.co/nvidia/DeepSeek-R1-FP4), ready for fast inference with TensorRT-LLM.
 - [NVIDIA Dynamo](https://github.com/ai-dynamo/dynamo): A datacenter scale distributed inference serving framework that works seamlessly with TensorRT-LLM.
-- [AutoDeploy](./examples/auto_deploy/README.md): An experimental backend for TensorRT-LLM to simplify and accelerate the deployment of PyTorch models.
+- [AutoDeploy](./examples/auto_deploy/README.md): A prototype backend for TensorRT-LLM to simplify and accelerate the deployment of PyTorch models.
 - [WeChat Discussion Group](https://github.com/NVIDIA/TensorRT-LLM/issues/5359): A real-time channel for TensorRT-LLM Q&A and news.

cpp/include/tensorrt_llm/batch_manager/kvCacheEventManager.h

Lines changed: 20 additions & 2 deletions
@@ -18,6 +18,7 @@

 #include "tensorrt_llm/executor/executor.h"

+#include <atomic>
 #include <chrono>
 #include <condition_variable>
 #include <deque>
@@ -36,7 +37,8 @@ using BlockPtr = std::shared_ptr<KVCacheBlock>;
 class KVCacheEventManager
 {
 public:
-    explicit KVCacheEventManager(size_t maxKVEventEntries);
+    explicit KVCacheEventManager(size_t maxKVEventEntries, std::optional<SizeType32> attentionDpRank = std::nullopt,
+        std::optional<SizeType32> attentionDpSize = std::nullopt, SizeType32 attentionDpEventsGatherPeriodMs = 5);

     ~KVCacheEventManager();
     KVCacheEventManager(KVCacheEventManager& other) = delete;
@@ -61,14 +63,19 @@ class KVCacheEventManager
     // Worker thread which adds events to mEvents.
     void worker();

+    // Thread which exchanges events if attentionDP is enabled
+    void exchangeAttentionDpThread();
+
 private:
     // Add an event to mEventQueue
     void enqueueEvent(executor::KVCacheEvent&& event);

     /// @brief Flag to terminate the worker
-    bool mRun;
+    std::atomic<bool> mRun;
     /// @brief Worker thread
     std::thread mWorkerThread;
+    /// @brief Exchange thread for attention DP events
+    std::thread mExchangeAttentionDpThread;

     /// @brief The deque of events
     std::deque<executor::KVCacheEvent> mEvents;
@@ -91,6 +98,17 @@ class KVCacheEventManager
     size_t mMaxSize;
     /// @brief An auto-incrementing event id counter
     size_t mEventId;
+
+    /// @brief Attention DP ranks and size
+    /// If set, we will exchange KV cache events and accumulate on rank 0
+    std::optional<SizeType32> mAttentionDpRank;
+    std::optional<SizeType32> mAttentionDpSize;
+
+    /// @brief The period in milliseconds to gather attention DP events across rank
+    SizeType32 mAttentionDpEventsGatherPeriodMs;
+
+    /// @brief MPI communicator for attention DP
+    std::unique_ptr<tensorrt_llm::mpi::MpiComm> mMpiComm;
 };

 } // namespace tensorrt_llm::batch_manager::kv_cache_manager
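The shape of this change: the constructor gains an optional attention DP rank and size plus a gather period, a dedicated exchange thread joins the existing worker, and mRun becomes std::atomic<bool> because the shutdown flag is now read by two threads. Below is a minimal, self-contained sketch of that two-thread lifecycle; the class and every name in it are illustrative stand-ins, not the TensorRT-LLM implementation (the real exchange thread gathers serialized events over MPI on the configured period).

#include <atomic>
#include <chrono>
#include <condition_variable>
#include <deque>
#include <iostream>
#include <mutex>
#include <thread>

// Sketch of the pattern the header implies: a worker draining an event queue
// and a periodic exchange loop, both watching one atomic termination flag.
class EventManagerSketch
{
public:
    explicit EventManagerSketch(int gatherPeriodMs)
        : mRun{true}
        , mWorkerThread{[this] { worker(); }}
        , mExchangeThread{[this, gatherPeriodMs] { exchange(gatherPeriodMs); }}
    {
    }

    ~EventManagerSketch()
    {
        {
            std::lock_guard<std::mutex> lock(mMutex);
            mRun.store(false); // written under the mutex so the CV wakeup is not missed
        }
        mCv.notify_all();
        mWorkerThread.join();
        mExchangeThread.join();
    }

    void enqueue(int event)
    {
        {
            std::lock_guard<std::mutex> lock(mMutex);
            mQueue.push_back(event);
        }
        mCv.notify_one();
    }

private:
    void worker()
    {
        std::unique_lock<std::mutex> lock(mMutex);
        while (true)
        {
            mCv.wait(lock, [this] { return !mQueue.empty() || !mRun.load(); });
            while (!mQueue.empty())
            {
                std::cout << "event " << mQueue.front() << '\n';
                mQueue.pop_front();
            }
            if (!mRun.load())
            {
                return;
            }
        }
    }

    void exchange(int periodMs)
    {
        // Stand-in for the periodic cross-rank gather; the real thread uses MPI.
        while (mRun.load())
        {
            std::this_thread::sleep_for(std::chrono::milliseconds(periodMs));
        }
    }

    std::atomic<bool> mRun;
    std::mutex mMutex;
    std::condition_variable mCv;
    std::deque<int> mQueue;
    std::thread mWorkerThread;
    std::thread mExchangeThread;
};

int main()
{
    EventManagerSketch manager{5};
    manager.enqueue(1);
    manager.enqueue(2);
    std::this_thread::sleep_for(std::chrono::milliseconds(20));
}

The atomic type is what lets the exchange loop poll the flag without taking the queue mutex, while the store in the destructor still happens under the mutex before notifying so the condition-variable wait cannot miss the shutdown.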

cpp/include/tensorrt_llm/batch_manager/llmRequest.h

Lines changed: 4 additions & 1 deletion
@@ -2027,7 +2027,7 @@ class GenericLlmRequest

         // Scatter the input tokens to other beam
         mTokens = BeamTokens(mSamplingConfig.beamWidth, inputTokens);
-        mLastTokens = VecTokens(mSamplingConfig.beamWidth);
+        mLastTokens = VecTokens(mSamplingConfig.beamWidth, inputTokens.back());

         // Init mUniqueTokens
         VecUniqueTokens uniqueTokens{inputTokens.size()};
@@ -2347,6 +2347,9 @@ class LlmRequest : public GenericLlmRequest<runtime::ITensor::SharedPtr>
     void movePromptEmbeddingTableToGpu(runtime::BufferManager const& manager);

     void moveLoraWeightsToGpu(runtime::BufferManager const& manager);
+
+    // Remove LoRA weights and LoRA config tensors
+    void removeLoraTensors();
 };

 } // namespace tensorrt_llm::batch_manager
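The one-line fix above swaps VecTokens(beamWidth), which value-initializes every beam's last token to 0, for the fill constructor seeded with the last prompt token, matching the scatter of inputTokens to all beams on the preceding line. A tiny standalone illustration of the two std::vector constructors, assuming VecTokens is a vector of integer token ids as the surrounding code suggests:

#include <cassert>
#include <cstdint>
#include <vector>

using TokenIdType = std::int32_t;
using VecTokens = std::vector<TokenIdType>;

int main()
{
    VecTokens inputTokens{11, 42, 7};
    int const beamWidth = 4;

    // Before the fix: the size-only constructor value-initializes every
    // element, so each beam's "last token" starts out as 0.
    VecTokens before(beamWidth);
    assert(before == VecTokens({0, 0, 0, 0}));

    // After the fix: the fill constructor seeds every beam with the last
    // prompt token, consistent with mTokens scattering inputTokens per beam.
    VecTokens after(beamWidth, inputTokens.back());
    assert(after == VecTokens({7, 7, 7, 7}));
}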

cpp/include/tensorrt_llm/common/quantization.h

Lines changed: 57 additions & 16 deletions
@@ -122,6 +122,16 @@ class QuantMode
         return QuantMode(BaseType(1u) << 14);
     }

+    static constexpr QuantMode w4a8Mxfp4Mxfp8() noexcept
+    {
+        return QuantMode(BaseType(1u) << 15);
+    }
+
+    static constexpr QuantMode w4a16Mxfp4() noexcept
+    {
+        return QuantMode(BaseType(1u) << 16);
+    }
+
     constexpr BaseType value() const noexcept
     {
         return mValue;
@@ -202,14 +212,25 @@ class QuantMode
         return isSet(w4a8Mxfp4Fp8());
     }

+    constexpr bool hasW4a8Mxfp4Mxfp8() const noexcept
+    {
+        return isSet(w4a8Mxfp4Mxfp8());
+    }
+
+    constexpr bool hasW4a16Mxfp4() const noexcept
+    {
+        return isSet(w4a16Mxfp4());
+    }
+
     constexpr bool hasKvCacheQuant() const noexcept
     {
         return hasInt8KvCache() || hasFp8KvCache() || hasFp4KvCache();
     }

     static constexpr QuantMode fromDescription(bool quantizeWeights, bool quantizeActivations, bool perToken,
         bool perChannel, bool perGroup, bool useInt4Weights, bool useInt8KvCache, bool useFp8KvCache, bool useFp8Qdq,
-        bool useFp8RowWise, bool useW4a8QServe, bool useFp4Quant, bool useFp8BlockScales, bool useW4a8Mxfp4Fp8)
+        bool useFp8RowWise, bool useW4a8QServe, bool useFp4Quant, bool useFp8BlockScales, bool useW4a8Mxfp4Fp8,
+        bool useW4a8Mxfp4Mxfp8, bool useW4a16Mxfp4)
     {
         QuantMode quantMode{};
         if (quantizeWeights)
@@ -278,25 +299,35 @@ class QuantMode
             quantMode += w4a8Mxfp4Fp8();
         }

+        if (useW4a8Mxfp4Mxfp8)
+        {
+            quantMode += w4a8Mxfp4Mxfp8();
+        }
+
+        if (useW4a16Mxfp4)
+        {
+            quantMode += w4a16Mxfp4();
+        }
+
         return quantMode;
     }

     static constexpr QuantMode useSmoothQuant(bool perToken = false, bool perChannel = false)
     {
-        return fromDescription(
-            true, true, perToken, perChannel, false, false, false, false, false, false, false, false, false, false);
+        return fromDescription(true, true, perToken, perChannel, false, false, false, false, false, false, false, false,
+            false, false, false, false);
     }

     static constexpr QuantMode useQServe(bool perGroup)
     {
-        return fromDescription(
-            true, true, false, false, perGroup, true, false, false, false, false, true, false, false, false);
+        return fromDescription(true, true, false, false, perGroup, true, false, false, false, false, true, false, false,
+            false, false, false);
     }

     static constexpr QuantMode useWeightOnly(bool useInt4Weights = false, bool perGroup = false)
     {
         return fromDescription(true, false, false, false, perGroup, useInt4Weights, false, false, false, false, false,
-            false, false, false);
+            false, false, false, false, false);
     }

     static QuantMode const fromQuantAlgo(
@@ -353,28 +384,38 @@ class QuantMode
         }
         else if (quantAlgo == "FP8")
         {
-            quantMode = fromDescription(
-                false, false, false, false, false, false, false, false, true, false, false, false, false, false);
+            quantMode = fromDescription(false, false, false, false, false, false, false, false, true, false, false,
+                false, false, false, false, false);
         }
         else if (quantAlgo == "FP8_ROWWISE")
         {
-            quantMode = fromDescription(
-                false, false, true, true, false, false, false, false, false, true, false, false, false, false);
+            quantMode = fromDescription(false, false, true, true, false, false, false, false, false, true, false, false,
+                false, false, false, false);
         }
         else if (quantAlgo == "FP4")
        {
-            quantMode = fromDescription(
-                false, false, false, false, false, false, false, false, false, false, false, true, false, false);
+            quantMode = fromDescription(false, false, false, false, false, false, false, false, false, false, false,
+                true, false, false, false, false);
         }
         else if (quantAlgo == "FP8_BLOCK_SCALES")
         {
-            quantMode = fromDescription(
-                false, false, false, false, false, false, false, false, false, false, false, false, true, false);
+            quantMode = fromDescription(false, false, false, false, false, false, false, false, false, false, false,
+                false, true, false, false, false);
         }
         else if (quantAlgo == "W4A8_MXFP4_FP8")
         {
-            quantMode = fromDescription(
-                false, false, false, false, false, false, false, false, false, false, false, false, false, true);
+            quantMode = fromDescription(false, false, false, false, false, false, false, false, false, false, false,
+                false, false, true, false, false);
+        }
+        else if (quantAlgo == "W4A8_MXFP4_MXFP8")
+        {
+            quantMode = fromDescription(false, false, false, false, false, false, false, false, false, false, false,
+                false, false, false, true, false);
+        }
+        else if (quantAlgo == "W4A16_MXFP4")
+        {
+            quantMode = fromDescription(false, false, false, false, false, false, false, false, false, false, false,
+                false, false, false, false, true);
        }

         if (kvCacheQuantAlgo == "INT8")
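QuantMode packs each quantization feature into one bit of an integer mask: the two new algorithms claim bits 15 and 16, and because fromDescription takes positional booleans, every existing call site grows two trailing false arguments. A compact sketch of the same bit-flag pattern (a simplified illustration with only three flags, not the real class):

#include <cstdint>
#include <iostream>

// Bit positions mirror the diff (w4a8Mxfp4Fp8 = 14, w4a8Mxfp4Mxfp8 = 15,
// w4a16Mxfp4 = 16); composing modes is bitwise OR, querying is bitwise AND.
class QuantModeSketch
{
public:
    using BaseType = std::uint32_t;

    constexpr QuantModeSketch() noexcept = default;

    static constexpr QuantModeSketch w4a8Mxfp4Fp8() noexcept { return QuantModeSketch(BaseType(1u) << 14); }
    static constexpr QuantModeSketch w4a8Mxfp4Mxfp8() noexcept { return QuantModeSketch(BaseType(1u) << 15); }
    static constexpr QuantModeSketch w4a16Mxfp4() noexcept { return QuantModeSketch(BaseType(1u) << 16); }

    constexpr bool isSet(QuantModeSketch mode) const noexcept { return (mValue & mode.mValue) == mode.mValue; }
    constexpr bool hasW4a16Mxfp4() const noexcept { return isSet(w4a16Mxfp4()); }

    constexpr QuantModeSketch& operator+=(QuantModeSketch const& other) noexcept
    {
        mValue |= other.mValue;
        return *this;
    }

private:
    constexpr explicit QuantModeSketch(BaseType value) noexcept : mValue(value) {}
    BaseType mValue = 0;
};

int main()
{
    QuantModeSketch mode{};
    mode += QuantModeSketch::w4a16Mxfp4();
    std::cout << std::boolalpha << mode.hasW4a16Mxfp4() << '\n';        // true
    std::cout << mode.isSet(QuantModeSketch::w4a8Mxfp4Mxfp8()) << '\n'; // false
}

The diff itself shows the cost of positional booleans: each new flag touches every fromDescription call site, which is why the hunks above are mostly mechanical false-padding.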

cpp/include/tensorrt_llm/executor/executor.h

Lines changed: 16 additions & 2 deletions
@@ -1001,6 +1001,7 @@ class KvCacheConfig
         std::optional<FloatType> const& crossKvCacheFraction = std::nullopt,
         std::optional<RetentionPriority> secondaryOffloadMinPriority = std::nullopt, size_t eventBufferMaxSize = 0,
         bool enablePartialReuse = true, bool copyOnPartialReuse = true, bool useUvm = false,
+        SizeType32 attentionDpEventsGatherPeriodMs = 5,
         std::optional<tensorrt_llm::runtime::RuntimeDefaults> const& runtimeDefaults = std::nullopt);

     [[nodiscard]] bool getEnableBlockReuse() const;
@@ -1016,6 +1017,7 @@ class KvCacheConfig
     [[nodiscard]] std::optional<RetentionPriority> getSecondaryOffloadMinPriority() const;
     [[nodiscard]] size_t getEventBufferMaxSize() const;
     [[nodiscard]] bool getUseUvm() const;
+    [[nodiscard]] SizeType32 getAttentionDpEventsGatherPeriodMs() const;

     void setEnableBlockReuse(bool enableBlockReuse);
     void setEnablePartialReuse(bool enablePartialReuse);
@@ -1030,6 +1032,7 @@ class KvCacheConfig
     void setSecondaryOffloadMinPriority(std::optional<RetentionPriority> secondaryOffloadMinPriority);
     void setEventBufferMaxSize(size_t eventBufferMaxSize);
     void setUseUvm(bool useUvm);
+    void setAttentionDpEventsGatherPeriodMs(SizeType32 attentionDpEventsGatherPeriodMs);

     void fillEmptyFieldsFromRuntimeDefaults(tensorrt_llm::runtime::RuntimeDefaults const& runtimeDefaults);

@@ -1085,6 +1088,9 @@ class KvCacheConfig

     /// @brief Whether to use UVM for the KV cache.
     bool mUseUvm;
+
+    /// @brief The period in milliseconds to gather attention DP events across ranks
+    SizeType32 mAttentionDpEventsGatherPeriodMs;
 };

 /// @brief Configuration class for the runtime perf knobs
@@ -1702,6 +1708,12 @@ struct KVCacheUpdatedData
     explicit KVCacheUpdatedData(IdType blockHash)
         : blockHash{blockHash} {};

+    explicit KVCacheUpdatedData(IdType blockHash, std::optional<KVCacheEventDiff<SizeType32>> cacheLevel,
+        std::optional<KVCacheEventDiff<SizeType32>> priority)
+        : blockHash{blockHash}
+        , cacheLevel{cacheLevel}
+        , priority{priority} {};
+
     KVCacheUpdatedData& cacheLevelUpdated(SizeType32 oldValue, SizeType32 newValue)
     {
         cacheLevel = KVCacheEventDiff<SizeType32>{oldValue, newValue};
@@ -1726,15 +1738,17 @@ using KVCacheEventData = std::variant<KVCacheCreatedData, KVCacheStoredData, KVC

 struct KVCacheEvent
 {
-
-    KVCacheEvent(IdType eventId, KVCacheEventData data, SizeType32 windowSize);
+    KVCacheEvent(IdType eventId, KVCacheEventData data, SizeType32 windowSize,
+        std::optional<SizeType32> attentionDpRank = std::nullopt);

     /// @brief The unique id of this event
     IdType eventId;
     /// @brief The data corresponding to this event
     KVCacheEventData data;
     /// @brief The sliding window size
     SizeType32 windowSize;
+    /// @brief The attention DP rank of the event, if applicable
+    std::optional<SizeType32> attentionDpRank;
 };

 /// @brief Exposes a limited set of KV cache manager functionalities
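KVCacheEventData is a std::variant over the event payload types, and KVCacheEvent now carries an optional attention DP rank so events from runs without attention DP are representable unchanged. A hedged sketch of consuming such an event with std::visit; the payload structs below are hypothetical stand-ins for the real KVCacheCreatedData and friends:

#include <cstdint>
#include <iostream>
#include <optional>
#include <string>
#include <type_traits>
#include <variant>

// Illustrative payload types; the real ones are declared in executor.h.
struct CreatedData
{
    int numBlocks;
};

struct StoredData
{
    std::string blockHash;
};

using EventData = std::variant<CreatedData, StoredData>;

// Mirrors the extended KVCacheEvent shape: the rank is optional, mirroring
// std::optional<SizeType32> attentionDpRank in the header.
struct EventSketch
{
    std::uint64_t eventId;
    EventData data;
    int windowSize;
    std::optional<int> attentionDpRank;
};

void handle(EventSketch const& ev)
{
    // std::visit dispatches on whichever alternative the variant holds.
    std::visit(
        [](auto const& d)
        {
            using T = std::decay_t<decltype(d)>;
            if constexpr (std::is_same_v<T, CreatedData>)
            {
                std::cout << "created " << d.numBlocks << " blocks";
            }
            else
            {
                std::cout << "stored block " << d.blockHash;
            }
        },
        ev.data);
    std::cout << " (rank " << (ev.attentionDpRank ? std::to_string(*ev.attentionDpRank) : "n/a") << ")\n";
}

int main()
{
    handle(EventSketch{1, CreatedData{64}, 4096, std::nullopt});
    handle(EventSketch{2, StoredData{"abc123"}, 4096, 0});
}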

cpp/include/tensorrt_llm/executor/serialization.h

Lines changed: 47 additions & 0 deletions
@@ -302,6 +302,53 @@ class Serialization
     [[nodiscard]] static std::vector<RequestStatsPerIteration> deserializeRequestStatsPerIterationVec(
         std::vector<char>& buffer);

+    // KVCacheEvent deque
+    [[nodiscard]] static std::vector<char> serialize(std::deque<KVCacheEvent> const& kvCacheEvents);
+    [[nodiscard]] static std::deque<KVCacheEvent> deserializeKVCacheEvents(std::vector<char>& buffer);
+
+    // KVCacheEvent
+    [[nodiscard]] static size_t serializedSize(KVCacheEvent const& event);
+    static void serialize(KVCacheEvent const& event, std::ostream& os);
+    [[nodiscard]] static KVCacheEvent deserializeKVCacheEvent(std::istream& is);
+
+    // KVCacheCreatedData
+    [[nodiscard]] static size_t serializedSize(KVCacheCreatedData const& data);
+    static void serialize(KVCacheCreatedData const& data, std::ostream& os);
+    [[nodiscard]] static KVCacheCreatedData deserializeKVCacheCreatedData(std::istream& is);
+
+    // KVCacheStoredData
+    [[nodiscard]] static size_t serializedSize(KVCacheStoredData const& data);
+    static void serialize(KVCacheStoredData const& data, std::ostream& os);
+    [[nodiscard]] static KVCacheStoredData deserializeKVCacheStoredData(std::istream& is);
+
+    // KVCacheStoredBlockData
+    [[nodiscard]] static size_t serializedSize(KVCacheStoredBlockData const& data);
+    static void serialize(KVCacheStoredBlockData const& data, std::ostream& os);
+    [[nodiscard]] static KVCacheStoredBlockData deserializeKVCacheStoredBlockData(std::istream& is);
+
+    // KVCacheRemovedData
+    [[nodiscard]] static size_t serializedSize(KVCacheRemovedData const& data);
+    static void serialize(KVCacheRemovedData const& data, std::ostream& os);
+    [[nodiscard]] static KVCacheRemovedData deserializeKVCacheRemovedData(std::istream& is);
+
+    // KVCacheEventDiff
+    template <typename T>
+    [[nodiscard]] static size_t serializedSize(KVCacheEventDiff<T> const& data);
+    template <typename T>
+    static void serialize(KVCacheEventDiff<T> const& data, std::ostream& os);
+    template <typename T>
+    [[nodiscard]] static KVCacheEventDiff<T> deserializeKVCacheEventDiff(std::istream& is);
+
+    // KVCacheUpdateData
+    [[nodiscard]] static size_t serializedSize(KVCacheUpdatedData const& data);
+    static void serialize(KVCacheUpdatedData const& data, std::ostream& os);
+    [[nodiscard]] static KVCacheUpdatedData deserializeKVCacheUpdatedData(std::istream& is);
+
+    // UniqueToken
+    [[nodiscard]] static size_t serializedSize(tensorrt_llm::runtime::UniqueToken const& token);
+    static void serialize(tensorrt_llm::runtime::UniqueToken const& token, std::ostream& os);
+    [[nodiscard]] static tensorrt_llm::runtime::UniqueToken deserializeUniqueToken(std::istream& is);
+
     // String
     static std::string deserializeString(std::istream& is);
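Each newly serializable type follows the header's three-function convention: serializedSize to budget the buffer, serialize into a std::ostream, and a type-named deserialize from a std::istream (named per type because C++ cannot overload on return type alone). A minimal sketch of that triple for a hypothetical two-field struct, not the library's actual wire format:

#include <cassert>
#include <cstdint>
#include <istream>
#include <ostream>
#include <sstream>

// Hypothetical payload; the real KVCacheEventDiff et al. live in executor.h.
struct DiffSketch
{
    std::int32_t oldValue;
    std::int32_t newValue;
};

size_t serializedSize(DiffSketch const&)
{
    return 2 * sizeof(std::int32_t);
}

void serialize(DiffSketch const& d, std::ostream& os)
{
    // Fixed-width fields written raw; both ends must agree on the layout.
    os.write(reinterpret_cast<char const*>(&d.oldValue), sizeof(d.oldValue));
    os.write(reinterpret_cast<char const*>(&d.newValue), sizeof(d.newValue));
}

DiffSketch deserializeDiffSketch(std::istream& is)
{
    DiffSketch d{};
    is.read(reinterpret_cast<char*>(&d.oldValue), sizeof(d.oldValue));
    is.read(reinterpret_cast<char*>(&d.newValue), sizeof(d.newValue));
    return d;
}

int main()
{
    std::stringstream buffer;
    serialize(DiffSketch{1, 2}, buffer);
    assert(buffer.str().size() == serializedSize(DiffSketch{}));
    DiffSketch const roundTripped = deserializeDiffSketch(buffer);
    assert(roundTripped.oldValue == 1 && roundTripped.newValue == 2);
}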

cpp/include/tensorrt_llm/runtime/utils/mpiTags.h

Lines changed: 4 additions & 0 deletions
@@ -68,6 +68,10 @@ enum class MpiTag : int
     // LogitsThread
     kSpecDecLogitsId = 129,
     kSpecDecLogitsData = 1025,
+
+    // KvCacheEventManager
+    kKvCacheEventSize = 1026,
+    kKvCacheEvent = 1027
 };

 } // namespace tensorrt_llm::mpi
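Two adjacent tags suggest the usual two-message handshake for variable-length payloads: the byte count travels first under one tag so the receiver can size its buffer, then the serialized events follow under the second. A hedged sketch of that protocol in plain MPI (run with mpirun -n 2); only the tag values come from the diff, everything else is illustrative:

#include <mpi.h>

#include <cstdint>
#include <iostream>
#include <vector>

// Tag values mirror the new enum entries; the code is a generic two-message
// exchange, not the actual KVCacheEventManager implementation.
constexpr int kKvCacheEventSize = 1026;
constexpr int kKvCacheEvent = 1027;

int main(int argc, char** argv)
{
    MPI_Init(&argc, &argv);
    int rank = 0;
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);

    if (rank == 1)
    {
        // Sender: announce the payload size first, then ship the bytes.
        std::vector<char> payload{'k', 'v'};
        std::uint64_t size = payload.size();
        MPI_Send(&size, 1, MPI_UINT64_T, /*dest=*/0, kKvCacheEventSize, MPI_COMM_WORLD);
        MPI_Send(payload.data(), static_cast<int>(size), MPI_CHAR, 0, kKvCacheEvent, MPI_COMM_WORLD);
    }
    else if (rank == 0)
    {
        // Receiver: learn the size, allocate, then receive the payload.
        std::uint64_t size = 0;
        MPI_Recv(&size, 1, MPI_UINT64_T, /*source=*/1, kKvCacheEventSize, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
        std::vector<char> payload(size);
        MPI_Recv(payload.data(), static_cast<int>(size), MPI_CHAR, 1, kKvCacheEvent, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
        std::cout << "received " << size << " bytes\n";
    }

    MPI_Finalize();
}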

cpp/kernels/fmha_v2/fmha_test.py

Lines changed: 3 additions & 3 deletions
@@ -50,7 +50,7 @@ def getSMVersion():
                          ids=["fp16", "bf16", "fp16-fp32", "e4m3"])
 @pytest.mark.parametrize('flag', [
     "-s-q 128 -paged-kv", "-s-q 63 -paged-kv", "-paged-kv",
-    "-softcapping-scale-bmm1 30", "-contiguous-q-kv"
+    "-softcapping-scale-bmm1 30", "-contiguous-q-kv", "-use-attention-sinks"
 ])
 @pytest.mark.parametrize('tiled_kernel', ["", "-force-non-tiled"])
 def test_trtllm_flash_attention_fmha(d, s, dtype, flag, tiled_kernel):
@@ -117,8 +117,8 @@ def test_trtllm_flash_attention_fmha(d, s, dtype, flag, tiled_kernel):
         f"bin/fmha.exe -d {d} -h 16 -b 8 -s {s} -min-s 128 -custom-mask -gqa 2 -v {verbose} {dtype} {epsilon} {flag} {tiled_kernel}",
         shell=True,
         check=True)
-    # alibi and softcapping-scale-bmm1 are mutually exclusive.
-    if '-softcapping-scale-bmm1' not in flag:
+    # alibi doesn't work with softcapping-scale-bmm1/use-attention-sinks.
+    if '-softcapping-scale-bmm1' not in flag and '-use-attention-sinks' not in flag:
         subprocess.run(
             f"bin/fmha.exe -d {d} -h 16 -b 8 -s {s} -min-s 128 -causal-mask -alibi -v {verbose} {dtype} {epsilon} {flag} {tiled_kernel}",
             shell=True,
