diff --git a/Project.toml b/Project.toml index a4361b59..6b23e199 100644 --- a/Project.toml +++ b/Project.toml @@ -16,7 +16,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [compat] ArrayInterface = "2.7, 3.0" -ChainRulesCore = "0.10.7, 1" +ChainRulesCore = "0.10.7" DocStringExtensions = "0.8" RecipesBase = "0.7, 0.8, 1.0" Requires = "0.5, 1.0" diff --git a/src/zygote.jl b/src/zygote.jl index 047ae548..e21aa3e9 100644 --- a/src/zygote.jl +++ b/src/zygote.jl @@ -1,4 +1,4 @@ -function ChainRulesCore.rrule(::typeof(getindex),VA::AbstractVectorOfArray, i::Int) +function ChainRulesCore.rrule(::typeof(getindex),VA::AbstractVectorOfArray, i::Union{Int,AbstractArray{Int},CartesianIndex,Colon,BitArray,AbstractArray{Bool}}) function AbstractVectorOfArray_getindex_adjoint(Δ) Δ′ = [ (i == j ? Δ : zero(x)) for (x,j) in zip(VA.u, 1:length(VA))] (NoTangent(),Δ′,NoTangent()) @@ -6,7 +6,7 @@ function ChainRulesCore.rrule(::typeof(getindex),VA::AbstractVectorOfArray, i::I VA[i],AbstractVectorOfArray_getindex_adjoint end -function ChainRulesCore.rrule(::typeof(getindex),VA::AbstractVectorOfArray, indices::Vararg{Int,N}) where {N} +function ChainRulesCore.rrule(::typeof(getindex),VA::AbstractVectorOfArray, indices::Union{Int,AbstractArray{Int},CartesianIndex,Colon,BitArray,AbstractArray{Bool}}...) function AbstractVectorOfArray_getindex_adjoint(Δ) Δ′ = zero(VA) Δ′[indices...] = Δ @@ -43,15 +43,16 @@ function ChainRulesCore.rrule(::typeof(getproperty),A::ArrayPartition, s::Symbol A.x,literal_ArrayPartition_x_adjoint end -ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i::Int) +ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i::Union{Int,AbstractArray{Int},CartesianIndex,Colon,BitArray,AbstractArray{Bool}}) function AbstractVectorOfArray_getindex_adjoint(Δ) Δ′ = [ (i == j ? Δ : zero(x)) for (x,j) in zip(VA.u, 1:length(VA))] (Δ′,nothing) end + @show VA[i] VA[i],AbstractVectorOfArray_getindex_adjoint end -ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i::Int, j::Int...) +ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i::Union{Int,AbstractArray{Int},CartesianIndex,Colon,BitArray,AbstractArray{Bool}}, j::Union{Int,AbstractArray{Int},CartesianIndex,Colon,BitArray,AbstractArray{Bool}}...) function AbstractVectorOfArray_getindex_adjoint(Δ) Δ′ = zero(VA) Δ′[i,j...] = Δ