Skip to content

Commit 0e7d97f

Browse files
remove extra chainrules
1 parent 0564506 commit 0e7d97f

File tree

1 file changed

+0
-66
lines changed

1 file changed

+0
-66
lines changed

src/zygote.jl

Lines changed: 0 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -1,68 +1,2 @@
1-
function ChainRulesCore.rrule(::typeof(getindex), VA::AbstractVectorOfArray,
2-
i::Union{Int, AbstractArray{Int}, CartesianIndex, Colon,
3-
BitArray, AbstractArray{Bool}})
4-
function AbstractVectorOfArray_getindex_adjoint(Δ)
5-
Δ′ = [(i == j ? Δ : zero(x)) for (x, j) in zip(VA.u, 1:length(VA))]
6-
(NoTangent(), VectorOfArray(Δ′), NoTangent())
7-
end
8-
VA[i], AbstractVectorOfArray_getindex_adjoint
9-
end
10-
11-
function ChainRulesCore.rrule(::typeof(getindex), VA::AbstractVectorOfArray,
12-
indices::Union{Int, AbstractArray{Int}, CartesianIndex, Colon,
13-
BitArray, AbstractArray{Bool}}...)
14-
function AbstractVectorOfArray_getindex_adjoint(Δ)
15-
Δ′ = zero(VA)
16-
Δ′[indices...] = Δ
17-
(NoTangent(), VectorOfArray(Δ′), map(_ -> NoTangent(), indices)...)
18-
end
19-
VA[indices...], AbstractVectorOfArray_getindex_adjoint
20-
end
21-
22-
function ChainRulesCore.rrule(::Type{<:ArrayPartition}, x::S,
23-
::Type{Val{copy_x}} = Val{false}) where {S <: Tuple, copy_x}
24-
function ArrayPartition_adjoint(_y)
25-
y = Array(_y)
26-
starts = vcat(0, cumsum(reduce(vcat, length.(x))))
27-
NoTangent(),
28-
ntuple(i -> reshape(y[(starts[i] + 1):starts[i + 1]], size(x[i])), length(x)),
29-
NoTangent()
30-
end
31-
32-
ArrayPartition(x, Val{copy_x}), ArrayPartition_adjoint
33-
end
34-
35-
function ChainRulesCore.rrule(::Type{<:VectorOfArray}, u)
36-
VectorOfArray(u),
37-
y -> (NoTangent(),
38-
[y[ntuple(x -> Colon(), ndims(y) - 1)..., i] for i in 1:size(y)[end]])
39-
end
40-
41-
function ChainRulesCore.rrule(::Type{<:DiffEqArray}, u, t)
42-
DiffEqArray(u, t),
43-
y -> (NoTangent(),
44-
[y[ntuple(x -> Colon(), ndims(y) - 1)..., i] for i in 1:size(y)[end]],
45-
NoTangent())
46-
end
47-
48-
function ChainRulesCore.rrule(::typeof(getproperty), A::ArrayPartition, s::Symbol)
49-
if s !== :x
50-
error("$s is not a field of ArrayPartition")
51-
end
52-
function literal_ArrayPartition_x_adjoint(d)
53-
(NoTangent(),
54-
ArrayPartition((isnothing(d[i]) ? zero(A.x[i]) : d[i] for i in 1:length(d))...))
55-
end
56-
A.x, literal_ArrayPartition_x_adjoint
57-
end
58-
591
# Define a new species of projection operator for this type:
602
ChainRulesCore.ProjectTo(x::VectorOfArray) = ChainRulesCore.ProjectTo{VectorOfArray}()
61-
62-
# Gradient from iteration will be e.g. Vector{Vector}, this makes it another AbstractMatrix
63-
#(::ChainRulesCore.ProjectTo{VectorOfArray})(dx::AbstractVector{<:AbstractArray}) = VectorOfArray(dx)
64-
# Gradient from broadcasting will be another AbstractArray
65-
#(::ChainRulesCore.ProjectTo{VectorOfArray})(dx::AbstractArray) = dx
66-
67-
# These rules duplicate the `rrule` methods above, because Zygote looks for an `@adjoint`
68-
# definition first, and finds its own before finding those.

0 commit comments

Comments
 (0)