|
1 |
| -function ChainRulesCore.rrule(::typeof(getindex), VA::AbstractVectorOfArray, |
2 |
| - i::Union{Int, AbstractArray{Int}, CartesianIndex, Colon, |
3 |
| - BitArray, AbstractArray{Bool}}) |
4 |
| - function AbstractVectorOfArray_getindex_adjoint(Δ) |
5 |
| - Δ′ = [(i == j ? Δ : zero(x)) for (x, j) in zip(VA.u, 1:length(VA))] |
6 |
| - (NoTangent(), VectorOfArray(Δ′), NoTangent()) |
7 |
| - end |
8 |
| - VA[i], AbstractVectorOfArray_getindex_adjoint |
9 |
| -end |
10 |
| - |
11 |
| -function ChainRulesCore.rrule(::typeof(getindex), VA::AbstractVectorOfArray, |
12 |
| - indices::Union{Int, AbstractArray{Int}, CartesianIndex, Colon, |
13 |
| - BitArray, AbstractArray{Bool}}...) |
14 |
| - function AbstractVectorOfArray_getindex_adjoint(Δ) |
15 |
| - Δ′ = zero(VA) |
16 |
| - Δ′[indices...] = Δ |
17 |
| - (NoTangent(), VectorOfArray(Δ′), map(_ -> NoTangent(), indices)...) |
18 |
| - end |
19 |
| - VA[indices...], AbstractVectorOfArray_getindex_adjoint |
20 |
| -end |
21 |
| - |
22 |
| -function ChainRulesCore.rrule(::Type{<:ArrayPartition}, x::S, |
23 |
| - ::Type{Val{copy_x}} = Val{false}) where {S <: Tuple, copy_x} |
24 |
| - function ArrayPartition_adjoint(_y) |
25 |
| - y = Array(_y) |
26 |
| - starts = vcat(0, cumsum(reduce(vcat, length.(x)))) |
27 |
| - NoTangent(), |
28 |
| - ntuple(i -> reshape(y[(starts[i] + 1):starts[i + 1]], size(x[i])), length(x)), |
29 |
| - NoTangent() |
30 |
| - end |
31 |
| - |
32 |
| - ArrayPartition(x, Val{copy_x}), ArrayPartition_adjoint |
33 |
| -end |
34 |
| - |
35 |
| -function ChainRulesCore.rrule(::Type{<:VectorOfArray}, u) |
36 |
| - VectorOfArray(u), |
37 |
| - y -> (NoTangent(), |
38 |
| - [y[ntuple(x -> Colon(), ndims(y) - 1)..., i] for i in 1:size(y)[end]]) |
39 |
| -end |
40 |
| - |
41 |
| -function ChainRulesCore.rrule(::Type{<:DiffEqArray}, u, t) |
42 |
| - DiffEqArray(u, t), |
43 |
| - y -> (NoTangent(), |
44 |
| - [y[ntuple(x -> Colon(), ndims(y) - 1)..., i] for i in 1:size(y)[end]], |
45 |
| - NoTangent()) |
46 |
| -end |
47 |
| - |
48 |
| -function ChainRulesCore.rrule(::typeof(getproperty), A::ArrayPartition, s::Symbol) |
49 |
| - if s !== :x |
50 |
| - error("$s is not a field of ArrayPartition") |
51 |
| - end |
52 |
| - function literal_ArrayPartition_x_adjoint(d) |
53 |
| - (NoTangent(), |
54 |
| - ArrayPartition((isnothing(d[i]) ? zero(A.x[i]) : d[i] for i in 1:length(d))...)) |
55 |
| - end |
56 |
| - A.x, literal_ArrayPartition_x_adjoint |
57 |
| -end |
58 |
| - |
59 | 1 | # Define a new species of projection operator for this type:
|
60 | 2 | ChainRulesCore.ProjectTo(x::VectorOfArray) = ChainRulesCore.ProjectTo{VectorOfArray}()
|
61 |
| - |
62 |
| -# Gradient from iteration will be e.g. Vector{Vector}, this makes it another AbstractMatrix |
63 |
| -#(::ChainRulesCore.ProjectTo{VectorOfArray})(dx::AbstractVector{<:AbstractArray}) = VectorOfArray(dx) |
64 |
| -# Gradient from broadcasting will be another AbstractArray |
65 |
| -#(::ChainRulesCore.ProjectTo{VectorOfArray})(dx::AbstractArray) = dx |
66 |
| - |
67 |
| -# These rules duplicate the `rrule` methods above, because Zygote looks for an `@adjoint` |
68 |
| -# definition first, and finds its own before finding those. |
0 commit comments