Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
117b8e8
Remove test_ad calls from interface tests
penelopeysm Aug 13, 2025
d0e6a6f
remove dead code
penelopeysm Aug 13, 2025
fb6d71c
don't remove that
penelopeysm Aug 13, 2025
b020f9e
Remove macOS from CI
penelopeysm Aug 13, 2025
cd4e62f
fix AD
penelopeysm Aug 13, 2025
78f3ef9
Refactor AD tests
penelopeysm Aug 14, 2025
982d19b
add doctests, fix other things
penelopeysm Aug 14, 2025
1b9964a
fix
penelopeysm Aug 14, 2025
4f1004c
format
penelopeysm Aug 14, 2025
481c5ca
Disable failing Enzyme rule tests
penelopeysm Aug 14, 2025
eb4bdfb
Re-enable Enzyme tests
penelopeysm Aug 14, 2025
1400d75
format
penelopeysm Aug 14, 2025
c993dfb
separate mooncake rule test into its own thing
penelopeysm Aug 14, 2025
ef21f50
Skip broken Enzyme tests
penelopeysm Aug 15, 2025
e81001e
add macOS and windows test
penelopeysm Aug 15, 2025
d67f0c0
format
penelopeysm Aug 15, 2025
28d0ca8
flows only fail on 1.11 apparently
penelopeysm Aug 15, 2025
a7852b6
comment
penelopeysm Aug 15, 2025
a86b716
Use new DI patch
penelopeysm Aug 22, 2025
098ef8d
Merge branch 'main' into py/tests2
penelopeysm Aug 22, 2025
ce870f7
Fix Enzyme tests again
penelopeysm Aug 22, 2025
ba4db0e
Merge branch 'main' into py/tests2
penelopeysm Sep 11, 2025
1d8e224
fix merge
penelopeysm Sep 11, 2025
9971c97
Add link to issue
penelopeysm Sep 12, 2025
af877e8
Remove warnonly=true from docs build
penelopeysm Sep 15, 2025
4b2cbba
Fix wrong function in test
penelopeysm Sep 15, 2025
fac6c00
update Enzyme failures
penelopeysm Sep 16, 2025
7382f7b
fix more stuff
penelopeysm Sep 16, 2025
ebe3de3
That should fix ReverseDiff
penelopeysm Sep 26, 2025
99cf586
I meant this
penelopeysm Sep 26, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 18 additions & 21 deletions .github/workflows/AD.yml → .github/workflows/CI.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
name: AD tests
name: CI

on:
push:
Expand All @@ -14,33 +14,30 @@ concurrency:

jobs:
test:
runs-on: ${{ matrix.os }}
runs-on: ${{ matrix.runner.os }}
strategy:
fail-fast: false
matrix:
version:
- 'min'
- '1'
os:
- ubuntu-latest
- macOS-latest
AD:
- Enzyme
- ForwardDiff
- Mooncake
- Tracker
- ReverseDiff
runner:
- version: '1'
os: 'ubuntu-latest'
- version: '1'
os: 'macos-latest'
- version: '1'
os: 'windows-latest'
- version: 'min'
os: 'ubuntu-latest'
group:
- 'Interface'
- 'AD'

steps:
- uses: actions/checkout@v5

- uses: julia-actions/setup-julia@v2
with:
version: ${{ matrix.version }}

version: ${{ matrix.runner.version }}
- uses: julia-actions/cache@v2
- uses: julia-actions/julia-buildpkg@v1

- uses: julia-actions/julia-runtest@v1
continue-on-error: ${{ matrix.AD == 'Enzyme' && matrix.version == '1' }}
env:
GROUP: AD
AD: ${{ matrix.AD }}
GROUP: ${{ matrix.group }}
41 changes: 41 additions & 0 deletions .github/workflows/DocTests.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# We want to only run doctests on a single version of Julia, because
# things like error messages / output can change between versions and
# is fragile to test against.
name: Doctests

on:
push:
branches:
- main
pull_request:
merge_group:
types: [checks_requested]

# needed to allow julia-actions/cache to delete old caches that it has created
permissions:
actions: write
contents: read

# Cancel existing tests on the same PR if a new commit is added to a pull request
concurrency:
group: ${{ github.workflow }}-${{ github.ref || github.run_id }}
cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }}

jobs:
test:
runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v4

- uses: julia-actions/setup-julia@v2
with:
version: '1'

- uses: julia-actions/cache@v2

- uses: julia-actions/julia-buildpkg@v1

- uses: julia-actions/julia-runtest@v1
env:
GROUP: Doctests
36 changes: 0 additions & 36 deletions .github/workflows/Interface.yml

This file was deleted.

2 changes: 1 addition & 1 deletion docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,6 @@ Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"

[compat]
Documenter = "1.14"
Documenter = "1"
Functors = "0.3"
StableRNGs = "1"
1 change: 1 addition & 0 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,5 @@ makedocs(;
"Examples" => "examples.md",
],
checkdocs=:exports,
doctest=false,
)
2 changes: 2 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0"
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
Expand Down Expand Up @@ -32,6 +33,7 @@ AdvancedHMC = "0.6, 0.7, 0.8"
ChainRulesTestUtils = "0.7, 1"
ChangesOfVariables = "0.1"
Combinatorics = "1.0.2"
DifferentiationInterface = "0.7.7"
DistributionsAD = "0.6.3"
Documenter = "1"
Enzyme = "0.13.12"
Expand Down
47 changes: 9 additions & 38 deletions test/ad/chainrules.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
using Random: Xoshiro
module BijectorsChainRulesTests

using Bijectors
using LinearAlgebra
using ChainRulesTestUtils: ChainRulesCore
using ChainRulesTestUtils: ChainRulesCore, test_frule, test_rrule, ⊢, ChainRulesTestUtils
using FiniteDifferences: FiniteDifferences
using Random: Xoshiro
using Test

# HACK: This is a workaround to test `Bijectors._inv_link_chol_lkj` which produces an
# upper-triangular `Matrix`, leading to `test_rrule` comaring the _full_ `Matrix`,
Expand Down Expand Up @@ -29,42 +34,6 @@ end
test_frule(Bijectors.find_alpha, x, y, z)
test_rrule(Bijectors.find_alpha, x, y, z)

if @isdefined Mooncake
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

moved to test/ad/mooncake.jl

rng = Xoshiro(123456)
@testset "$mode" for mode in (Mooncake.ReverseMode, Mooncake.ForwardMode)
Mooncake.TestUtils.test_rule(
rng,
Bijectors.find_alpha,
x,
y,
z;
is_primitive=true,
perf_flag=:none,
mode=mode,
)
Mooncake.TestUtils.test_rule(
rng,
Bijectors.find_alpha,
x,
y,
3;
is_primitive=true,
perf_flag=:none,
mode=mode,
)
Mooncake.TestUtils.test_rule(
rng,
Bijectors.find_alpha,
x,
y,
UInt32(3);
is_primitive=true,
perf_flag=:none,
mode=mode,
)
end
end

test_rrule(
Bijectors.combine,
Bijectors.PartitionMask(3, [1], [2]) ⊢ ChainRulesTestUtils.NoTangent(),
Expand Down Expand Up @@ -182,3 +151,5 @@ end
end
end
end

end # module BijectorsChainRulesTests
62 changes: 38 additions & 24 deletions test/ad/corr.jl
Original file line number Diff line number Diff line change
@@ -1,35 +1,49 @@
@testset "AD for VecCorrBijector" begin
d = 4
dist = LKJ(d, 2.0)
b = bijector(dist)
binv = inverse(b)
using Enzyme: ForwardMode

x = rand(dist)
y = b(x)
@testset "VecCorrBijector: $backend_name" for (backend_name, adtype) in TEST_ADTYPES
ENZYME_FWD_AND_1p11 = VERSION >= v"1.11" && adtype isa AutoEnzyme{<:Enzyme.ForwardMode}

test_ad(y) do x
sum(transform(b, binv(x)))
end
@testset "d = $d" for d in (1, 2, 4)
dist = LKJ(d, 2.0)
b = bijector(dist)
binv = inverse(b)

x = rand(dist)
y = b(x)

test_ad(y) do y
sum(transform(binv, y))
roundtrip(y) = sum(transform(b, binv(y)))
inverse_only(y) = sum(transform(binv, y))
if d == 4 && ENZYME_FWD_AND_1p11
@test_throws Enzyme.Compiler.EnzymeNoDerivativeError test_ad(
roundtrip, adtype, y
)
@test_throws Enzyme.Compiler.EnzymeNoDerivativeError test_ad(
inverse_only, adtype, y
)
else
test_ad(roundtrip, adtype, y)
test_ad(inverse_only, adtype, y)
end
end
end

@testset "AD for VecCholeskyBijector" begin
d = 4
dist = LKJCholesky(d, 2.0)
b = bijector(dist)
binv = inverse(b)
@testset "VecCholeskyBijector: $backend_name" for (backend_name, adtype) in TEST_ADTYPES
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you explain TEST_ADTYPES array to me in more detail? Is this the best way to do things with this fixed array?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pretty much. The idea is that Bijectors provides these functions f and you want to make sure that it can be differentiated with all adtypes of interest. The same sort of loop is used elsewhere see e.g.

https://github.com/TuringLang/Turing.jl/blob/296f6540d069afd7347ad17680ea87cf7a666652/test/ad.jl#L240-L264

https://github.com/TuringLang/DynamicPPL.jl/blob/72491583d9b04f9e4689b756d0d593fa94caade1/test/ad.jl#L40-L47

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

... and each adtype corresponds to an AD backend of interest (plus some configuration options e.g. whether to use forward- or reverse-mode).

@testset "d = $d, uplo = $uplo" for d in (1, 2, 4), uplo in ('U', 'L')
dist = LKJCholesky(d, 2.0, uplo)
b = bijector(dist)
binv = inverse(b)

x = rand(dist)
y = b(x)
x = rand(dist)
y = b(x)
cholesky_to_triangular =
uplo == 'U' ? Bijectors.cholesky_upper : Bijectors.cholesky_lower

test_ad(y) do y
sum(transform(b, binv(y)))
end
roundtrip(y) = sum(transform(b, binv(y)))
test_ad(roundtrip, adtype, y)

test_ad(y) do y
sum(Bijectors.cholesky_upper(transform(binv, y)))
# we need to tack on `cholesky_upper`/`cholesky_lower`, because directly calling
# `sum` on a LinearAlgebra.Cholesky doesn't give a scalar
inverse_only(y) = sum(cholesky_to_triangular(transform(binv, y)))
test_ad(inverse_only, adtype, y)
end
end
Loading
Loading