Skip to content

Commit 9980686

Browse files
Merge pull request #331 from AayushSabharwal/as/linear-indexing
fix: remove linear indexing from getindex adjoints
2 parents dd34e93 + e56b3da commit 9980686

File tree

2 files changed

+6
-22
lines changed

2 files changed

+6
-22
lines changed

ext/RecursiveArrayToolsZygoteExt.jl

Lines changed: 3 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -18,23 +18,14 @@ function ChainRulesCore.rrule(T::Type{<:RecursiveArrayTools.GPUArraysCore.Abstra
1818
T(xs), ȳ -> (ChainRulesCore.NoTangent(), ȳ)
1919
end
2020

21-
@adjoint function getindex(VA::AbstractVectorOfArray, i::Int)
22-
function AbstractVectorOfArray_getindex_adjoint(Δ)
23-
Δ′ = [(i == j ? Δ : FillArrays.Fill(zero(eltype(x)), size(x)))
24-
for (x, j) in zip(VA.u, 1:length(VA))]
25-
(VectorOfArray(Δ′), nothing)
26-
end
27-
VA[i], AbstractVectorOfArray_getindex_adjoint
28-
end
29-
3021
@adjoint function getindex(VA::AbstractVectorOfArray,
3122
i::Union{BitArray, AbstractArray{Bool}})
3223
function AbstractVectorOfArray_getindex_adjoint(Δ)
3324
Δ′ = [(i[j] ? Δ[j] : FillArrays.Fill(zero(eltype(x)), size(x)))
3425
for (x, j) in zip(VA.u, 1:length(VA))]
3526
(VectorOfArray(Δ′), nothing)
3627
end
37-
VA[i], AbstractVectorOfArray_getindex_adjoint
28+
VA[:, i], AbstractVectorOfArray_getindex_adjoint
3829
end
3930

4031
@adjoint function getindex(VA::AbstractVectorOfArray, i::AbstractArray{Int})
@@ -44,24 +35,14 @@ end
4435
for (x, j) in zip(VA.u, 1:length(VA))]
4536
(VectorOfArray(Δ′), nothing)
4637
end
47-
VA[i], AbstractVectorOfArray_getindex_adjoint
48-
end
49-
50-
@adjoint function getindex(VA::AbstractVectorOfArray,
51-
i::Union{Int, AbstractArray{Int}})
52-
function AbstractVectorOfArray_getindex_adjoint(Δ)
53-
Δ′ = [(i[j] ? Δ[j] : FillArrays.Fill(zero(eltype(x)), size(x)))
54-
for (x, j) in zip(VA.u, 1:length(VA))]
55-
(VectorOfArray(Δ′), nothing)
56-
end
57-
VA[i], AbstractVectorOfArray_getindex_adjoint
38+
VA[:, i], AbstractVectorOfArray_getindex_adjoint
5839
end
5940

6041
@adjoint function getindex(VA::AbstractVectorOfArray, i::Colon)
6142
function AbstractVectorOfArray_getindex_adjoint(Δ)
6243
(VectorOfArray(Δ), nothing)
6344
end
64-
VA[i], AbstractVectorOfArray_getindex_adjoint
45+
VA.u[i], AbstractVectorOfArray_getindex_adjoint
6546
end
6647

6748
@adjoint function getindex(VA::AbstractVectorOfArray, i::Int,

src/vector_of_array.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,9 @@ function Base.Array{U}(VA::AbstractVectorOfArray) where {U}
128128
vecs = vec.(VA.u)
129129
Array(reshape(reduce(hcat, vecs), size(VA.u[1])..., length(VA.u)))
130130
end
131+
function Adapt.adapt_structure(to, VA::AbstractVectorOfArray)
132+
Adapt.adapt(to, Array(VA))
133+
end
131134

132135
function VectorOfArray(vec::AbstractVector{T}, ::NTuple{N}) where {T, N}
133136
VectorOfArray{eltype(T), N, typeof(vec)}(vec)

0 commit comments

Comments
 (0)