Skip to content

Commit ccaa311

Browse files
Merge pull request #336 from AayushSabharwal/as/adjoint-fix
fix: fix several adjoints, copy and zero methods for VoA
2 parents 030923c + f1a40fc commit ccaa311

File tree

3 files changed

+106
-18
lines changed

3 files changed

+106
-18
lines changed

ext/RecursiveArrayToolsZygoteExt.jl

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,11 @@ end
5050
Colon, BitArray, AbstractArray{Bool}}...)
5151
function AbstractVectorOfArray_getindex_adjoint(Δ)
5252
Δ′ = VectorOfArray([zero(x) for (x, j) in zip(VA.u, 1:length(VA))])
53-
Δ′[i, j...] = Δ
53+
if isempty(j)
54+
Δ′.u[i] = Δ
55+
else
56+
Δ′[i, j...] = Δ
57+
end
5458
(Δ′, nothing, map(_ -> nothing, j)...)
5559
end
5660
VA[i, j...], AbstractVectorOfArray_getindex_adjoint
@@ -104,13 +108,25 @@ end
104108
end
105109

106110
@adjoint function Base.Array(VA::AbstractVectorOfArray)
107-
Array(VA),
108-
y -> (Array(y),)
111+
adj = let VA=VA
112+
function Array_adjoint(y)
113+
VA = copy(VA)
114+
copyto!(VA, y)
115+
return (VA,)
116+
end
117+
end
118+
Array(VA), adj
109119
end
110120

111121
@adjoint function Base.view(A::AbstractVectorOfArray, I...)
112-
view(A, I...),
113-
y -> (view(y, I...), ntuple(_ -> nothing, length(I))...)
122+
adj = let A = A, I = I
123+
function view_adjoint(y)
124+
A = zero(A)
125+
view(A, I...) .= y
126+
return (A, map(_ -> nothing, I)...)
127+
end
128+
end
129+
view(A, I...), adj
114130
end
115131

116132
ChainRulesCore.ProjectTo(a::AbstractVectorOfArray) = ChainRulesCore.ProjectTo{VectorOfArray}((sz = size(a)))

src/vector_of_array.jl

Lines changed: 48 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,24 @@ function DiffEqArray(vec::AbstractVector{T},
160160
p,
161161
sys)
162162
end
163+
164+
# ambiguity resolution
165+
function DiffEqArray(vec::AbstractVector{VT},
166+
ts::AbstractVector,
167+
::NTuple{N, Int}) where {T, N, VT <: AbstractArray{T, N}}
168+
DiffEqArray{eltype(T), N, typeof(vec), typeof(ts), Nothing, Nothing}(vec,
169+
ts,
170+
nothing,
171+
nothing)
172+
end
173+
function DiffEqArray(vec::AbstractVector{VT},
174+
ts::AbstractVector,
175+
::NTuple{N, Int}, p) where {T, N, VT <: AbstractArray{T, N}}
176+
DiffEqArray{eltype(T), N, typeof(vec), typeof(ts), typeof(p), Nothing}(vec,
177+
ts,
178+
p,
179+
nothing)
180+
end
163181
# Assume that the first element is representative of all other elements
164182

165183
function DiffEqArray(vec::AbstractVector,
@@ -174,9 +192,10 @@ function DiffEqArray(vec::AbstractVector,
174192
something(parameters, []),
175193
something(independent_variables, [])))
176194
_size = size(vec[1])
195+
T = eltype(vec[1])
177196
return DiffEqArray{
178-
eltype(eltype(vec)),
179-
length(_size),
197+
T,
198+
length(_size) + 1,
180199
typeof(vec),
181200
typeof(ts),
182201
typeof(p),
@@ -466,19 +485,25 @@ end
466485
tuples(VA::DiffEqArray) = tuple.(VA.t, VA.u)
467486

468487
# Growing the array simply adds to the container vector
469-
function Base.copy(VA::AbstractDiffEqArray)
470-
typeof(VA)(copy(VA.u),
471-
copy(VA.t),
472-
(VA.p === nothing) ? nothing : copy(VA.p),
473-
(VA.sys === nothing) ? nothing : copy(VA.sys))
488+
function _copyfield(VA, fname)
489+
if fname == :u
490+
copy(VA.u)
491+
elseif fname == :t
492+
copy(VA.t)
493+
else
494+
getfield(VA, fname)
495+
end
496+
end
497+
function Base.copy(VA::AbstractVectorOfArray)
498+
typeof(VA)((_copyfield(VA, fname) for fname in fieldnames(typeof(VA)))...)
474499
end
475-
Base.copy(VA::AbstractVectorOfArray) = typeof(VA)(copy(VA.u))
476-
477-
Base.zero(VA::AbstractVectorOfArray) = VectorOfArray(Base.zero.(VA.u))
478500

479-
function Base.zero(VA::AbstractDiffEqArray)
480-
u = Base.zero.(VA.u)
481-
DiffEqArray(u, VA.t, parameter_values(VA), symbolic_container(VA))
501+
function Base.zero(VA::AbstractVectorOfArray)
502+
val = copy(VA)
503+
for i in eachindex(VA.u)
504+
val.u[i] = zero(VA.u[i])
505+
end
506+
return val
482507
end
483508

484509
Base.sizehint!(VA::AbstractVectorOfArray{T, N}, i) where {T, N} = sizehint!(VA.u, i)
@@ -563,6 +588,16 @@ end
563588
function Base.copyto!(dest::AbstractVectorOfArray{T,N}, src::AbstractVectorOfArray{T,N}) where {T,N}
564589
copyto!.(dest.u, src.u)
565590
end
591+
function Base.copyto!(dest::AbstractVectorOfArray{T, N}, src::AbstractArray{T, N}) where {T, N}
592+
for (i, slice) in enumerate(eachslice(src, dims = ndims(src)))
593+
copyto!(dest.u[i], slice)
594+
end
595+
dest
596+
end
597+
function Base.copyto!(dest::AbstractVectorOfArray{T, N, <:AbstractVector{T}}, src::AbstractVector{T}) where {T, N}
598+
copyto!(dest.u, src)
599+
dest
600+
end
566601
# Required for broadcasted setindex! when slicing across subarrays
567602
# E.g. if `va = VectorOfArray([rand(3, 3) for i in 1:5])`
568603
# Need this method for `va[2, :, :] .= 3.0`

test/interface_tests.jl

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using RecursiveArrayTools, StaticArrays, Test
22
using FastBroadcast
3+
using SymbolicIndexingInterface: SymbolCache
34

45
t = 1:3
56
testva = VectorOfArray([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
@@ -149,6 +150,42 @@ testda = DiffEqArray(recursivecopy(testva.u), testts)
149150
fill!(testda, testval)
150151
@test all(x -> (x == testval), testda)
151152

153+
# copyto!
154+
testva = VectorOfArray(collect(0.1:0.1:1.0))
155+
arr = 0.2:0.2:2.0
156+
copyto!(testva, arr)
157+
@test Array(testva) == arr
158+
testva = VectorOfArray([i * ones(3, 2) for i in 1:4])
159+
arr = rand(3, 2, 4)
160+
copyto!(testva, arr)
161+
@test Array(testva) == arr
162+
testva = VectorOfArray([
163+
ones(3, 2, 2),
164+
VectorOfArray([
165+
2ones(3, 2),
166+
VectorOfArray([3ones(3), 4ones(3)])
167+
]),
168+
DiffEqArray([
169+
5ones(3, 2),
170+
VectorOfArray([
171+
6ones(3),
172+
7ones(3),
173+
]),
174+
], [0.1, 0.2], [100.0, 200.0], SymbolCache([:x, :y], [:a, :b], :t))
175+
])
176+
arr = rand(3, 2, 2, 3)
177+
copyto!(testva, arr)
178+
@test Array(testva) == arr
179+
# ensure structure and fields are maintained
180+
@test testva.u[1] isa Array
181+
@test testva.u[2] isa VectorOfArray
182+
@test testva.u[2].u[2] isa VectorOfArray
183+
@test testva.u[3] isa DiffEqArray
184+
@test testva.u[3].u[2] isa VectorOfArray
185+
@test testva.u[3].t == [0.1, 0.2]
186+
@test testva.u[3].p == [100.0, 200.0]
187+
@test testva.u[3].sys isa SymbolCache
188+
152189
# check any
153190
recs = [collect(1:5), collect(6:10), collect(11:15)]
154191
testts = rand(5)

0 commit comments

Comments
 (0)