-
-
Notifications
You must be signed in to change notification settings - Fork 10.5k
Memcpy kernel for flash attention #29
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
implementation is done. need testing (will do it on Thursday) the memory saving strategy is orthogonal to this kernel, so I would not include it in this PR |
678bb06
to
07e9891
Compare
07e9891
to
e21845e
Compare
Hey @suquark thanks for the PR! I have a quick question: have you also measured the performance diff between the two kernels before and after the optimization? |
see the PR comment for the optimized kernel performance comparison |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM.
* optimize * add benchmark * add assert * add test
Update optimum-intel
It's faster Signed-off-by: Nick Hill <[email protected]>
Adding fp8 gemm computation
sync release with IBM/release
…ack_acc_bf16 fix linear init impacts on generation
Add official doc index. Move the release content to the right place. Signed-off-by: wangxiyuan <[email protected]>
* Fix truncated output Signed-off-by: Woosuk Kwon <[email protected]> * fix Signed-off-by: Woosuk Kwon <[email protected]> --------- Signed-off-by: Woosuk Kwon <[email protected]>
* Fix truncated output Signed-off-by: Woosuk Kwon <[email protected]> * fix Signed-off-by: Woosuk Kwon <[email protected]> --------- Signed-off-by: Woosuk Kwon <[email protected]>
Memcpy kernel for flash attention
The performance is pretty good (theoretical optimal throughput is 1.6TB/s for A100-40GB), considering the memory layout is not ideal.
result for unoptimized kernel:
the optimized kernel works much better for smaller number of tokens (+20% speedup)