-
Notifications
You must be signed in to change notification settings - Fork 18
Stacked cache for MLPerf #154
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Conversation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
… ring buffer support then fix the mask. Int8 updates also included but not tested.
…o the end of existing cache attention.
… into jit function.
…he. Fix the kv cache quantization test.
…im for quantization from 1,3 to -3,-1 to make it more robust;
…d for performance validation.
…on. Refactor to use only 1 flash attention kernel. Changes the modified ring buffer ragged attention kernel with quantization, layer, etc.
…run_interactive in CPU mode can work. When we default ring buffer to false, should add additional flags to run_interactive CI to set test mode to true so that pallas kernel can run.
5fbcd47
to
134247c
Compare
FanhaiLu1
approved these changes
Jul 20, 2024
qihqi
pushed a commit
that referenced
this pull request
Jul 23, 2024
* Almost working except mask, need to rebase to main to pick up the the ring buffer support then fix the mask. Int8 updates also included but not tested. * Fixed the test_model_impl for llama, but test_llama_e2e is still failing. * Adds lazy_cache_update and restructure the cache flags. * Disable all the prints. Fix create engine. * Fix typos and minor errors. * Fixes create engine. * Adds new_cache_stacked and fixes cache update. * Fix cache update when new_cach_stacked is False. * Fix the cache manager and make unit tests pass except for 1. * Updates the exportable model to return cache. * Removed the fori loop in cache finalize. Moves the cache.finalize() to the end of existing cache attention. * Try to use shard_map for cache update. * Fix update single cache line in cache.finalize() * Adds int8 support. * Int8 left aligned lazy cache update working, performance still not good enough. * Fix the stacked cache introduced in the previous couple of commits. * Put original ragged attention back. * Add the original ragged attention kernel. * Fixes the bf16/int8 cache stack. * Fix int8 stacked cache insertion in engine and finalization. * Fixes int8 with lazy cache update. * Updates the int8 test. * Fix the int8 ragged attention output sharding. * Fix group query attention broadcasting issue. * Fix shard map input issue. Variables not listed as inputs are freezed into jit function. * Fix the flash attention mask shape; Fix the update single cache line quant version * Adds the kv cache test. * Replace quantized cache "pos" with "input_pos" to align with bf16 cache. Fix the kv cache quantization test. * Fix prefill cache insertion issue for stacked cache; Changes reduce dim for quantization from 1,3 to -3,-1 to make it more robust; * Adds lazy cache update with generate cache stacked new cache unstacked for performance validation. * Fix the shard map sharding for stacked generate cache and unstacked new cache. * Using Jax API to slicing instead of Pytorch index slicing. * Adds stacked cache support in ragged attention reference kernel. * Adds stacked cache support for the modified ragged kernel. * Llama2 70b int8 optimization done. Output not correct yet. * Remove testing temp output files. * Fix the llama 70b output accuracy resulting from gqa. * Fixes the attention output slicing issue when not using flash attention. Refactor to use only 1 flash attention kernel. Changes the modified ring buffer ragged attention kernel with quantization, layer, etc. * Fix the pallas kernel OOB issue * Fix tests; Fix lint issues; * Fix the interactive script. * Fix lint errors. * Fix errors. * Fix the comments. * Fix based on comments; Fix all the unit tests. * Fix the remaining pylint errors. * Default ring buffer back to true so that all the test_run_server and run_interactive in CPU mode can work. When we default ring buffer to false, should add additional flags to run_interactive CI to set test mode to true so that pallas kernel can run. * Fix all the lint errors. * Remove the deps/JetStream changes. * Fix merge errors, fix lint errors.
qihqi
pushed a commit
that referenced
this pull request
Jul 23, 2024
* Almost working except mask, need to rebase to main to pick up the the ring buffer support then fix the mask. Int8 updates also included but not tested. * Fixed the test_model_impl for llama, but test_llama_e2e is still failing. * Adds lazy_cache_update and restructure the cache flags. * Disable all the prints. Fix create engine. * Fix typos and minor errors. * Fixes create engine. * Adds new_cache_stacked and fixes cache update. * Fix cache update when new_cach_stacked is False. * Fix the cache manager and make unit tests pass except for 1. * Updates the exportable model to return cache. * Removed the fori loop in cache finalize. Moves the cache.finalize() to the end of existing cache attention. * Try to use shard_map for cache update. * Fix update single cache line in cache.finalize() * Adds int8 support. * Int8 left aligned lazy cache update working, performance still not good enough. * Fix the stacked cache introduced in the previous couple of commits. * Put original ragged attention back. * Add the original ragged attention kernel. * Fixes the bf16/int8 cache stack. * Fix int8 stacked cache insertion in engine and finalization. * Fixes int8 with lazy cache update. * Updates the int8 test. * Fix the int8 ragged attention output sharding. * Fix group query attention broadcasting issue. * Fix shard map input issue. Variables not listed as inputs are freezed into jit function. * Fix the flash attention mask shape; Fix the update single cache line quant version * Adds the kv cache test. * Replace quantized cache "pos" with "input_pos" to align with bf16 cache. Fix the kv cache quantization test. * Fix prefill cache insertion issue for stacked cache; Changes reduce dim for quantization from 1,3 to -3,-1 to make it more robust; * Adds lazy cache update with generate cache stacked new cache unstacked for performance validation. * Fix the shard map sharding for stacked generate cache and unstacked new cache. * Using Jax API to slicing instead of Pytorch index slicing. * Adds stacked cache support in ragged attention reference kernel. * Adds stacked cache support for the modified ragged kernel. * Llama2 70b int8 optimization done. Output not correct yet. * Remove testing temp output files. * Fix the llama 70b output accuracy resulting from gqa. * Fixes the attention output slicing issue when not using flash attention. Refactor to use only 1 flash attention kernel. Changes the modified ring buffer ragged attention kernel with quantization, layer, etc. * Fix the pallas kernel OOB issue * Fix tests; Fix lint issues; * Fix the interactive script. * Fix lint errors. * Fix errors. * Fix the comments. * Fix based on comments; Fix all the unit tests. * Fix the remaining pylint errors. * Default ring buffer back to true so that all the test_run_server and run_interactive in CPU mode can work. When we default ring buffer to false, should add additional flags to run_interactive CI to set test mode to true so that pallas kernel can run. * Fix all the lint errors. * Remove the deps/JetStream changes. * Fix merge errors, fix lint errors.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Same as #151, check in to this branch for MLPerf