From 7a0414d2d35b9bb7982bd3aaa78335547eacb9bf Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Wed, 24 Apr 2024 06:34:46 +0000 Subject: [PATCH 1/4] feat: handle vectorofarray better while projecting --- ext/RecursiveArrayToolsZygoteExt.jl | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/ext/RecursiveArrayToolsZygoteExt.jl b/ext/RecursiveArrayToolsZygoteExt.jl index a5c9f4a5..bad985a9 100644 --- a/ext/RecursiveArrayToolsZygoteExt.jl +++ b/ext/RecursiveArrayToolsZygoteExt.jl @@ -31,6 +31,7 @@ end @adjoint function getindex(VA::AbstractVectorOfArray, i::AbstractArray{Int}) function AbstractVectorOfArray_getindex_adjoint(Δ) + @show "in hete at vecint" iter = 0 Δ′ = [(j ∈ i ? Δ[iter += 1] : FillArrays.Fill(zero(eltype(x)), size(x))) for (x, j) in zip(VA.u, 1:length(VA))] @@ -77,14 +78,14 @@ end ArrayPartition(x, Val{copy_x}), ArrayPartition_adjoint end -@adjoint function VectorOfArray(u) - VectorOfArray(u), - y -> begin - y isa Ref && (y = VectorOfArray(y[].u)) - (VectorOfArray([y[ntuple(x -> Colon(), ndims(y) - 1)..., i] - for i in 1:size(y)[end]]),) - end -end +# @adjoint function VectorOfArray(u) +# VectorOfArray(u), +# y -> begin +# y isa Ref && (y = VectorOfArray(y[].u)) +# (VectorOfArray([y[ntuple(x -> Colon(), ndims(y) - 1)..., i] +# for i in 1:size(y)[end]]),) +# end +# end @adjoint function Base.copy(u::VectorOfArray) copy(u), @@ -145,8 +146,12 @@ end function (p::ChainRulesCore.ProjectTo{VectorOfArray})(x::Union{ AbstractArray, AbstractVectorOfArray}) - arr = reshape(x, p.sz) - return VectorOfArray([arr[:, i] for i in 1:p.sz[end]]) + if eltype(x) <: Number + arr = reshape(x, p.sz) + return VectorOfArray([arr[:, i] for i in 1:p.sz[end]]) + elseif eltype(x) <: AbstractArray + return VectorOfArray(x) + end end @adjoint function Broadcast.broadcasted(::typeof(+), x::AbstractVectorOfArray, @@ -271,6 +276,7 @@ end ȳ -> (nothing, Zygote._project(x, ȳ)) function Zygote.unbroadcast(x::AbstractVectorOfArray, x̄) + @show x̄ N = ndims(x̄) if length(x) == length(x̄) Zygote._project(x, x̄) # ProjectTo handles reshape, offsets, structured matrices, row vectors From 90761a2e9c9b729e1a7b77e7b4ca40aa6502707f Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Wed, 24 Apr 2024 06:38:27 +0000 Subject: [PATCH 2/4] chore: rm unused changes --- ext/RecursiveArrayToolsZygoteExt.jl | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/ext/RecursiveArrayToolsZygoteExt.jl b/ext/RecursiveArrayToolsZygoteExt.jl index bad985a9..087cea6d 100644 --- a/ext/RecursiveArrayToolsZygoteExt.jl +++ b/ext/RecursiveArrayToolsZygoteExt.jl @@ -31,7 +31,6 @@ end @adjoint function getindex(VA::AbstractVectorOfArray, i::AbstractArray{Int}) function AbstractVectorOfArray_getindex_adjoint(Δ) - @show "in hete at vecint" iter = 0 Δ′ = [(j ∈ i ? Δ[iter += 1] : FillArrays.Fill(zero(eltype(x)), size(x))) for (x, j) in zip(VA.u, 1:length(VA))] @@ -78,14 +77,14 @@ end ArrayPartition(x, Val{copy_x}), ArrayPartition_adjoint end -# @adjoint function VectorOfArray(u) -# VectorOfArray(u), -# y -> begin -# y isa Ref && (y = VectorOfArray(y[].u)) -# (VectorOfArray([y[ntuple(x -> Colon(), ndims(y) - 1)..., i] -# for i in 1:size(y)[end]]),) -# end -# end +@adjoint function VectorOfArray(u) + VectorOfArray(u), + y -> begin + y isa Ref && (y = VectorOfArray(y[].u)) + (VectorOfArray([y[ntuple(x -> Colon(), ndims(y) - 1)..., i] + for i in 1:size(y)[end]]),) + end +end @adjoint function Base.copy(u::VectorOfArray) copy(u), From f5228217cffc37c53ffc7cbc85430d58b73c65ba Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Wed, 24 Apr 2024 06:39:18 +0000 Subject: [PATCH 3/4] chore: rm debug statements --- ext/RecursiveArrayToolsZygoteExt.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/ext/RecursiveArrayToolsZygoteExt.jl b/ext/RecursiveArrayToolsZygoteExt.jl index 087cea6d..26613aee 100644 --- a/ext/RecursiveArrayToolsZygoteExt.jl +++ b/ext/RecursiveArrayToolsZygoteExt.jl @@ -275,7 +275,6 @@ end ȳ -> (nothing, Zygote._project(x, ȳ)) function Zygote.unbroadcast(x::AbstractVectorOfArray, x̄) - @show x̄ N = ndims(x̄) if length(x) == length(x̄) Zygote._project(x, x̄) # ProjectTo handles reshape, offsets, structured matrices, row vectors From a8e204c16eeb537140df4b2395f67dd2a3eacb04 Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Sat, 27 Apr 2024 12:37:08 +0000 Subject: [PATCH 4/4] test: symbolic indexing adjoint --- test/downstream/symbol_indexing.jl | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/test/downstream/symbol_indexing.jl b/test/downstream/symbol_indexing.jl index 06b724a9..e384f714 100644 --- a/test/downstream/symbol_indexing.jl +++ b/test/downstream/symbol_indexing.jl @@ -34,6 +34,12 @@ sol_new = DiffEqArray(sol.u[1:10], @test_throws Exception sol[τ] @test_throws Exception sol_new[τ] +gs, = Zygote.gradient(sol) do sol + sum(sol[fol_separate.x]) +end + +@test "Symbolic Indexing ADjoint" all(all.(isone, gs.u)) + # Tables interface test_tables_interface(sol_new, [:timestamp, Symbol("x(t)")], hcat(sol_new[t], sol_new[x]))