|
1 |
| -function ChainRulesCore.rrule(::typeof(getindex),VA::AbstractVectorOfArray, i::Int) |
| 1 | +function ChainRulesCore.rrule(::typeof(getindex),VA::AbstractVectorOfArray, i::Union{Int,AbstractArray{Int},CartesianIndex,Colon,BitArray,AbstractArray{Bool}}) |
2 | 2 | function AbstractVectorOfArray_getindex_adjoint(Δ)
|
3 | 3 | Δ′ = [ (i == j ? Δ : zero(x)) for (x,j) in zip(VA.u, 1:length(VA))]
|
4 | 4 | (NoTangent(),Δ′,NoTangent())
|
5 | 5 | end
|
6 | 6 | VA[i],AbstractVectorOfArray_getindex_adjoint
|
7 | 7 | end
|
8 | 8 |
|
9 |
| -function ChainRulesCore.rrule(::typeof(getindex),VA::AbstractVectorOfArray, indices::Vararg{Int,N}) where {N} |
| 9 | +function ChainRulesCore.rrule(::typeof(getindex),VA::AbstractVectorOfArray, indices::Union{Int,AbstractArray{Int},CartesianIndex,Colon,BitArray,AbstractArray{Bool}}...) |
10 | 10 | function AbstractVectorOfArray_getindex_adjoint(Δ)
|
11 | 11 | Δ′ = zero(VA)
|
12 | 12 | Δ′[indices...] = Δ
|
@@ -43,15 +43,16 @@ function ChainRulesCore.rrule(::typeof(getproperty),A::ArrayPartition, s::Symbol
|
43 | 43 | A.x,literal_ArrayPartition_x_adjoint
|
44 | 44 | end
|
45 | 45 |
|
46 |
| -ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i::Int) |
| 46 | +ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i::Union{Int,AbstractArray{Int},CartesianIndex,Colon,BitArray,AbstractArray{Bool}}) |
47 | 47 | function AbstractVectorOfArray_getindex_adjoint(Δ)
|
48 | 48 | Δ′ = [ (i == j ? Δ : zero(x)) for (x,j) in zip(VA.u, 1:length(VA))]
|
49 | 49 | (Δ′,nothing)
|
50 | 50 | end
|
| 51 | + @show VA[i] |
51 | 52 | VA[i],AbstractVectorOfArray_getindex_adjoint
|
52 | 53 | end
|
53 | 54 |
|
54 |
| -ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i::Int, j::Int...) |
| 55 | +ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i::Union{Int,AbstractArray{Int},CartesianIndex,Colon,BitArray,AbstractArray{Bool}}, j::Union{Int,AbstractArray{Int},CartesianIndex,Colon,BitArray,AbstractArray{Bool}}...) |
55 | 56 | function AbstractVectorOfArray_getindex_adjoint(Δ)
|
56 | 57 | Δ′ = zero(VA)
|
57 | 58 | Δ′[i,j...] = Δ
|
|
0 commit comments