diff --git a/ext/RecursiveArrayToolsZygoteExt.jl b/ext/RecursiveArrayToolsZygoteExt.jl index a5c9f4a5..26613aee 100644 --- a/ext/RecursiveArrayToolsZygoteExt.jl +++ b/ext/RecursiveArrayToolsZygoteExt.jl @@ -145,8 +145,12 @@ end function (p::ChainRulesCore.ProjectTo{VectorOfArray})(x::Union{ AbstractArray, AbstractVectorOfArray}) - arr = reshape(x, p.sz) - return VectorOfArray([arr[:, i] for i in 1:p.sz[end]]) + if eltype(x) <: Number + arr = reshape(x, p.sz) + return VectorOfArray([arr[:, i] for i in 1:p.sz[end]]) + elseif eltype(x) <: AbstractArray + return VectorOfArray(x) + end end @adjoint function Broadcast.broadcasted(::typeof(+), x::AbstractVectorOfArray, diff --git a/test/downstream/symbol_indexing.jl b/test/downstream/symbol_indexing.jl index 06b724a9..e384f714 100644 --- a/test/downstream/symbol_indexing.jl +++ b/test/downstream/symbol_indexing.jl @@ -34,6 +34,12 @@ sol_new = DiffEqArray(sol.u[1:10], @test_throws Exception sol[τ] @test_throws Exception sol_new[τ] +gs, = Zygote.gradient(sol) do sol + sum(sol[fol_separate.x]) +end + +@test "Symbolic Indexing ADjoint" all(all.(isone, gs.u)) + # Tables interface test_tables_interface(sol_new, [:timestamp, Symbol("x(t)")], hcat(sol_new[t], sol_new[x]))