
llama : add option to override model tensor buffers #11397

Merged
merged 9 commits into master from sl/custom-tensor-offload on Apr 2, 2025

Conversation

slaren
Member

@slaren slaren commented Jan 24, 2025

Adds a command line parameter --override-tensor (-ot) that allows changing the buffer type where a model tensor is allocated. This gives the user fine-grained control over which tensors are offloaded to each device.

How is this useful? For example, to force the experts in MoE models to stay on the CPU while offloading the rest to the GPU, you could use -ngl 99 -ot exps=CPU. This may allow more efficient offloading schemes.

The syntax is <tensor name pattern>=<buffer type>. Currently the pattern is just a string search (edit: this is no longer the case, it is now a C++ regex search), i.e. any tensor whose name contains the characters in <tensor name pattern> will be matched and loaded into the given buffer type. Multiple overrides can be given by separating them with commas, or by passing the -ot option multiple times. To see which tensors are being matched, enable debugging output with -v.

At this point it is just a demo, feel free to experiment and report if you find any interesting uses.

Edit: added regex support. For example, to keep the experts of layers 20-99 on the CPU you could use -ot "[2-9][0-9]\.ffn_.*_exps\.=CPU"
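
For anyone curious how a <tensor name pattern>=<buffer type> list could be parsed and matched in practice, here is a minimal standalone C++ sketch. It is illustrative only and not the code in this PR; the split-at-the-last-'=' rule, the buffer type names and the first-match-wins behaviour are assumptions made for the example.

// tensor_override_demo.cpp : illustrative sketch, not the PR's implementation.
// Parses "<pattern>=<buffer type>" overrides and matches them against tensor
// names with std::regex; the first matching override wins.
// build: c++ -std=c++17 tensor_override_demo.cpp
#include <cstdio>
#include <regex>
#include <string>
#include <utility>
#include <vector>

int main() {
    // Overrides as they might appear on the command line, e.g.
    //   -ot "[2-9][0-9]\.ffn_.*_exps\.=CPU" -ot exps=CUDA0
    const std::vector<std::string> args = {
        R"([2-9][0-9]\.ffn_.*_exps\.=CPU)",
        "exps=CUDA0",
    };

    // Split each override at the last '=' into (regex pattern, buffer type name).
    std::vector<std::pair<std::regex, std::string>> overrides;
    for (const auto & arg : args) {
        const size_t pos = arg.find_last_of('=');
        overrides.emplace_back(std::regex(arg.substr(0, pos)), arg.substr(pos + 1));
    }

    // A few hypothetical tensor names following llama.cpp's blk.N.* naming scheme.
    const char * tensors[] = {
        "blk.5.ffn_up_exps.weight",     // matches "exps"          -> CUDA0
        "blk.42.ffn_down_exps.weight",  // matches the layer regex -> CPU
        "blk.42.attn_q.weight",         // matches nothing         -> default placement
    };

    for (const char * name : tensors) {
        const char * buft = "(default)";
        for (const auto & [re, type] : overrides) {
            if (std::regex_search(name, re)) { buft = type.c_str(); break; }
        }
        std::printf("%-30s -> %s\n", name, buft);
    }
}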

TODO:

  • Fix pipeline parallelism check
  • Support overriding KV cache allocation (different PR)

@slaren slaren added the demo Demonstrate some concept or idea, not intended to be merged label Jan 24, 2025
@slaren slaren changed the title llama : add option to override tensor buffers llama : add option to override model tensor buffers Jan 24, 2025
@slaren slaren added the need feedback Testing and feedback with results are needed label Jan 24, 2025
@bmtwl
Contributor

bmtwl commented Jan 26, 2025

Is there a chance that the direction you're taking these changes might allow for scheduling specific threads to work on specific tensors? With R1 coming out, I'm very interested in reviving my work on trying to improve memory locality to increase CPU inference speeds.

@slaren
Member Author

slaren commented Jan 26, 2025

No, that's something that would need to be handled at a lower level in the CPU backend.

@bmtwl
Contributor

bmtwl commented Jan 26, 2025

No, that's something that would need to be handled at a lower level in the CPU backend.

Thanks for the reply @slaren. I figured it wouldn't directly help, but thought maybe you'd be adding useful metadata to tensor objects that could help coordinate affinity in the future. I'll start a fresh branch and see how far I get.

At this point it is just a demo, feel free to experiment and report if you find any interesting uses.

I'll also try to pull this branch and test it to see what the speedup and sysmem savings look like.

@bmtwl
Contributor

bmtwl commented Jan 27, 2025

Quick, non-scientific initial test with Deepseek R1 at q6 on llama-server with -ot exps=CPU:

-ngl 0 = 4.65t/s
-ngl 10 = 5.15t/s
-ngl 20 = 5.64t/s
-ngl 30 = 6.10t/s
-ngl 40 = 6.95t/s

So there is definitely major speedup potential in this patch. I can't offload all 62 layers for this model because I only have 24GB VRAM, but I expect the trend would continue in the same general direction. This is without dropping caches, so it's inefficient, but I didn't have time to do a proper drop/reload cycle since it takes so long to read everything back into memory on each test run.

@saood06

saood06 commented Jan 27, 2025

Quick, non-scientific initial test with Deepseek R1 at q6 on llama-server with -ot exps=CPU:

-ngl 0 = 4.65t/s -ngl 10 = 5.15t/s -ngl 20 = 5.64t/s -ngl 30 = 6.10t/s -ngl 40 = 6.95t/s

So there is definitely major speedup potential in this patch. I can't offload all 62 layers for this model because I only have 24GB VRAM, but I expect the trend would continue in the same general direction. This is without dropping caches, so it's inefficient, but I didn't have time to do a proper drop/reload cycle since it takes so long to read everything back into memory on each test run.

@bmtwl
Do you mind testing performance with -nkvo?

@jukofyork
Collaborator

What are the shared expert tensors called in llama.cpp - is there a pattern that catches the routed experts (that only activate 1/32 of the time), but doesn't catch the shared experts?

@slaren
Member Author

slaren commented Jan 28, 2025

I believe the pattern exps will not match the shared experts, since they are called ffn_xxx_shexp.weight. You can use the gguf preview feature in huggingface to see the names of the tensors. Also remember that you can use multiple patterns, it doesn't have to be a single one.
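
A quick way to double-check a pattern like this is a couple of std::regex_search calls; the layer index below is made up for the example:

#include <cassert>
#include <regex>

int main() {
    const std::regex pattern("exps");
    assert( std::regex_search("blk.3.ffn_gate_exps.weight",  pattern)); // routed experts: matched
    assert(!std::regex_search("blk.3.ffn_gate_shexp.weight", pattern)); // shared expert: not matched
}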

@jukofyork
Collaborator

I believe the pattern exps will not match the shared experts, since they are called ffn_xxx_shexp.weight. You can use the gguf preview feature in huggingface to see the names of the tensors. Also remember that you can use multiple patterns, it doesn't have to be a single one.

Thanks - I'll give this a try later in the week.

This PR, together with this Reddit post, opens up the interesting possibility:

https://old.reddit.com/r/LocalLLaMA/comments/1ibbloy/158bit_deepseek_r1_131gb_dynamic_gguf/

of quantising up/gate projections to q2_k and down projections to q4_k (or something similar), then keeping everything else as q8_0.

Sadly I need to move some stuff about to get space to upscale the fp8 download to bf16 before I can try it, but will report back when I do.

@jukofyork
Collaborator

Quick, non-scientific initial test with Deepseek R1 at q6 on llama-server with -ot exps=CPU:

-ngl 0 = 4.65t/s -ngl 10 = 5.15t/s -ngl 20 = 5.64t/s -ngl 30 = 6.10t/s -ngl 40 = 6.95t/s

So there is definitely major speedup potential in this patch. I can't offload all 62 layers for this model because I only have 24GB VRAM, but I expect the trend would continue in the same general direction. This is without dropping caches, so it's inefficient, but I didn't have time to do a proper drop/reload cycle since it takes so long to read everything back into memory on each test run.

It might be worth trying q4_0, as it should almost let you offload all the layers, and IIRC it should be slightly faster to dequantise than the K-quants?

@jukofyork
Collaborator

Is there a chance that the direction you're taking these changes might allow for scheduling specific threads to work on specific tensors? With R1 coming out, I'm very interested in reviving my work on trying to improve memory locality to increase CPU inference speeds.

Just being able to split the experts between NUMA nodes would make a big difference, but I'm not sure how easy that would be, as IIRC the experts' tensors are all packed into one huge tensor now?

@BarfingLemurs
Contributor

During normal operation, when I fit a model between RAM and VRAM, does the offloading follow a set layer sequence? (Layer 0 is chosen first to be offloaded to the GPU, then layer 1, etc.)

Between GPU offloading and RAM, which takes priority?

Quick, non-scientific initial test with Deepseek R1 at q6 on llama-server with -ot exps=CPU:

-ngl 0 = 4.65t/s -ngl 10 = 5.15t/s -ngl 20 = 5.64t/s -ngl 30 = 6.10t/s -ngl 40 = 6.95t/s

So there is definitely major speedup potential in this patch. I can't offload all 62 layers for this model because I only have 24GB VRAM, but I expect the trend would continue in the same general direction. This is without dropping caches, so it's inefficient, but I didn't have time to do a proper drop/reload cycle since it takes so long to read everything back into memory on each test run.

Do you remember how much of a speedup? No need for extensive benchmarks, just the rough % estimate.

@saood06

saood06 commented Feb 2, 2025

Quick, non-scientific initial test with Deepseek R1 at q6 on llama-server with -ot exps=CPU:

-ngl 0 = 4.65t/s -ngl 10 = 5.15t/s -ngl 20 = 5.64t/s -ngl 30 = 6.10t/s -ngl 40 = 6.95t/s

I can't seem to offload more than 29 layers of R1 (unsloth's UD-IQ2_XXS) via RPC. 29 layers and below work fine, but 30 just crashes my rpc_server, with no error output. It is not a VRAM issue: even with the context set very low, so that usage is nowhere near my GPU's limits, it still crashes.

@jukofyork
Collaborator

Quick, non-scientific initial test with Deepseek R1 at q6 on llama-server with -ot exps=CPU:
-ngl 0 = 4.65t/s -ngl 10 = 5.15t/s -ngl 20 = 5.64t/s -ngl 30 = 6.10t/s -ngl 40 = 6.95t/s

I can't seem to offload more than 29 layers of R1 (unsloth's UD-IQ2_XXS) via RPC. 29 layers and below work fine, but 30 just crashes my rpc_server, with no error output. It is not a VRAM issue: even with the context set very low, so that usage is nowhere near my GPU's limits, it still crashes.

I had a similar problem: if I used a single GPU (via CUDA_VISIBLE_DEVICES=0) it ran fine, and if I used both GPUs with the --no-kv-offload option it also ran fine (but much slower).

If I didn't use either of these it tried to allocate this 1.4TB monster buffer:

llama_init_from_model: pipeline parallelism enabled (n_copies=4)
ggml_backend_cuda_buffer_type_alloc_buffer: allocating 1407257.91 MiB on device 0: cudaMalloc failed: out of memory
ggml_gallocr_reserve_n: failed to allocate CUDA0 buffer of size 1475616865280
ggml_backend_cuda_buffer_type_alloc_buffer: allocating 351268.28 MiB on device 0: cudaMalloc failed: out of memory
ggml_gallocr_reserve_n: failed to allocate CUDA0 buffer of size 368331484928
ggml_backend_cuda_buffer_type_alloc_buffer: allocating 353465.98 MiB on device 0: cudaMalloc failed: out of memory
ggml_gallocr_reserve_n: failed to allocate CUDA0 buffer of size 370635939584

After some searching I found this issue:

#7217

and recompiled using -DGGML_SCHED_MAX_COPIES=1 and now it's working fine.

(It's likely nothing to do with this PR, but thought it might help!)

@jukofyork
Collaborator

@saood06

I figured it out: you have to reorder the devices so the local CUDA devices are last:

#11606
#11424

and mainly these:

#11435

You don't need to run RPC servers for local devices.

#9296
#11424

For those that don't get it (like me initially), you first need to check the device names using the --list-devices option (example below):

 $ llama.cpp/build/bin/llama-server --rpc <IP1>:<PORT1> --rpc <IP2>:<PORT2> --list-devices
ggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 2 CUDA devices:
  Device 0: NVIDIA GeForce RTX XXXX, compute capability 8.6, VMM: yes
  Device 1: NVIDIA GeForce GTX YYYY, compute capability 7.5, VMM: yes
Available devices:
  CUDA0: NVIDIA GeForce RTX XXXX (A MiB, B MiB free)
  CUDA1: NVIDIA GeForce GTX YYYY (A MiB, B MiB free)
  RPC[IP1:PORT1]: RPC[IP1:PORT1] (A MiB, B MiB free)
  RPC[IP2:PORT2]: RPC[IP2:PORT2] (A MiB, B MiB free)

It is under Available devices where you get the device names. Next time you launch llama-server, you will use the --device option with the order you want for your devices. An example:

$ llama.cpp/build/bin/llama-server --rpc <IP1>:<PORT1> --rpc <IP2>:<PORT2> \
--device RPC[IP1:PORT1],CUDA0,CUDA1,RPC[IP2:PORT2] \
-ngl 33 --tensor_split 3/20/10/0 --device-draft CUDA1,RPC[IP2:PORT2] -ngld 99 [...]

This way, you can set up the order however you want. In the complicated example above, the main model is offloaded to the first RPC device (using IP1:PORT1 address), mostly on the CUDA0 device, and partially to the CUDA1 device, while the draft model is offloaded to the CUDA1 device and the second RPC device (using IP2:PORT2 address).

Means this works:

--device "RPC[IP1:PORT1],RPC[IP1:PORT2],RPC[IP1:PORT1],RPC[IP2:PORT2],CUDA0,CUDA1"

But if I don't do this I get OOM errors with plenty of VRAM left like you had.

@saood06

saood06 commented Feb 5, 2025

I'm testing this with and without #11446. Without it, on unsloth's UD-IQ2_XXS, I was only able to offload 29 layers; with it, I was able to allocate only 28 (on a Q4_K_S quant). This is not a VRAM issue: there was plenty of spare VRAM, and it would even get past allocation and reach warmup, where the rpc-server would then just crash.

The other issue is performance: the more layers I offload, the worse performance gets, while bmtwl shows a performance increase with more layers offloaded when using non-RPC offloading.

@ro99

ro99 commented Feb 5, 2025

I am able to load the model with llama-server -m /mnt/models/DeepSeek-R1-GGUF/DeepSeek-R1-UD-Q2_K_XL/DeepSeek-R1-UD-Q2_K_XL-00001-of-00005.gguf --threads 28 --host 0.0.0.0 --port 5001 -c 8192 -ngl 99 -ot exps=CPU :

PID DEV TYPE GPU MEM HOST MEM Command
16431 0 Compute 13294MiB 54% 215686MiB /opt/llama.cpp/build/bin/llama-server -m /mnt/models/DeepSeek-R1-GGUF/DeepSeek-R1-UD-Q2_K_XL/DeepSeek-R1-UD-Q2_K_XL-000
16431 2 Compute 12088MiB 49% 215686MiB /opt/llama.cpp/build/bin/llama-server -m /mnt/models/DeepSeek-R1-GGUF/DeepSeek-R1-UD-Q2_K_XL/DeepSeek-R1-UD-Q2_K_XL-000
16431 3 Compute 11616MiB 47% 215686MiB /opt/llama.cpp/build/bin/llama-server -m /mnt/models/DeepSeek-R1-GGUF/DeepSeek-R1-UD-Q2_K_XL/DeepSeek-R1-UD-Q2_K_XL-000
16431 1 Compute 11488MiB 47% 215686MiB /opt/llama.cpp/build/bin/llama-server -m /mnt/models/DeepSeek-R1-GGUF/DeepSeek-R1-UD-Q2_K_XL/DeepSeek-R1-UD-Q2_K_XL-000

But as soon as I send the prompt I receive:

/opt/llama.cpp/ggml/src/ggml-alloc.c:182: not enough space in the buffer
ggml_dyn_tallocr_alloc: not enough space in the buffer to allocate 18446744073709550624 bytes, largest block available 9223372036854775807 bytes
[New LWP 16444]
[New LWP 16445]
[New LWP 16446]
[New LWP 16447]
...
[New LWP 16533]
[New LWP 16534]
[New LWP 16535]
[New LWP 16536]
[Thread debugging using libthread_db enabled]
Using host libthread_db library "/lib/x86_64-linux-gnu/libthread_db.so.1".
0x00007f1e950d0bd7 in wait4 () from /lib/x86_64-linux-gnu/libc.so.6
#0  0x00007f1e950d0bd7 in wait4 () from /lib/x86_64-linux-gnu/libc.so.6
#1  0x00007f1e95527fc1 in ggml_abort () from /opt/llama.cpp/build/bin/libggml-base.so
#2  0x00007f1e9553619c in ggml_gallocr_allocate_node () from /opt/llama.cpp/build/bin/libggml-base.so
#3  0x00007f1e955369d0 in ggml_gallocr_reserve_n () from /opt/llama.cpp/build/bin/libggml-base.so
#4  0x00007f1e9553c244 in ggml_backend_sched_alloc_graph () from /opt/llama.cpp/build/bin/libggml-base.so
#5  0x00007f1e95646030 in llama_decode_impl(llama_context&, llama_batch) () from /opt/llama.cpp/build/bin/libllama.so
#6  0x00007f1e95646f57 in llama_decode () from /opt/llama.cpp/build/bin/libllama.so
#7  0x000055f47d6647c9 in server_context::update_slots() ()
#8  0x000055f47d64f4d1 in server_queue::start_loop() ()
#9  0x000055f47d5fd067 in main ()
[Inferior 1 (process 16431) detached]
Aborted (core dumped)

Without --override-tensor, offloading 20 layers to the GPU works fine.

Testing with 4x RTX 3090 and 320GiB RAM. Built with cmake -B build -DGGML_CUDA=ON -DGGML_SCHED_MAX_COPIES=1.

@jukofyork
Collaborator

Without --override-tensor, offloading 20 layers to the GPU works fine.

Testing with 4x RTX 3090 and 320GiB RAM. Built with cmake -B build -DGGML_CUDA=ON -DGGML_SCHED_MAX_COPIES=1.

Maybe try -ngl 61 to keep the output layer on the CPU too (that oddly worked for me earlier when I was having trouble with the RPC stuff).

@ro99

ro99 commented Feb 5, 2025

Maybe try -ngl 61 to keep the output layer on the CPU too (that oddly worked for me earlier when I was having trouble with the RPC stuff).

No luck, still the same issue.

Oddly enough, the issue only happens when sending more than 450 tokens.

@slaren
Member Author

slaren commented Feb 5, 2025

ggml_dyn_tallocr_alloc: not enough space in the buffer to allocate 18446744073709550624 bytes

It's trying to allocate a tensor of size 2^64, which suggests there is an integer overflow somewhere. If you set the environment variable GGML_SCHED_DEBUG=2, it will print the graph before allocating it, which may give some indication of which tensor is causing this. Or just change the error message in ggml_dyn_tallocr_alloc to include the tensor name.
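
As a side note on why the number looks the way it does: 18446744073709550624 is 2^64 - 992, i.e. a small negative byte count that wrapped around when stored as an unsigned 64-bit size. A minimal illustration of that arithmetic (not the actual code path):

#include <cstdint>
#include <cstdio>

int main() {
    const int64_t signed_size = -992;                  // e.g. a size computation that went slightly negative
    const size_t  as_unsigned = (size_t) signed_size;  // wraps around to 2^64 - 992
    std::printf("%zu\n", as_unsigned);                 // prints 18446744073709550624 on a 64-bit system
}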

@ro99

ro99 commented Feb 6, 2025

It's trying to allocate a tensor of size 2^64, which suggests there is an integer overflow somewhere. If you set the environment variable GGML_SCHED_DEBUG=2, it will print the graph before allocating it, which may give some indication of which tensor is causing this. Or just change the error message in ggml_dyn_tallocr_alloc to include the tensor name.

It is the CPU#ffn_moe_topk-60#0 tensor.

Is it possible to try to force this particular one to be allocated into the GPU buffer?

@slaren
Member Author

slaren commented Feb 6, 2025

This is most likely a bug; we need to understand why it is happening and fix it. Since you mentioned that it only happens with large prompts, I suspect that this is caused by a zero-sized tensor. When evaluating a batch where no logits are required (which happens when evaluating a prompt that needs to be split into multiple ubatches), zero-size tensors are created to skip the calculation of the logits.
I cannot run this model, so I would need your help to figure out why this is happening. Can you print more details about the tensor? Something like this should do it:

diff --git a/ggml/src/ggml-alloc.c b/ggml/src/ggml-alloc.c
index 9a3bf9f29..470ef13e6 100644
--- a/ggml/src/ggml-alloc.c
+++ b/ggml/src/ggml-alloc.c
@@ -179,6 +179,9 @@ static size_t ggml_dyn_tallocr_alloc(struct ggml_dyn_tallocr * alloc, size_t siz
             // this should never happen
             GGML_LOG_ERROR("%s: not enough space in the buffer to allocate %zu bytes, largest block available %zu bytes\n",
                     __func__, size, max_avail);
+            GGML_LOG_ERROR("%s: tensor: %s, shape: %ld %ld %ld %ld, size: %zu",
+                __func__, tensor->name, tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3],
+                ggml_nbytes(tensor));
             GGML_ABORT("not enough space in the buffer");
         }
     }

@slaren
Member Author

slaren commented Feb 6, 2025

Ok nvm, I think I see the problem. I will push a possible fix soon.

@zts9989

zts9989 commented Mar 14, 2025

I'll upload the modified thread pool code later when I have time. Please note upfront that the code I wrote is very messy and doesn't account for a wider range of build environments. It's for your reference.

@jukofyork
Collaborator

jukofyork commented Mar 27, 2025

Any chance the conflicts could be resolved?

I was just in the process of finalising the latest MLA PR but can't test it without this PR! :)

Bump 😃

Sorry for the bump, but this PR is really essential for me to test the MLA stuff using the full-sized DeepSeek models, and now that all the refactoring has settled down I was hoping to revive the -mla option code I wrote for use with the new build_attn_mha():

ggml_tensor * llm_graph_context::build_attn_mha(
         ggml_cgraph * gf,
         ggml_tensor * q,
         ggml_tensor * k,
         ggml_tensor * v,
         ggml_tensor * kq_b,
         ggml_tensor * kq_mask,
             bool      v_trans,
             float     kq_scale) const {
  //const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
  //const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);

  //const int64_t n_head    = hparams.n_head(il);
  //const int64_t n_head_kv = hparams.n_head_kv(il);

  //const auto & n_embd_head_k = hparams.n_embd_head_k;
  //const auto & n_embd_head_v = hparams.n_embd_head_v;

    const auto n_embd_head_v = v_trans ? v->ne[1] : v->ne[0];

    const auto n_tokens = q->ne[1];
    const auto n_head   = q->ne[2];
    const auto n_kv     = k->ne[1];

I think it should be way cleaner to add now, as nearly all the ugliness came from those hard-coded GQA assumptions from the GGUF file (i.e. MLA converts into MQA or MHA depending on whether you use the "naive" method or not).

I'm not sure if it might be worth factoring out the tensor-name regex stuff for use with this PR that aims to do something similar for the llama-quantize code:

#12511

@slaren slaren marked this pull request as ready for review April 2, 2025 00:32
@slaren slaren requested a review from ggerganov April 2, 2025 00:32
@ggerganov
Member

I'm not sure if I am using it correctly, but on my Mac, overriding the buffers seems to lead to double the allocation. I am testing with:

make -j && ./bin/llama-cli -m ../models/deepseek-v2-lite-chat/ggml-model-q8_0.gguf -ot "ffn_.*"=CPU -lv 1

And I see in the output:

0.00.093.090 D tensor blk.0.ffn_norm.weight buffer type overriden to CPU
0.00.093.100 D tensor blk.0.ffn_gate.weight buffer type overriden to CPU
0.00.093.104 D tensor blk.0.ffn_down.weight buffer type overriden to CPU
0.00.093.108 D tensor blk.0.ffn_up.weight buffer type overriden to CPU
0.00.093.159 D tensor blk.1.ffn_norm.weight buffer type overriden to CPU
0.00.093.168 D tensor blk.1.ffn_gate_inp.weight buffer type overriden to CPU
0.00.093.173 D tensor blk.1.ffn_gate_exps.weight buffer type overriden to CPU
0.00.093.178 D tensor blk.1.ffn_down_exps.weight buffer type overriden to CPU
0.00.093.182 D tensor blk.1.ffn_up_exps.weight buffer type overriden to CPU
0.00.093.188 D tensor blk.1.ffn_gate_shexp.weight buffer type overriden to CPU
0.00.093.209 D tensor blk.1.ffn_down_shexp.weight buffer type overriden to CPU
0.00.093.215 D tensor blk.1.ffn_up_shexp.weight buffer type overriden to CPU
0.00.093.259 D tensor blk.2.ffn_norm.weight buffer type overriden to CPU
0.00.093.265 D tensor blk.2.ffn_gate_inp.weight buffer type overriden to CPU
0.00.093.270 D tensor blk.2.ffn_gate_exps.weight buffer type overriden to CPU
...
0.00.095.542 D tensor blk.26.ffn_up_exps.weight buffer type overriden to CPU
0.00.095.548 D tensor blk.26.ffn_gate_shexp.weight buffer type overriden to CPU
0.00.095.553 D tensor blk.26.ffn_down_shexp.weight buffer type overriden to CPU
0.00.095.561 D tensor blk.26.ffn_up_shexp.weight buffer type overriden to CPU
0.00.095.566 D load_tensors: tensor 'token_embd.weight' (q8_0) (and 212 others) cannot be used with preferred buffer type CPU_AARCH64, using CPU instead
0.00.096.438 D ggml_backend_metal_log_allocated_size: allocated buffer, size = 15924.97 MiB, (15925.03 / 147456.00)
0.00.357.084 I load_tensors: offloading 27 repeating layers to GPU
0.00.357.088 I load_tensors: offloading output layer to GPU
0.00.357.089 I load_tensors: offloaded 28/28 layers to GPU
0.00.357.096 I load_tensors: Metal_Mapped model buffer size = 15924.97 MiB
0.00.357.097 I load_tensors:   CPU_Mapped model buffer size = 15698.50 MiB
........................................................................................
0.00.357.946 I llama_context: constructing llama_context

If I don't pass the -ot argument I see this:

0.00.393.685 I load_tensors: offloading 27 repeating layers to GPU
0.00.393.689 I load_tensors: offloading output layer to GPU
0.00.393.689 I load_tensors: offloaded 28/28 layers to GPU
0.00.393.699 I load_tensors: Metal_Mapped model buffer size = 15924.97 MiB
0.00.393.699 I load_tensors:   CPU_Mapped model buffer size =   212.50 MiB
......................................................................................
0.00.394.853 I llama_context: constructing llama_context

The Metal_Mapped buffer is the same size in both cases. Is this expected?

@slaren
Member Author

slaren commented Apr 2, 2025

With Metal you would need to disable mmap to see lower memory usage, since the entire file or a large fraction of it will remain mapped.
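
For example, the earlier test command with memory mapping disabled would be something like this (--no-mmap is the standard llama.cpp flag for that):

make -j && ./bin/llama-cli -m ../models/deepseek-v2-lite-chat/ggml-model-q8_0.gguf -ot "ffn_.*"=CPU --no-mmap -lv 1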

@slaren slaren merged commit e0e912f into master Apr 2, 2025
48 checks passed
@slaren slaren deleted the sl/custom-tensor-offload branch April 2, 2025 12:52
@jukofyork
Collaborator

@slaren Thanks (and sorry for the bump again)!

@Panchovix

Panchovix commented Apr 2, 2025

Sorry to hijack the thread, but how would you suggest running DeepSeek-R3-UD-Q2_K_XL.gguf on a system with 192GB RAM and 128GB VRAM across multiple GPUs (24/24/32/48 GB of VRAM, ordered by CUDA_VISIBLE_DEVICES)?

Would running

llama-server -m /DeepSeek-R3-UD-Q2_K_XL.gguf -c 8192 -ngl 99 -ot exps=CPU work out of the box for the multiple GPUs, or do I have to set the tensor split with -ts, as the 4 GPUs don't necessarily have the same amount of VRAM?

EDIT: It seems to work but uses just 10-12 GB of VRAM on each GPU

@BarfingLemurs
Contributor

You would need to increase the layers offloaded to fill the VRAM of each GPU as much as possible.

./llama-cli -ngl 38 -ts 9/29 -m DeepSeek-R1-UD-IQ1_S-00001-of-00003.gguf -ctk q4_0 -p "hello!" --override-kv deepseek2.expert_used_count=int:4 -ot "[4-9][0-9]\.ffn_.*_exps\.=CPU" -c 1024

That's how I ran it on my 2x 3090s.

@jerryrual

The good news

Using the -ot / --override-tensor flag seems to work properly in my testing. Running with -ngl 62 -ot exps=CPU is the fastest way to run R1 671B UD-Q2_K_XL (212GiB of weights) on 256GB RAM plus a single CUDA GPU on my ThreadRipper Pro 24-core test rig with llama.cpp.

It is counter-intuitive to me that offloading fewer layers onto the GPU makes it go faster, and I presume this has something to do with CUDA graphs not working as well when even a single expert is also in VRAM, but I'm really just speculating wildly.

This method is still not quite as fast as ktransformers, but it is faster than running ktransformers --no-use_cuda_graph.

The technically unrelated news

I had hoped to use this fine-grained offload method to distribute experts across 6 NUMA nodes on a big dual-socket Intel Xeon 6980P. While it does technically work and runs, it is much slower than just running normally with no NUMA optimizations at all. I even tried making a patch to the rpc-server example to allow specifying the number of threads and forcing the CPU backend.

--override-tensor works well with RPC devices, and I appreciate how specifying the flag multiple times stacks the way I would expect. However, as others have mentioned above, the current synchronous RPC send() implementation seems to bottleneck attempts to distribute computation and is not a true async tensor-parallel optimized solution. (vLLM seems to implement some of this, and I hope to test it to find out how well the CPU backend works on it.)

Example

I tried a few configurations, including 5x rpc-server backends and a single llama-server frontend, each in a different NUMA node. I also tried a simpler version with 1x rpc-server on a single NUMA node on the opposite CPU socket from the llama-server frontend. Even communicating over the loopback device, the performance was much worse.

I'll leave the commands and some info for anyone interested inside the fold below. There is also a whole discussion on the challenges of running llama.cpp across more than a single NUMA node over here.

Cheers!

EDIT: Tried one last time with -nkvo / --no-kv-offload (disable KV offload), but it didn't make a significant difference; still very slow and not saturating the CPU cores, probably because they are waiting around for send() calls...

Example selective RPC backend offloading experiments

Does llama.cpp support offloading all routed experts to CPU host memory? That is the same as ktransformers, so would the two solutions have the same performance?

@ddh0
Contributor

ddh0 commented Apr 24, 2025

How would one override tensor buffers when using the API from include/llama.h?

The public interface currently looks like this:

    struct llama_model_tensor_buft_override {
        const char * pattern;
        ggml_backend_buffer_type_t buft;
    };

    struct llama_model_params {
        // NULL-terminated list of devices to use for offloading (if NULL, all available devices are used)
        ggml_backend_dev_t * devices;

        // NULL-terminated list of buffer types to use for tensors that match a pattern
        const struct llama_model_tensor_buft_override * tensor_buft_overrides;

        // ... etc ...
    };

As I understand it, since buft's type is ggml_backend_buffer_type_t, which is not part of the API, I'm not able to use the tensor override feature.

Maybe the API could be updated to something like this (based on ggml/include/ggml-backend.h#L130):

    enum ggml_backend_dev_type {
        // CPU device using system memory
        GGML_BACKEND_DEVICE_TYPE_CPU,
        // GPU device using dedicated memory
        GGML_BACKEND_DEVICE_TYPE_GPU,
        // accelerator devices intended to be used together with the CPU backend (e.g. BLAS or AMX)
        GGML_BACKEND_DEVICE_TYPE_ACCEL
    };

    struct llama_model_tensor_buft_override {
        const char * pattern;
        ggml_backend_dev_type buft;
    };

This would allow users of the public interface to create their own llama_model_tensor_buft_override and control tensor overrides. I don't think this is a complete solution; I'm just trying to illustrate the problem I'm having. Maybe there is a way to do this that I'm not aware of.

@slaren
Member Author

slaren commented Apr 24, 2025

You need to use the ggml API to obtain the buffer types, in the same way the llama.cpp examples do it. The llama.cpp API includes ggml; in fact, when you include llama.h, ggml.h and ggml-backend.h are also included and available to applications.
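
For reference, a rough sketch of what that can look like from application code is below. Treat it as an assumption-laden outline rather than a recipe: the ggml-backend calls (ggml_backend_dev_count/get/name/buffer_type) and llama_model_load_from_file are used as I understand them, and the null-terminator convention for tensor_buft_overrides should be double-checked against the llama.cpp examples (e.g. common/arg.cpp).

// buft_override_sketch.cpp : illustrative outline, not a verified recipe.
#include <string>
#include <vector>

#include "llama.h"   // also brings in ggml.h and ggml-backend.h

// Look up a buffer type by device name, e.g. "CPU" or "CUDA0".
static ggml_backend_buffer_type_t find_buft_by_name(const std::string & name) {
    for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
        ggml_backend_dev_t dev = ggml_backend_dev_get(i);
        if (name == ggml_backend_dev_name(dev)) {
            return ggml_backend_dev_buffer_type(dev);
        }
    }
    return nullptr;
}

int main() {
    llama_backend_init();

    // The pattern string and the overrides vector must stay alive while the
    // model params are in use.
    static const char * pattern = "exps";

    std::vector<llama_model_tensor_buft_override> overrides;
    overrides.push_back({ pattern, find_buft_by_name("CPU") });
    overrides.push_back({ nullptr, nullptr });   // terminator (assumed convention)

    llama_model_params params = llama_model_default_params();
    params.n_gpu_layers          = 99;
    params.tensor_buft_overrides = overrides.data();

    llama_model * model = llama_model_load_from_file("model.gguf", params);
    if (model) {
        // ... create a context, run inference, etc. ...
        llama_model_free(model);
    }

    llama_backend_free();
    return 0;
}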

@ddh0
Contributor

ddh0 commented Apr 24, 2025

I see, thank you

@lilunxm12

lilunxm12 commented Apr 29, 2025

Is it possible to use --override-tensor to specify which layers go to GPU0 and which to GPU1?

With Qwen3-30B-A3B, it's too large to fit on my single 4070 Ti Super, but a perfectly usable quant exists if I plug in an old 3060. I can't find a definitive answer, but I assume it would be like DeepSeek, in that some layers should be prioritized to the beefier card.

Does that make sense for a pure GPU setup? I understand that layers need to be processed sequentially, otherwise GPU-to-GPU communication overhead kicks in frequently, but I can't find how much of an impact that is.

@ubergarm

ubergarm commented Apr 29, 2025

@lilunxm12

Is it possible to use --override-tensor to specify which layers go to GPU0 and which to GPU1?

This is possible with the ik_llama.cpp fork, e.g. you can put tensors/layers exactly where you want them across multiple CUDA devices or the CPU. For example, I'm currently experimenting with the Qwen3-235B-A22B MoE to fit a quant perfectly on 24GB VRAM + 96GB RAM.

Example Command and Logs

Partial logs shown for brevity:

# ik_llama.cpp fork

./build/bin/llama-sweep-bench \
  --model "$model" \
  -fa \
  -ctk f16 -ctv f16 \
  -c 32768 \
  -fmoe \
  -amb 512 \
  -ot blk\.[0-4]\.ffn.*=CUDA0 \
  -ot blk\.[5-9]\.ffn.*=CUDA1 \
  -ot exps=CPU \
  -ngl 99 \
  --threads 24

.
.
.

Tensor blk.0.ffn_norm.weight buffer type overriden to CUDA0
Tensor blk.0.ffn_gate_inp.weight buffer type overriden to CUDA0
Tensor blk.0.ffn_gate_exps.weight buffer type overriden to CUDA0
Tensor blk.0.ffn_down_exps.weight buffer type overriden to CUDA0
Tensor blk.0.ffn_up_exps.weight buffer type overriden to CUDA0
Tensor blk.1.ffn_norm.weight buffer type overriden to CUDA0
Tensor blk.1.ffn_gate_inp.weight buffer type overriden to CUDA0
Tensor blk.1.ffn_gate_exps.weight buffer type overriden to CUDA0
Tensor blk.1.ffn_down_exps.weight buffer type overriden to CUDA0
Tensor blk.1.ffn_up_exps.weight buffer type overriden to CUDA0
Tensor blk.2.ffn_norm.weight buffer type overriden to CUDA0
Tensor blk.2.ffn_gate_inp.weight buffer type overriden to CUDA0
Tensor blk.2.ffn_gate_exps.weight buffer type overriden to CUDA0
Tensor blk.2.ffn_down_exps.weight buffer type overriden to CUDA0
Tensor blk.2.ffn_up_exps.weight buffer type overriden to CUDA0
Tensor blk.3.ffn_norm.weight buffer type overriden to CUDA0
Tensor blk.3.ffn_gate_inp.weight buffer type overriden to CUDA0
Tensor blk.3.ffn_gate_exps.weight buffer type overriden to CUDA0
Tensor blk.3.ffn_down_exps.weight buffer type overriden to CUDA0
Tensor blk.3.ffn_up_exps.weight buffer type overriden to CUDA0
Tensor blk.4.ffn_norm.weight buffer type overriden to CUDA0
Tensor blk.4.ffn_gate_inp.weight buffer type overriden to CUDA0
Tensor blk.4.ffn_gate_exps.weight buffer type overriden to CUDA0
Tensor blk.4.ffn_down_exps.weight buffer type overriden to CUDA0
Tensor blk.4.ffn_up_exps.weight buffer type overriden to CUDA0
Tensor blk.5.ffn_norm.weight buffer type overriden to CUDA1
Tensor blk.5.ffn_gate_inp.weight buffer type overriden to CUDA1
Tensor blk.5.ffn_gate_exps.weight buffer type overriden to CUDA1
Tensor blk.5.ffn_down_exps.weight buffer type overriden to CUDA1
Tensor blk.5.ffn_up_exps.weight buffer type overriden to CUDA1
Tensor blk.6.ffn_norm.weight buffer type overriden to CUDA1
Tensor blk.6.ffn_gate_inp.weight buffer type overriden to CUDA1
Tensor blk.6.ffn_gate_exps.weight buffer type overriden to CUDA1
Tensor blk.6.ffn_down_exps.weight buffer type overriden to CUDA1
Tensor blk.6.ffn_up_exps.weight buffer type overriden to CUDA1
Tensor blk.7.ffn_norm.weight buffer type overriden to CUDA1
Tensor blk.7.ffn_gate_inp.weight buffer type overriden to CUDA1
Tensor blk.7.ffn_gate_exps.weight buffer type overriden to CUDA1
Tensor blk.7.ffn_down_exps.weight buffer type overriden to CUDA1
Tensor blk.7.ffn_up_exps.weight buffer type overriden to CUDA1
Tensor blk.8.ffn_norm.weight buffer type overriden to CUDA1
Tensor blk.8.ffn_gate_inp.weight buffer type overriden to CUDA1
Tensor blk.8.ffn_gate_exps.weight buffer type overriden to CUDA1
Tensor blk.8.ffn_down_exps.weight buffer type overriden to CUDA1
Tensor blk.8.ffn_up_exps.weight buffer type overriden to CUDA1
Tensor blk.9.ffn_norm.weight buffer type overriden to CUDA1
Tensor blk.9.ffn_gate_inp.weight buffer type overriden to CUDA1
Tensor blk.9.ffn_gate_exps.weight buffer type overriden to CUDA1
Tensor blk.9.ffn_down_exps.weight buffer type overriden to CUDA1
Tensor blk.9.ffn_up_exps.weight buffer type overriden to CUDA1
Tensor blk.10.ffn_gate_exps.weight buffer type overriden to CPU
Tensor blk.10.ffn_down_exps.weight buffer type overriden to CPU
Tensor blk.10.ffn_up_exps.weight buffer type overriden to CPU
Tensor blk.11.ffn_gate_exps.weight buffer type overriden to CPU
Tensor blk.11.ffn_down_exps.weight buffer type overriden to CPU
Tensor blk.11.ffn_up_exps.weight buffer type overriden to CPU
Tensor blk.12.ffn_gate_exps.weight buffer type overriden to CPU
Tensor blk.12.ffn_down_exps.weight buffer type overriden to CPU
Tensor blk.12.ffn_up_exps.weight buffer type overriden to CPU
Tensor blk.13.ffn_gate_exps.weight buffer type overriden to CPU
Tensor blk.13.ffn_down_exps.weight buffer type overriden to CPU
Tensor blk.13.ffn_up_exps.weight buffer type overriden to CPU
Tensor blk.14.ffn_gate_exps.weight buffer type overriden to CPU
Tensor blk.14.ffn_down_exps.weight buffer type overriden to CPU
Tensor blk.14.ffn_up_exps.weight buffer type overriden to CPU
Tensor blk.15.ffn_gate_exps.weight buffer type overriden to CPU
.
.
.

I'm not 100% clear on whether mainline llama.cpp allows specifying anything other than CPU or possibly CUDA, but it may just need a little more regex support added to do this. Sorry, I haven't had time to check in the mad rush of Qwen3.

I understand that layers need to be processed sequentially, otherwise GPU-to-GPU communication overhead kicks in frequently, but I can't find how much of an impact that is.

It's perfectly fine, and a good idea, to place some layers on CUDA0 and other layers on CUDA1, and it will perform well. No need to worry about P2P, NVLink, etc., as this is not the tensor-parallel/data-parallel setup that vLLM and sglang may use for, say, 8x or 16x GPU nodes.

@slaren
Member Author

slaren commented Apr 29, 2025

The ik_llama implementation is just a copy-paste of this PR, far from being merely "inspired" by it as claimed.

@ubergarm

ubergarm commented Apr 29, 2025

@lilunxm12

Is it possible to use --override-tensor to specify which layers go to GPU0 and which to GPU1?

Okay, I had a moment to circle back around and test this out with a recent version of mainline llama.cpp. It does seem to allow you to specify e.g. CUDA0 and CUDA1 and CPU all together.

I'd recommend setting --verbosity 1 and doing something like this with a model you have handy; it will print out the Available buffer types:

./build/bin/llama-server \
    --verbosity 1 \
    --model /mnt/astrodata/llm/models/bartowski/THUDM_GLM-Z1-32B-0414-GGUF/THUDM_GLM-Z1-32B-0414-IQ4_XS.gguf \
    -fa \
    --n-gpu-layers 99 \
    --ctx-size 8192 \
    --cache-type-k q8_0 \
    --cache-type-v q8_0 \
    -ot attn=FOO \
    -nkvo \
    --threads 16 \
    --host 127.0.0.1 \
    --port 8088

Available buffer types:
  CPU
  CUDA0
error while handling argument "-ot": unknown buffer type

Then you can piece it together as you like, with either one long -ot ... or multiple -ot flags, being mindful of order of operations (earlier regexes match first and take precedence over later ones).

Note that it doesn't print out unmatched layers even with --verbosity 1, so you can add one last match that assigns them to the buffer type they would have gone to anyway, just to get explicit confirmation.

Here is another quick example. It is quite flexible and handy for some models, e.g. if you want to offload only the attention and KV cache to the CPU.

Example of using -ot to place exact tensors/layers on different CPU/GPU backends
  1. Place all attn tensors on all layers onto CPU.
  2. Place all ffn tensors from layers [0-5] on CUDA0
  3. Don't offload kv cache (leave it on CPU).
  4. If you have more GPUs you could add more e.g. -ot blk\.[6-9]\.ffn.*=CUDA1
  5. I usually set -ngl 99 or another big number, then manually override with multiple regexes.
./build/bin/llama-server \
    --verbosity 1 \
    --model /mnt/astrodata/llm/models/bartowski/THUDM_GLM-Z1-32B-0414-GGUF/THUDM_GLM-Z1-32B-0414-IQ4_XS.gguf \
    -fa \
    --n-gpu-layers 99 \
    --ctx-size 8192 \
    --cache-type-k q8_0 \
    --cache-type-v q8_0 \
    -ot attn=CPU \
    -ot blk\.[0-5]\.ffn.*=CUDA0 \
    -nkvo \
    --threads 16 \
    --host 127.0.0.1 \
    --port 8088

tensor blk.0.attn_norm.weight buffer type overriden to CPU
tensor blk.0.attn_q.weight buffer type overriden to CPU
tensor blk.0.attn_k.weight buffer type overriden to CPU
tensor blk.0.attn_v.weight buffer type overriden to CPU
tensor blk.0.attn_output.weight buffer type overriden to CPU
tensor blk.0.ffn_norm.weight buffer type overriden to CUDA0
tensor blk.0.ffn_down.weight buffer type overriden to CUDA0
tensor blk.0.ffn_up.weight buffer type overriden to CUDA0
tensor blk.1.attn_norm.weight buffer type overriden to CPU
tensor blk.1.attn_q.weight buffer type overriden to CPU
tensor blk.1.attn_k.weight buffer type overriden to CPU
tensor blk.1.attn_v.weight buffer type overriden to CPU
tensor blk.1.attn_output.weight buffer type overriden to CPU
tensor blk.1.ffn_norm.weight buffer type overriden to CUDA0
tensor blk.1.ffn_down.weight buffer type overriden to CUDA0
tensor blk.1.ffn_up.weight buffer type overriden to CUDA0
tensor blk.2.attn_norm.weight buffer type overriden to CPU
tensor blk.2.attn_q.weight buffer type overriden to CPU
tensor blk.2.attn_k.weight buffer type overriden to CPU
tensor blk.2.attn_v.weight buffer type overriden to CPU
tensor blk.2.attn_output.weight buffer type overriden to CPU
tensor blk.2.ffn_norm.weight buffer type overriden to CUDA0
tensor blk.2.ffn_down.weight buffer type overriden to CUDA0
tensor blk.2.ffn_up.weight buffer type overriden to CUDA0
tensor blk.3.attn_norm.weight buffer type overriden to CPU
tensor blk.3.attn_q.weight buffer type overriden to CPU
tensor blk.3.attn_k.weight buffer type overriden to CPU
tensor blk.3.attn_v.weight buffer type overriden to CPU
tensor blk.3.attn_output.weight buffer type overriden to CPU
tensor blk.3.ffn_norm.weight buffer type overriden to CUDA0
tensor blk.3.ffn_down.weight buffer type overriden to CUDA0
tensor blk.3.ffn_up.weight buffer type overriden to CUDA0
tensor blk.4.attn_norm.weight buffer type overriden to CPU
tensor blk.4.attn_q.weight buffer type overriden to CPU
tensor blk.4.attn_k.weight buffer type overriden to CPU
tensor blk.4.attn_v.weight buffer type overriden to CPU
tensor blk.4.attn_output.weight buffer type overriden to CPU
tensor blk.4.ffn_norm.weight buffer type overriden to CUDA0
tensor blk.4.ffn_down.weight buffer type overriden to CUDA0
tensor blk.4.ffn_up.weight buffer type overriden to CUDA0
tensor blk.5.attn_norm.weight buffer type overriden to CPU
tensor blk.5.attn_q.weight buffer type overriden to CPU
tensor blk.5.attn_k.weight buffer type overriden to CPU
tensor blk.5.attn_v.weight buffer type overriden to CPU
tensor blk.5.attn_output.weight buffer type overriden to CPU
tensor blk.5.ffn_norm.weight buffer type overriden to CUDA0
tensor blk.5.ffn_down.weight buffer type overriden to CUDA0
tensor blk.5.ffn_up.weight buffer type overriden to CUDA0
tensor blk.6.attn_norm.weight buffer type overriden to CPU
tensor blk.6.attn_q.weight buffer type overriden to CPU
tensor blk.6.attn_k.weight buffer type overriden to CPU
tensor blk.6.attn_v.weight buffer type overriden to CPU
tensor blk.6.attn_output.weight buffer type overriden to CPU
tensor blk.7.attn_norm.weight buffer type overriden to CPU
tensor blk.7.attn_q.weight buffer type overriden to CPU
tensor blk.7.attn_k.weight buffer type overriden to CPU
tensor blk.7.attn_v.weight buffer type overriden to CPU
tensor blk.7.attn_output.weight buffer type overriden to CPU
.
.
.

@slaren Hey, sorry, I don't understand what appears to be "beef" between the two forks. I recognize there is history here that goes way beyond me. I was confused about whether this -ot feature fully supported regex or not, but I realize I was thinking of the new llama-quantize --tensor-type regex logic.

I appreciate everyone, thanks!
