Skip to content

Commit e8f5e3d

Browse files
Merge pull request #164 from SciML/adjoint_dispatch
catch adjoint dispatch
2 parents ba9749d + 8124486 commit e8f5e3d

File tree

2 files changed

+6
-5
lines changed

2 files changed

+6
-5
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
1616

1717
[compat]
1818
ArrayInterface = "2.7, 3.0"
19-
ChainRulesCore = "0.10.7, 1"
19+
ChainRulesCore = "0.10.7"
2020
DocStringExtensions = "0.8"
2121
RecipesBase = "0.7, 0.8, 1.0"
2222
Requires = "0.5, 1.0"

src/zygote.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
1-
function ChainRulesCore.rrule(::typeof(getindex),VA::AbstractVectorOfArray, i::Int)
1+
function ChainRulesCore.rrule(::typeof(getindex),VA::AbstractVectorOfArray, i::Union{Int,AbstractArray{Int},CartesianIndex,Colon,BitArray,AbstractArray{Bool}})
22
function AbstractVectorOfArray_getindex_adjoint(Δ)
33
Δ′ = [ (i == j ? Δ : zero(x)) for (x,j) in zip(VA.u, 1:length(VA))]
44
(NoTangent(),Δ′,NoTangent())
55
end
66
VA[i],AbstractVectorOfArray_getindex_adjoint
77
end
88

9-
function ChainRulesCore.rrule(::typeof(getindex),VA::AbstractVectorOfArray, indices::Vararg{Int,N}) where {N}
9+
function ChainRulesCore.rrule(::typeof(getindex),VA::AbstractVectorOfArray, indices::Union{Int,AbstractArray{Int},CartesianIndex,Colon,BitArray,AbstractArray{Bool}}...)
1010
function AbstractVectorOfArray_getindex_adjoint(Δ)
1111
Δ′ = zero(VA)
1212
Δ′[indices...] = Δ
@@ -43,15 +43,16 @@ function ChainRulesCore.rrule(::typeof(getproperty),A::ArrayPartition, s::Symbol
4343
A.x,literal_ArrayPartition_x_adjoint
4444
end
4545

46-
ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i::Int)
46+
ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i::Union{Int,AbstractArray{Int},CartesianIndex,Colon,BitArray,AbstractArray{Bool}})
4747
function AbstractVectorOfArray_getindex_adjoint(Δ)
4848
Δ′ = [ (i == j ? Δ : zero(x)) for (x,j) in zip(VA.u, 1:length(VA))]
4949
(Δ′,nothing)
5050
end
51+
@show VA[i]
5152
VA[i],AbstractVectorOfArray_getindex_adjoint
5253
end
5354

54-
ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i::Int, j::Int...)
55+
ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i::Union{Int,AbstractArray{Int},CartesianIndex,Colon,BitArray,AbstractArray{Bool}}, j::Union{Int,AbstractArray{Int},CartesianIndex,Colon,BitArray,AbstractArray{Bool}}...)
5556
function AbstractVectorOfArray_getindex_adjoint(Δ)
5657
Δ′ = zero(VA)
5758
Δ′[i,j...] = Δ

0 commit comments

Comments
 (0)