
CUDA graphs break quantized K cache #7492

Closed

Description

@JohannesGaessler

It is already possible on master to quantize the K cache, e.g. via `-ctk q8_0`. However, this is currently broken for batch size 1. Disabling CUDA graphs via the environment variable `GGML_CUDA_DISABLE_GRAPHS=1` fixes the issue.

cc: @agray3
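
For background, a minimal sketch of the CUDA graph mechanism involved (hypothetical kernel and buffer names, not llama.cpp code): kernel launches are captured once into a graph, and every replay reuses the captured arguments unless the node parameters are explicitly patched between launches.

```cpp
#include <cuda_runtime.h>

// Hypothetical kernel standing in for any per-token work.
__global__ void scale(float * x, int n) {
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        x[i] *= 2.0f;
    }
}

int main() {
    const int n = 1024;
    float * d_x;
    cudaMalloc(&d_x, n * sizeof(float));

    cudaStream_t stream;
    cudaStreamCreate(&stream);

    // Capture the launch into a graph instead of executing it immediately.
    cudaGraph_t graph;
    cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
    scale<<<(n + 255) / 256, 256, 0, stream>>>(d_x, n);
    cudaStreamEndCapture(stream, &graph);

    cudaGraphExec_t graph_exec;
    cudaGraphInstantiate(&graph_exec, graph, nullptr, nullptr, 0);

    // Each launch replays the captured kernel with the captured arguments
    // (d_x, n). If a later step needs different arguments, the node's
    // parameters must be updated first, e.g. with
    // cudaGraphExecKernelNodeSetParams; missing such an update is the class
    // of bug in this issue.
    for (int step = 0; step < 8; ++step) {
        cudaGraphLaunch(graph_exec, stream);
    }
    cudaStreamSynchronize(stream);

    cudaGraphExecDestroy(graph_exec);
    cudaGraphDestroy(graph);
    cudaStreamDestroy(stream);
    cudaFree(d_x);
    return 0;
}
```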

Activity

JohannesGaessler (Collaborator, Author) commented on May 23, 2024

To reproduce, for example:

```sh
./perplexity --file wikitext-2-raw/wiki.test.raw --n-gpu-layers 99 --model models/opt/${model_name}-${quantization}.gguf --chunks 1 -ctk q8_0 -b 1
```
agray3 (Contributor) commented on May 23, 2024

Noted - I'll take a look.

agray3 added a commit that references this issue on May 24, 2024: a5fd193

agray3 (Contributor) commented on May 24, 2024

It seems that this case hits some conditions which cause extra memory copies in the matrix multiplication nodes, and these copies are causing an issue. A workaround is at agray3@a5fd193, which disables CUDA graphs under the specific conditions involved. However, I'm not sure if this is overkill and may unnecessarily disable CUDA graphs for other cases where they are desired - do you have any insight? I'm not yet sure what is causing the issue with the copies; it may be related to kernel parameter changes (like those I already dealt with for other copy kernels).
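
As a rough illustration of the workaround's shape: scan the ggml graph before capture and fall back to normal stream execution when a problematic pattern is found. `ggml_cgraph`, `ggml_tensor`, `GGML_OP_MUL_MAT`, and `ggml_is_quantized` are real ggml identifiers, but the specific condition below is a placeholder, not the one used in agray3@a5fd193:

```cpp
#include "ggml.h"
#include <stddef.h>

// Hypothetical sketch: decide per graph whether CUDA graph capture is safe.
// The condition is a placeholder for the checks in the actual workaround.
static bool cuda_graphs_ok(const struct ggml_cgraph * cgraph) {
    for (int i = 0; i < cgraph->n_nodes; i++) {
        const struct ggml_tensor * node = cgraph->nodes[i];
        // Placeholder: a mat-mul reading a quantized tensor, i.e. roughly the
        // shape that triggers the extra memory copies described above.
        if (node->op == GGML_OP_MUL_MAT &&
            node->src[0] != NULL && ggml_is_quantized(node->src[0]->type)) {
            return false; // execute this graph without CUDA graphs
        }
    }
    return true;
}
```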

JohannesGaessler (Collaborator, Author) commented on May 24, 2024

I noticed this bug in the context of working on quantized KV cache for FlashAttention. These kernels (by themselves) do not do any memory copies but still suffer from this problem. So perhaps the issue is (also) the conversion of FP32 to the quantized format?
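
For reference, a simplified stand-in for what such a quantizing copy does. The block layout matches ggml's q8_0 (32 values sharing one FP16 scale), but the kernel name and signature here are illustrative, not ggml's actual templated copy kernel:

```cpp
#include <cuda_fp16.h>
#include <stdint.h>

#define QK8_0 32

// Mirrors ggml's q8_0 block: one fp16 scale for 32 int8 quants.
typedef struct {
    half   d;
    int8_t qs[QK8_0];
} block_q8_0;

// Illustrative quantizing copy: each thread converts one block of 32 floats.
__global__ void cpy_f32_q8_0(const float * src, block_q8_0 * dst, int nblocks) {
    const int ib = blockIdx.x * blockDim.x + threadIdx.x;
    if (ib >= nblocks) {
        return;
    }
    const float * x = src + ib * QK8_0;
    block_q8_0 * y = dst + ib;

    // Scale so that the largest magnitude maps to 127.
    float amax = 0.0f;
    for (int j = 0; j < QK8_0; ++j) {
        amax = fmaxf(amax, fabsf(x[j]));
    }
    const float d  = amax / 127.0f;
    const float id = d != 0.0f ? 1.0f / d : 0.0f;

    y->d = __float2half(d);
    for (int j = 0; j < QK8_0; ++j) {
        y->qs[j] = (int8_t) roundf(x[j] * id);
    }
}
```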

agray3 (Contributor) commented on May 27, 2024

I've now identified the issue - see the fix at #7565. The problem was that the implementation assumed that only a single CUDA kernel was associated with nodes of type GGML_OP_CPY when performing parameter updates to the graph for each token, but in this case there are two such kernels (cpy_f32_f16 and cpy_f32_q). The perplexity reproducer is now working for me with this fix (and CUDA graphs give a nice 23% speedup on my A100-PCIe system).
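
A sketch of the shape of such a fix, with simplified kernel signatures (the real change is in #7565, and the real copy kernels take more parameters): when patching kernel arguments in the instantiated graph, the captured function of each node is compared against the full set of possible copy kernels instead of a single one.

```cpp
#include <cuda_runtime.h>
#include <vector>

// Stand-ins for the two copy kernels named above; the real definitions (and
// signatures) live in llama.cpp's CUDA backend.
__global__ void cpy_f32_f16(const char * src, char * dst) { /* omitted */ }
__global__ void cpy_f32_q  (const char * src, char * dst) { /* omitted */ }

// Patch the destination argument of every copy node in an instantiated graph.
// new_dst_arg must point at persistent storage holding the new device pointer:
// kernelParams entries are pointers to the argument values, not the values.
void update_cpy_nodes(cudaGraphExec_t graph_exec, cudaGraph_t graph,
                      char ** new_dst_arg) {
    size_t n_nodes = 0;
    cudaGraphGetNodes(graph, nullptr, &n_nodes);
    std::vector<cudaGraphNode_t> nodes(n_nodes);
    cudaGraphGetNodes(graph, nodes.data(), &n_nodes);

    // The set of kernels a GGML_OP_CPY node may run; assuming only the first
    // one (the pre-fix behavior) leaves cpy_f32_q nodes with stale arguments.
    const std::vector<void *> cpy_fns = {
        (void *) cpy_f32_f16,
        (void *) cpy_f32_q,
    };

    for (cudaGraphNode_t node : nodes) {
        cudaGraphNodeType type;
        cudaGraphNodeGetType(node, &type);
        if (type != cudaGraphNodeTypeKernel) {
            continue;
        }
        cudaKernelNodeParams params = {};
        cudaGraphKernelNodeGetParams(node, &params);
        for (void * fn : cpy_fns) {
            if (params.func == fn) {
                params.kernelParams[1] = new_dst_arg; // arg 1 is dst here
                cudaGraphExecKernelNodeSetParams(graph_exec, node, &params);
                break;
            }
        }
    }
}
```

Comparing against a list also extends naturally if more copy kernels (e.g. for additional quantization formats) are added later.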


Labels: Nvidia GPU, bug
