Quantized LLM training in pure CUDA/C++.
llm.q is an implementation of (quantized) large language model training in CUDA, inspired by llm.c. It is particularly aimed at medium-sized training setups, i.e., a single node with multiple GPUs.
The code is written in C++20 and requires CUDA 12 or later. It depends on NCCL for communication and cuDNN for fast attention. Multi-GPU training can be run either in multi-process mode (requires OpenMPI) or in multi-thread mode. On a recent Ubuntu system, the following should provide the required dependencies (adapt the CUDA version as needed):
# build tools
apt install cmake ninja-build git gcc-13 g++-13
# libs
apt install cuda-12-8 cudnn9-cuda-12-8 libnccl2 libnccl-dev libopenmpi-dev
Additional header-only dependencies are automatically downloaded by CMake during the build process.
To build the training executable, run
mkdir build
cmake -S . -B build
cmake --build build --parallel --target train
In order to train/fine-tune a model, you first need some data. The tokenize_data script provides a utility to prepare token files for training. It only supports a limited number of datasets, but hacking it for your own dataset should be straightforward.
uv run tokenize_data.py --dataset tiny-shakespeare --model qwen
This will create `tiny-shakespeare-qwen-train.bin` and `tiny-shakespeare-qwen-eval.bin`.
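For a custom dataset, the general recipe is to tokenize your text with the target model's tokenizer and write the token IDs out as a binary file. The sketch below is purely illustrative and is not the exact format produced by `tokenize_data.py` (which may add headers, shard the output, or use a different integer width); check that script for the layout the trainer actually expects.

```python
# Hypothetical sketch only: tokenize a raw text file with the Qwen tokenizer
# and dump the token IDs as a flat uint32 array. The real on-disk format is
# defined by tokenize_data.py; adapt accordingly.
import numpy as np
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
text = open("my_dataset.txt", encoding="utf-8").read()
ids = np.asarray(tok.encode(text), dtype=np.uint32)
ids.tofile("my-dataset-qwen-train.bin")
print(f"wrote {ids.size} tokens")
```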
Let's fine-tune the smallest Qwen model on this data:
./build/train --model=Qwen/Qwen2.5-0.5B \
--train-file=data/tiny-shakespeare-qwen/train.bin \
--eval-file=data/tiny-shakespeare-qwen/eval.bin \
--model-dtype=bf16 --opt-m-dtype=bf16 --opt-v-dtype=bf16 \
--matmul-dtype=e4m3 \
--recompute-ffn --recompute-att \
--grad-accumulation=8 --steps=30 \
--learning-rate=1e-5 --gpus=1 --batch-size=8
The program will print some logging information, such as the following:
[Options]
recompute-swiglu : true
recompute-norm : false
[...]
[System 0]
Device 0: NVIDIA GeForce RTX 4090
CUDA version: driver 13000, runtime 13000
Memory: 906 MiB / 24080 MiB
Loading model from `/huggingface/hub/models--Qwen--Qwen2.5-0.5B/snapshots/060db6499f32faf8b98477b0a26969ef7d8b9987/model.safetensors`
done
[Dataset]
train: 329k tokens
data/tiny-shakespeare-qwen/train.bin : 329663
eval: 33k tokens
data/tiny-shakespeare-qwen/eval.bin : 33698
[Allocator State]
Adam V: 942 MiB
Adam M: 942 MiB
Gradients: 942 MiB
Activations: 4505 MiB
Master: 682 MiB
Weights: 601 MiB
Free: 14078 MiB
Reserved: 483 MiB
Other: 903 MiB
If `[Allocator State]` shows a lot of `Free` memory (as it does here), you can try increasing the batch size (and adjusting `--grad-accumulation` accordingly), or reduce the amount of activation checkpointing. For example, on a 16 GiB 4060 Ti, `--recompute-ffn --recompute-att` can be replaced by `--recompute-swiglu`, which increases the activation memory from 4.5 GiB to 9 GiB, and the speed from ~11k tps to ~13k tps.
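As a sanity check on how to read this output: the allocator categories appear to account for essentially all device memory; in the run above they add up (within a couple of MiB) to the 24080 MiB capacity reported under `[System 0]`.

```python
# Allocator categories from the log above, in MiB; together they cover
# essentially the whole RTX 4090 in this example.
categories = dict(adam_v=942, adam_m=942, grads=942, activations=4505,
                  master=682, weights=601, free=14078, reserved=483, other=903)
print(sum(categories.values()))  # 24078 MiB, vs. 24080 MiB total
```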
Then, the actual training will begin:
[T] step 0 [ 19.9%] | time: 1869 ms | norm 4.315545 | loss 3.282568 | tps 35064 | sol 42.9%
[T] step 1 [ 39.8%] | time: 1709 ms | norm 8.423664 | loss 3.310652 | tps 38347 | sol 46.9%
[T] step 2 [ 59.6%] | time: 1708 ms | norm 4.818971 | loss 3.330125 | tps 38370 | sol 47.0%
[T] step 3 [ 79.5%] | time: 1715 ms | norm 5.247286 | loss 3.259991 | tps 38213 | sol 46.8%
[V] step 4 [ 0.0%] | time: 165 ms | eval 2.945187 | train 3.295834 | tps 148k
Each `[T]` line is a training step (in contrast to `[V]`, a validation step). It shows the step number, the progress within the epoch, the elapsed time, and the current loss and gradient norm. It also reports the current throughput in tokens per second, and sets this in relation to the GPU's speed of light (SOL), i.e., the fastest possible speed if the GPU were only running the strictly necessary matmuls at peak FLOP/s.
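A common back-of-the-envelope way to think about SOL (not necessarily the exact accounting used by llm.q) is that a dense transformer spends roughly 6 matmul FLOPs per parameter per token for a combined forward and backward pass, so the speed-of-light throughput is the peak matmul FLOP/s divided by 6N.

```python
# Rough speed-of-light estimate (assumption: ~6 matmul FLOPs per parameter
# per token for forward + backward; llm.q's exact SOL accounting may differ).
def sol_tokens_per_second(n_params: float, peak_matmul_flops: float) -> float:
    return peak_matmul_flops / (6.0 * n_params)

# Illustrative numbers only: a 0.5B-parameter model on a GPU with
# 3e14 matmul FLOP/s gives a SOL on the order of 1e5 tokens/s.
print(f"{sol_tokens_per_second(0.5e9, 3e14):.0f} tok/s")
```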
After 30 steps, the training will finish and save the final model to `model.safetensors`. In addition, a log file will be created, which contains the training log in JSON format. We can visualize the log using the `plot-training-run.py` utility script:
uv run python/plot-training-run.py log.json
This shows the training and evaluation losses over time, for quick inspection. For a more detailed and interactive workflow, you can export the log to Weights & Biases:
uv run python/export-wandb.py --log-file log.json --project <ProjectName>
Finally, we want to evaluate the produced model, which by default is placed in `output/model.safetensors`. This file is transformers-compatible, in contrast to the checkpoints that get generated at intermediate steps of longer training runs. However, as llm.q does not include any tokenization, we still need the tokenizer files in order to run the model with transformers. A simple way to get them is to copy them over from the original HF checkpoint, in this example:
cp /huggingface/hub/models--Qwen--Qwen2.5-0.5B/snapshots/060db6499f32faf8b98477b0a26969ef7d8b9987/tokenizer* output/
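With the tokenizer files in place, a quick optional sanity check is to load the exported model with transformers and sample a few tokens before running a full evaluation:

```python
# Optional sanity check: load the exported checkpoint with transformers
# and generate a short continuation.
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("./output")
model = AutoModelForCausalLM.from_pretrained("./output")

inputs = tok("ROMEO:", return_tensors="pt")
out = model.generate(**inputs, max_new_tokens=40, do_sample=True, top_p=0.9)
print(tok.decode(out[0], skip_special_tokens=True))
```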
Then we can run lm-eval evaluations:
uv run lm_eval --model hf --model_args pretrained=./output --tasks hellaswag --batch_size auto
which yields
| Tasks | Version | Filter | n-shot | Metric | | Value | | Stderr |
|---|---|---|---|---|---|---|---|---|
| hellaswag | 1 | none | 0 | acc | ↑ | 0.4043 | ± | 0.0049 |
| | | none | 0 | acc_norm | ↑ | 0.5270 | ± | 0.0050 |
Of course, we wouldn't expect tiny-shakespeare to improve these evaluation metrics, but for a real training run, running evaluations with lm-eval provides independent verification of the resulting model.
For a real training run, we aim for a 1.5B model in Qwen architecture, trained on 10B tokens of Climb using 4 x RTX 4090 GPUs:
uv run tokenize_data.py --dataset climb-10b --model qwen
./build/train --model=Qwen/Qwen2.5-1.5B \
--from-scratch \
"--train-file=data/climb-10b-qwen/train-*.bin" --eval-file=data/climb-10b-qwen/eval.bin \
--ckpt-interval=10000 --steps=33900 \
--eval-num-steps=40 \
--learning-rate=0.0006 --final-lr-fraction=0.1 --warmup=150 \
--seq-len=2048 --batch-size=9 \
--model-dtype=bf16 --matmul-dtype=e4m3 --opt-m-dtype=bf16 --opt-v-dtype=bf16 \
--gpus=4 \
--recompute-ffn --recompute-norm \
--shard-weights --persistent-quants --offload-quants --write-combined --memcpy-all-gather
In this case, we tokenize the data into multiple files, so we need to specify a glob expression for `--train-file`. At the moment, the glob gets resolved inside the program, so we need to quote it to prevent the shell from expanding it into multiple values for the option. We configure the total number of training steps, the learning-rate schedule, dtypes, recomputation, and weight sharding. Note that despite the `--from-scratch` flag, we need to specify the HF model name; this allows the training script to look up the model dimensions in the `config.json` file.
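The step count follows from the token budget: with the default `--grad-accumulation` of 4 (see the options below), each optimizer step consumes batch-size × seq-len × grad-accumulation × num-gpus tokens, so 33900 steps work out to roughly 10B tokens.

```python
# Token budget for the run above (assuming the default --grad-accumulation=4).
tokens_per_step = 9 * 2048 * 4 * 4   # batch-size * seq-len * grad-accum * gpus
print(tokens_per_step)               # 294912
print(tokens_per_step * 33900)       # 9,997,516,800 ~= 10B tokens
```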
Then we can watch the loss go down.
After about 40 hours, the training finishes, and we can run evaluations as above:
| Tasks | Version | Filter | n-shot | Metric | | Ours | Qwen2.5-1.5B | TinyLLama |
|---|---|---|---|---|---|---|---|---|
| arc_challenge | 1 | none | 0 | acc | ↑ | 0.3490 | 0.4104 | 0.3097 |
| | | none | 0 | acc_norm | ↑ | 0.3729 | 0.4539 | 0.3268 |
| arc_easy | 1 | none | 0 | acc | ↑ | 0.6911 | 0.7529 | 0.6166 |
| | | none | 0 | acc_norm | ↑ | 0.6292 | 0.7184 | 0.5547 |
| boolq | 2 | none | 0 | acc | ↑ | 0.5936 | 0.7251 | 0.5599 |
| copa | 1 | none | 0 | acc | ↑ | 0.7100 | 0.8300 | 0.7500 |
| hellaswag | 1 | none | 0 | acc | ↑ | 0.4094 | 0.5020 | 0.4670 |
| | | none | 0 | acc_norm | ↑ | 0.5301 | 0.6781 | 0.6147 |
| openbookqa | 1 | none | 0 | acc | ↑ | 0.2680 | 0.3240 | 0.2520 |
| | | none | 0 | acc_norm | ↑ | 0.3560 | 0.4040 | 0.3680 |
| piqa | 1 | none | 0 | acc | ↑ | 0.7329 | 0.7584 | 0.7263 |
| | | none | 0 | acc_norm | ↑ | 0.7269 | 0.7617 | 0.7356 |
| winogrande | 1 | none | 0 | acc | ↑ | 0.5485 | 0.6377 | 0.5943 |
Unsurprisingly, the model performs worse than the official Qwen2.5-1.5B model, which has been trained on 1000x as much data. Pitted against TinyLLama, however, the model holds its own; in particular, the higher-quality Climb data seems to have helped with tasks like arc. Not bad for something that can be done on a single workstation in a reasonable amount of time. If you were to rent the corresponding compute on vast.ai, at $0.31 per GPU-hour (Sept 26, 2025), this training run would have cost less than $50.
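The cost figure is simple arithmetic:

```python
# 4 GPUs for ~40 hours at $0.31 per GPU-hour
print(4 * 40 * 0.31)  # 49.6 -> just under $50
```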
The training script accepts numerous command-line arguments to configure model training, optimization, memory management, and data handling. Below is a detailed description of each argument category:
- `--model <path>` - Path to the Hugging Face model directory or name of a HF model that is cached locally. The script will automatically locate model files if given a model name, but it will not download new models.
- `--from-scratch` - Train the model from random initialization instead of loading pre-trained weights.
- `--model-dtype <dtype>` - Data type for model parameters. Supported types are `fp32` and `bf16`. This is the data type used for model master weights, and also for all activations and gradients. Matrix multiplications may be performed in lower precision according to the `--matmul-dtype` argument.
- `--train-file <path>` - Path to the binary file containing training tokens.
- `--eval-file <path>` - Path to the binary file containing validation tokens.
- `--batch-size, --batch <int>` - Micro-batch size (default: 4). This is the batch size per GPU.
- `--seq-len <int>` - Sequence length for training (default: 1024).
- `--learning-rate, --lr <float>` - Base learning rate (default: 1e-5).
- `--warmup <int>` - Number of warmup steps. If not specified, defaults to a fraction of total steps.
- `--final-lr-fraction <float>` - Fraction of the base learning rate to use for the final steps (default: 1.0).
- `--beta-1 <float>` - Beta1 parameter for the Adam optimizer (default: 0.9).
- `--beta-2 <float>` - Beta2 parameter for the Adam optimizer (default: 0.95).
- `--grad-accumulation <int>` - Number of micro-batches per optimizer step (default: 4). Effective batch size = batch-size × grad-accumulation × num-gpus (see the worked example after this list).
- `--grad-clip <float>` - Gradient clipping threshold (default: 1.0).
- `--weight-decay <float>` - Weight decay for matrix parameters (default: 1.0).
- `--steps <int>` - Total number of training steps (default: 1000).
- `--eval-every-n-steps <int>` - Run evaluation every N optimizer steps (default: 100).
- `--eval-num-steps <int>` - Number of evaluation batches to process (default: 100). The evaluation at the end of training is always performed on the full dataset.
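As a worked example, the fine-tuning run from the walkthrough above (`--batch-size=8 --grad-accumulation=8 --gpus=1`, with `--seq-len` left at its default of 1024) processes 64 sequences, i.e. 65,536 tokens, per optimizer step:

```python
# Effective batch size for the fine-tuning example above.
batch_size, grad_accumulation, num_gpus, seq_len = 8, 8, 1, 1024
sequences_per_step = batch_size * grad_accumulation * num_gpus   # 64
tokens_per_step = sequences_per_step * seq_len                   # 65536
print(sequences_per_step, tokens_per_step)
```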
- `--out-file <path>` - Path where to save the final trained model (default: "model.safetensors").
- `--checkpoint-dir <path>` - Directory for saving training checkpoints (default: "ckpt").
- `--ckpt-interval <int>` - Save a checkpoint every N optimizer steps (default: 100).
- `--continue` - Continue training from the latest checkpoint in `checkpoint-dir`.
- `--log-file <path>` - Path for the training log in JSON format (default: "log.json").
- `--log-gpu-util <int>` - Log GPU utilization every N steps (default: 25). Set to 0 to disable.
- `--matmul-dtype <dtype>` - Data type for matrix multiplications. If not specified, uses the model dtype.
- `--opt-m-dtype <dtype>` - Data type for first-order momentum (Adam's m). Can be `fp32` or `bf16`.
- `--opt-v-dtype <dtype>` - Data type for second-order momentum (Adam's v). Can be `fp32` or `bf16`.
The following switches enable fine-grained control over which parts of the activations are recomputed during the backward pass. Recomputing `swiglu` and `norm` is computationally cheap -- these operations are memory-bound and run at the speed of a memcpy. Recomputing matrix multiplications, as required by the other options, is more expensive. Additional memory savings, especially for larger models, can be achieved by the offloading options described later.
- `--recompute-swiglu` - Recompute the SwiGLU activation during the backward pass to save activation memory. As the SwiGLU sits at the widest part of the model, this results in substantial memory savings at only a moderate compute increase (especially for large models).
- `--recompute-norm` - Recompute RMS normalizations during the backward pass.
- `--recompute-ffn` - Recompute the entire feed-forward block during the backward pass (implies `--recompute-swiglu`).
- `--recompute-qkv` - Recompute the QKV projections during the backward pass.
- `--recompute-att` - Recompute the entire attention block during the backward pass (implies `--recompute-qkv`).
- `--gpus <int>` - Number of GPUs to use. If 0, uses all available GPUs.
- `--zero-level <int>` - ZeRO redundancy optimization level:
  - 1: Sharded optimizer states (default)
  - 2: Sharded gradients + optimizer states
  - 3: Sharded weights + gradients + optimizer states

You can also configure weight and gradient sharding individually, using the `--shard-weights` and `--shard-gradients` flags. When training in fp8, for example, it makes sense to enable weight sharding before gradient sharding, as the weights need only half the amount of bandwidth.
These options enable offloading parts of the model and optimizer state to pinned host memory. For the optimizer state, this slows down the optimizer step drastically (it is a memory-bound operation), but if enough gradient accumulation steps are performed, the overall contribution of the optimizer step is negligible. When training in lower precision, the `--persistent-quants` option avoids re-quantization of the weights; this increases memory usage, but when combined with `--offload-quants`, the additional memory is placed on the host. In a PCIe setting where any GPU-to-GPU communication has to pass through host memory anyway, this can actually lead to significant speed-ups, especially if combined with the `--memcpy-all-gather` option.
- `--offload-master` - Store master weights in pinned host memory.
- `--offload-opt-m` - Store first-order momentum in pinned host memory.
- `--offload-opt-v` - Store second-order momentum in pinned host memory.
- `--persistent-quants` - Keep quantized weights in memory instead of re-quantizing (requires `--shard-weights`).
- `--offload-quants` - Store quantized weights in pinned host memory (requires `--persistent-quants`).
- `--use-write-combined` - Use write-combined memory for offloaded tensors. In some situations, this may improve PCIe throughput.
- `--memcpy-all-gather` - Use memcpy for all-gather operations (threads backend only). Memcpy generally gets better bandwidth utilization on PCIe, and does not consume SM resources.
- `--memcpy-send-recv` - Use memcpy for send/receive operations (threads backend only).
- `--all-to-all-reduce` - Use an all-to-all-based reduce algorithm (combine with `--memcpy-send-recv`).
- `--use-cuda-graphs` / `--no-use-cuda-graphs` - Enable or disable CUDA graphs for performance.
The src directory contains four major subdirectories: kernels, models, training, and utilities.
- Kernels: The `kernels` directory contains the CUDA kernels; each file is mostly self-contained, except for using a small subset of utilities. Each kernel should be declared in `kernels.h` with all its dtype combinations. The actual CUDA kernels should only take primitive parameters (i.e., pointers, integers, floats, etc.); but `kernels.h` should also provide a `Tensor` version of the kernel, which is implemented in `kernels.cpp` and extracts the data pointers from the `Tensor` object before dispatching to the right dtype combination.
- Utilities: Contains code for memory management, safetensors, error handling, GPU monitoring, etc., as well as basic type declarations and the `Tensor` class.
- Training: Contains training utilities, such as checkpointing, data loading, and logging. It also defines an abstract `Model` interface, which is the only way in which training utilities should interact with the model.
- Models: Contains the model implementations. This should not include anything from `training`, except for the `Model` interface.
TPS: tokens per second. SOL: Speed of light. TTB: Time to a billion tokens.
Note that the numbers below are snapshots from a single machine. There can be significant variability depending on cooling, power supply, etc., especially in non-datacenter deployments.
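TTB follows directly from TPS; for example, the 90k tok/s entry in the first table corresponds to roughly 3.1 hours per billion tokens:

```python
# Time-to-a-billion-tokens from throughput.
def ttb_hours(tps: float) -> float:
    return 1e9 / tps / 3600

print(f"{ttb_hours(90_000):.2f} h")  # ~3.09 h, i.e. the 3:05h in the first row below
```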
RTX PRO 6000 (courtesy of Datacrunch)
Using `--model-dtype=bf16 --opt-m-dtype=bf16 --opt-v-dtype=bf16 --use-cuda-graphs`, and a total batch size of 524288 for <= 3B models (single-GPU 7B uses half the total batch size).
| Model | nGPU | DType | Batch | TPS | SOL | TTB |
|---|---|---|---|---|---|---|
| Qwen2.5-0.5B | 1 | fp8 | 32 | 90k | 36% | 3:05h |
| Qwen2.5-0.5B | 1 | bf16 | 32 | 85k | 52% | 3:16h |
| Qwen2.5-1.5B | 1 | fp8 | 32 | 36k | 41% | 7:43h |
| Qwen2.5-1.5B | 1 | bf16 | 32 | 32k | 61% | 8:41h |
| Qwen2.5-3B | 1 | fp8 | 16 | 22k | 46% | 13h |
| Qwen2.5-3B | 1 | bf16 | 16 | 17k | 63% | 16h |
| Qwen2.5-7B | 1 | fp8 | 4 | 11k | 53% | 25h |
| Qwen2.5-7B | 1 | bf16 | 4 | 8k | 67% | 35h |
| Qwen2.5-1.5B | 4 | fp8 | 32 | 145k | 40% | 1:55h |
| Qwen2.5-1.5B | 4 | bf16 | 32 | 129k | 61% | 2:09h |
| Qwen2.5-3B | 4 | fp8 | 16 | 87k | 46% | 3:11h |
| Qwen2.5-3B | 4 | bf16 | 16 | 68k | 64% | 4:05h |
| Qwen2.5-7B | 4 | fp8 | 8 | 45k | 53% | 6:10h |
| Qwen2.5-7B | 4 | bf16 | 4 | 23k | 62% | 12h |
| Qwen2.5-14B | 4 | fp8 | 4 | 23k | 52% | 12h |
| Qwen2.5-7B* | 4 | fp8 | 16 | 46k | 54% | 6:02h |
The run marked with * uses the additional arguments `--shard-weights --memcpy-all-gather --persistent-quants --offload-quants` to offload quantized weights, which enables the increase in batch size.
| Model | nGPU | DType | Batch | TPS | SOL | TTB |
|---|---|---|---|---|---|---|
| Qwen2.5-0.5B | 1 | fp8 | 32 | 155k | 32% | 1:47h |
| Qwen2.5-0.5B | 1 | bf16 | 32 | 144k | 45% | 1:55h |
| Qwen2.5-1.5B | 1 | fp8 | 16 | 68k | 39% | 4:05h |
| Qwen2.5-1.5B | 1 | bf16 | 16 | 58k | 55% | 4:47h |
| Qwen2.5-3B | 1 | fp8 | 16 | 39k | 42% | 7:07h |
| Qwen2.5-3B | 1 | bf16 | 8 | 30k | 57% | 9:15h |
| Qwen2.5-7B¹ | 1 | fp8 | 4 | 18k | 42% | 15h |
| Qwen2.5-7B | 1 | bf16 | 4 | 14k | 60% | 20h |
^1: `--recompute-swiglu --recompute-norm` to fit batch size 4. Smaller batch sizes drastically reduce efficiency.
| Model | nGPU | DType | Batch | TPS | SOL | TTB |
|---|---|---|---|---|---|---|
| Qwen2.5-0.5B | 1 | fp8 | 8 | 49k | 60% | 5:40h |
| Qwen2.5-0.5B | 1 | bf16 | 8 | 40k | 75% | 6:56h |
| Qwen2.5-1.5B | 1 | fp8 | 4 | 19k | 67% | 14h |
| Qwen2.5-1.5B | 1 | bf16 | 4 | 14k | 80% | 20h |
| Qwen2.5-3B¹ | 1 | fp8 | 4 | 7.9k | 51% | 35h |
| Qwen2.5-3B² | 1 | bf16 | 4 | 6.2k | 71% | 45h |
| Qwen2.5-7B³ | 1 | fp8 | 4 | 3.0k | 44% | 92h |
| Qwen2.5-7B⁴ | 1 | bf16 | 4 | 1.9k | 27% | 146h |
| Qwen2.5-0.5B | 4 | fp8 | 8 | 185k | 57% | 1:30h |
| Qwen2.5-0.5B | 4 | bf16 | 8 | 151k | 71% | 1:50h |
| Qwen2.5-1.5B⁵ | 4 | fp8 | 8 | 67k | 57% | 4:08h |
| Qwen2.5-1.5B⁵ | 4 | bf16 | 8 | 47k | 68% | 5:54h |
| Qwen2.5-3B² | 4 | fp8 | 4 | 34k | 55% | 8:10h |
| Qwen2.5-3B² | 4 | bf16 | 4 | 24k | 69% | 12h |
| Qwen2.5-7B⁶ | 4 | fp8 | 8 | 13k | 47% | 21h |
| Qwen2.5-7B⁷ | 4 | bf16 | 8 | 7.0k | 46% | 40h |
| Qwen2.5-14B⁶ | 4 | fp8 | 4 | 3.2k | 22% | 87h |
| Qwen2.5-14B⁷ | 4 | bf16 | 4 | 2.5k | 33% | 111h |
^1: `--offload-opt-m --offload-opt-v --recompute-swiglu --recompute-norm --offload-master`
^2: `--offload-opt-m --offload-opt-v --recompute-swiglu --recompute-norm`
^3: `--offload-opt-m --offload-opt-v --recompute-ffn --recompute-norm --recompute-att --offload-master --shard-weights --persistent-quants --offload-quants`
^4: `--offload-opt-m --offload-opt-v --recompute-ffn --recompute-norm --recompute-att --offload-master --shard-weights --memcpy-all-gather`
^5: `--recompute-norm --recompute-swiglu`
^6: `--offload-opt-m --offload-opt-v --recompute-ffn --recompute-norm --recompute-att --offload-master --shard-weights --memcpy-all-gather --shard-gradients --memcpy-send-recv --all-to-all-reduce --persistent-quants --offload-quants`
^7: `--offload-opt-m --offload-opt-v --recompute-ffn --recompute-norm --recompute-att --offload-master --shard-weights --memcpy-all-gather --shard-gradients --memcpy-send-recv --all-to-all-reduce`
| Model | nGPU | DType | Batch | TPS | SOL | TTB |
|---|---|---|---|---|---|---|
| Qwen2.5-0.5B | 1 | fp8 | 32 | 48k | 53% | 5:47h |
| Qwen2.5-0.5B | 1 | bf16 | 16 | 42k | 72% | 6:36h |
| Qwen2.5-1.5B | 1 | fp8 | 8 | 20k | 63% | 14h |
| Qwen2.5-1.5B | 1 | bf16 | 8 | 17k | 87% | 16h |
| Qwen2.5-3B | 1 | fp8 | 4 | 11k | 66% | 25h |
| Qwen2.5-3B | 1 | bf16 | 4 | 8.9k | 93% | 31h |
| Qwen2.5-7B¹ | 1 | fp8 | 4 | 5.1k | 65% | 71h |
| Qwen2.5-7B² | 1 | bf16 | 4 | 3.9k | 93% | 45h |
| Qwen2.5-0.5B | 4 | fp8 | 32 | 189k | 52% | 1:28h |
| Qwen2.5-0.5B | 4 | bf16 | 16 | 170k | 72% | 1:38h |
| Qwen2.5-1.5B | 4 | fp8 | 16 | 81k | 62% | 3:25h |
| Qwen2.5-1.5B | 4 | bf16 | 8 | 65k | 85% | 4:16h |
| Qwen2.5-3B | 4 | fp8 | 8 | 45k | 65% | 6:10h |
| Qwen2.5-3B | 4 | bf16 | 8 | 35k | 90% | 7:56h |
| Qwen2.5-7B¹ | 4 | fp8 | 4 | 18k | 59% | 15h |
| Qwen2.5-7B² | 4 | bf16 | 4 | 15k | 88% | 18h |
| Qwen2.5-14B³ | 4 | fp8 | 4 | 7.1k | 45% | 39h |
| Qwen2.5-14B³ | 4 | bf16 | 4 | 4.8k | 57% | 58h |
^1: `--offload-opt-m --offload-opt-v --recompute-swiglu --recompute-norm --offload-master`
^2: `--offload-opt-m --offload-opt-v --recompute-swiglu --recompute-norm`
^3: `--offload-opt-m --offload-opt-v --recompute-ffn --recompute-norm --offload-master --shard-weights --persistent-quants --offload-quants`
^4: `--offload-opt-m --offload-opt-v --recompute-ffn --recompute-norm --offload-master --shard-weights`
The CMake target `recompute-test` produces an executable that accepts a subset of the training arguments. It runs ten training steps, then resets the data loader and runs the same steps again, but this time with all recomputation options disabled. The resulting losses and gradient norms should be bitwise identical to the first run. Example: `./recompute-test --matmul-dtype=e4m3 --recompute-ffn` validates that, when running in fp8 mode, recomputing the ffn block yields results bitwise identical to the run without recomputation.
This implementation wouldn't have been possible without Andrej Karpathy's llm.c.