Skip to content

Commit b18d10a

Browse files
catch adjoint dispatch
1 parent ba9749d commit b18d10a

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

src/zygote.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
1-
function ChainRulesCore.rrule(::typeof(getindex),VA::AbstractVectorOfArray, i::Int)
1+
function ChainRulesCore.rrule(::typeof(getindex),VA::AbstractVectorOfArray, i::Union{Int,AbstractArray{Int},CartesianIndex,Colon,BitArray,AbstractArray{Bool}})
22
function AbstractVectorOfArray_getindex_adjoint(Δ)
33
Δ′ = [ (i == j ? Δ : zero(x)) for (x,j) in zip(VA.u, 1:length(VA))]
44
(NoTangent(),Δ′,NoTangent())
55
end
66
VA[i],AbstractVectorOfArray_getindex_adjoint
77
end
88

9-
function ChainRulesCore.rrule(::typeof(getindex),VA::AbstractVectorOfArray, indices::Vararg{Int,N}) where {N}
9+
function ChainRulesCore.rrule(::typeof(getindex),VA::AbstractVectorOfArray, indices::Union{Int,AbstractArray{Int},CartesianIndex,Colon,BitArray,AbstractArray{Bool}}...)
1010
function AbstractVectorOfArray_getindex_adjoint(Δ)
1111
Δ′ = zero(VA)
1212
Δ′[indices...] = Δ
@@ -43,15 +43,16 @@ function ChainRulesCore.rrule(::typeof(getproperty),A::ArrayPartition, s::Symbol
4343
A.x,literal_ArrayPartition_x_adjoint
4444
end
4545

46-
ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i::Int)
46+
ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i::Union{Int,AbstractArray{Int},CartesianIndex,Colon,BitArray,AbstractArray{Bool}})
4747
function AbstractVectorOfArray_getindex_adjoint(Δ)
4848
Δ′ = [ (i == j ? Δ : zero(x)) for (x,j) in zip(VA.u, 1:length(VA))]
4949
(Δ′,nothing)
5050
end
51+
@show VA[i]
5152
VA[i],AbstractVectorOfArray_getindex_adjoint
5253
end
5354

54-
ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i::Int, j::Int...)
55+
ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i::Union{Int,AbstractArray{Int},CartesianIndex,Colon,BitArray,AbstractArray{Bool}}, j::Union{Int,AbstractArray{Int},CartesianIndex,Colon,BitArray,AbstractArray{Bool}}...)
5556
function AbstractVectorOfArray_getindex_adjoint(Δ)
5657
Δ′ = zero(VA)
5758
Δ′[i,j...] = Δ

0 commit comments

Comments
 (0)