## Summary
This release brings major upgrades in performance and platform compatibility (most notably, a new Metal backend via WGPU passthrough). CubeCL now powers backends for Cuda, Metal, Rocm, Vulkan, and WebGpu. Tensor operation fusion support has been greatly expanded to optimize element-wise, reduction, and matmul operations.
A new compilation cache and improved autotune cache speed up repeated runs by reusing precompiled binaries and tuned kernel configurations. Data parallel training now scales better across multiple GPUs with automatic batch assignment to each worker. A new tensor slice API offers a simpler, more intuitive way to index tensors.
This version also comes with broad performance gains across tensor operations, especially for reductions, matmul, and convolutions. An initial implementation of quantized matmul is now available, with further quantization improvements planned.
As with previous releases, this one includes various bug fixes, further optimizations, and enhanced documentation.
Be sure to check out the new burn-bench to compare performance across different versions, hardware and backends.
## CubeCL Backends
Burn supports `Cuda`, `Rocm`, `Vulkan`, `WebGpu`, and the newly added `Metal` backend.
Each backend can be used through its respective type alias, provided that the appropriate backend feature flag is also enabled.
### Metal

```toml
burn = { version = "0.17.0", features = ["metal"] }
```

```rust
use burn::prelude::*;
use burn::backend::wgpu::{Metal, WgpuDevice};

let tensor = Tensor::<Metal, 2>::zeros([2, 4], &WgpuDevice::default());
```
### Cuda

```toml
burn = { version = "0.17.0", features = ["cuda"] }
```

```rust
use burn::prelude::*;
use burn::backend::cuda::{Cuda, CudaDevice};

let tensor = Tensor::<Cuda, 2>::zeros([2, 4], &CudaDevice::default());
```
### Rocm

```toml
burn = { version = "0.17.0", features = ["rocm"] }
```

```rust
use burn::prelude::*;
use burn::backend::rocm::{Rocm, HipDevice};

let tensor = Tensor::<Rocm, 2>::zeros([2, 4], &HipDevice::default());
```
### Vulkan

```toml
burn = { version = "0.17.0", features = ["vulkan"] }
```

```rust
use burn::prelude::*;
use burn::backend::wgpu::{Vulkan, WgpuDevice};

let tensor = Tensor::<Vulkan, 2>::zeros([2, 4], &WgpuDevice::default());
```
### WebGpu

```toml
burn = { version = "0.17.0", features = ["webgpu"] }
```

```rust
use burn::prelude::*;
use burn::backend::wgpu::{WebGpu, WgpuDevice};

let tensor = Tensor::<WebGpu, 2>::zeros([2, 4], &WgpuDevice::default());
```
> [!WARNING]
> When using one of the `wgpu` backends, you may encounter compilation errors related to recursive type evaluation. This is due to complex type nesting within the `wgpu` dependency chain.

To resolve this issue, add the following line at the top of your `main.rs` or `lib.rs` file:

```rust
#![recursion_limit = "256"]
```
The default recursion limit (128) is often just below the required depth (typically 130-150) due to deeply nested associated types and trait bounds.
## Data Loader and Batcher
The `Batcher` trait has been updated to improve multi-device support. Previously, batcher implementations stored a device internally, which could lead to all data being loaded on the same device. The latest changes make the `DataLoader` generic over the backend, while the device is passed explicitly:
```diff
-impl<B: Backend> Batcher<MyItem, MyBatch<B>> for MyBatcher<B> {
+impl<B: Backend> Batcher<B, MyItem, MyBatch<B>> for MyBatcher {
-    fn batch(&self, items: Vec<MyItem>) -> MyBatch<B> {
+    fn batch(&self, items: Vec<MyItem>, device: &B::Device) -> MyBatch<B> {
         // The correct `device` is already provided for the batching logic to use
     }
 }
```
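For context, here is a minimal sketch of the updated trait in use, assuming items of type `Vec<f32>` stacked into a 2D batch; the `MyBatcher` and `MyBatch` names are illustrative, not part of Burn's API:

```rust
use burn::data::dataloader::batcher::Batcher;
use burn::prelude::*;

#[derive(Clone, Default)]
pub struct MyBatcher;

#[derive(Clone, Debug)]
pub struct MyBatch<B: Backend> {
    pub inputs: Tensor<B, 2>,
}

impl<B: Backend> Batcher<B, Vec<f32>, MyBatch<B>> for MyBatcher {
    fn batch(&self, items: Vec<Vec<f32>>, device: &B::Device) -> MyBatch<B> {
        // Each item becomes a row, created directly on the device the data
        // loader provides (no device is stored in the batcher itself).
        let rows: Vec<Tensor<B, 2>> = items
            .iter()
            .map(|item| Tensor::<B, 1>::from_floats(item.as_slice(), device).unsqueeze())
            .collect();
        MyBatch {
            inputs: Tensor::cat(rows, 0),
        }
    }
}
```

Because the device now arrives as an argument, the same batcher instance can serve workers pinned to different devices.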
The device can now be set when building a data loader:
```diff
 let dataloader = DataLoaderBuilder::new(batcher)
     .batch_size(batch_size)
     .shuffle(seed)
     .num_workers(num_workers)
+    .set_device(device)
     .build(dataset);
```
This step is not required for the `Learner`, which handles the device configuration automatically.
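As a rough illustration of that automatic handling (a hedged sketch: the builder calls follow Burn's `LearnerBuilder` API, while `artifact_dir`, `optim`, and `lr_scheduler` are placeholder names), passing several devices enables the multi-GPU data-parallel training mentioned in the summary:

```rust
// Sketch only: the devices handed to the learner are forwarded to its
// internal data loaders, and batches are assigned to each worker
// automatically during data-parallel training.
let learner = LearnerBuilder::new(artifact_dir)
    .devices(vec![CudaDevice::new(0), CudaDevice::new(1)])
    .num_epochs(num_epochs)
    .build(model, optim, lr_scheduler);

let model_trained = learner.fit(dataloader_train, dataloader_valid);
```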
## Better Tensor Slicing & Indexing
Tensor slicing now fully adopts idiomatic Rust range syntax, replacing the older `(i64, i64)` and `Option` tuple forms.
For example:
```diff
 let tensor = Tensor::<B, 2>::zeros([m, n], &device);
-let slice = tensor.slice([(0, -1), (0, -2)]);
+let slice = tensor.slice([0..-1, 0..-2]);
```
For more complex or mixed range types, use the `s![]` macro:
```diff
 let tensor = Tensor::<B, 3>::zeros([b, s, d], &device);
-let slice = tensor.slice([None, Some((t as i64, t as i64 + 1)), None]);
+let slice = tensor.slice(s![.., t..t + 1, ..]);
```
The macro is inspired by ndarray's `s![]` (at least, by name) and helps build flexible slice patterns.
```rust
use burn::prelude::*;

let tensor = Tensor::<B, 4>::zeros([8, 4, 2, 3], &device);
let slice = tensor.slice(s![..=4, 0..=3, .., -1]);

assert_eq!(slice.dims(), [5, 4, 2, 1]);
```
## Changelog
### Module & Tensor
- Feature add new one hot function meeting multi-dimensions (ranks) (#2613) @tiruka
- Expand GRU support (#2704) @nwhitehead
- feat: bitwise-ops-for-tensors (#2498) @quinton11
- Feat: Add PoissonNLL loss (#2765) @salvomcl
- Add metric parametrized name (#2808) @laggui
- Add boolean and/or to bool tensors (#2802) @wingertge
- Add ATOL/RTOL defaults (#2824) @crutcher
- Feat: Add tan trig function (#2854) @Msa360
- Refactor quantization schemes (#2849 #3036) @laggui @maxtremblay
- Vectorize pooling for optimization (#2905) @wingertge
- Feat: Add Cosh and Sinh (#2959) @Msa360
- Refactor in-memory recorder load args (#2892) @BjornTheProgrammer
- Improve gradient checkpointing (#2997) @nathanielsimard
- Optimize minmax (#3009) @nathanielsimard
- Improve `tensor.slice(...)` to support multiple range types (#3061) @laggui
#### Bug Fixes
- Fix bce loss log (#2741) @laggui
- Fix repeat_dim backward w/ dim size > 1 (#2777) @laggui
- [Fix] `tch` upgrade (#2834) @wingertge
- Check channels_in matches in convolution layers (#2944) @chlobes
- Fixed GroupNorm implementation (#2945) @computer-whisperer
### Backends
- Migrate to type magic autotune (#2710) @wingertge
- Feat/fused matmul tune (#2726) @nathanielsimard
- Feat/shared sum (#2737) @maxtremblay
- Improve fusion for broadcasting, mix vectorization and reshape operation (#2773 #2833) @nathanielsimard
- Fuse gather (#2793) @nathanielsimard
- Feat/fuse select (#2797 #2804 #2903) @nathanielsimard
- Remove from_data conversions in backends (#2783) @laggui
- Feat fuse swap dims (#2801 #2877) @nathanielsimard
- [Feature] reduce fuse on read (#2870) @nathanielsimard
- [Feat] SIMD acceleration for ndarray backend (#2851) @wingertge
- Perf/reduce fuse on write (#2937) @nathanielsimard
- [metal] Add CubeCL metal compiler support (#2993) @syl20bnr
- Compilation Cache (#3020) @nathanielsimard
- Cubecl quantize matmul (#3022 #3030) @maxtremblay
#### Bug Fixes
- Fix from data fusion (#2735 #2778) @laggui @nathanielsimard
- Fix constant creation in fusion to cast at compile time, not runtime (#2782) @wingertge
- Fix two autotune issues on wasm (#2899) @ArthurBrussee
- Fix/reduce out of bounds (#2906) @nathanielsimard
- Fix fusion bug (#3031) @nathanielsimard
- Fix metal backend name (#3040) @nathanielsimard
- Fix matmul dynamic line size support (#3056) @nathanielsimard
- Fix: matmul lower precision / flex32 (#3059) @nathanielsimard
- Fix/autotune cache conflicts (#3070) @nathanielsimard
### Documentation & Examples
- Wasserstein Generative Adversarial Network (#2660) @wangjiawen2013
- Add modern lstm (#2752) @wangjiawen2013
- Improve tensor docs (#2951) @PtiLuky
#### Fixes
- chore: fix some comments (#2717) @sunxunle
- Add hardsigmoid formula and fix WGAN doc + default lr (#2706) @laggui
- Fix db-pedia-infer backend (#2736) @laggui
- Fixed typo in the burn book chapter advanced unit no-std. (#2731) @xmy314
- typo - correct `smp_serde` to `rmp_serde` as per crate's name in url (#2744) @cameronbraid
- typo - missing `tick` which was breaking formatting (#2745) @cameronbraid
- Remove autodiff from generate (#2759) @laggui
- Remove empty format precision specifier (#2785) @hkBst
- Update tch instructions (#2844 #2976) @laggui
- Fix from_embedded and bool ops docs (#2848) @laggui
- Fix tiny typo in mathematical expression (#2867) @janhohenheim
- Fix typos (#2927) @crutcher
- Fix/web example (#2954 #2978) @laggui
- Fix: burn-book getting-started Use Declarations (#2966) @jerryshell
- chore: fix comment (#3008) @tsinghuacoder
### ONNX Support
- Code generation bug fix for ONNX import (#2708) @antimora
- Floor Node (#2792) @akshitgaur2005
- One hot ONNX (#2784) @akshitgaur2005
- Onnx op topk (#2305) @oojo12
- Fix output elem type for `unsqueeze` and `reshape` (#2807) @christeefy
- Feat/Split ONNX Import (#2568) @agelas
- Refactor GatherNode to support scalar outputs. (#2828) @loloxwg
- Rename dim to rank for ONNX import (#2831) @antimora
- Add rank inference for tan (#2868) @Msa360
- Add Gemm (#2841) @akshitgaur2005
- Fix RandomNormalLike ONNX node output rank (#2936) @Knight-Ops
- Support multiple outputs being tracked in BurnGraph during ONNX conversion (#2938) @Knight-Ops
- Ignore ONNX optional node inputs/outputs (#2935) @Knight-Ops
- Fix ONNX flatten to match spec (#2940) @catch-twenty-two
- burn-import: add some tests for ConstantNode (#2623) @jameshiew @laggui
- Update SUPPORTED-ONNX-OPS.md with the latest info (#3064) @antimora
### Enhancements
- Add new burn-vision crate (#2753 #2810 #2842) @wingertge
- Improve Burn compilation times (#2815 #2994) @nathanielsimard
- Support training in no-std (#2830) @ivila
- Perf: Speed up element and TensorData conversion (#2913) @wingertge
- Feat/cubecl caching (#2902) @nathanielsimard
- Improve multi-device data loading strategy (#2890 #3035) @laggui
- Autotune level matmul double buffering (#2988) @nathanielsimard @louisfd
### Refactoring
- Remove deprecated Data and DataSerialize (#2703) @laggui
- Clean up train system metrics (#2707) @laggui
- Move IR to its own crate (#2796 #2798) @laggui
- Refactor burn jit => burn-cubecl (#2809) @nathanielsimard
- Cleanup Tensor Registry in fusion (#2826) @nathanielsimard
- Migrate conv2d to cubecl (#2908 #3018) @wingertge
- Update to edition 2024 (#2931) @laggui
- Update runtime names (#2909) @nathanielsimard
- Migrate backend comparison (#2961) @laggui
- Improve test tolerance assertions (#3024) @maxtremblay @laggui
- [hip] Move burn-hip to burn-rocm and rename backend to ROCm (#3062) @syl20bnr
### Miscellaneous
- Fix no default features flags + update cubecl (#2725) @laggui
- Replace return with terminate (#2742) @maxtremblay
- Clean up -jit suffix in feature flags and modules (#2705) @laggui
- Fix types under autotune flag (#2750) @laggui
- Fix BackendValues in backend-comparison after removal of jit suffix (#2756) @syl20bnr
- Update cubecl (#2764) @wingertge
- Fix optional burn-import dep + impl module types for isize (#2774) @laggui
- Update cubecl with fix to shared_sum (#2779) @maxtremblay
- feat: using rustls instead of native-tls (#2799) @ShoofLLC
- bump cubecl version with dummy implementations (#2814) @maxtremblay
- Add data_dir optional argument to Huggingface DataLoader to enable some manual download use cases (#2817) @Pablo1785
- Bump xtask to 1.1.9 (#2896) @syl20bnr
- Fix test checks for macos (#2952) @PtiLuky
- Update cargo deps (#2962) @Brooooooklyn
- Add train end event (#2967) @laggui
- Update cubecl bitcast -> reinterpret (#2985) @maxtremblay
- Update cubecl (#2869 #2888 #2990 #2996) @louisfd
- Update wgpu to v25 (#3007) @syl20bnr
- update cubecl: sync full cyclic checked (#3025) @louisfd
- Fix autotune measurement (#3043) @nathanielsimard