Skip to content

Commit f1e9526

Browse files
fix: fix several adjoints, copy and zero methods for VoA
1 parent 030923c commit f1e9526

File tree

2 files changed

+48
-16
lines changed

2 files changed

+48
-16
lines changed

ext/RecursiveArrayToolsZygoteExt.jl

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,11 @@ end
5050
Colon, BitArray, AbstractArray{Bool}}...)
5151
function AbstractVectorOfArray_getindex_adjoint(Δ)
5252
Δ′ = VectorOfArray([zero(x) for (x, j) in zip(VA.u, 1:length(VA))])
53-
Δ′[i, j...] = Δ
53+
if isempty(j)
54+
Δ′.u[i] = Δ
55+
else
56+
Δ′[i, j...] = Δ
57+
end
5458
(Δ′, nothing, map(_ -> nothing, j)...)
5559
end
5660
VA[i, j...], AbstractVectorOfArray_getindex_adjoint
@@ -104,13 +108,25 @@ end
104108
end
105109

106110
@adjoint function Base.Array(VA::AbstractVectorOfArray)
107-
Array(VA),
108-
y -> (Array(y),)
111+
adj = let VA=VA
112+
function Array_adjoint(y)
113+
VA = copy(VA)
114+
VA .= y
115+
return (VA,)
116+
end
117+
end
118+
Array(VA), adj
109119
end
110120

111121
@adjoint function Base.view(A::AbstractVectorOfArray, I...)
112-
view(A, I...),
113-
y -> (view(y, I...), ntuple(_ -> nothing, length(I))...)
122+
adj = let A = A, I = I
123+
function view_adjoint(y)
124+
A = zero(A)
125+
view(A, I...) .= y
126+
return (A, map(_ -> nothing, I)...)
127+
end
128+
end
129+
view(A, I...), adj
114130
end
115131

116132
ChainRulesCore.ProjectTo(a::AbstractVectorOfArray) = ChainRulesCore.ProjectTo{VectorOfArray}((sz = size(a)))

src/vector_of_array.jl

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,16 @@ function DiffEqArray(vec::AbstractVector{T},
160160
p,
161161
sys)
162162
end
163+
function DiffEqArray(vec::AbstractVector{VT},
164+
ts::AbstractVector,
165+
::NTuple{N, Int},
166+
p = nothing,
167+
sys = nothing) where {T, N, VT <: AbstractArray{T, N}}
168+
DiffEqArray{eltype(T), N, typeof(vec), typeof(ts), typeof(p), typeof(sys)}(vec,
169+
ts,
170+
p,
171+
sys)
172+
end
163173
# Assume that the first element is representative of all other elements
164174

165175
function DiffEqArray(vec::AbstractVector,
@@ -466,19 +476,25 @@ end
466476
tuples(VA::DiffEqArray) = tuple.(VA.t, VA.u)
467477

468478
# Growing the array simply adds to the container vector
469-
function Base.copy(VA::AbstractDiffEqArray)
470-
typeof(VA)(copy(VA.u),
471-
copy(VA.t),
472-
(VA.p === nothing) ? nothing : copy(VA.p),
473-
(VA.sys === nothing) ? nothing : copy(VA.sys))
479+
function _copyfield(VA, fname)
480+
if fname == :u
481+
copy(VA.u)
482+
elseif fname == :t
483+
copy(VA.t)
484+
else
485+
getfield(VA, fname)
486+
end
487+
end
488+
function Base.copy(VA::AbstractVectorOfArray)
489+
typeof(VA)((_copyfield(VA, fname) for fname in fieldnames(typeof(VA)))...)
474490
end
475-
Base.copy(VA::AbstractVectorOfArray) = typeof(VA)(copy(VA.u))
476-
477-
Base.zero(VA::AbstractVectorOfArray) = VectorOfArray(Base.zero.(VA.u))
478491

479-
function Base.zero(VA::AbstractDiffEqArray)
480-
u = Base.zero.(VA.u)
481-
DiffEqArray(u, VA.t, parameter_values(VA), symbolic_container(VA))
492+
function Base.zero(VA::AbstractVectorOfArray)
493+
val = copy(VA)
494+
for i in eachindex(VA.u)
495+
val.u[i] = zero(VA[i])
496+
end
497+
return val
482498
end
483499

484500
Base.sizehint!(VA::AbstractVectorOfArray{T, N}, i) where {T, N} = sizehint!(VA.u, i)

0 commit comments

Comments
 (0)