From 42a1debef43b11ca06f4921ef3baa287b8547118 Mon Sep 17 00:00:00 2001 From: Rachmanino <18805904201@163.com> Date: Wed, 9 Jul 2025 07:40:46 +0000 Subject: [PATCH 1/3] Tune linear attention examples on H100 --- .../example_linear_attn_bwd.py | 36 ++++++++++--------- .../example_linear_attn_fwd.py | 32 +++++++++-------- 2 files changed, 38 insertions(+), 30 deletions(-) diff --git a/examples/linear_attention/example_linear_attn_bwd.py b/examples/linear_attention/example_linear_attn_bwd.py index b0db08ed8..42b692c99 100644 --- a/examples/linear_attention/example_linear_attn_bwd.py +++ b/examples/linear_attention/example_linear_attn_bwd.py @@ -10,7 +10,9 @@ from fla.ops.linear_attn import fused_chunk_linear_attn # We compare with FLA -@tl.jit(out_idx=[4, 5, 6]) +@tl.jit(out_idx=[4, 5, 6], + pass_configs={"tl.disable_tma_lower": True, + "tl.disable_warp_specialized": True}) def chunk_linear_attn_bwd_kernel( B, S, @@ -26,21 +28,21 @@ def chunk_linear_attn_bwd_kernel( accum_dtype = 'float' chunk_size = 64 - BK = BV = 64 + BK = BV = 64 # Set to 128 can be faster, but has some numerical differences with FLA assert S % chunk_size == 0 and DK % BK == 0 and DV % BV == 0 NK = tl.cdiv(DK, BK) NV = tl.cdiv(DV, BV) NT = tl.cdiv(S, chunk_size) @T.prim_func - def main( - Q: T.Tensor([B, S, H, DK], dtype), - K: T.Tensor([B, S, H, DK], dtype), - V: T.Tensor([B, S, H, DV], dtype), - dO: T.Tensor([B, S, H, DV], dtype), - dQ: T.Tensor([NV, B, S, H, DK], dtype), - dK: T.Tensor([NV, B, S, H, DK], dtype), - dV: T.Tensor([NK, B, S, H, DV], dtype), + def chunk_linear_attn_bwd( + Q: T.Tensor([B, S, H, DK], dtype), # type: ignore + K: T.Tensor([B, S, H, DK], dtype), # type: ignore + V: T.Tensor([B, S, H, DV], dtype), # type: ignore + dO: T.Tensor([B, S, H, DV], dtype), # type: ignore + dQ: T.Tensor([NV, B, S, H, DK], dtype), # type: ignore + dK: T.Tensor([NV, B, S, H, DK], dtype), # type: ignore + dV: T.Tensor([NK, B, S, H, DV], dtype), # type: ignore ): with T.Kernel(NV, NK, B * H) as (i_v, i_k, i_bh): i_b = i_bh // H @@ -71,6 +73,7 @@ def main( h_shared: tl.layout.make_swizzled_layout(h_shared), dh_shared: tl.layout.make_swizzled_layout(dh_shared) }) + T.use_swizzle(10) # Calculate dQ for i in T.Pipelined(0, NT, num_stages=1): @@ -107,7 +110,7 @@ def main( T.copy( dO[i_b, start * chunk_size:(start + 1) * chunk_size, i_h, i_v * BV:(i_v + 1) * BV], do) - T.copy(dh, dh_shared) + # Calculate dk T.gemm( @@ -116,6 +119,7 @@ def main( for row, col in T.Parallel(chunk_size, chunk_size): ds_shared[row, col] = T.if_then_else(row <= col, ds[row, col], 0) T.gemm(ds_shared, q, dk, clear_accum=True) + T.copy(dh, dh_shared) T.gemm(v, dh_shared, dk, transpose_B=True) # Calculate dv @@ -135,7 +139,7 @@ def main( dv, dV[i_k, i_b, start * chunk_size:(start + 1) * chunk_size, i_h, i_v * BV:(i_v + 1) * BV]) - return main + return chunk_linear_attn_bwd def postprocess(dQ, dK, dV): @@ -148,8 +152,8 @@ def postprocess(dQ, dK, dV): def main(): parser = argparse.ArgumentParser() parser.add_argument('--B', type=int, default=8, help='Batch size') - parser.add_argument('--S', type=int, default=2048, help='Seq len') - parser.add_argument('--H', type=int, default=64, help='Num heads') + parser.add_argument('--S', type=int, default=4096, help='Seq len') + parser.add_argument('--H', type=int, default=32, help='Num heads') parser.add_argument('--D', type=int, default=256, help='Head dim') args = parser.parse_args() B, S, H, D = args.B, args.S, args.H, args.D @@ -161,7 +165,7 @@ def main(): kernel = chunk_linear_attn_bwd_kernel(B, S, H, D, D) dq, dk, 
dv = postprocess(*kernel(q, k, v, do)) - o_ref, h_ref = fused_chunk_linear_attn(q, k, v, output_final_state=True, normalize=False) + o_ref, _ = fused_chunk_linear_attn(q, k, v, output_final_state=True, normalize=False) o_ref.backward(do, retain_graph=True) if torch.allclose(dq, q.grad) and torch.allclose(dk, k.grad) and torch.allclose(dv, v.grad): print('Passed all tests!✅') @@ -169,7 +173,7 @@ def main(): print('Failed some tests!❌') t1 = do_bench(lambda: o_ref.backward(do, retain_graph=True), warmup=25, rep=100) q.grad = k.grad = v.grad = None - o_ref, h_ref = fused_chunk_linear_attn(q, k, v, output_final_state=True, normalize=False) + o_ref, _ = fused_chunk_linear_attn(q, k, v, output_final_state=True, normalize=False) t2 = do_bench(lambda: postprocess(*kernel(q, k, v, do)), warmup=25, rep=100) print(f'Triton latency: {t1:.3f} ms') print(f'TileLang latency: {t2:.3f} ms') diff --git a/examples/linear_attention/example_linear_attn_fwd.py b/examples/linear_attention/example_linear_attn_fwd.py index afba81a02..23cbed4c8 100644 --- a/examples/linear_attention/example_linear_attn_fwd.py +++ b/examples/linear_attention/example_linear_attn_fwd.py @@ -10,7 +10,9 @@ from fla.ops.linear_attn import fused_chunk_linear_attn # We compare with FLA -@tl.jit(out_idx=[3, 4]) +@tl.jit(out_idx=[3, 4], + pass_configs={"tl.disable_tma_lower": True, + "tl.disable_warp_specialized": True}) def chunk_linear_attn_fwd_kernel( B, S, @@ -26,16 +28,18 @@ def chunk_linear_attn_fwd_kernel( accum_dtype = 'float' chunk_size = 64 - BK = BV = 64 + BK = BV = 64 # Set to 128 can be faster, but has some numerical differences with FLA assert S % chunk_size == 0 and DK % BK == 0 and DV % BV == 0 NK = tl.cdiv(DK, BK) NV = tl.cdiv(DV, BV) NT = tl.cdiv(S, chunk_size) @T.prim_func - def main(Q: T.Tensor([B, S, H, DK], dtype), K: T.Tensor([B, S, H, DK], dtype), - V: T.Tensor([B, S, H, DV], dtype), O: T.Tensor([NK, B, S, H, DV], dtype), - final_state: T.Tensor([B, H, DK, DV], accum_dtype)): + def chunk_linear_attn_fwd(Q: T.Tensor([B, S, H, DK], dtype), # type: ignore + K: T.Tensor([B, S, H, DK], dtype), # type: ignore + V: T.Tensor([B, S, H, DV], dtype), # type: ignore + O: T.Tensor([NK, B, S, H, DV], dtype), # type: ignore + final_state: T.Tensor([B, H, DK, DV], accum_dtype)): # type: ignore with T.Kernel(NV, NK, B * H) as (i_v, i_k, i_bh): i_b = i_bh // H i_h = i_bh % H @@ -57,9 +61,9 @@ def main(Q: T.Tensor([B, S, H, DK], dtype), K: T.Tensor([B, S, H, DK], dtype), h_shared: tl.layout.make_swizzled_layout(h_shared), s_shared: tl.layout.make_swizzled_layout(s_shared), }) - T.use_swizzle(8) + T.use_swizzle(10) - for i in T.Pipelined(0, NT, num_stages=1): + for i in T.Pipelined(0, NT, num_stages=2): for row, col in T.Parallel(chunk_size, BK): q[row, col] = Q[i_b, i * chunk_size + row, i_h, i_k * BK + col] * scale T.copy(K[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_k * BK:(i_k + 1) * BK], k) @@ -69,10 +73,10 @@ def main(Q: T.Tensor([B, S, H, DK], dtype), K: T.Tensor([B, S, H, DK], dtype), for row, col in T.Parallel(chunk_size, chunk_size): s_shared[row, col] = T.if_then_else(row >= col, s[row, col], 0) - T.gemm(s_shared, v, o, clear_accum=True) + T.gemm(s_shared, v, o, clear_accum=True) T.copy(h, h_shared) - T.gemm(q, h_shared, o) T.gemm(k, v, h, transpose_A=True) + T.gemm(q, h_shared, o) T.copy( o, O[i_k, i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_v * BV:(i_v + 1) * BV]) @@ -80,10 +84,10 @@ def main(Q: T.Tensor([B, S, H, DK], dtype), K: T.Tensor([B, S, H, DK], dtype), # Output final state T.copy(h, final_state[i_b, i_h, 
i_k * BK:(i_k + 1) * BK, i_v * BV:(i_v + 1) * BV]) - return main + return chunk_linear_attn_fwd -def postprocess(o, h): +def postprocess(o, h): o = o[0] if o.size(0) == 1 else o.sum(0) return o, h @@ -91,8 +95,8 @@ def postprocess(o, h): def main(): parser = argparse.ArgumentParser() parser.add_argument('--B', type=int, default=8, help='Batch size') - parser.add_argument('--S', type=int, default=2048, help='Seq len') - parser.add_argument('--H', type=int, default=64, help='Num heads') + parser.add_argument('--S', type=int, default=4096, help='Seq len') + parser.add_argument('--H', type=int, default=32, help='Num heads') parser.add_argument('--D', type=int, default=256, help='Head dim') args = parser.parse_args() B, S, H, D = args.B, args.S, args.H, args.D @@ -114,7 +118,7 @@ def main(): lambda: fused_chunk_linear_attn(q, k, v, output_final_state=True, normalize=False)[0], warmup=25, rep=100) - t2 = do_bench(lambda: kernel(q, k, v)[0].sum(0), warmup=25, rep=100) + t2 = do_bench(lambda: postprocess(*kernel(q, k, v)), warmup=25, rep=100) print(f'Triton latency: {t1:.3f} ms') print(f'TileLang latency: {t2:.3f} ms') print(f'Speedup: {t1/t2:.3f}x') From 21c8dc5abb72837f61f9191601fea9aa9d9b8f61 Mon Sep 17 00:00:00 2001 From: Rachmanino <18805904201@163.com> Date: Wed, 9 Jul 2025 07:41:30 +0000 Subject: [PATCH 2/3] Add retnet fwd kernel --- .../linear_attention/example_retention_fwd.py | 126 ++++++++++++++++++ 1 file changed, 126 insertions(+) create mode 100644 examples/linear_attention/example_retention_fwd.py diff --git a/examples/linear_attention/example_retention_fwd.py b/examples/linear_attention/example_retention_fwd.py new file mode 100644 index 000000000..4be2b23fc --- /dev/null +++ b/examples/linear_attention/example_retention_fwd.py @@ -0,0 +1,126 @@ +# Copyright (c) Tile-AI Corporation. +# Licensed under the MIT License. + +import torch +import tilelang as tl +import tilelang.language as T +from tilelang.profiler import do_bench + +import argparse + + +@tl.jit(out_idx=3, + pass_configs={"tl.disable_tma_lower": True, + "tl.disable_warp_specialized": True}) +def chunk_retention_fwd_kernel( + B, + S, + H, + DK, + DV, + dtype: str = 'float16', + scale: float = None, +) -> torch.Tensor: + + if scale is None: + scale = DK**-0.5 + accum_dtype = 'float' + + chunk_size = 64 + BK = BV = 64 # Set to 128 can be faster, but has some numerical differences with FLA + assert S % chunk_size == 0 and DK % BK == 0 and DV % BV == 0 + NK = tl.cdiv(DK, BK) + NV = tl.cdiv(DV, BV) + NT = tl.cdiv(S, chunk_size) + + @T.prim_func + def chunk_retention_fwd( + Q: T.Tensor([B, S, H, DK], dtype), # type: ignore + K: T.Tensor([B, S, H, DK], dtype), # type: ignore + V: T.Tensor([B, S, H, DV], dtype), # type: ignore + O: T.Tensor([NK, B, S, H, DV], dtype), # type: ignore + ): + with T.Kernel(NV, NK, B * H) as (i_v, i_k, i_bh): + i_b = i_bh // H + i_h = i_bh % H + log_decay = T.alloc_var('float32') + log_decay = T.log2(1 - T.exp2(-5. 
- 1.* i_h)) # Head-specific log decay + + q = T.alloc_shared([chunk_size, BK], dtype) + k = T.alloc_shared([chunk_size, BK], dtype) + v = T.alloc_shared([chunk_size, BV], dtype) + h = T.alloc_fragment([BK, BV], accum_dtype) + h_shared = T.alloc_shared([BK, BV], dtype) + s = T.alloc_fragment([chunk_size, chunk_size], accum_dtype) + s_shared = T.alloc_shared([chunk_size, chunk_size], dtype) + o = T.alloc_fragment([chunk_size, BV], accum_dtype) + T.clear(h) + + T.annotate_layout({ + q: tl.layout.make_swizzled_layout(q), + k: tl.layout.make_swizzled_layout(k), + v: tl.layout.make_swizzled_layout(v), + h_shared: tl.layout.make_swizzled_layout(h_shared), + s_shared: tl.layout.make_swizzled_layout(s_shared), + }) + T.use_swizzle(10) + + for i in T.Pipelined(0, NT): + for row, col in T.Parallel(chunk_size, BK): + q[row, col] = Q[i_b, i * chunk_size + row, i_h, i_k * BK + col] * scale + T.copy(K[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_k * BK:(i_k + 1) * BK], k) + T.copy(V[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_v * BV:(i_v + 1) * BV], v) + + T.gemm(q, k, s, clear_accum=True, transpose_B=True) + for row, col in T.Parallel(chunk_size, chunk_size): + s_shared[row, col] = T.if_then_else( + row >= col, + s[row, col] * T.exp2((row-col) * log_decay), + 0 + ) + + T.copy(h, h_shared) + T.gemm(q, h_shared, o, clear_accum=True) + for row, col in T.Parallel(chunk_size, BV): + o[row, col] = T.exp2((row + 1) * log_decay) * o[row, col] + T.gemm(s_shared, v, o) + + for row, col in T.Parallel(chunk_size, BV): + v[row, col] = v[row,col] * T.exp2((chunk_size - row - 1) * log_decay) + for row, col in T.Parallel(BK, BV): + h[row, col] = T.exp2(chunk_size * log_decay) * h[row, col] + T.copy( + o, O[i_k, i_b, i * chunk_size:(i + 1) * chunk_size, i_h, + i_v * BV:(i_v + 1) * BV]) + T.gemm(k, v, h, transpose_A=True) + + return chunk_retention_fwd + + +def postprocess(o): + return o if o.size(0) == 1 else o.sum(0) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--B', type=int, default=8, help='Batch size') + parser.add_argument('--S', type=int, default=4096, help='Seq len') + parser.add_argument('--H', type=int, default=32, help='Num heads') + parser.add_argument('--D', type=int, default=128, help='Head dim') + args = parser.parse_args() + B, S, H, D = args.B, args.S, args.H, args.D + total_flops = 2.0 * B * S * S * H * D # causal + + q = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16) + k = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16) + v = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16) + + kernel = chunk_retention_fwd_kernel(B, S, H, D, D) + + t = do_bench(lambda: postprocess(kernel(q, k, v)), warmup=25, rep=100) + print(f'Tilelang latency: {t:.3f} ms') + print(f'Tilelang TFLOPs: {total_flops/t * 1e-9}') + + +if __name__ == '__main__': + main() From f1c6e70aefd64b55635dea591b567251e10f5b55 Mon Sep 17 00:00:00 2001 From: Rachmanino <18805904201@163.com> Date: Wed, 9 Jul 2025 07:43:54 +0000 Subject: [PATCH 3/3] fix lint --- .../example_linear_attn_bwd.py | 10 +- .../example_linear_attn_fwd.py | 24 ++- .../linear_attention/example_retention_fwd.py | 32 ++- examples/linear_attention/example_retnet.py | 194 ------------------ 4 files changed, 34 insertions(+), 226 deletions(-) delete mode 100644 examples/linear_attention/example_retnet.py diff --git a/examples/linear_attention/example_linear_attn_bwd.py b/examples/linear_attention/example_linear_attn_bwd.py index 42b692c99..a944a9a20 100644 --- 
a/examples/linear_attention/example_linear_attn_bwd.py +++ b/examples/linear_attention/example_linear_attn_bwd.py @@ -10,9 +10,12 @@ from fla.ops.linear_attn import fused_chunk_linear_attn # We compare with FLA -@tl.jit(out_idx=[4, 5, 6], - pass_configs={"tl.disable_tma_lower": True, - "tl.disable_warp_specialized": True}) +@tl.jit( + out_idx=[4, 5, 6], + pass_configs={ + "tl.disable_tma_lower": True, + "tl.disable_warp_specialized": True + }) def chunk_linear_attn_bwd_kernel( B, S, @@ -110,7 +113,6 @@ def chunk_linear_attn_bwd( T.copy( dO[i_b, start * chunk_size:(start + 1) * chunk_size, i_h, i_v * BV:(i_v + 1) * BV], do) - # Calculate dk T.gemm( diff --git a/examples/linear_attention/example_linear_attn_fwd.py b/examples/linear_attention/example_linear_attn_fwd.py index 23cbed4c8..734c54c8c 100644 --- a/examples/linear_attention/example_linear_attn_fwd.py +++ b/examples/linear_attention/example_linear_attn_fwd.py @@ -10,9 +10,12 @@ from fla.ops.linear_attn import fused_chunk_linear_attn # We compare with FLA -@tl.jit(out_idx=[3, 4], - pass_configs={"tl.disable_tma_lower": True, - "tl.disable_warp_specialized": True}) +@tl.jit( + out_idx=[3, 4], + pass_configs={ + "tl.disable_tma_lower": True, + "tl.disable_warp_specialized": True + }) def chunk_linear_attn_fwd_kernel( B, S, @@ -35,11 +38,12 @@ def chunk_linear_attn_fwd_kernel( NT = tl.cdiv(S, chunk_size) @T.prim_func - def chunk_linear_attn_fwd(Q: T.Tensor([B, S, H, DK], dtype), # type: ignore - K: T.Tensor([B, S, H, DK], dtype), # type: ignore - V: T.Tensor([B, S, H, DV], dtype), # type: ignore - O: T.Tensor([NK, B, S, H, DV], dtype), # type: ignore - final_state: T.Tensor([B, H, DK, DV], accum_dtype)): # type: ignore + def chunk_linear_attn_fwd( + Q: T.Tensor([B, S, H, DK], dtype), # type: ignore + K: T.Tensor([B, S, H, DK], dtype), # type: ignore + V: T.Tensor([B, S, H, DV], dtype), # type: ignore + O: T.Tensor([NK, B, S, H, DV], dtype), # type: ignore + final_state: T.Tensor([B, H, DK, DV], accum_dtype)): # type: ignore with T.Kernel(NV, NK, B * H) as (i_v, i_k, i_bh): i_b = i_bh // H i_h = i_bh % H @@ -73,7 +77,7 @@ def chunk_linear_attn_fwd(Q: T.Tensor([B, S, H, DK], dtype), # type: ignore for row, col in T.Parallel(chunk_size, chunk_size): s_shared[row, col] = T.if_then_else(row >= col, s[row, col], 0) - T.gemm(s_shared, v, o, clear_accum=True) + T.gemm(s_shared, v, o, clear_accum=True) T.copy(h, h_shared) T.gemm(k, v, h, transpose_A=True) T.gemm(q, h_shared, o) @@ -87,7 +91,7 @@ def chunk_linear_attn_fwd(Q: T.Tensor([B, S, H, DK], dtype), # type: ignore return chunk_linear_attn_fwd -def postprocess(o, h): +def postprocess(o, h): o = o[0] if o.size(0) == 1 else o.sum(0) return o, h diff --git a/examples/linear_attention/example_retention_fwd.py b/examples/linear_attention/example_retention_fwd.py index 4be2b23fc..6d44a9160 100644 --- a/examples/linear_attention/example_retention_fwd.py +++ b/examples/linear_attention/example_retention_fwd.py @@ -9,9 +9,7 @@ import argparse -@tl.jit(out_idx=3, - pass_configs={"tl.disable_tma_lower": True, - "tl.disable_warp_specialized": True}) +@tl.jit(out_idx=3, pass_configs={"tl.disable_tma_lower": True, "tl.disable_warp_specialized": True}) def chunk_retention_fwd_kernel( B, S, @@ -35,16 +33,16 @@ def chunk_retention_fwd_kernel( @T.prim_func def chunk_retention_fwd( - Q: T.Tensor([B, S, H, DK], dtype), # type: ignore - K: T.Tensor([B, S, H, DK], dtype), # type: ignore - V: T.Tensor([B, S, H, DV], dtype), # type: ignore - O: T.Tensor([NK, B, S, H, DV], dtype), # type: ignore - ): + Q: 
T.Tensor([B, S, H, DK], dtype), # type: ignore + K: T.Tensor([B, S, H, DK], dtype), # type: ignore + V: T.Tensor([B, S, H, DV], dtype), # type: ignore + O: T.Tensor([NK, B, S, H, DV], dtype), # type: ignore + ): with T.Kernel(NV, NK, B * H) as (i_v, i_k, i_bh): i_b = i_bh // H i_h = i_bh % H log_decay = T.alloc_var('float32') - log_decay = T.log2(1 - T.exp2(-5. - 1.* i_h)) # Head-specific log decay + log_decay = T.log2(1 - T.exp2(-5. - 1. * i_h)) # Head-specific log decay q = T.alloc_shared([chunk_size, BK], dtype) k = T.alloc_shared([chunk_size, BK], dtype) @@ -73,20 +71,18 @@ def chunk_retention_fwd( T.gemm(q, k, s, clear_accum=True, transpose_B=True) for row, col in T.Parallel(chunk_size, chunk_size): - s_shared[row, col] = T.if_then_else( - row >= col, - s[row, col] * T.exp2((row-col) * log_decay), - 0 - ) + s_shared[row, + col] = T.if_then_else(row >= col, s[row, col] * T.exp2( + (row - col) * log_decay), 0) T.copy(h, h_shared) T.gemm(q, h_shared, o, clear_accum=True) for row, col in T.Parallel(chunk_size, BV): o[row, col] = T.exp2((row + 1) * log_decay) * o[row, col] - T.gemm(s_shared, v, o) + T.gemm(s_shared, v, o) for row, col in T.Parallel(chunk_size, BV): - v[row, col] = v[row,col] * T.exp2((chunk_size - row - 1) * log_decay) + v[row, col] = v[row, col] * T.exp2((chunk_size - row - 1) * log_decay) for row, col in T.Parallel(BK, BV): h[row, col] = T.exp2(chunk_size * log_decay) * h[row, col] T.copy( @@ -97,7 +93,7 @@ def chunk_retention_fwd( return chunk_retention_fwd -def postprocess(o): +def postprocess(o): return o if o.size(0) == 1 else o.sum(0) @@ -120,7 +116,7 @@ def main(): t = do_bench(lambda: postprocess(kernel(q, k, v)), warmup=25, rep=100) print(f'Tilelang latency: {t:.3f} ms') print(f'Tilelang TFLOPs: {total_flops/t * 1e-9}') - + if __name__ == '__main__': main() diff --git a/examples/linear_attention/example_retnet.py b/examples/linear_attention/example_retnet.py deleted file mode 100644 index 0b05ed3da..000000000 --- a/examples/linear_attention/example_retnet.py +++ /dev/null @@ -1,194 +0,0 @@ -import argparse -import torch -import tilelang -import tilelang.language as T - - -@tilelang.jit(out_idx=[4]) -def retnet(batch, heads, seq_len, dim_qk, dim_v, block_M, block_N): - qk_shape = [batch, seq_len, heads, dim_qk] - v_shape = [batch, seq_len, heads, dim_v] - dtype = "float16" - accum_dtype = "float" - - @T.prim_func - def main( - Q: T.Tensor(qk_shape, dtype), - K: T.Tensor(qk_shape, dtype), - V: T.Tensor(v_shape, dtype), - mask: T.Tensor([heads, seq_len, seq_len], dtype), - Output: T.Tensor(v_shape, dtype), - ): - with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=128 * 2) as (bx, by, bz): - Q_shared = T.alloc_shared([block_M, dim_qk], dtype) - K_shared = T.alloc_shared([block_N, dim_qk], dtype) - V_shared = T.alloc_shared([block_N, dim_v], dtype) - mask_shared = T.alloc_shared([block_M, block_N], dtype) - acc_o_shared = T.alloc_shared([block_M, dim_v], dtype) - mask_local = T.alloc_fragment([block_M, block_N], dtype) - acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) - acc_s_1 = T.alloc_fragment([block_M, block_N], accum_dtype) - acc_s_shared = T.alloc_shared([block_M, block_N], dtype) - acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) - acc_o = T.alloc_fragment([block_M, dim_v], accum_dtype) - abs_sum = T.alloc_fragment([block_M], accum_dtype) - r_wo_clamp = T.alloc_fragment([block_M], accum_dtype) - r = T.alloc_fragment([block_M], accum_dtype) - r_new = T.alloc_fragment([block_M], accum_dtype) - - T.annotate_layout({ - Q_shared: 
tilelang.layout.make_swizzled_layout(Q_shared), - mask_shared: tilelang.layout.make_swizzled_layout(mask_shared), - acc_s_shared: tilelang.layout.make_swizzled_layout(acc_s_shared), - acc_o_shared: tilelang.layout.make_swizzled_layout(acc_o_shared) - }) - - T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared) - - T.fill(r, 0) - T.fill(r_new, 0) - T.fill(r_wo_clamp, 0) - T.fill(acc_o, 0) - loop_range = T.ceildiv(seq_len, block_N) - for k in T.Pipelined(loop_range, num_stages=1): - T.copy(K[bz, k * block_N:(k + 1) * block_N, by, :], K_shared) - T.clear(acc_s) - T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) - T.copy(mask[by, bx * block_M:(bx + 1) * block_M, k * block_N:(k + 1) * block_N], - mask_shared) - T.copy(mask_shared, mask_local) - for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = acc_s[i, j] * mask_local[i, j] - T.copy(acc_s, acc_s_shared) - T.copy(acc_s_shared, acc_s_1) - T.reduce_abssum(acc_s_1, abs_sum, dim=1) - for i in T.Parallel(block_M): - r_wo_clamp[i] = r_wo_clamp[i] + abs_sum[i] - for i in T.Parallel(block_M): - r_new[i] = T.max(r_wo_clamp[i], 1) - for i, j in T.Parallel(block_M, dim_v): - acc_o[i, j] = T.if_then_else(k > 0, acc_o[i, j] * r[i] / r_new[i], acc_o[i, j]) - T.copy(r_new, r) - for i, j in T.Parallel(block_M, block_N): - acc_s_1[i, j] = acc_s_1[i, j] / r_new[i] - T.copy(acc_s_1, acc_s_cast) - T.copy(V[bz, k * block_N:(k + 1) * block_N, by, :], V_shared) - T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullCol) - T.copy(acc_o, acc_o_shared) - T.copy(acc_o_shared, Output[bz, bx * block_M:(bx + 1) * block_M, by, :]) - - return main - - -def ref_program(Q, K, V, mask): - qk = torch.einsum('bqhd,bkhd->bhqk', Q, K) - qkm = qk * mask - r = qkm.detach().abs().sum(dim=-1, keepdim=True).clamp(min=1.0) - o = torch.einsum('bhqk,bkhd->bqhd', qkm / r, V) - return o.to(dtype=torch.float16) - - -def ref_inference(Q, K, V, prev_kv, prev_scale, decay): - # Q : batch, seqlen, num_heads, head_dimqk - # K : batch, seqlen, num_heads, head_dimqk - # V : batch, seqlen, num_heads, head_dimv - # prev_kv : batch, num_heads, head_dimv, head_dimqk - # prev_scale : num_heads, 1, 1 - # decay : num_heads, 1, 1 - seqlen = V.size(1) - num_heads = V.size(2) - assert seqlen == 1, "Only support seqlen == 1" - - qr = Q.transpose(1, 2).contiguous() # batch, num_heads, 1, head_dimqk - kr = K.transpose(1, 2).contiguous() # batch, num_heads, 1, head_dimqk - v = V.transpose(1, 2).transpose(2, 3).contiguous() # batch, num_heads, head_dimv, 1 - - kv = kr * v # batch, num_heads, head_dimv, head_dimqk - scale = prev_scale * decay + 1 # num_heads, 1, 1 - kv = prev_kv * (prev_scale.sqrt() * decay / scale.sqrt()).view( - num_heads, 1, 1) + kv / scale.sqrt().view(num_heads, 1, 1) - output = torch.sum(qr * kv, dim=3) - return output - - -def retnet_inference(batch, heads, dim_qk, dim_v, block_M): - qk_shape = [batch, 1, heads, dim_qk] - v_shape = [batch, 1, heads, dim_v] - dtype = "float16" - accum_dtype = "float" - - @T.prim_func - def main( - Q: T.Tensor(qk_shape, dtype), - K: T.Tensor(qk_shape, dtype), - V: T.Tensor(v_shape, dtype), - prev_kv: T.Tensor([batch, heads, dim_v, dim_qk], dtype), - prev_scale: T.Tensor([heads], dtype), - decay: T.Tensor([heads], dtype), - Output: T.Tensor([batch, heads, dim_v], dtype), - ): - with T.Kernel(T.ceildiv(dim_v, block_M), heads, batch, threads=128) as (bx, by, bz): - Q_local = T.alloc_fragment([1, dim_qk], dtype) - K_local = T.alloc_fragment([dim_qk], dtype) - V_local = T.alloc_fragment([block_M], dtype) 
- kv_local = T.alloc_fragment([block_M, dim_qk], accum_dtype) - prev_kv_local = T.alloc_fragment([block_M, dim_qk], dtype) - prev_scale_local = T.alloc_fragment([1], dtype) - decay_local = T.alloc_fragment([1], accum_dtype) - # scale_local = T.alloc_fragment([1], accum_dtype) - qkv_local = T.alloc_fragment([block_M, dim_qk], accum_dtype) - o_local = T.alloc_fragment([block_M], accum_dtype) - - T.annotate_layout({ - prev_scale_local: T.Layout(prev_scale_local.shape, lambda i: i), - decay_local: T.Layout(decay_local.shape, lambda i: i), - # scale_local: T.Layout(scale_local.shape, lambda i : i), - kv_local: T.Fragment(kv_local.shape, lambda i, j: j // 8), - }) - - T.copy(Q[bz, 0, by, :], Q_local) - T.copy(K[bz, 0, by, :], K_local) - T.copy(V[bz, 0, by, bx * block_M:(bx + 1) * block_M], V_local) - T.copy(prev_kv[bz, by, bx * block_M:(bx + 1) * block_M, :], prev_kv_local) - prev_scale_local[0] = prev_scale[by] - decay_local[0] = decay[by] - for i, j in T.Parallel(block_M, dim_qk): - kv_local[i, j] = K_local[j] * V_local[i] - for i, j in T.Parallel(block_M, dim_qk): - kv_local[i, j] += kv_local[i, j] - for i, j in T.Parallel(block_M, dim_qk): - kv_local[i, j] += prev_kv_local[i, j] * T.sqrt(prev_scale[by]) * decay[by] - for i, j in T.Parallel(block_M, dim_qk): - kv_local[i, j] = kv_local[i, j] / T.sqrt(prev_scale[by] * decay[by] + 1) - for i, j in T.Parallel(block_M, dim_qk): - qkv_local[i, j] = Q_local[0, j] * kv_local[i, j] - T.reduce_sum(qkv_local, o_local, dim=1) - T.copy(o_local, Output[bz, by, bx * block_M:(bx + 1) * block_M]) - - return main - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=1, help='Batch size') - parser.add_argument('--h', type=int, default=10, help='Number of heads') - parser.add_argument('--n_ctx', type=int, default=4096, help='Context size') - parser.add_argument('--dim_qk', type=int, default=256, help='Head dimension') - parser.add_argument('--dim_v', type=int, default=448, help='Head dimension') - args = parser.parse_args() - BATCH, H, N_CTX, dim_qk, dim_v = args.batch, args.h, args.n_ctx, args.dim_qk, args.dim_v - total_flops = 2.0 * BATCH * H * N_CTX * N_CTX * (dim_qk + dim_v) - BLOCK_M = 64 - BLOCK_N = 64 - kernel = retnet(BATCH, H, N_CTX, dim_qk, dim_v, BLOCK_M, BLOCK_N) - profiler = kernel.get_profiler(tilelang.TensorSupplyType.Normal) - - ins = profiler._get_inputs() - - ref_outs = ref_program(*ins) - lib_outs = kernel(*ins) - - profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01) - latency = profiler.do_bench(n_warmup=10, n_repeat=10, profiler="torch") - print("tilelang: {:.2f} ms".format(latency)) - print("tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
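The examples above validate the linear-attention kernels against FLA's fused_chunk_linear_attn, but the chunk-wise recurrence they implement can also be written down directly. Below is a minimal plain-PyTorch sketch of that recurrence (per chunk: o = tril(q kᵀ) v + q h, then h += kᵀ v); the helper name ref_chunk_linear_attn_fwd and its chunk_size argument are introduced here for illustration only and are not part of the patch.

import torch


def ref_chunk_linear_attn_fwd(q, k, v, chunk_size=64, scale=None):
    # q, k: [B, S, H, DK]; v: [B, S, H, DV]; requires S % chunk_size == 0,
    # matching the assert in chunk_linear_attn_fwd_kernel.
    B, S, H, DK = q.shape
    DV = v.shape[-1]
    out_dtype = v.dtype
    if scale is None:
        scale = DK ** -0.5
    q = (q.float() * scale).transpose(1, 2)      # [B, H, S, DK]
    k = k.float().transpose(1, 2)                # [B, H, S, DK]
    v = v.float().transpose(1, 2)                # [B, H, S, DV]
    h = torch.zeros(B, H, DK, DV, device=q.device, dtype=torch.float32)
    causal = torch.tril(torch.ones(chunk_size, chunk_size, device=q.device))
    o = torch.empty_like(v)
    for start in range(0, S, chunk_size):
        sl = slice(start, start + chunk_size)
        qc, kc, vc = q[:, :, sl], k[:, :, sl], v[:, :, sl]
        s = (qc @ kc.transpose(-1, -2)) * causal  # intra-chunk, causal
        o[:, :, sl] = s @ vc + qc @ h             # h holds previous chunks only
        h = h + kc.transpose(-1, -2) @ vc         # running state K^T V
    return o.transpose(1, 2).to(out_dtype), h     # h corresponds to final_state

With fp16 inputs this should agree with postprocess(*chunk_linear_attn_fwd_kernel(B, S, H, D, D)(q, k, v)) only up to relaxed half-precision tolerances, since the kernel keeps s_shared and h_shared in fp16 shared memory.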
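example_retention_fwd.py only benchmarks the new kernel, so for completeness here is a minimal plain-PyTorch sketch of the chunk-wise retention math it implements: a head-specific decay gamma_h = 1 - 2^(-5 - h), an intra-chunk mask gamma^(i - j) for i >= j, an inter-chunk term scaled by gamma^(i + 1), and the state update state = gamma^C * state + kᵀ (gamma^(C - 1 - j) v). The helper name ref_chunk_retention_fwd is introduced here for illustration only and is not part of the patch.

import torch


def ref_chunk_retention_fwd(q, k, v, chunk_size=64, scale=None):
    # q, k: [B, S, H, DK]; v: [B, S, H, DV]; requires S % chunk_size == 0.
    B, S, H, DK = q.shape
    DV = v.shape[-1]
    out_dtype = v.dtype
    if scale is None:
        scale = DK ** -0.5
    q = (q.float() * scale).transpose(1, 2)       # [B, H, S, DK]
    k = k.float().transpose(1, 2)
    v = v.float().transpose(1, 2)                 # [B, H, S, DV]
    # Head-specific decay, same formula as log_decay in the kernel.
    gamma = 1 - torch.exp2(-5.0 - torch.arange(H, device=q.device, dtype=torch.float32))
    gamma = gamma.view(1, H, 1, 1)
    idx = torch.arange(chunk_size, device=q.device, dtype=torch.float32)
    rel = idx[:, None] - idx[None, :]                                    # i - j
    decay_mask = torch.tril(gamma ** rel)                                # gamma^(i-j) if i >= j else 0
    cross_decay = gamma ** (idx + 1).view(1, 1, chunk_size, 1)           # gamma^(i+1)
    k_decay = gamma ** (chunk_size - 1 - idx).view(1, 1, chunk_size, 1)  # gamma^(C-1-j)
    state = torch.zeros(B, H, DK, DV, device=q.device, dtype=torch.float32)
    o = torch.empty_like(v)
    for start in range(0, S, chunk_size):
        sl = slice(start, start + chunk_size)
        qc, kc, vc = q[:, :, sl], k[:, :, sl], v[:, :, sl]
        s = (qc @ kc.transpose(-1, -2)) * decay_mask       # intra-chunk with decay
        o[:, :, sl] = s @ vc + cross_decay * (qc @ state)  # state holds previous chunks
        state = (gamma ** chunk_size) * state + kc.transpose(-1, -2) @ (k_decay * vc)
    return o.transpose(1, 2).to(out_dtype)

Its output should match postprocess(chunk_retention_fwd_kernel(B, S, H, D, D)(q, k, v)) within relaxed fp16 tolerances, giving the retention example the same kind of correctness check the linear-attention examples already have against FLA.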