Commit 3b01fac

bugfix: fix blackwell fmha hanging issue for empty kv_len (#1198)
## πŸ“Œ Description

Cherry-picked from the CUTLASS v4.0 changes.

## πŸš€ Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### βœ… Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## πŸ§ͺ Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes

cc @pavanimajety
1 parent 8c2d5ef commit 3b01fac

File tree

4 files changed (+182, -0 lines)


include/flashinfer/attention/blackwell/collective/sm100_fmha_fwd_epilogue_tma_warpspecialized.hpp

Lines changed: 1 addition & 0 deletions

@@ -40,6 +40,7 @@ namespace cutlass::fmha::collective {

 template <class Element, class ElementAcc, class TileShape>
 struct Sm100FmhaFwdEpilogueTmaWarpspecialized {
+  using ElementOut = Element;
   using Pipeline = cutlass::PipelineAsync<2>;
   // using ShapeT = cute::Shape<int32_t, int32_t, cute::Shape<int32_t, int32_t>>;
   // using StrideO = cute::Shape<int32_t, _1, cute::Shape<int32_t, int32_t>>;

The new `ElementOut` alias exposes the epilogue's output element type so that the mainloop's new `correction_empty` (below) can query it via `typename CollectiveEpilogue::ElementOut`.

include/flashinfer/attention/blackwell/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp

Lines changed: 66 additions & 0 deletions

@@ -1088,6 +1088,72 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
     pipeline_epi.producer_commit(pipeline_epi_producer_state);
     ++pipeline_epi_producer_state;
   }
+
+  template <class BlkCoord, class ProblemShape, class ParamsProblemShape, class TensorStorageEpi,
+            class CollectiveEpilogue>
+  CUTLASS_DEVICE auto correction_empty(
+      BlkCoord const& blk_coord, Params const& params, ProblemShape const& problem_shape,
+      ParamsProblemShape const& params_problem_shape, TensorStorageEpi& shared_storage_epi,
+      PipelineE& pipeline_epi, typename PipelineE::PipelineState& pipeline_epi_producer_state,
+      CollectiveEpilogue& epilogue) {
+    pipeline_epi.producer_acquire(pipeline_epi_producer_state);
+
+    Tensor sO = make_tensor(make_smem_ptr(shared_storage_epi.smem_o.data()),
+                            typename TensorStorageEpi::SmemLayoutO{});
+    Tensor gLSE = make_tensor(make_gmem_ptr(epilogue.params.ptr_LSE), epilogue.params.layout_LSE);
+    int thread_idx = threadIdx.x % (4 * NumThreadsPerWarp);
+
+    using ElementOut = typename CollectiveEpilogue::ElementOut;
+    auto tiled_copy = make_cotiled_copy(
+        Copy_Atom<UniversalCopy<uint32_t>, ElementOut>{},
+        make_ordered_layout(make_shape(_128{}, Int<sizeof(uint32_t) / sizeof(ElementOut)>{}),
+                            Step<_1, _0>{}),
+        sO.layout());
+
+    auto thr_copy = tiled_copy.get_slice(thread_idx);
+    auto tOgO = thr_copy.partition_D(sO);
+    auto tOrO = make_tensor<ElementOut>(shape(tOgO(_, _, _, _0{})));
+    clear(tOrO);
+
+    copy(tiled_copy, tOrO, tOgO(_, _, _, _0{}));
+
+    if (epilogue.params.ptr_LSE != nullptr) {
+      int qo_tile_idx = get<0>(blk_coord);
+      int qo_head_idx = get<2, 0>(blk_coord);
+      int batch_idx = get<2, 1>(blk_coord);
+      int qo_len = get<0>(problem_shape);
+      int segment_offset = get<0>(params_problem_shape).segment_offsets[batch_idx];
+      int row_idx = thread_idx + get<0>(TileShape{}) * qo_tile_idx;
+
+      if (row_idx < qo_len) {
+        gLSE(segment_offset + row_idx, qo_head_idx) = -cuda::std::numeric_limits<float>::infinity();
+      }
+    }
+
+    pipeline_epi.producer_commit(pipeline_epi_producer_state);
+    ++pipeline_epi_producer_state;
+
+    copy(tiled_copy, tOrO, tOgO(_, _, _, _1{}));
+    cutlass::arch::fence_view_async_shared();
+    pipeline_epi.producer_acquire(pipeline_epi_producer_state);
+
+    if (epilogue.params.ptr_LSE != nullptr) {
+      int qo_tile_idx = get<0>(blk_coord);
+      int qo_head_idx = get<2, 0>(blk_coord);
+      int batch_idx = get<2, 1>(blk_coord);
+      int qo_len = get<0>(problem_shape);
+      int segment_offset = get<0>(params_problem_shape).segment_offsets[batch_idx];
+      int row_idx = thread_idx + get<0>(TileShape{}) * qo_tile_idx + get<0>(TileShapeQK{});
+
+      if (row_idx < qo_len) {
+        gLSE(segment_offset + row_idx, qo_head_idx) = -cuda::std::numeric_limits<float>::infinity();
+      }
+    }
+
+    cutlass::arch::fence_view_async_shared();
+    pipeline_epi.producer_commit(pipeline_epi_producer_state);
+    ++pipeline_epi_producer_state;
+  }
 };

 }  // namespace cutlass::fmha::collective
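For reference, `correction_empty` encodes the standard convention for attention over an empty key set: the output tile is zero-filled, and the LSE is set to -inf, since the logsumexp of an empty set is -inf. The two producer_acquire/commit rounds appear to mirror the epilogue's two-stage `PipelineAsync<2>`, each round covering one half of the query tile. A minimal host-side sketch of the same semantics for a single request (illustrative code, not from the repository):

```python
import torch

def single_request_attention(q, k, v, sm_scale):
    """q: [qo_len, d_qk], k: [kv_len, d_qk], v: [kv_len, d_vo]."""
    if k.shape[0] == 0:
        # What correction_empty writes on device: a zeroed O tile and LSE = -inf.
        o = torch.zeros(q.shape[0], v.shape[-1], dtype=q.dtype)
        lse = torch.full((q.shape[0],), float("-inf"))
        return o, lse
    s = (q.float() @ k.float().T) * sm_scale  # [qo_len, kv_len] attention scores
    lse = torch.logsumexp(s, dim=-1)          # [qo_len]
    o = (torch.softmax(s, dim=-1) @ v.float()).to(q.dtype)
    return o, lse
```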

include/flashinfer/attention/blackwell/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp

Lines changed: 21 additions & 0 deletions

@@ -380,6 +380,10 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
        continue;
      }

+      if (get<1>(logical_problem_shape) == 0) {  // kv_len == 0
+        continue;
+      }
+
      bool is_softmax_0 = role == WarpRole::Softmax0;

      mainloop.softmax(
@@ -404,6 +408,13 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
        continue;
      }

+      if (get<1>(logical_problem_shape) == 0) {  // kv_len == 0
+        mainloop.correction_empty(blk_coord, params.mainloop, logical_problem_shape,
+                                  params.problem_shape, shared_storage.epilogue,
+                                  pipeline_corr_epi, pipeline_corr_epi_producer_state, epilogue);
+        continue;
+      }
+
      mainloop.correction(blk_coord, params.mainloop, params.problem_shape, logical_problem_shape,
                          shared_storage.epilogue, pipeline_s0_corr,
                          pipeline_s0_corr_consumer_state, pipeline_s1_corr,
@@ -437,6 +448,10 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
        continue;
      }

+      if (get<1>(logical_problem_shape) == 0) {  // kv_len == 0
+        continue;
+      }
+
      mainloop.mma(
          blk_coord, params.mainloop, logical_problem_shape, shared_storage.mainloop,
          pipeline_load_q, pipeline_load_q_consumer_state, pipeline_load_k,
@@ -461,6 +476,11 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
        continue;
      }

+      if (get<1>(logical_problem_shape) == 0) {  // kv_len == 0
+        work_idx++;
+        continue;
+      }
+
      mainloop.load(blk_coord, logical_problem_shape, params.mainloop, params.problem_shape,
                    shared_storage.mainloop, pipeline_load_q, pipeline_load_q_producer_state,
                    pipeline_load_k, pipeline_load_k_producer_state, pipeline_load_v,
@@ -491,6 +511,7 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
      epilogue.store(blk_coord, logical_problem_shape, params.epilogue, params.problem_shape,
                     shared_storage.epilogue, pipeline_corr_epi,
                     pipeline_corr_epi_consumer_state);
+
      work_idx++;
    }
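The hang these guards fix is a warp-specialization deadlock: if only some warp roles skipped a kv_len == 0 tile, the inter-warp pipelines would fall out of step and a consumer would block forever on a commit that never comes. With the guards above, the load, MMA, and softmax roles all skip the empty tile consistently, while the correction role still produces an explicit empty result through `correction_empty`, so `epilogue.store` always receives the tile it is waiting for. A loose host-side analogy of that invariant, with a bounded queue standing in for `pipeline_corr_epi` (names and structure are illustrative only):

```python
import queue
import threading

corr_to_epi = queue.Queue(maxsize=2)  # stand-in for the corr->epi pipeline

def correction_warp(tiles):
    for t in tiles:
        if t["kv_len"] == 0:
            # Analogue of correction_empty: still emit a tile (zeros, lse = -inf),
            # otherwise the epilogue consumer below would block forever.
            corr_to_epi.put({"o": 0.0, "lse": float("-inf")})
        else:
            corr_to_epi.put({"o": "corrected", "lse": "computed"})

def epilogue_warp(num_tiles, out):
    for _ in range(num_tiles):
        out.append(corr_to_epi.get())  # deadlocks if a producer silently skips a tile

tiles = [{"kv_len": 50}, {"kv_len": 0}, {"kv_len": 50}]
results = []
producer = threading.Thread(target=correction_warp, args=(tiles,))
consumer = threading.Thread(target=epilogue_warp, args=(len(tiles), results))
producer.start(); consumer.start()
producer.join(); consumer.join()
assert len(results) == len(tiles)
```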

tests/test_blackwell_fmha.py

Lines changed: 94 additions & 0 deletions

@@ -244,6 +244,89 @@ def test_blackwell_cutlass_varlen(
     torch.testing.assert_close(lse, lse_ref, rtol=1e-3, atol=1e-3)


+@pytest.mark.parametrize("qo_indptr_list", [[0, 10, 20, 30, 40, 50, 60, 100]])
+@pytest.mark.parametrize("kv_indptr_list", [[0, 50, 50, 50, 50, 50, 50, 50]])
+@pytest.mark.parametrize("num_qo_heads", [32])
+@pytest.mark.parametrize("num_kv_heads", [8, 32])
+@pytest.mark.parametrize("head_dim_qk", [192, 128])
+@pytest.mark.parametrize("head_dim_vo", [128])
+@pytest.mark.parametrize("sm_scale", [1.0 / math.sqrt(128)])
+@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16])
+def test_blackwell_cutlass_qo_kv_varlen(
+    qo_indptr_list,
+    kv_indptr_list,
+    num_qo_heads,
+    num_kv_heads,
+    head_dim_qk,
+    head_dim_vo,
+    sm_scale,
+    dtype,
+):
+    causal = False
+    if not is_sm100a_supported(torch.device("cuda")):
+        pytest.skip("SM100A is not supported on this device")
+    torch.manual_seed(42)
+    q = torch.randn(
+        qo_indptr_list[-1],
+        num_qo_heads,
+        head_dim_qk,
+        dtype=dtype,
+        device="cuda",
+    )
+    k = torch.randn(
+        kv_indptr_list[-1],
+        num_kv_heads,
+        head_dim_qk,
+        dtype=dtype,
+        device="cuda",
+    )
+    v = torch.randn(
+        kv_indptr_list[-1],
+        num_kv_heads,
+        head_dim_vo,
+        dtype=dtype,
+        device="cuda",
+    )
+
+    qo_indptr = torch.tensor(qo_indptr_list, device="cuda", dtype=torch.int32)
+    kv_indptr = torch.tensor(kv_indptr_list, device="cuda", dtype=torch.int32)
+
+    wrapper = flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper(
+        torch.empty(128 * 1024 * 1024, device="cuda", dtype=torch.uint8),
+        kv_layout="NHD",
+        backend="cutlass",
+    )
+
+    wrapper.plan(
+        qo_indptr,
+        kv_indptr,
+        num_qo_heads,
+        num_kv_heads,
+        head_dim_qk,
+        head_dim_vo=head_dim_vo,
+        causal=causal,
+        sm_scale=sm_scale,
+        q_data_type=dtype,
+        kv_data_type=dtype,
+    )
+    o, lse = wrapper.run(q, k, v, return_lse=True)
+
+    gqa_group_ratio = num_qo_heads // num_kv_heads
+    k_repeated = torch.repeat_interleave(k, gqa_group_ratio, dim=1)
+    v_repeated = torch.repeat_interleave(v, gqa_group_ratio, dim=1)
+
+    o_ref, lse_ref = attention_varlen_ref(
+        q, k_repeated, v_repeated, qo_indptr, kv_indptr, causal, sm_scale
+    )
+
+    if dtype == torch.half:
+        torch.testing.assert_close(o[10:60], o_ref[10:60], rtol=1e-3, atol=1e-3)
+    else:
+        torch.testing.assert_close(o[10:60], o_ref[10:60], rtol=1e-2, atol=1e-2)
+
+    torch.testing.assert_close(lse, lse_ref, rtol=1e-3, atol=1e-3)
+
+
 if __name__ == "__main__":
     test_blackwell_cutlass_fmha(
         9,
@@ -268,3 +351,14 @@ def test_blackwell_cutlass_varlen(
         True,
         torch.bfloat16,
     )
+
+    test_blackwell_cutlass_qo_kv_varlen(
+        [0, 10, 20, 30, 40, 50, 60, 100],
+        [0, 50, 50, 50, 50, 50, 50, 50],
+        32,
+        8,
+        128,
+        128,
+        1,
+        torch.bfloat16,
+    )