@@ -199,7 +199,7 @@ struct DMA
        // The kv_offset_start.
        int kv_offset_start = is_chunked_attention
            ? ((q_step_offset >> params.log2_chunked_attention_size) << params.log2_chunked_attention_size)
-           : max(0, q_step_offset - params.sliding_window_size);
+           : max(0, q_step_offset + 1 - params.sliding_window_size);
        kv_idx_start = kv_offset_start / STEP_KV;
    }
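The "+ 1" above is an off-by-one fix for the sliding-window start: with it, the attended range [kv_offset_start, q_step_offset] spans exactly sliding_window_size tokens instead of sliding_window_size + 1. Below is a minimal host-side sketch of the corrected computation; the standalone function, the STEP_KV value, and the sample numbers are illustrative only and not part of the kernel.

#include <algorithm>
#include <cstdio>

// Illustrative tile size; the kernel's STEP_KV is a compile-time constant.
constexpr int STEP_KV = 64;

// First kv offset that a query at q_step_offset may attend to under a sliding window.
// Mirrors the "+ 1" fix above so that [start, q_step_offset] holds exactly
// sliding_window_size tokens.
int sliding_window_kv_offset_start(int q_step_offset, int sliding_window_size)
{
    return std::max(0, q_step_offset + 1 - sliding_window_size);
}

int main()
{
    int q_step_offset = 128;
    int sliding_window_size = 128;
    int kv_offset_start = sliding_window_kv_offset_start(q_step_offset, sliding_window_size);
    int kv_idx_start = kv_offset_start / STEP_KV;
    // Prints kv_offset_start=1 kv_idx_start=0: offsets 1..128 are exactly 128 tokens.
    std::printf("kv_offset_start=%d kv_idx_start=%d\n", kv_offset_start, kv_idx_start);
    return 0;
}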
@@ -388,51 +388,6 @@ struct DMA
            elect_one_, {-1, -1, -1, -1, -1, -1, -1, -1});
    }

-    // Calculate the start tile idx.
-    inline __device__ int remap_kv_tile_idx(
-        int kv_tile_idx, int num_kv_cache_tiles, int past_kv_length, int sliding_window_size)
-    {
-
-        // The remapped kv tile idx.
-        int remapped_kv_tile_idx = kv_tile_idx;
-        // This will be removed later as the remapping will be handled by the kvCacheManager in TRTLLM.
-#ifdef GENERATE_CUBIN
-        // Sliding window attention + chunked context needs special handling.
-        if constexpr (SLIDING_OR_CHUNKED_ATTENTION)
-        {
-            // For chunked context (i.e. separate q and kv layout), the kv cache might be
-            // overwritten after the last chunk is processed.
-            // To deal with this issue, the new tokens' kv will be appended to the kv cache first,
-            // and overwrite the kv cache after FMHA is done.
-            // The kv input layout is like: [cyclic kv cache] + [new tokens' kv].
-            // There are two possible cases:
-            // 1. The kv cache hasn't been overwritten while processing previous chunks, so we can
-            //    take it normally, where we have full kv cache.
-            // 2. The kv cache has been overwritten while processing previous chunks. We need to
-            //    mask out the tokens in the kv cache based on the sliding window size. It needs
-            //    to track the last kv cache token's position in a circular way.
-
-            // Remap the kv tile index when the kv cache has been overwritten in a circular way.
-            if (past_kv_length > sliding_window_size)
-            {
-                // Map the kv tile index to the new tokens' kv.
-                if (kv_tile_idx * STEP_KV >= past_kv_length)
-                {
-                    remapped_kv_tile_idx
-                        = num_kv_cache_tiles + int((kv_tile_idx * STEP_KV - past_kv_length) / STEP_KV);
-                }
-                else
-                {
-                    // Map the kv tile index to the cyclic kv cache.
-                    remapped_kv_tile_idx = kv_tile_idx % num_kv_cache_tiles;
-                }
-            }
-        }
-#endif
-        // Return the remapped kv tile idx.
-        return remapped_kv_tile_idx;
-    }
-
    // Support contiguous Q + contiguous/paged KV separate cache.
    inline __device__ void run_separate_q_and_kv(
        bert::Fused_multihead_attention_params_v2 const& params, Shared* shared)
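The removed remap_kv_tile_idx handled the [cyclic kv cache] + [new tokens' kv] layout used by sliding window attention with chunked context; per its own comment, that remapping is expected to be handled by the kvCacheManager in TRT-LLM instead. For reference, its index arithmetic can be reproduced with the small host-side sketch below; the free-function form, the STEP_KV constant, and the sample values are illustrative only.

#include <cstdio>

// Illustrative tile size; the kernel's STEP_KV is a compile-time constant.
constexpr int STEP_KV = 64;

// Host-side reproduction of the deleted remapping for the
// [cyclic kv cache] + [new tokens' kv] input layout.
int remap_kv_tile_idx(int kv_tile_idx, int num_kv_cache_tiles, int past_kv_length, int sliding_window_size)
{
    // Case 1: the cyclic kv cache has not been overwritten yet; take the tile as-is.
    if (past_kv_length <= sliding_window_size)
    {
        return kv_tile_idx;
    }
    // Case 2a: tiles at or past the cached tokens map into the appended new-token region.
    if (kv_tile_idx * STEP_KV >= past_kv_length)
    {
        return num_kv_cache_tiles + (kv_tile_idx * STEP_KV - past_kv_length) / STEP_KV;
    }
    // Case 2b: otherwise the tile wraps into the cyclic kv cache.
    return kv_tile_idx % num_kv_cache_tiles;
}

int main()
{
    // E.g. a 256-token (4-tile) cyclic cache that already holds 320 past tokens:
    // tile 1 still lands in the cyclic cache, tile 5 lands in the new-token region.
    std::printf("%d\n", remap_kv_tile_idx(1, 4, 320, 256)); // 1 % 4 = 1
    std::printf("%d\n", remap_kv_tile_idx(5, 4, 320, 256)); // 4 + (5 * 64 - 320) / 64 = 4
    return 0;
}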
@@ -560,24 +515,20 @@ struct DMA
        // Iterate over the kv tiles for this q step.
        for (int kv_step_idx = kv_idx_start; kv_step_idx < kv_idx_end; kv_step_idx++)
        {
-           // Remap the kv tile idx if sliding window attention is enabled.
-           // Sliding_window_size should be a multiple of STEP_KV.
-           int remapped_kv_step_idx = remap_kv_tile_idx(kv_step_idx, params.sliding_window_size / STEP_KV,
-               past_kv_length, params.sliding_window_size);
            // The barrier id.
            int bar_id;
            // Load paged kv input.
            if constexpr (PAGED_KV_INPUT)
            {
-               bar_id = load_paged_kv(bidh_kv, remapped_kv_step_idx * STEP_KV, num_valid_kv_blocks,
+               bar_id = load_paged_kv(bidh_kv, kv_step_idx * STEP_KV, num_valid_kv_blocks,
                    params.paged_kv_cache.mTokensPerBlockLog2, params.blocks_per_tma_load,
                    params.blocks_per_tma_load_log2, params.paged_kv_cache.mMaxBlocksPerSeq,
                    paged_block_offsets, desc_k, desc_v, shared, cbw_k, cbw_v, cbw_v_scratch);
            }
            else
            {
-               bar_id = load_kv(bidh_kv, remapped_kv_step_idx * STEP_KV, desc_k, desc_v, shared, cbw_k,
-                   cbw_v, cbw_v_scratch);
+               bar_id = load_kv(
+                   bidh_kv, kv_step_idx * STEP_KV, desc_k, desc_v, shared, cbw_k, cbw_v, cbw_v_scratch);
            }

            // Opportunistically hide headinfo in the shadow of UTMALDGs of the QKV tensor