CUDA graphs break quantized K cache #7492
As of right now it is already possible on master to quantize the K cache via e.g. `-ctk q8_0`. However, this is currently broken on master for batch size 1. Disabling CUDA graphs via the environment variable `GGML_CUDA_DISABLE_GRAPHS=1` fixes the issue.

To reproduce for example:

cc: @agray3
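A hypothetical invocation along these lines illustrates the reproduction and the workaround; the binary name, model path and sampling flags are placeholders, and only `-ctk q8_0` and `GGML_CUDA_DISABLE_GRAPHS=1` are taken from the report above:

```sh
# Hypothetical sketch -- binary name and model path are placeholders.
# Batch-size-1 generation with a quantized K cache (-ctk q8_0):
./main -m models/model.gguf -ngl 99 -ctk q8_0 -p "Hello" -n 64

# Workaround from the report: disable CUDA graphs via the environment variable.
GGML_CUDA_DISABLE_GRAPHS=1 ./main -m models/model.gguf -ngl 99 -ctk q8_0 -p "Hello" -n 64
```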
Noted - I'll take a look.
Fixes ggml-org#7492
It seems that this case hits some conditions that introduce extra memory copies in the matrix multiplication nodes, and these copies are causing the issue. A workaround is at agray3@a5fd193, which disables CUDA graphs under these specific conditions. However, I'm not sure if this is overkill and may unnecessarily disable CUDA graphs in other cases where they are desired - do you have any insight? I'm not yet sure what is causing the issue with the copies; it may be related to kernel parameter changes (like those I already dealt with for other copy kernels).
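As a rough illustration of that workaround idea (this is not the actual patch in agray3@a5fd193; `Node`, `Op` and the condition flag are placeholders, not ggml types): scan the compute graph before launch and fall back to normal stream execution when a disqualifying node is found.

```cpp
// Sketch only: disable CUDA graph use for a batch whenever a mat-mul node
// would trigger the extra memory copies described above.
#include <vector>

enum Op { OP_MUL_MAT, OP_CPY, OP_OTHER };

struct Node {
    Op   op;
    bool needs_extra_copy; // stands in for "the specific conditions" mentioned above
};

// Returns false when the graph contains a matrix-multiplication node that would
// require the problematic extra copies, so the caller launches kernels normally
// instead of replaying a CUDA graph.
static bool can_use_cuda_graph(const std::vector<Node> & graph) {
    for (const Node & node : graph) {
        if (node.op == OP_MUL_MAT && node.needs_extra_copy) {
            return false;
        }
    }
    return true;
}

int main() {
    std::vector<Node> graph = { {OP_CPY, false}, {OP_MUL_MAT, true} };
    return can_use_cuda_graph(graph) ? 0 : 1; // 1: fall back to stream execution
}
```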
I noticed this bug in the context of working on quantized KV cache for FlashAttention. These kernels (by themselves) do not do any memory copies but still suffer from this problem. So perhaps the issue is (also) the conversion of FP32 to the quantized format?
I've now identified the issue - see the fix at #7565. The problem was that the implementation assumed that only a single CUDA kernel was associated with nodes of type GGML_OP_CPY when performing parameter updates to the graph for each token. But in this case, there are 2 such kernels (…)
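To make that failure mode concrete, here is a minimal, self-contained sketch (not the ggml-cuda implementation; kernel names, sizes and the trivial "quantization" are placeholders): one logical copy step is captured as two kernels in a single CUDA graph, so the per-token parameter update has to patch every matching kernel node rather than assuming there is only one.

```cpp
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdio>
#include <vector>

__global__ void cpy_f32_f16(const float * src, __half * dst, int n) {
    int i = blockIdx.x*blockDim.x + threadIdx.x;
    if (i < n) dst[i] = __float2half(src[i]);
}

__global__ void cpy_f32_q8(const float * src, int8_t * dst, int n) {
    int i = blockIdx.x*blockDim.x + threadIdx.x;
    if (i < n) dst[i] = (int8_t) src[i]; // stand-in for real q8_0 quantization
}

int main() {
    const int n = 1024;
    float  * src;     cudaMalloc(&src,     n*sizeof(float));
    __half * dst_f16; cudaMalloc(&dst_f16, n*sizeof(__half));
    int8_t * dst_q8;  cudaMalloc(&dst_q8,  n*sizeof(int8_t));

    cudaStream_t stream;
    cudaStreamCreate(&stream);

    // Capture BOTH copy kernels into one graph -- this mirrors a GGML_OP_CPY
    // node that expands to more than one CUDA kernel.
    cudaGraph_t graph;
    cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
    cpy_f32_f16<<<(n+255)/256, 256, 0, stream>>>(src, dst_f16, n);
    cpy_f32_q8 <<<(n+255)/256, 256, 0, stream>>>(src, dst_q8,  n);
    cudaStreamEndCapture(stream, &graph);

    cudaGraphExec_t graph_exec;
    cudaGraphInstantiate(&graph_exec, graph, 0); // CUDA 12-style signature

    // Per-token update: walk ALL kernel nodes and update the parameters of
    // EVERY node whose function is one of the copy kernels. Assuming only one
    // such node per copy step is the bug described in this thread.
    for (int token = 0; token < 4; ++token) {
        size_t n_nodes = 0;
        cudaGraphGetNodes(graph, nullptr, &n_nodes);
        std::vector<cudaGraphNode_t> nodes(n_nodes);
        cudaGraphGetNodes(graph, nodes.data(), &n_nodes);

        for (cudaGraphNode_t node : nodes) {
            cudaGraphNodeType type;
            cudaGraphNodeGetType(node, &type);
            if (type != cudaGraphNodeTypeKernel) {
                continue;
            }
            cudaKernelNodeParams params = {};
            cudaGraphKernelNodeGetParams(node, &params);
            if (params.func == (void *) cpy_f32_f16 || params.func == (void *) cpy_f32_q8) {
                // In a real integration the destination-pointer argument
                // (params.kernelParams[1]) would be redirected to the new
                // K-cache location for this token before re-setting it here.
                cudaGraphExecKernelNodeSetParams(graph_exec, node, &params);
            }
        }
        cudaGraphLaunch(graph_exec, stream);
    }
    cudaStreamSynchronize(stream);
    printf("launched updated graph 4 times\n");
    return 0;
}
```

The point mirrored from the fix is the matching step: the update loop identifies copy kernels by their function pointer and handles however many of them the captured graph contains, instead of stopping at the first one.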