@@ -18,23 +18,14 @@ function ChainRulesCore.rrule(T::Type{<:RecursiveArrayTools.GPUArraysCore.Abstra
18
18
T (xs), ȳ -> (ChainRulesCore. NoTangent (), ȳ)
19
19
end
20
20
21
- @adjoint function getindex (VA:: AbstractVectorOfArray , i:: Int )
22
- function AbstractVectorOfArray_getindex_adjoint (Δ)
23
- Δ′ = [(i == j ? Δ : FillArrays. Fill (zero (eltype (x)), size (x)))
24
- for (x, j) in zip (VA. u, 1 : length (VA))]
25
- (VectorOfArray (Δ′), nothing )
26
- end
27
- VA[i], AbstractVectorOfArray_getindex_adjoint
28
- end
29
-
30
21
@adjoint function getindex (VA:: AbstractVectorOfArray ,
31
22
i:: Union{BitArray, AbstractArray{Bool}} )
32
23
function AbstractVectorOfArray_getindex_adjoint (Δ)
33
24
Δ′ = [(i[j] ? Δ[j] : FillArrays. Fill (zero (eltype (x)), size (x)))
34
25
for (x, j) in zip (VA. u, 1 : length (VA))]
35
26
(VectorOfArray (Δ′), nothing )
36
27
end
37
- VA[i], AbstractVectorOfArray_getindex_adjoint
28
+ VA[:, i], AbstractVectorOfArray_getindex_adjoint
38
29
end
39
30
40
31
@adjoint function getindex (VA:: AbstractVectorOfArray , i:: AbstractArray{Int} )
44
35
for (x, j) in zip (VA. u, 1 : length (VA))]
45
36
(VectorOfArray (Δ′), nothing )
46
37
end
47
- VA[i], AbstractVectorOfArray_getindex_adjoint
48
- end
49
-
50
- @adjoint function getindex (VA:: AbstractVectorOfArray ,
51
- i:: Union{Int, AbstractArray{Int}} )
52
- function AbstractVectorOfArray_getindex_adjoint (Δ)
53
- Δ′ = [(i[j] ? Δ[j] : FillArrays. Fill (zero (eltype (x)), size (x)))
54
- for (x, j) in zip (VA. u, 1 : length (VA))]
55
- (VectorOfArray (Δ′), nothing )
56
- end
57
- VA[i], AbstractVectorOfArray_getindex_adjoint
38
+ VA[:, i], AbstractVectorOfArray_getindex_adjoint
58
39
end
59
40
60
41
@adjoint function getindex (VA:: AbstractVectorOfArray , i:: Colon )
61
42
function AbstractVectorOfArray_getindex_adjoint (Δ)
62
43
(VectorOfArray (Δ), nothing )
63
44
end
64
- VA[i], AbstractVectorOfArray_getindex_adjoint
45
+ VA. u [i], AbstractVectorOfArray_getindex_adjoint
65
46
end
66
47
67
48
@adjoint function getindex (VA:: AbstractVectorOfArray , i:: Int ,
0 commit comments