From 5c6d7fd207692fedb81cb3264134db259aa68213 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Wed, 29 Sep 2021 12:51:04 -0400 Subject: [PATCH 1/6] take 1 --- src/zygote.jl | 26 +++++++++++++++++++++++--- 1 file changed, 23 insertions(+), 3 deletions(-) diff --git a/src/zygote.jl b/src/zygote.jl index d1b5a09e..c30ec23e 100644 --- a/src/zygote.jl +++ b/src/zygote.jl @@ -43,10 +43,28 @@ function ChainRulesCore.rrule(::typeof(getproperty),A::ArrayPartition, s::Symbol A.x,literal_ArrayPartition_x_adjoint end +#= + +# Define a new species of projection operator for this type: +ChainRulesCore.ProjectTo(x::VectorOfArray) = ProjectTo{VectorOfArray}() + +# Gradient from iteration will be e.g. Vector{Vector}, this makes it another AbstractMatrix +(::ProjectTo{VectorOfArray})(dx::AbstractVector{<:AbstractArray}) = VectorOfArray(dx) +# Gradient from broadcasting will be another AbstractArray +(::ProjectTo{VectorOfArray})(dx::AbstractArray) = dx + +But this may not be necessary? + +=# + + +# These rules duplicate the `rrule` methods above, because Zygote looks for an `@adjoint` +# definition first, and finds its own before finding those. + ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i::Union{Int,AbstractArray{Int},CartesianIndex,Colon,BitArray,AbstractArray{Bool}}) function AbstractVectorOfArray_getindex_adjoint(Δ) Δ′ = [ (i == j ? Δ : zero(x)) for (x,j) in zip(VA.u, 1:length(VA))] - (Δ′,nothing) + (VectorOfArray(Δ′),nothing) end VA[i],AbstractVectorOfArray_getindex_adjoint end @@ -55,11 +73,13 @@ ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i::Union{Int,A function AbstractVectorOfArray_getindex_adjoint(Δ) Δ′ = zero(VA) Δ′[i,j...] = Δ - (Δ′, i,map(_ -> nothing, j)...) + @show Δ′ + # (Δ′, i,map(_ -> nothing, j)...) # surely that i is a bug? + (Δ′, nothing, map(_ -> nothing, j)...) + # (VectorOfArray(Δ′), nothing, map(_ -> nothing, j)...) end VA[i,j...],AbstractVectorOfArray_getindex_adjoint end - ZygoteRules.@adjoint function ArrayPartition(x::S, ::Type{Val{copy_x}} = Val{false}) where {S<:Tuple,copy_x} function ArrayPartition_adjoint(_y) y = Array(_y) From 6af53d520ca89667f6019cc08ba92b265720372d Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Wed, 29 Sep 2021 16:30:39 -0400 Subject: [PATCH 2/6] fix up a few more overloads --- src/zygote.jl | 22 ++++++---------------- 1 file changed, 6 insertions(+), 16 deletions(-) diff --git a/src/zygote.jl b/src/zygote.jl index c30ec23e..1777df67 100644 --- a/src/zygote.jl +++ b/src/zygote.jl @@ -1,7 +1,7 @@ function ChainRulesCore.rrule(::typeof(getindex),VA::AbstractVectorOfArray, i::Union{Int,AbstractArray{Int},CartesianIndex,Colon,BitArray,AbstractArray{Bool}}) function AbstractVectorOfArray_getindex_adjoint(Δ) Δ′ = [ (i == j ? Δ : zero(x)) for (x,j) in zip(VA.u, 1:length(VA))] - (NoTangent(),Δ′,NoTangent()) + (NoTangent(),VectorOfArray(Δ′),NoTangent()) end VA[i],AbstractVectorOfArray_getindex_adjoint end @@ -10,7 +10,7 @@ function ChainRulesCore.rrule(::typeof(getindex),VA::AbstractVectorOfArray, indi function AbstractVectorOfArray_getindex_adjoint(Δ) Δ′ = zero(VA) Δ′[indices...] = Δ - (NoTangent(), Δ′, indices[1],map(_ -> NoTangent(), indices[2:end])...) + (NoTangent(), VectorOfArray(Δ′), indices[1],map(_ -> NoTangent(), indices[2:end])...) end VA[indices...],AbstractVectorOfArray_getindex_adjoint end @@ -19,7 +19,7 @@ function ChainRulesCore.rrule(::Type{<:ArrayPartition}, x::S, ::Type{Val{copy_x} function ArrayPartition_adjoint(_y) y = Array(_y) starts = vcat(0,cumsum(reduce(vcat,length.(x)))) - NoTangent(), ntuple(i -> reshape(y[starts[i]+1:starts[i+1]], size(x[i])), length(x)), NoTangent() + NoTangent(), ArrayPartition(ntuple(i -> reshape(y[starts[i]+1:starts[i+1]], size(x[i]))), length(x)), NoTangent() end ArrayPartition(x, Val{copy_x}), ArrayPartition_adjoint @@ -43,8 +43,6 @@ function ChainRulesCore.rrule(::typeof(getproperty),A::ArrayPartition, s::Symbol A.x,literal_ArrayPartition_x_adjoint end -#= - # Define a new species of projection operator for this type: ChainRulesCore.ProjectTo(x::VectorOfArray) = ProjectTo{VectorOfArray}() @@ -53,11 +51,6 @@ ChainRulesCore.ProjectTo(x::VectorOfArray) = ProjectTo{VectorOfArray}() # Gradient from broadcasting will be another AbstractArray (::ProjectTo{VectorOfArray})(dx::AbstractArray) = dx -But this may not be necessary? - -=# - - # These rules duplicate the `rrule` methods above, because Zygote looks for an `@adjoint` # definition first, and finds its own before finding those. @@ -73,10 +66,7 @@ ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i::Union{Int,A function AbstractVectorOfArray_getindex_adjoint(Δ) Δ′ = zero(VA) Δ′[i,j...] = Δ - @show Δ′ - # (Δ′, i,map(_ -> nothing, j)...) # surely that i is a bug? - (Δ′, nothing, map(_ -> nothing, j)...) - # (VectorOfArray(Δ′), nothing, map(_ -> nothing, j)...) + (VectorOfArray(Δ′), nothing, map(_ -> nothing, j)...) end VA[i,j...],AbstractVectorOfArray_getindex_adjoint end @@ -91,11 +81,11 @@ ZygoteRules.@adjoint function ArrayPartition(x::S, ::Type{Val{copy_x}} = Val{fal end ZygoteRules.@adjoint function VectorOfArray(u) - VectorOfArray(u),y -> ([y[ntuple(x->Colon(),ndims(y)-1)...,i] for i in 1:size(y)[end]],) + VectorOfArray(u),y -> (VectorOfArray([y[ntuple(x->Colon(),ndims(y)-1)...,i] for i in 1:size(y)[end]]),) end ZygoteRules.@adjoint function DiffEqArray(u,t) - DiffEqArray(u,t),y -> ([y[ntuple(x->Colon(),ndims(y)-1)...,i] for i in 1:size(y)[end]],nothing) + DiffEqArray(u,t),y -> (DiffEqArray([y[ntuple(x->Colon(),ndims(y)-1)...,i] for i in 1:size(y)[end]],t),nothing) end ZygoteRules.@adjoint function ZygoteRules.literal_getproperty(A::ArrayPartition, ::Val{:x}) From b554abc94d26de6e8e32c19cadfa1f9c5bdd449e Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Wed, 29 Sep 2021 16:41:06 -0400 Subject: [PATCH 3/6] namespace --- src/zygote.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/zygote.jl b/src/zygote.jl index 1777df67..031fb9ae 100644 --- a/src/zygote.jl +++ b/src/zygote.jl @@ -44,12 +44,12 @@ function ChainRulesCore.rrule(::typeof(getproperty),A::ArrayPartition, s::Symbol end # Define a new species of projection operator for this type: -ChainRulesCore.ProjectTo(x::VectorOfArray) = ProjectTo{VectorOfArray}() +ChainRulesCore.ProjectTo(x::VectorOfArray) = ChainRulesCore.ProjectTo{VectorOfArray}() # Gradient from iteration will be e.g. Vector{Vector}, this makes it another AbstractMatrix -(::ProjectTo{VectorOfArray})(dx::AbstractVector{<:AbstractArray}) = VectorOfArray(dx) +(::ChainRulesCore.ProjectTo{VectorOfArray})(dx::AbstractVector{<:AbstractArray}) = VectorOfArray(dx) # Gradient from broadcasting will be another AbstractArray -(::ProjectTo{VectorOfArray})(dx::AbstractArray) = dx +(::ChainRulesCore.ProjectTo{VectorOfArray})(dx::AbstractArray) = dx # These rules duplicate the `rrule` methods above, because Zygote looks for an `@adjoint` # definition first, and finds its own before finding those. From fe2db07529818eb4b16bf545467589a52be3fc73 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Thu, 30 Sep 2021 01:40:54 -0400 Subject: [PATCH 4/6] Update src/zygote.jl Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com> --- src/zygote.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/zygote.jl b/src/zygote.jl index 031fb9ae..ce930d65 100644 --- a/src/zygote.jl +++ b/src/zygote.jl @@ -10,7 +10,7 @@ function ChainRulesCore.rrule(::typeof(getindex),VA::AbstractVectorOfArray, indi function AbstractVectorOfArray_getindex_adjoint(Δ) Δ′ = zero(VA) Δ′[indices...] = Δ - (NoTangent(), VectorOfArray(Δ′), indices[1],map(_ -> NoTangent(), indices[2:end])...) + (NoTangent(), VectorOfArray(Δ′), map(_ -> NoTangent(), indices)...) end VA[indices...],AbstractVectorOfArray_getindex_adjoint end From c84842aaa4cd2a739b74de9480cd9cd195b0eee9 Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Thu, 30 Sep 2021 01:43:22 -0400 Subject: [PATCH 5/6] comment out extra projection --- src/zygote.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/zygote.jl b/src/zygote.jl index ce930d65..328b1adf 100644 --- a/src/zygote.jl +++ b/src/zygote.jl @@ -49,7 +49,7 @@ ChainRulesCore.ProjectTo(x::VectorOfArray) = ChainRulesCore.ProjectTo{VectorOfAr # Gradient from iteration will be e.g. Vector{Vector}, this makes it another AbstractMatrix (::ChainRulesCore.ProjectTo{VectorOfArray})(dx::AbstractVector{<:AbstractArray}) = VectorOfArray(dx) # Gradient from broadcasting will be another AbstractArray -(::ChainRulesCore.ProjectTo{VectorOfArray})(dx::AbstractArray) = dx +#(::ChainRulesCore.ProjectTo{VectorOfArray})(dx::AbstractArray) = dx # These rules duplicate the `rrule` methods above, because Zygote looks for an `@adjoint` # definition first, and finds its own before finding those. From b3ed973b0167f0e60f4b7c4e801af850ce021b17 Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Thu, 30 Sep 2021 02:47:22 -0400 Subject: [PATCH 6/6] fix pullbacks and use some FillArrays --- Project.toml | 1 + src/RecursiveArrayTools.jl | 3 +++ src/zygote.jl | 8 ++++---- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/Project.toml b/Project.toml index ecf5418d..3a356377 100644 --- a/Project.toml +++ b/Project.toml @@ -7,6 +7,7 @@ version = "2.17.2" ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" +FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" Requires = "ae029012-a4dd-5104-9daa-d747884805df" diff --git a/src/RecursiveArrayTools.jl b/src/RecursiveArrayTools.jl index 59ad81e4..88fb4240 100644 --- a/src/RecursiveArrayTools.jl +++ b/src/RecursiveArrayTools.jl @@ -11,6 +11,9 @@ using Requires, RecipesBase, StaticArrays, Statistics, import ChainRulesCore import ChainRulesCore: NoTangent import ZygoteRules + +using FillArrays + abstract type AbstractVectorOfArray{T, N, A} <: AbstractArray{T, N} end abstract type AbstractDiffEqArray{T, N, A} <: AbstractVectorOfArray{T, N, A} end diff --git a/src/zygote.jl b/src/zygote.jl index 328b1adf..fa7bd019 100644 --- a/src/zygote.jl +++ b/src/zygote.jl @@ -47,7 +47,7 @@ end ChainRulesCore.ProjectTo(x::VectorOfArray) = ChainRulesCore.ProjectTo{VectorOfArray}() # Gradient from iteration will be e.g. Vector{Vector}, this makes it another AbstractMatrix -(::ChainRulesCore.ProjectTo{VectorOfArray})(dx::AbstractVector{<:AbstractArray}) = VectorOfArray(dx) +#(::ChainRulesCore.ProjectTo{VectorOfArray})(dx::AbstractVector{<:AbstractArray}) = VectorOfArray(dx) # Gradient from broadcasting will be another AbstractArray #(::ChainRulesCore.ProjectTo{VectorOfArray})(dx::AbstractArray) = dx @@ -56,7 +56,7 @@ ChainRulesCore.ProjectTo(x::VectorOfArray) = ChainRulesCore.ProjectTo{VectorOfAr ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i::Union{Int,AbstractArray{Int},CartesianIndex,Colon,BitArray,AbstractArray{Bool}}) function AbstractVectorOfArray_getindex_adjoint(Δ) - Δ′ = [ (i == j ? Δ : zero(x)) for (x,j) in zip(VA.u, 1:length(VA))] + Δ′ = [(i == j ? Δ : Fill(zero(eltype(x)),size(x))) for (x,j) in zip(VA.u, 1:length(VA))] (VectorOfArray(Δ′),nothing) end VA[i],AbstractVectorOfArray_getindex_adjoint @@ -64,8 +64,8 @@ end ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i::Union{Int,AbstractArray{Int},CartesianIndex,Colon,BitArray,AbstractArray{Bool}}, j::Union{Int,AbstractArray{Int},CartesianIndex,Colon,BitArray,AbstractArray{Bool}}...) function AbstractVectorOfArray_getindex_adjoint(Δ) - Δ′ = zero(VA) - Δ′[i,j...] = Δ + Δ′ = [(i == j ? zero(x) : Fill(zero(eltype(x)),size(x))) for (x,j) in zip(VA.u, 1:length(VA))] + Δ′[i][j...] = Δ (VectorOfArray(Δ′), nothing, map(_ -> nothing, j)...) end VA[i,j...],AbstractVectorOfArray_getindex_adjoint