diff --git a/ext/RecursiveArrayToolsZygoteExt.jl b/ext/RecursiveArrayToolsZygoteExt.jl index ac2f1622..c4ddc1a7 100644 --- a/ext/RecursiveArrayToolsZygoteExt.jl +++ b/ext/RecursiveArrayToolsZygoteExt.jl @@ -18,15 +18,6 @@ function ChainRulesCore.rrule(T::Type{<:RecursiveArrayTools.GPUArraysCore.Abstra T(xs), ȳ -> (ChainRulesCore.NoTangent(), ȳ) end -@adjoint function getindex(VA::AbstractVectorOfArray, i::Int) - function AbstractVectorOfArray_getindex_adjoint(Δ) - Δ′ = [(i == j ? Δ : FillArrays.Fill(zero(eltype(x)), size(x))) - for (x, j) in zip(VA.u, 1:length(VA))] - (VectorOfArray(Δ′), nothing) - end - VA[i], AbstractVectorOfArray_getindex_adjoint -end - @adjoint function getindex(VA::AbstractVectorOfArray, i::Union{BitArray, AbstractArray{Bool}}) function AbstractVectorOfArray_getindex_adjoint(Δ) @@ -34,7 +25,7 @@ end for (x, j) in zip(VA.u, 1:length(VA))] (VectorOfArray(Δ′), nothing) end - VA[i], AbstractVectorOfArray_getindex_adjoint + VA[:, i], AbstractVectorOfArray_getindex_adjoint end @adjoint function getindex(VA::AbstractVectorOfArray, i::AbstractArray{Int}) @@ -44,24 +35,14 @@ end for (x, j) in zip(VA.u, 1:length(VA))] (VectorOfArray(Δ′), nothing) end - VA[i], AbstractVectorOfArray_getindex_adjoint -end - -@adjoint function getindex(VA::AbstractVectorOfArray, - i::Union{Int, AbstractArray{Int}}) - function AbstractVectorOfArray_getindex_adjoint(Δ) - Δ′ = [(i[j] ? Δ[j] : FillArrays.Fill(zero(eltype(x)), size(x))) - for (x, j) in zip(VA.u, 1:length(VA))] - (VectorOfArray(Δ′), nothing) - end - VA[i], AbstractVectorOfArray_getindex_adjoint + VA[:, i], AbstractVectorOfArray_getindex_adjoint end @adjoint function getindex(VA::AbstractVectorOfArray, i::Colon) function AbstractVectorOfArray_getindex_adjoint(Δ) (VectorOfArray(Δ), nothing) end - VA[i], AbstractVectorOfArray_getindex_adjoint + VA.u[i], AbstractVectorOfArray_getindex_adjoint end @adjoint function getindex(VA::AbstractVectorOfArray, i::Int, diff --git a/src/vector_of_array.jl b/src/vector_of_array.jl index 44fd15bc..a30a476b 100644 --- a/src/vector_of_array.jl +++ b/src/vector_of_array.jl @@ -128,6 +128,9 @@ function Base.Array{U}(VA::AbstractVectorOfArray) where {U} vecs = vec.(VA.u) Array(reshape(reduce(hcat, vecs), size(VA.u[1])..., length(VA.u))) end +function Adapt.adapt_structure(to, VA::AbstractVectorOfArray) + Adapt.adapt(to, Array(VA)) +end function VectorOfArray(vec::AbstractVector{T}, ::NTuple{N}) where {T, N} VectorOfArray{eltype(T), N, typeof(vec)}(vec)