Skip to content
11 changes: 11 additions & 0 deletions ext/RecursiveArrayToolsZygoteExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,17 @@ end
end
end

Zygote.@adjoint function Zygote.literal_getproperty(A::RecursiveArrayTools.AbstractVectorOfArray, ::Val{:u})
function literal_VectorOfArray_x_adjoint(d)
m = map(enumerate(d)) do (idx, d_i)
isnothing(d_i) && return zero(A.u[idx])
d_i
end
(VectorOfArray(m), nothing)
end
A.u, literal_VectorOfArray_x_adjoint
end

@adjoint function literal_getproperty(A::ArrayPartition, ::Val{:x})
function literal_ArrayPartition_x_adjoint(d)
(ArrayPartition((isnothing(d[i]) ? zero(A.x[i]) : d[i] for i in 1:length(d))...),)
Expand Down
6 changes: 6 additions & 0 deletions test/adjoints.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,3 +92,9 @@ loss(x)
VectorOfArray([collect((3i):(3i + 3)) for i in 1:5])
@test Zygote.gradient(loss10, x)[1] == ForwardDiff.gradient(loss10, x)
@test Zygote.gradient(loss11, x)[1] == ForwardDiff.gradient(loss11, x)

voa = RecursiveArrayTools.VectorOfArray(fill(rand(3), 3))
voa_gs, = Zygote.gradient(voa) do x
sum(sum.(x.u))
end
@test voa_gs isa RecursiveArrayTools.VectorOfArray
Loading