-
Notifications
You must be signed in to change notification settings - Fork 254
[Dev] Update linear attention examples to enhance performance on Hopper GPUs #621
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,7 +10,12 @@ | |
from fla.ops.linear_attn import fused_chunk_linear_attn # We compare with FLA | ||
|
||
|
||
@tl.jit(out_idx=[3, 4]) | ||
@tl.jit( | ||
out_idx=[3, 4], | ||
pass_configs={ | ||
"tl.disable_tma_lower": True, | ||
"tl.disable_warp_specialized": True | ||
}) | ||
def chunk_linear_attn_fwd_kernel( | ||
B, | ||
S, | ||
|
@@ -26,16 +31,19 @@ def chunk_linear_attn_fwd_kernel( | |
accum_dtype = 'float' | ||
|
||
chunk_size = 64 | ||
BK = BV = 64 | ||
BK = BV = 64 # Set to 128 can be faster, but has some numerical differences with FLA | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Adding a comment here is good, but it would be better to quantify the numerical differences with FLA, or provide a link to a discussion or issue where these differences are analyzed. This helps future readers understand the trade-offs involved in choosing different block sizes. BK = BV = 64 # Set to 128 can be faster, but has some numerical differences with FLA. See [link to discussion/issue] for details. |
||
assert S % chunk_size == 0 and DK % BK == 0 and DV % BV == 0 | ||
NK = tl.cdiv(DK, BK) | ||
NV = tl.cdiv(DV, BV) | ||
NT = tl.cdiv(S, chunk_size) | ||
|
||
@T.prim_func | ||
def main(Q: T.Tensor([B, S, H, DK], dtype), K: T.Tensor([B, S, H, DK], dtype), | ||
V: T.Tensor([B, S, H, DV], dtype), O: T.Tensor([NK, B, S, H, DV], dtype), | ||
final_state: T.Tensor([B, H, DK, DV], accum_dtype)): | ||
def chunk_linear_attn_fwd( | ||
Q: T.Tensor([B, S, H, DK], dtype), # type: ignore | ||
K: T.Tensor([B, S, H, DK], dtype), # type: ignore | ||
V: T.Tensor([B, S, H, DV], dtype), # type: ignore | ||
O: T.Tensor([NK, B, S, H, DV], dtype), # type: ignore | ||
final_state: T.Tensor([B, H, DK, DV], accum_dtype)): # type: ignore | ||
with T.Kernel(NV, NK, B * H) as (i_v, i_k, i_bh): | ||
i_b = i_bh // H | ||
i_h = i_bh % H | ||
|
@@ -57,9 +65,9 @@ def main(Q: T.Tensor([B, S, H, DK], dtype), K: T.Tensor([B, S, H, DK], dtype), | |
h_shared: tl.layout.make_swizzled_layout(h_shared), | ||
s_shared: tl.layout.make_swizzled_layout(s_shared), | ||
}) | ||
T.use_swizzle(8) | ||
T.use_swizzle(10) | ||
|
||
for i in T.Pipelined(0, NT, num_stages=1): | ||
for i in T.Pipelined(0, NT, num_stages=2): | ||
for row, col in T.Parallel(chunk_size, BK): | ||
q[row, col] = Q[i_b, i * chunk_size + row, i_h, i_k * BK + col] * scale | ||
T.copy(K[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_k * BK:(i_k + 1) * BK], k) | ||
|
@@ -71,16 +79,16 @@ def main(Q: T.Tensor([B, S, H, DK], dtype), K: T.Tensor([B, S, H, DK], dtype), | |
|
||
T.gemm(s_shared, v, o, clear_accum=True) | ||
T.copy(h, h_shared) | ||
T.gemm(q, h_shared, o) | ||
T.gemm(k, v, h, transpose_A=True) | ||
T.gemm(q, h_shared, o) | ||
T.copy( | ||
o, O[i_k, i_b, i * chunk_size:(i + 1) * chunk_size, i_h, | ||
i_v * BV:(i_v + 1) * BV]) | ||
|
||
# Output final state | ||
T.copy(h, final_state[i_b, i_h, i_k * BK:(i_k + 1) * BK, i_v * BV:(i_v + 1) * BV]) | ||
|
||
return main | ||
return chunk_linear_attn_fwd | ||
|
||
|
||
def postprocess(o, h): | ||
|
@@ -91,8 +99,8 @@ def postprocess(o, h): | |
def main(): | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument('--B', type=int, default=8, help='Batch size') | ||
parser.add_argument('--S', type=int, default=2048, help='Seq len') | ||
parser.add_argument('--H', type=int, default=64, help='Num heads') | ||
parser.add_argument('--S', type=int, default=4096, help='Seq len') | ||
parser.add_argument('--H', type=int, default=32, help='Num heads') | ||
parser.add_argument('--D', type=int, default=256, help='Head dim') | ||
args = parser.parse_args() | ||
B, S, H, D = args.B, args.S, args.H, args.D | ||
|
@@ -114,7 +122,7 @@ def main(): | |
lambda: fused_chunk_linear_attn(q, k, v, output_final_state=True, normalize=False)[0], | ||
warmup=25, | ||
rep=100) | ||
t2 = do_bench(lambda: kernel(q, k, v)[0].sum(0), warmup=25, rep=100) | ||
t2 = do_bench(lambda: postprocess(*kernel(q, k, v)), warmup=25, rep=100) | ||
print(f'Triton latency: {t1:.3f} ms') | ||
print(f'TileLang latency: {t2:.3f} ms') | ||
print(f'Speedup: {t1/t2:.3f}x') | ||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
# Copyright (c) Tile-AI Corporation. | ||
# Licensed under the MIT License. | ||
|
||
import torch | ||
import tilelang as tl | ||
import tilelang.language as T | ||
from tilelang.profiler import do_bench | ||
|
||
import argparse | ||
|
||
|
||
@tl.jit(out_idx=3, pass_configs={"tl.disable_tma_lower": True, "tl.disable_warp_specialized": True}) | ||
def chunk_retention_fwd_kernel( | ||
B, | ||
S, | ||
H, | ||
DK, | ||
DV, | ||
dtype: str = 'float16', | ||
scale: float = None, | ||
) -> torch.Tensor: | ||
|
||
if scale is None: | ||
scale = DK**-0.5 | ||
accum_dtype = 'float' | ||
|
||
chunk_size = 64 | ||
BK = BV = 64 # Set to 128 can be faster, but has some numerical differences with FLA | ||
assert S % chunk_size == 0 and DK % BK == 0 and DV % BV == 0 | ||
NK = tl.cdiv(DK, BK) | ||
NV = tl.cdiv(DV, BV) | ||
NT = tl.cdiv(S, chunk_size) | ||
|
||
@T.prim_func | ||
def chunk_retention_fwd( | ||
Q: T.Tensor([B, S, H, DK], dtype), # type: ignore | ||
K: T.Tensor([B, S, H, DK], dtype), # type: ignore | ||
V: T.Tensor([B, S, H, DV], dtype), # type: ignore | ||
O: T.Tensor([NK, B, S, H, DV], dtype), # type: ignore | ||
): | ||
with T.Kernel(NV, NK, B * H) as (i_v, i_k, i_bh): | ||
i_b = i_bh // H | ||
i_h = i_bh % H | ||
log_decay = T.alloc_var('float32') | ||
log_decay = T.log2(1 - T.exp2(-5. - 1. * i_h)) # Head-specific log decay | ||
|
||
q = T.alloc_shared([chunk_size, BK], dtype) | ||
k = T.alloc_shared([chunk_size, BK], dtype) | ||
v = T.alloc_shared([chunk_size, BV], dtype) | ||
h = T.alloc_fragment([BK, BV], accum_dtype) | ||
h_shared = T.alloc_shared([BK, BV], dtype) | ||
s = T.alloc_fragment([chunk_size, chunk_size], accum_dtype) | ||
s_shared = T.alloc_shared([chunk_size, chunk_size], dtype) | ||
o = T.alloc_fragment([chunk_size, BV], accum_dtype) | ||
T.clear(h) | ||
|
||
T.annotate_layout({ | ||
q: tl.layout.make_swizzled_layout(q), | ||
k: tl.layout.make_swizzled_layout(k), | ||
v: tl.layout.make_swizzled_layout(v), | ||
h_shared: tl.layout.make_swizzled_layout(h_shared), | ||
s_shared: tl.layout.make_swizzled_layout(s_shared), | ||
}) | ||
T.use_swizzle(10) | ||
|
||
for i in T.Pipelined(0, NT): | ||
for row, col in T.Parallel(chunk_size, BK): | ||
q[row, col] = Q[i_b, i * chunk_size + row, i_h, i_k * BK + col] * scale | ||
T.copy(K[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_k * BK:(i_k + 1) * BK], k) | ||
T.copy(V[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_v * BV:(i_v + 1) * BV], v) | ||
|
||
T.gemm(q, k, s, clear_accum=True, transpose_B=True) | ||
for row, col in T.Parallel(chunk_size, chunk_size): | ||
s_shared[row, | ||
col] = T.if_then_else(row >= col, s[row, col] * T.exp2( | ||
(row - col) * log_decay), 0) | ||
|
||
T.copy(h, h_shared) | ||
T.gemm(q, h_shared, o, clear_accum=True) | ||
for row, col in T.Parallel(chunk_size, BV): | ||
o[row, col] = T.exp2((row + 1) * log_decay) * o[row, col] | ||
T.gemm(s_shared, v, o) | ||
|
||
for row, col in T.Parallel(chunk_size, BV): | ||
v[row, col] = v[row, col] * T.exp2((chunk_size - row - 1) * log_decay) | ||
for row, col in T.Parallel(BK, BV): | ||
h[row, col] = T.exp2(chunk_size * log_decay) * h[row, col] | ||
T.copy( | ||
o, O[i_k, i_b, i * chunk_size:(i + 1) * chunk_size, i_h, | ||
i_v * BV:(i_v + 1) * BV]) | ||
T.gemm(k, v, h, transpose_A=True) | ||
|
||
return chunk_retention_fwd | ||
|
||
|
||
def postprocess(o): | ||
return o if o.size(0) == 1 else o.sum(0) | ||
|
||
|
||
def main(): | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument('--B', type=int, default=8, help='Batch size') | ||
parser.add_argument('--S', type=int, default=4096, help='Seq len') | ||
parser.add_argument('--H', type=int, default=32, help='Num heads') | ||
parser.add_argument('--D', type=int, default=128, help='Head dim') | ||
args = parser.parse_args() | ||
B, S, H, D = args.B, args.S, args.H, args.D | ||
total_flops = 2.0 * B * S * S * H * D # causal | ||
|
||
q = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16) | ||
k = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16) | ||
v = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16) | ||
|
||
kernel = chunk_retention_fwd_kernel(B, S, H, D, D) | ||
|
||
t = do_bench(lambda: postprocess(kernel(q, k, v)), warmup=25, rep=100) | ||
print(f'Tilelang latency: {t:.3f} ms') | ||
print(f'Tilelang TFLOPs: {total_flops/t * 1e-9}') | ||
|
||
|
||
if __name__ == '__main__': | ||
main() |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Adding a comment here is good, but it would be better to quantify the numerical differences with FLA, or provide a link to a discussion or issue where these differences are analyzed. This helps future readers understand the trade-offs involved in choosing different block sizes.