38 changes: 22 additions & 16 deletions examples/linear_attention/example_linear_attn_bwd.py
@@ -10,7 +10,12 @@
from fla.ops.linear_attn import fused_chunk_linear_attn # We compare with FLA


@tl.jit(out_idx=[4, 5, 6])
@tl.jit(
out_idx=[4, 5, 6],
pass_configs={
"tl.disable_tma_lower": True,
"tl.disable_warp_specialized": True
})
def chunk_linear_attn_bwd_kernel(
B,
S,
@@ -26,21 +31,21 @@ def chunk_linear_attn_bwd_kernel(
accum_dtype = 'float'

chunk_size = 64
BK = BV = 64
BK = BV = 64 # Set to 128 can be faster, but has some numerical differences with FLA
Contributor comment (medium):

Adding a comment here is good, but it would be better to quantify the numerical differences with FLA, or provide a link to a discussion or issue where these differences are analyzed. This helps future readers understand the trade-offs involved in choosing different block sizes.

BK = BV = 64  # Set to 128 can be faster, but has some numerical differences with FLA. See [link to discussion/issue] for details.
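
One way to act on this, sketched as a hypothetical follow-up (the helper name report_grad_diff is an assumption, not part of the PR): reuse the FLA reference gradients that the example already produces via o_ref.backward(do) and report the worst-case deviation of the TileLang gradients when BK = BV = 128.

def report_grad_diff(dq, dk, dv, q, k, v):
    # q.grad / k.grad / v.grad hold the FLA reference gradients after o_ref.backward(do)
    for name, ours, ref in (('dQ', dq, q.grad), ('dK', dk, k.grad), ('dV', dv, v.grad)):
        abs_err = (ours.float() - ref.float()).abs().max().item()
        rel_err = abs_err / ref.float().abs().max().item()
        print(f'{name}: max abs diff {abs_err:.3e}, max rel diff {rel_err:.3e}')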

assert S % chunk_size == 0 and DK % BK == 0 and DV % BV == 0
NK = tl.cdiv(DK, BK)
NV = tl.cdiv(DV, BV)
NT = tl.cdiv(S, chunk_size)

@T.prim_func
def main(
Q: T.Tensor([B, S, H, DK], dtype),
K: T.Tensor([B, S, H, DK], dtype),
V: T.Tensor([B, S, H, DV], dtype),
dO: T.Tensor([B, S, H, DV], dtype),
dQ: T.Tensor([NV, B, S, H, DK], dtype),
dK: T.Tensor([NV, B, S, H, DK], dtype),
dV: T.Tensor([NK, B, S, H, DV], dtype),
def chunk_linear_attn_bwd(
Q: T.Tensor([B, S, H, DK], dtype), # type: ignore
K: T.Tensor([B, S, H, DK], dtype), # type: ignore
V: T.Tensor([B, S, H, DV], dtype), # type: ignore
dO: T.Tensor([B, S, H, DV], dtype), # type: ignore
dQ: T.Tensor([NV, B, S, H, DK], dtype), # type: ignore
dK: T.Tensor([NV, B, S, H, DK], dtype), # type: ignore
dV: T.Tensor([NK, B, S, H, DV], dtype), # type: ignore
):
with T.Kernel(NV, NK, B * H) as (i_v, i_k, i_bh):
i_b = i_bh // H
@@ -71,6 +76,7 @@ def main(
h_shared: tl.layout.make_swizzled_layout(h_shared),
dh_shared: tl.layout.make_swizzled_layout(dh_shared)
})
T.use_swizzle(10)

# Calculate dQ
for i in T.Pipelined(0, NT, num_stages=1):
@@ -107,7 +113,6 @@ def main(
T.copy(
dO[i_b, start * chunk_size:(start + 1) * chunk_size, i_h,
i_v * BV:(i_v + 1) * BV], do)
T.copy(dh, dh_shared)

# Calculate dk
T.gemm(
@@ -116,6 +121,7 @@
for row, col in T.Parallel(chunk_size, chunk_size):
ds_shared[row, col] = T.if_then_else(row <= col, ds[row, col], 0)
T.gemm(ds_shared, q, dk, clear_accum=True)
T.copy(dh, dh_shared)
T.gemm(v, dh_shared, dk, transpose_B=True)

# Calculate dv
Expand All @@ -135,7 +141,7 @@ def main(
dv, dV[i_k, i_b, start * chunk_size:(start + 1) * chunk_size, i_h,
i_v * BV:(i_v + 1) * BV])

return main
return chunk_linear_attn_bwd


def postprocess(dQ, dK, dV):
@@ -148,8 +154,8 @@ def postprocess(dQ, dK, dV):
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--B', type=int, default=8, help='Batch size')
parser.add_argument('--S', type=int, default=2048, help='Seq len')
parser.add_argument('--H', type=int, default=64, help='Num heads')
parser.add_argument('--S', type=int, default=4096, help='Seq len')
parser.add_argument('--H', type=int, default=32, help='Num heads')
parser.add_argument('--D', type=int, default=256, help='Head dim')
args = parser.parse_args()
B, S, H, D = args.B, args.S, args.H, args.D
@@ -161,15 +167,15 @@ def main():

kernel = chunk_linear_attn_bwd_kernel(B, S, H, D, D)
dq, dk, dv = postprocess(*kernel(q, k, v, do))
o_ref, h_ref = fused_chunk_linear_attn(q, k, v, output_final_state=True, normalize=False)
o_ref, _ = fused_chunk_linear_attn(q, k, v, output_final_state=True, normalize=False)
o_ref.backward(do, retain_graph=True)
if torch.allclose(dq, q.grad) and torch.allclose(dk, k.grad) and torch.allclose(dv, v.grad):
print('Passed all tests!✅')
else:
print('Failed some tests!❌')
t1 = do_bench(lambda: o_ref.backward(do, retain_graph=True), warmup=25, rep=100)
q.grad = k.grad = v.grad = None
o_ref, h_ref = fused_chunk_linear_attn(q, k, v, output_final_state=True, normalize=False)
o_ref, _ = fused_chunk_linear_attn(q, k, v, output_final_state=True, normalize=False)
t2 = do_bench(lambda: postprocess(*kernel(q, k, v, do)), warmup=25, rep=100)
print(f'Triton latency: {t1:.3f} ms')
print(f'TileLang latency: {t2:.3f} ms')
32 changes: 20 additions & 12 deletions examples/linear_attention/example_linear_attn_fwd.py
@@ -10,7 +10,12 @@
from fla.ops.linear_attn import fused_chunk_linear_attn # We compare with FLA


@tl.jit(out_idx=[3, 4])
@tl.jit(
out_idx=[3, 4],
pass_configs={
"tl.disable_tma_lower": True,
"tl.disable_warp_specialized": True
})
def chunk_linear_attn_fwd_kernel(
B,
S,
@@ -26,16 +31,19 @@ def chunk_linear_attn_fwd_kernel(
accum_dtype = 'float'

chunk_size = 64
BK = BV = 64
BK = BV = 64 # Set to 128 can be faster, but has some numerical differences with FLA
Contributor comment (medium):

Adding a comment here is good, but it would be better to quantify the numerical differences with FLA, or provide a link to a discussion or issue where these differences are analyzed. This helps future readers understand the trade-offs involved in choosing different block sizes.

BK = BV = 64  # Set to 128 can be faster, but has some numerical differences with FLA. See [link to discussion/issue] for details.
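
On the forward side, the same gap could be made explicit with a tolerance check; a minimal sketch, assuming the q, k, v, kernel, and fused_chunk_linear_attn names from this example's main() and with tolerances chosen only for illustration:

import torch

o_ref, _ = fused_chunk_linear_attn(q, k, v, output_final_state=True, normalize=False)
o = kernel(q, k, v)[0].sum(0)  # fold the NK partial outputs, as the benchmark lambda does
# The rtol/atol needed for this to pass at BK = BV = 128 versus 64 would document the difference.
torch.testing.assert_close(o, o_ref, rtol=1e-2, atol=1e-2)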

assert S % chunk_size == 0 and DK % BK == 0 and DV % BV == 0
NK = tl.cdiv(DK, BK)
NV = tl.cdiv(DV, BV)
NT = tl.cdiv(S, chunk_size)

@T.prim_func
def main(Q: T.Tensor([B, S, H, DK], dtype), K: T.Tensor([B, S, H, DK], dtype),
V: T.Tensor([B, S, H, DV], dtype), O: T.Tensor([NK, B, S, H, DV], dtype),
final_state: T.Tensor([B, H, DK, DV], accum_dtype)):
def chunk_linear_attn_fwd(
Q: T.Tensor([B, S, H, DK], dtype), # type: ignore
K: T.Tensor([B, S, H, DK], dtype), # type: ignore
V: T.Tensor([B, S, H, DV], dtype), # type: ignore
O: T.Tensor([NK, B, S, H, DV], dtype), # type: ignore
final_state: T.Tensor([B, H, DK, DV], accum_dtype)): # type: ignore
with T.Kernel(NV, NK, B * H) as (i_v, i_k, i_bh):
i_b = i_bh // H
i_h = i_bh % H
@@ -57,9 +65,9 @@ def main(Q: T.Tensor([B, S, H, DK], dtype), K: T.Tensor([B, S, H, DK], dtype),
h_shared: tl.layout.make_swizzled_layout(h_shared),
s_shared: tl.layout.make_swizzled_layout(s_shared),
})
T.use_swizzle(8)
T.use_swizzle(10)

for i in T.Pipelined(0, NT, num_stages=1):
for i in T.Pipelined(0, NT, num_stages=2):
for row, col in T.Parallel(chunk_size, BK):
q[row, col] = Q[i_b, i * chunk_size + row, i_h, i_k * BK + col] * scale
T.copy(K[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_k * BK:(i_k + 1) * BK], k)
@@ -71,16 +79,16 @@ def main(Q: T.Tensor([B, S, H, DK], dtype), K: T.Tensor([B, S, H, DK], dtype),

T.gemm(s_shared, v, o, clear_accum=True)
T.copy(h, h_shared)
T.gemm(q, h_shared, o)
T.gemm(k, v, h, transpose_A=True)
T.gemm(q, h_shared, o)
T.copy(
o, O[i_k, i_b, i * chunk_size:(i + 1) * chunk_size, i_h,
i_v * BV:(i_v + 1) * BV])

# Output final state
T.copy(h, final_state[i_b, i_h, i_k * BK:(i_k + 1) * BK, i_v * BV:(i_v + 1) * BV])

return main
return chunk_linear_attn_fwd


def postprocess(o, h):
@@ -91,8 +99,8 @@ def postprocess(o, h):
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--B', type=int, default=8, help='Batch size')
parser.add_argument('--S', type=int, default=2048, help='Seq len')
parser.add_argument('--H', type=int, default=64, help='Num heads')
parser.add_argument('--S', type=int, default=4096, help='Seq len')
parser.add_argument('--H', type=int, default=32, help='Num heads')
parser.add_argument('--D', type=int, default=256, help='Head dim')
args = parser.parse_args()
B, S, H, D = args.B, args.S, args.H, args.D
@@ -114,7 +122,7 @@ def main():
lambda: fused_chunk_linear_attn(q, k, v, output_final_state=True, normalize=False)[0],
warmup=25,
rep=100)
t2 = do_bench(lambda: kernel(q, k, v)[0].sum(0), warmup=25, rep=100)
t2 = do_bench(lambda: postprocess(*kernel(q, k, v)), warmup=25, rep=100)
print(f'Triton latency: {t1:.3f} ms')
print(f'TileLang latency: {t2:.3f} ms')
print(f'Speedup: {t1/t2:.3f}x')
122 changes: 122 additions & 0 deletions examples/linear_attention/example_retention_fwd.py
@@ -0,0 +1,122 @@
# Copyright (c) Tile-AI Corporation.
# Licensed under the MIT License.

import torch
import tilelang as tl
import tilelang.language as T
from tilelang.profiler import do_bench

import argparse


@tl.jit(out_idx=3, pass_configs={"tl.disable_tma_lower": True, "tl.disable_warp_specialized": True})
def chunk_retention_fwd_kernel(
B,
S,
H,
DK,
DV,
dtype: str = 'float16',
scale: float = None,
) -> torch.Tensor:

if scale is None:
scale = DK**-0.5
accum_dtype = 'float'

chunk_size = 64
BK = BV = 64 # Set to 128 can be faster, but has some numerical differences with FLA
assert S % chunk_size == 0 and DK % BK == 0 and DV % BV == 0
NK = tl.cdiv(DK, BK)
NV = tl.cdiv(DV, BV)
NT = tl.cdiv(S, chunk_size)

@T.prim_func
def chunk_retention_fwd(
Q: T.Tensor([B, S, H, DK], dtype), # type: ignore
K: T.Tensor([B, S, H, DK], dtype), # type: ignore
V: T.Tensor([B, S, H, DV], dtype), # type: ignore
O: T.Tensor([NK, B, S, H, DV], dtype), # type: ignore
):
with T.Kernel(NV, NK, B * H) as (i_v, i_k, i_bh):
i_b = i_bh // H
i_h = i_bh % H
log_decay = T.alloc_var('float32')
log_decay = T.log2(1 - T.exp2(-5. - 1. * i_h)) # Head-specific log decay

q = T.alloc_shared([chunk_size, BK], dtype)
k = T.alloc_shared([chunk_size, BK], dtype)
v = T.alloc_shared([chunk_size, BV], dtype)
h = T.alloc_fragment([BK, BV], accum_dtype)
h_shared = T.alloc_shared([BK, BV], dtype)
s = T.alloc_fragment([chunk_size, chunk_size], accum_dtype)
s_shared = T.alloc_shared([chunk_size, chunk_size], dtype)
o = T.alloc_fragment([chunk_size, BV], accum_dtype)
T.clear(h)

T.annotate_layout({
q: tl.layout.make_swizzled_layout(q),
k: tl.layout.make_swizzled_layout(k),
v: tl.layout.make_swizzled_layout(v),
h_shared: tl.layout.make_swizzled_layout(h_shared),
s_shared: tl.layout.make_swizzled_layout(s_shared),
})
T.use_swizzle(10)

for i in T.Pipelined(0, NT):
for row, col in T.Parallel(chunk_size, BK):
q[row, col] = Q[i_b, i * chunk_size + row, i_h, i_k * BK + col] * scale
T.copy(K[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_k * BK:(i_k + 1) * BK], k)
T.copy(V[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_v * BV:(i_v + 1) * BV], v)

T.gemm(q, k, s, clear_accum=True, transpose_B=True)
for row, col in T.Parallel(chunk_size, chunk_size):
s_shared[row,
col] = T.if_then_else(row >= col, s[row, col] * T.exp2(
(row - col) * log_decay), 0)

T.copy(h, h_shared)
T.gemm(q, h_shared, o, clear_accum=True)
for row, col in T.Parallel(chunk_size, BV):
o[row, col] = T.exp2((row + 1) * log_decay) * o[row, col]
T.gemm(s_shared, v, o)

for row, col in T.Parallel(chunk_size, BV):
v[row, col] = v[row, col] * T.exp2((chunk_size - row - 1) * log_decay)
for row, col in T.Parallel(BK, BV):
h[row, col] = T.exp2(chunk_size * log_decay) * h[row, col]
T.copy(
o, O[i_k, i_b, i * chunk_size:(i + 1) * chunk_size, i_h,
i_v * BV:(i_v + 1) * BV])
T.gemm(k, v, h, transpose_A=True)

return chunk_retention_fwd


def postprocess(o):
return o if o.size(0) == 1 else o.sum(0)


def main():
parser = argparse.ArgumentParser()
parser.add_argument('--B', type=int, default=8, help='Batch size')
parser.add_argument('--S', type=int, default=4096, help='Seq len')
parser.add_argument('--H', type=int, default=32, help='Num heads')
parser.add_argument('--D', type=int, default=128, help='Head dim')
args = parser.parse_args()
B, S, H, D = args.B, args.S, args.H, args.D
total_flops = 2.0 * B * S * S * H * D # causal

q = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16)
k = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16)
v = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16)

kernel = chunk_retention_fwd_kernel(B, S, H, D, D)

t = do_bench(lambda: postprocess(kernel(q, k, v)), warmup=25, rep=100)
print(f'Tilelang latency: {t:.3f} ms')
print(f'Tilelang TFLOPs: {total_flops/t * 1e-9}')


if __name__ == '__main__':
main()