Commit a77f93e

symphonylyh authored and lancelly committed
[TRTLLM-6674][feat] (Breaking Change) Hopper SWA non-cyclic kernels + KV reuse + Spec Dec (NVIDIA#6379)
Signed-off-by: Haohang Huang <[email protected]>
Signed-off-by: symphonylyh <[email protected]>
Signed-off-by: Lanyu Liao <[email protected]>
1 parent eefbaf2 · commit a77f93e

File tree

140 files changed: +460, -466 lines

cpp/kernels/fmha_v2/src/fmha/gmem_tile_qkv_packed.h

Lines changed: 0 additions & 44 deletions
@@ -796,7 +796,6 @@ struct Gmem_tile_contiguous_kv
     template <typename Smem_tile>
     inline __device__ void load(Smem_tile& smem_tile)
     {
-        // TODO(perkzz): add remap_kv_row for sliding window attention.
         uint32_t preds[LDGS];
 #pragma unroll
         for (int ii = 0; ii < LDGS; ++ii)
@@ -1091,42 +1090,6 @@ struct Gmem_tile_paged_kv
         }
     }
 
-    ////////////////////////////////////////////////////////////////////////////////////////////////////
-    // Remap the row to the one in cyclic kv cache.
-    inline __device__ void remap_kv_row(int& row)
-    {
-        // Sliding window attention + chunked context needs special handling.
-        if constexpr (SLIDING_WINDOW_ATTENTION)
-        {
-            // For chunked context (i.e. separate q and kv layout), the kv cache might be overwritten
-            // after last chunk is processed.
-            // To deal with this issue, the new tokens' kv will be appended to the kv cache first, and
-            // overwrite the kv cache after FMHA is done.
-            // The kv input layout is like: [cyclic kv cache] + [new tokens' kv].
-            // There are two possible cases:
-            // 1. The kv cache hasn't been overwritten while processing previous chunks, so we can take
-            //    it normally, where we have full kv cache.
-            // 2. The kv cache has been overwritten while processing previous chunks. we need to mask
-            //    out the tokens in the kv cache based on the sliding window size. It needs to track the
-            //    last kv cache token's position in a circular way.
-
-            // Remap the kv row when kv cache has been overwritten in a circular way.
-            if (past_seqlen_ > sliding_window_size_)
-            {
-                // Map the kv row to the new tokens' kv.
-                if (row >= past_seqlen_)
-                {
-                    row = sliding_window_size_ + (row - past_seqlen_);
-                }
-                else
-                {
-                    // Map the kv row to the cyclic kv cache.
-                    row = row % sliding_window_size_;
-                }
-            }
-        }
-    }
-
     // Load data from memory.
     template <typename Smem_tile>
     inline __device__ void load(Smem_tile& smem_tile)
@@ -1144,13 +1107,6 @@ struct Gmem_tile_paged_kv
         for (int ii = 0; ii < LDGS; ++ii)
        {
             int row_idx = row_ + ii * (int) ROWS_PER_LDG;
-
-            // Remap row_idx if sliding window attention is used.
-            // This will be removed later as the remapping will be handled by the kvCacheManger in TRTLLM.
-#ifdef GENERATE_CUBIN
-            remap_kv_row(row_idx);
-#endif
-
             int paged_kv_block_idx = (row_idx >> paged_kv_log2_block_size_);
             char const* local_kv_ptr = reinterpret_cast<char*>(paged_kv_block_pool_ptr_
                 + params_kv_block_size_in_bytes_ * paged_kv_global_block_offsets_[paged_kv_block_idx]);
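
Note: the deleted remap_kv_row collapsed a logical KV row onto the
[cyclic kv cache] + [new tokens' kv] layout described in the deleted comments.
A minimal host-side sketch of that mapping, for reference only (the name
remap_kv_row_ref is ours; past_seqlen and sliding_window_size mirror the
member variables):

    // Hypothetical reference for the mapping the deleted remap_kv_row performed.
    // Physical layout: [cyclic kv cache of sliding_window_size rows] + [new tokens' kv].
    int remap_kv_row_ref(int row, int past_seqlen, int sliding_window_size)
    {
        // Remapping only matters once the cache has wrapped and old tokens were overwritten.
        if (past_seqlen > sliding_window_size)
        {
            if (row >= past_seqlen)
            {
                // Rows at or beyond past_seqlen live in the appended new-token region.
                return sliding_window_size + (row - past_seqlen);
            }
            // Older rows alias into the cyclic cache modulo its size.
            return row % sliding_window_size;
        }
        return row; // cache not overwritten yet: identity mapping
    }

For example, with sliding_window_size = 4 and past_seqlen = 6, row 5 maps to
5 % 4 = 1 inside the cyclic region, while row 6 (the first new token) maps to
4 + (6 - 6) = 4. The non-cyclic kernels introduced by this commit drop this
aliasing entirely, so row indices address the cache directly.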

cpp/kernels/fmha_v2/src/fmha/mask.h

Lines changed: 2 additions & 2 deletions
@@ -478,7 +478,7 @@ struct Mask<Traits, Cta_tile, 4> : public Mask<Traits, Cta_tile, 3>
     inline __device__ bool is_valid(int row, int col) const
     {
         // Is it a valid position in the sequence, i.e. are we in the lower triangle?
-        return (row >= col) && (col >= max(0, row - sliding_window_size_));
+        return (row >= col) && (col >= max(0, row + 1 - sliding_window_size_));
     }
 
     // The sliding window size.
@@ -946,7 +946,7 @@ struct Mask_hopper<Traits, Cta_tile, 4> : public Mask_hopper<Traits, Cta_tile, 3>
     inline __device__ bool is_valid(int row, int col) const
     {
         // Is it a valid position in the sequence?
-        return col <= row && col >= max(0, row - sliding_window_size_);
+        return col <= row && col >= max(0, row + 1 - sliding_window_size_);
     }
 
     // The sliding window size for attention.
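
Note: the recurring "+ 1" in this commit fixes an off-by-one in the window
bound. Under the old predicate col >= max(0, row - sliding_window_size), a
query at row attended to sliding_window_size + 1 keys (columns
row - sliding_window_size through row inclusive); the new bound admits exactly
sliding_window_size keys, the query itself included. A standalone check under
that reading (visible_keys is our name, not part of the diff):

    #include <algorithm>
    #include <cassert>

    // Number of keys visible to the query at `row` under the new predicate:
    // columns in [max(0, row + 1 - sliding_window_size), row].
    int visible_keys(int row, int sliding_window_size)
    {
        int const first = std::max(0, row + 1 - sliding_window_size);
        return row - first + 1;
    }

    int main()
    {
        assert(visible_keys(/*row=*/10, /*sliding_window_size=*/4) == 4);
        assert(visible_keys(/*row=*/2, /*sliding_window_size=*/4) == 3); // clipped at sequence start
        return 0;
    }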

cpp/kernels/fmha_v2/src/fmha/warpspec/compute.h

Lines changed: 1 addition & 1 deletion
@@ -288,7 +288,7 @@ struct Compute
         // The kv_left_mask_end is the start of the chunk.
         kv_left_mask_end = div_up(is_chunked_attention
                 ? ((tile_offset_end >> params.log2_chunked_attention_size) << params.log2_chunked_attention_size)
-                : (tile_offset_end - params.sliding_window_size),
+                : (tile_offset_end + 1 - params.sliding_window_size),
             STEP_KV);
     }

cpp/kernels/fmha_v2/src/fmha/warpspec/dma.h

Lines changed: 4 additions & 53 deletions
@@ -199,7 +199,7 @@ struct DMA
         // The kv_offset_start.
         int kv_offset_start = is_chunked_attention
             ? ((q_step_offset >> params.log2_chunked_attention_size) << params.log2_chunked_attention_size)
-            : max(0, q_step_offset - params.sliding_window_size);
+            : max(0, q_step_offset + 1 - params.sliding_window_size);
         kv_idx_start = kv_offset_start / STEP_KV;
     }
@@ -388,51 +388,6 @@ struct DMA
             elect_one_, {-1, -1, -1, -1, -1, -1, -1, -1});
     }
 
-    // Calculate the start tile idx.
-    inline __device__ int remap_kv_tile_idx(
-        int kv_tile_idx, int num_kv_cache_tiles, int past_kv_length, int sliding_window_size)
-    {
-
-        // The remapped kv tile idx.
-        int remapped_kv_tile_idx = kv_tile_idx;
-        // This will be removed later as the remapping will be handled by the kvCacheManger in TRTLLM.
-#ifdef GENERATE_CUBIN
-        // Sliding window attention + chunked context needs special handling.
-        if constexpr (SLIDING_OR_CHUNKED_ATTENTION)
-        {
-            // For chunked context (i.e. separate q and kv layout), the kv cache might be
-            // overwritten after last chunk is processed.
-            // To deal with this issue, the new tokens' kv will be appended to the kv cache first,
-            // and overwrite the kv cache after FMHA is done.
-            // The kv input layout is like: [cyclic kv cache] + [new tokens' kv].
-            // There are two possible cases:
-            // 1. The kv cache hasn't been overwritten while processing previous chunks, so we can
-            //    take it normally, where we have full kv cache.
-            // 2. The kv cache has been overwritten while processing previous chunks. we need to
-            //    mask out the tokens in the kv cache based on the sliding window size. It needs
-            //    to track the last kv cache token's position in a circular way.
-
-            // Remap the kv tile index when kv cache has been overwritten in a circular way.
-            if (past_kv_length > sliding_window_size)
-            {
-                // Map the kv tile index to the new tokens' kv.
-                if (kv_tile_idx * STEP_KV >= past_kv_length)
-                {
-                    remapped_kv_tile_idx
-                        = num_kv_cache_tiles + int((kv_tile_idx * STEP_KV - past_kv_length) / STEP_KV);
-                }
-                else
-                {
-                    // Map the kv tile index to the cyclic kv cache.
-                    remapped_kv_tile_idx = kv_tile_idx % num_kv_cache_tiles;
-                }
-            }
-        }
-#endif
-        // Return the remapped kv tile idx.
-        return remapped_kv_tile_idx;
-    }
-
     // Support contiguous Q + contiguous/paged KV separate cache.
     inline __device__ void run_separate_q_and_kv(
         bert::Fused_multihead_attention_params_v2 const& params, Shared* shared)
@@ -560,24 +515,20 @@ struct DMA
         // Iterate over the kv tiles for this q step.
         for (int kv_step_idx = kv_idx_start; kv_step_idx < kv_idx_end; kv_step_idx++)
         {
-            // Remap the kv tile idx if sliding window attention is enabled.
-            // Sliding_window_size should be multiple of STEP_KV.
-            int remapped_kv_step_idx = remap_kv_tile_idx(kv_step_idx, params.sliding_window_size / STEP_KV,
-                past_kv_length, params.sliding_window_size);
             // The barrier id.
             int bar_id;
             // Load paged kv input.
             if constexpr (PAGED_KV_INPUT)
             {
-                bar_id = load_paged_kv(bidh_kv, remapped_kv_step_idx * STEP_KV, num_valid_kv_blocks,
+                bar_id = load_paged_kv(bidh_kv, kv_step_idx * STEP_KV, num_valid_kv_blocks,
                     params.paged_kv_cache.mTokensPerBlockLog2, params.blocks_per_tma_load,
                     params.blocks_per_tma_load_log2, params.paged_kv_cache.mMaxBlocksPerSeq,
                     paged_block_offsets, desc_k, desc_v, shared, cbw_k, cbw_v, cbw_v_scratch);
             }
             else
             {
-                bar_id = load_kv(bidh_kv, remapped_kv_step_idx * STEP_KV, desc_k, desc_v, shared, cbw_k,
-                    cbw_v, cbw_v_scratch);
+                bar_id = load_kv(
+                    bidh_kv, kv_step_idx * STEP_KV, desc_k, desc_v, shared, cbw_k, cbw_v, cbw_v_scratch);
             }
 
             // Opportunistically hide headinfo in the shadow of UTMALDGs of the QKV tensor
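
Note: the deleted remap_kv_tile_idx is the tile-granularity analogue of the
remap_kv_row removed from gmem_tile_qkv_packed.h: assuming sliding_window_size
is a multiple of STEP_KV (as the deleted call-site comment required), so that
num_kv_cache_tiles == sliding_window_size / STEP_KV, the same two-case mapping
applies to tile indices. A compact sketch (reference name ours):

    // Hypothetical reference for the deleted tile-index remapping.
    int remap_kv_tile_idx_ref(
        int kv_tile_idx, int num_kv_cache_tiles, int past_kv_length, int step_kv)
    {
        int const sliding_window_size = num_kv_cache_tiles * step_kv;
        if (past_kv_length > sliding_window_size)
        {
            if (kv_tile_idx * step_kv >= past_kv_length)
            {
                // Tiles past the cached history come from the appended new-token region.
                return num_kv_cache_tiles + (kv_tile_idx * step_kv - past_kv_length) / step_kv;
            }
            // Historical tiles alias into the cyclic cache.
            return kv_tile_idx % num_kv_cache_tiles;
        }
        return kv_tile_idx;
    }

With the remapping gone, load_paged_kv and load_kv consume kv_step_idx *
STEP_KV directly, so KV offsets are no longer position-dependent aliases,
which lines up with the KV reuse called out in the commit title.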

cpp/kernels/fmha_v2/src/fmha/warpspec/epilogue.h

Lines changed: 1 addition & 1 deletion
@@ -134,7 +134,7 @@ struct Softmax_base
         else
         {
             // The sliding window start is the max of 0 and row - sliding_window_size.
-            return max(0, row - sliding_window_size_);
+            return max(0, row + 1 - sliding_window_size_);
         }
     }

cpp/kernels/fmha_v2/src/fused_multihead_attention.cpp

Lines changed: 1 addition & 1 deletion
@@ -1578,7 +1578,7 @@ int main(int argc, char** argv)
             }
             else
             {
-                valid = valid && (si >= std::max(int(so - sliding_window_size), 0));
+                valid = valid && (si >= std::max(int(so + 1 - sliding_window_size), 0));
             }
         }
         if (is_mtp)

cpp/kernels/fmha_v2/src/fused_multihead_flash_attention_kernel_noloop.h

Lines changed: 2 additions & 2 deletions
@@ -175,10 +175,10 @@ inline __device__ void device_flash_attention_nl(Params const& params)
 
     int const kv_loop_end = ((valid_seqlen + Cta_tile_p::N - 1) / Cta_tile_p::N) * Cta_tile_p::N;
     int const kv_loop_start = mask_sliding_window
-        ? (max(0, q_sequence_start - params.sliding_window_size) / Cta_tile_p::N) * Cta_tile_p::N
+        ? (max(0, q_sequence_start + 1 - params.sliding_window_size) / Cta_tile_p::N) * Cta_tile_p::N
         : 0;
     int const sliding_window_mask_end = mask_sliding_window
-        ? (max(0, q_sequence_start + Cta_tile_p::M - 1 - params.sliding_window_size) / Cta_tile_p::N) * Cta_tile_p::N
+        ? (max(0, q_sequence_start + Cta_tile_p::M - params.sliding_window_size) / Cta_tile_p::N) * Cta_tile_p::N
         : 0;
 
     static_assert(Cta_tile_p::M >= Cta_tile_p::N, "");
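
Note: the same +1 shift propagates into these tile-level loop bounds:
kv_loop_start is the window start of the first query row in the CTA tile,
rounded down to the KV tile size, and, roughly, KV tiles before
sliding_window_mask_end can cross the left window edge of some query row and
therefore need the full sliding-window mask check. A sketch of the bound
arithmetic (our simplification; cta_m/cta_n stand in for Cta_tile_p::M/N):

    #include <algorithm>

    // First KV position to visit and the end of the left-masked region,
    // both aligned down to the KV tile size cta_n.
    void kv_loop_bounds(int q_sequence_start, int sliding_window_size, int cta_m, int cta_n,
        int& kv_loop_start, int& sliding_window_mask_end)
    {
        // Window start of the first query row in the tile.
        kv_loop_start = (std::max(0, q_sequence_start + 1 - sliding_window_size) / cta_n) * cta_n;
        // Window start of the last query row (q_sequence_start + cta_m - 1), i.e.
        // (q_sequence_start + cta_m - 1) + 1 - sliding_window_size.
        sliding_window_mask_end
            = (std::max(0, q_sequence_start + cta_m - sliding_window_size) / cta_n) * cta_n;
    }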

cpp/kernels/fmha_v2/src/fused_multihead_flash_attention_kernel_noloop_tiled.h

Lines changed: 2 additions & 2 deletions
@@ -176,10 +176,10 @@ inline __device__ void device_flash_attention_nl_tiled(Params const& params)
 
     int const kv_loop_end = ((valid_seqlen + Cta_tile_p::N - 1) / Cta_tile_p::N) * Cta_tile_p::N;
     int const kv_loop_start = mask_sliding_window
-        ? (max(0, q_sequence_start - params.sliding_window_size) / Cta_tile_p::N) * Cta_tile_p::N
+        ? (max(0, q_sequence_start + 1 - params.sliding_window_size) / Cta_tile_p::N) * Cta_tile_p::N
         : 0;
     int const sliding_window_mask_end = mask_sliding_window
-        ? (max(0, q_sequence_start + Cta_tile_p::M - 1 - params.sliding_window_size) / Cta_tile_p::N) * Cta_tile_p::N
+        ? (max(0, q_sequence_start + Cta_tile_p::M - params.sliding_window_size) / Cta_tile_p::N) * Cta_tile_p::N
         : 0;
 
     // Move K and V tiles.

cpp/kernels/xqa/defines.h

Lines changed: 4 additions & 0 deletions
@@ -162,6 +162,10 @@ static_assert(CACHE_ELEM_ENUM != 0);
 #define OPTIMIZE_FOR_LATENCY 1
 #endif
 
+#ifndef IS_SPEC_DEC_TREE
+#define IS_SPEC_DEC_TREE 1 // by default SPEC_DEC expect tree-based draft token structure
+#endif
+
 #define DBG_BATCH_SIZE 2
 #define DBG_SEQ_LEN 256 * 4 + 3
 #define DBG_NB_CTAS_PER_SEQ 8
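
Note: since the new macro is wrapped in #ifndef, the draft-token layout can be
selected at build time. Presumably (the flag's semantics beyond the comment
above are an assumption on our part), a compile line such as

    nvcc -DIS_SPEC_DEC_TREE=0 ...   // opt out of the default tree-shaped draft tokens

would select a non-tree (e.g. linear chain) draft-token structure.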

cpp/kernels/xqa/mha.cu

Lines changed: 0 additions & 1 deletion
@@ -1592,7 +1592,6 @@ CUBIN_EXPORT __global__
 #endif
 
     uint32_t const cacheSeqLen = getCacheSeqLen<usePagedKVCache>(cacheList, idxReq);
-    static_assert(!(allowSlidingWindow && useSpecDec), "Sliding window is not yet supported in spec-dec mode");
 #if SLIDING_WINDOW
     bool const rtIsReallySliding = (cacheSeqLen > slidingWinSize);
     uint32_t const nbTotalSkipTokens = rtIsReallySliding ? cacheSeqLen - slidingWinSize : 0;
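
Note: with the static_assert gone, spec-dec kernels may now compile with
sliding window attention enabled; what remains is the runtime check in the
surviving lines, where sliding only actually happens once the cache outgrows
the window. A scalar restatement (our sketch; names follow the kernel):

    #include <cstdint>

    // Sliding is real only when the cache is longer than the window; the
    // oldest (cacheSeqLen - slidingWinSize) tokens then fall outside it.
    void slidingInfo(uint32_t cacheSeqLen, uint32_t slidingWinSize,
        bool& rtIsReallySliding, uint32_t& nbTotalSkipTokens)
    {
        rtIsReallySliding = cacheSeqLen > slidingWinSize;
        nbTotalSkipTokens = rtIsReallySliding ? cacheSeqLen - slidingWinSize : 0;
    }

For example, cacheSeqLen = 1000 with slidingWinSize = 256 skips the oldest
744 cached tokens.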
