Skip to content
7 changes: 5 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,17 +1,20 @@
name = "ChangesOfVariables"
uuid = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0"
version = "0.1.1"
version = "0.1.2"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[compat]
ChainRulesCore = "1"
julia = "1"

[extras]
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"

[targets]
test = ["Documenter", "ForwardDiff"]
test = ["ChainRulesTestUtils", "Documenter", "ForwardDiff"]
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ changes for functions that perform a change of variables (like coordinate
transformations).

`ChangesOfVariables` is a very lightweight package and has no dependencies
beyond `Base`, `LinearAlgebra` and `Test`.
beyond `Base`, `LinearAlgebra`, `Test` and `ChainRulesCore`.

## Documentation

Expand Down
1 change: 1 addition & 0 deletions src/ChangesOfVariables.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ transformations).
"""
module ChangesOfVariables

using ChainRulesCore
using LinearAlgebra
using Test

Expand Down
27 changes: 20 additions & 7 deletions src/with_ladj.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,17 +76,30 @@ export with_logabsdet_jacobian
end


@inline _get_y(y_with_ladj::NTuple{2,Any,}) = y_with_ladj[1]
@inline _get_ladj(y_with_ladj::NTuple{2,Any}) = y_with_ladj[2]

_with_ladj_on_mapped(map_or_bc::Function, y_with_ladj::Tuple{Any,Real}) = y_with_ladj
function _with_ladj_on_mapped(map_or_bc::F, y_with_ladj::Tuple{Any,Real}) where {F<:Union{typeof(map),typeof(broadcast)}}
return y_with_ladj
end

function _with_ladj_on_mapped(map_or_bc::Function, y_with_ladj)
y = map_or_bc(_get_y, y_with_ladj)
ladj = sum(map_or_bc(_get_ladj, y_with_ladj))
function _with_ladj_on_mapped(map_or_bc::F, y_with_ladj) where {F<:Union{typeof(map),typeof(broadcast)}}
y = map_or_bc(first, y_with_ladj)
ladj = sum(last, y_with_ladj)
(y, ladj)
end


# Need to use a type for this, type inference fails when using a pullback
# closure over YLT in the rrule, resulting in bad performance:
struct WithLadjOnMappedPullback{YLT} <: Function end
function (::WithLadjOnMappedPullback{YLT})(thunked_ΔΩ) where YLT
ys, ladj = unthunk(thunked_ΔΩ)
return NoTangent(), NoTangent(), map(y -> Tangent{YLT}(y, ladj), ys)
end

function ChainRulesCore.rrule(::typeof(_with_ladj_on_mapped), map_or_bc::F, y_with_ladj) where {F<:Union{typeof(map),typeof(broadcast)}}
YLT = eltype(y_with_ladj)
return _with_ladj_on_mapped(map_or_bc, y_with_ladj), WithLadjOnMappedPullback{YLT}()
end

function with_logabsdet_jacobian(mapped_f::Base.Fix1{<:Union{typeof(map),typeof(broadcast)}}, X)
map_or_bc = mapped_f.f
f = mapped_f.x
Expand Down
8 changes: 8 additions & 0 deletions test/test_with_ladj.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ using Test

using LinearAlgebra

using ChangesOfVariables
using ChangesOfVariables: test_with_logabsdet_jacobian
using ChainRulesTestUtils

include("getjacobian.jl")

Expand Down Expand Up @@ -59,4 +61,10 @@ include("getjacobian.jl")
test_with_logabsdet_jacobian(f, x, getjacobian)
end
end

@testset "rrules" begin
for map_or_bc in (map, broadcast)
test_rrule(ChangesOfVariables._with_ladj_on_mapped, map_or_bc, [(1.0, 2.0), (3.0, 4.0), (5.0, 6.0)])
end
end
end