Skip to content

Commit 57d167c

Browse files
hotfix add back some zygoterules
1 parent f59c6c2 commit 57d167c

File tree

3 files changed

+44
-0
lines changed

3 files changed

+44
-0
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
1212
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
1313
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
1414
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
15+
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
1516

1617
[compat]
1718
ArrayInterface = "2.7, 3.0"

src/RecursiveArrayTools.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ using DocStringExtensions
1010

1111
import ChainRulesCore
1212
import ChainRulesCore: NoTangent
13+
import ZygoteRules
1314
abstract type AbstractVectorOfArray{T, N, A} <: AbstractArray{T, N} end
1415
abstract type AbstractDiffEqArray{T, N, A} <: AbstractVectorOfArray{T, N, A} end
1516

src/zygote.jl

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,3 +42,45 @@ function ChainRulesCore.rrule(::typeof(getproperty),A::ArrayPartition, s::Symbol
4242
end
4343
A.x,literal_ArrayPartition_x_adjoint
4444
end
45+
46+
ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i)
47+
function AbstractVectorOfArray_getindex_adjoint(Δ)
48+
Δ′ = [ (i == j ? Δ : zero(x)) for (x,j) in zip(VA.u, 1:length(VA))]
49+
(Δ′,nothing)
50+
end
51+
VA[i],AbstractVectorOfArray_getindex_adjoint
52+
end
53+
54+
ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i, j...)
55+
function AbstractVectorOfArray_getindex_adjoint(Δ)
56+
Δ′ = zero(VA)
57+
Δ′[i,j...] = Δ
58+
(Δ′, i,map(_ -> nothing, j)...)
59+
end
60+
VA[i,j...],AbstractVectorOfArray_getindex_adjoint
61+
end
62+
63+
ZygoteRules.@adjoint function ArrayPartition(x::S, ::Type{Val{copy_x}} = Val{false}) where {S<:Tuple,copy_x}
64+
function ArrayPartition_adjoint(_y)
65+
y = Array(_y)
66+
starts = vcat(0,cumsum(reduce(vcat,length.(x))))
67+
ntuple(i -> reshape(y[starts[i]+1:starts[i+1]], size(x[i])), length(x)), nothing
68+
end
69+
70+
ArrayPartition(x, Val{copy_x}), ArrayPartition_adjoint
71+
end
72+
73+
ZygoteRules.@adjoint function VectorOfArray(u)
74+
VectorOfArray(u),y -> ([y[ntuple(x->Colon(),ndims(y)-1)...,i] for i in 1:size(y)[end]],)
75+
end
76+
77+
ZygoteRules.@adjoint function DiffEqArray(u,t)
78+
DiffEqArray(u,t),y -> ([y[ntuple(x->Colon(),ndims(y)-1)...,i] for i in 1:size(y)[end]],nothing)
79+
end
80+
81+
ZygoteRules.@adjoint function ZygoteRules.literal_getproperty(A::ArrayPartition, ::Val{:x})
82+
function literal_ArrayPartition_x_adjoint(d)
83+
(ArrayPartition((isnothing(d[i]) ? zero(A.x[i]) : d[i] for i in 1:length(d))...),)
84+
end
85+
A.x,literal_ArrayPartition_x_adjoint
86+
end

0 commit comments

Comments
 (0)