Skip to content

Commit 6d022cb

Browse files
Merge pull request #257 from SciML/fillarrays
Remove FillArrays dependency to reduce import times
2 parents 35b1c8c + 9f10f06 commit 6d022cb

File tree

6 files changed

+132
-190
lines changed

6 files changed

+132
-190
lines changed

Project.toml

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,7 @@ version = "2.38.0"
66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
88
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
9-
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
109
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
11-
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
1210
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
1311
IteratorInterfaceExtensions = "82899510-4779-5014-852e-03e436cf321d"
1412
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -18,26 +16,24 @@ StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
1816
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1917
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
2018
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
21-
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
2219

2320
[compat]
2421
Adapt = "3"
2522
ArrayInterface = "7"
26-
ChainRulesCore = "0.10.7, 1"
2723
DocStringExtensions = "0.8, 0.9"
28-
FillArrays = "0.11, 0.12, 0.13"
2924
GPUArraysCore = "0.1"
3025
IteratorInterfaceExtensions = "1"
3126
RecipesBase = "0.7, 0.8, 1.0"
3227
Requires = "1.0"
3328
StaticArraysCore = "1.1"
3429
SymbolicIndexingInterface = "0.1, 0.2"
3530
Tables = "1"
36-
ZygoteRules = "0.2"
31+
Zygote = "< 0.6.56"
3732
julia = "1.6"
3833

3934
[extensions]
4035
RecursiveArrayToolsTrackerExt = "Tracker"
36+
RecursiveArrayToolsZygoteExt = "Zygote"
4137

4238
[extras]
4339
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
@@ -60,3 +56,4 @@ test = ["SafeTestsets", "Aqua", "ForwardDiff", "LabelledArrays", "NLsolve", "Ord
6056

6157
[weakdeps]
6258
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
59+
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

ext/RecursiveArrayToolsZygoteExt.jl

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
module RecursiveArrayToolsZygoteExt
2+
3+
using RecursiveArrayTools
4+
5+
if isdefined(Base, :get_extension)
6+
using Zygote
7+
using Zygote: FillArrays, ChainRulesCore, literal_getproperty, @adjoint
8+
else
9+
using ..Zygote
10+
using ..Zygote: FillArrays, ChainRulesCore, literal_getproperty, @adjoint
11+
end
12+
13+
# Define a new species of projection operator for this type:
14+
ChainRulesCore.ProjectTo(x::VectorOfArray) = ChainRulesCore.ProjectTo{VectorOfArray}()
15+
16+
function ChainRulesCore.rrule(T::Type{<:RecursiveArrayTools.GPUArraysCore.AbstractGPUArray},
17+
xs::AbstractVectorOfArray)
18+
T(xs), ȳ -> (ChainRulesCore.NoTangent(), ȳ)
19+
end
20+
21+
@adjoint function getindex(VA::AbstractVectorOfArray, i::Int)
22+
function AbstractVectorOfArray_getindex_adjoint(Δ)
23+
Δ′ = [(i == j ? Δ : Fill(zero(eltype(x)), size(x)))
24+
for (x, j) in zip(VA.u, 1:length(VA))]
25+
(VectorOfArray(Δ′), nothing)
26+
end
27+
VA[i], AbstractVectorOfArray_getindex_adjoint
28+
end
29+
30+
@adjoint function getindex(VA::AbstractVectorOfArray,
31+
i::Union{BitArray, AbstractArray{Bool}})
32+
function AbstractVectorOfArray_getindex_adjoint(Δ)
33+
Δ′ = [(i[j] ? Δ[j] : Fill(zero(eltype(x)), size(x)))
34+
for (x, j) in zip(VA.u, 1:length(VA))]
35+
(VectorOfArray(Δ′), nothing)
36+
end
37+
VA[i], AbstractVectorOfArray_getindex_adjoint
38+
end
39+
40+
@adjoint function getindex(VA::AbstractVectorOfArray, i::AbstractArray{Int})
41+
function AbstractVectorOfArray_getindex_adjoint(Δ)
42+
iter = 0
43+
Δ′ = [(j i ? Δ[iter += 1] : Fill(zero(eltype(x)), size(x)))
44+
for (x, j) in zip(VA.u, 1:length(VA))]
45+
(VectorOfArray(Δ′), nothing)
46+
end
47+
VA[i], AbstractVectorOfArray_getindex_adjoint
48+
end
49+
50+
@adjoint function getindex(VA::AbstractVectorOfArray,
51+
i::Union{Int, AbstractArray{Int}})
52+
function AbstractVectorOfArray_getindex_adjoint(Δ)
53+
Δ′ = [(i[j] ? Δ[j] : Fill(zero(eltype(x)), size(x)))
54+
for (x, j) in zip(VA.u, 1:length(VA))]
55+
(VectorOfArray(Δ′), nothing)
56+
end
57+
VA[i], AbstractVectorOfArray_getindex_adjoint
58+
end
59+
60+
@adjoint function getindex(VA::AbstractVectorOfArray, i::Colon)
61+
function AbstractVectorOfArray_getindex_adjoint(Δ)
62+
(VectorOfArray(Δ), nothing)
63+
end
64+
VA[i], AbstractVectorOfArray_getindex_adjoint
65+
end
66+
67+
@adjoint function getindex(VA::AbstractVectorOfArray, i::Int,
68+
j::Union{Int, AbstractArray{Int}, CartesianIndex,
69+
Colon, BitArray, AbstractArray{Bool}}...)
70+
function AbstractVectorOfArray_getindex_adjoint(Δ)
71+
Δ′ = VectorOfArray([zero(x) for (x, j) in zip(VA.u, 1:length(VA))])
72+
Δ′[i, j...] = Δ
73+
(Δ′, nothing, map(_ -> nothing, j)...)
74+
end
75+
VA[i, j...], AbstractVectorOfArray_getindex_adjoint
76+
end
77+
78+
@adjoint function ArrayPartition(x::S,
79+
::Type{Val{copy_x}} = Val{false}) where {
80+
S <:
81+
Tuple,
82+
copy_x
83+
}
84+
function ArrayPartition_adjoint(_y)
85+
y = Array(_y)
86+
starts = vcat(0, cumsum(reduce(vcat, length.(x))))
87+
ntuple(i -> reshape(y[(starts[i] + 1):starts[i + 1]], size(x[i])), length(x)),
88+
nothing
89+
end
90+
91+
ArrayPartition(x, Val{copy_x}), ArrayPartition_adjoint
92+
end
93+
94+
@adjoint function VectorOfArray(u)
95+
VectorOfArray(u),
96+
y -> (VectorOfArray([y[ntuple(x -> Colon(), ndims(y) - 1)..., i]
97+
for i in 1:size(y)[end]]),)
98+
end
99+
100+
@adjoint function DiffEqArray(u, t)
101+
DiffEqArray(u, t),
102+
y -> (DiffEqArray([y[ntuple(x -> Colon(), ndims(y) - 1)..., i] for i in 1:size(y)[end]],
103+
t), nothing)
104+
end
105+
106+
@adjoint function literal_getproperty(A::ArrayPartition, ::Val{:x})
107+
function literal_ArrayPartition_x_adjoint(d)
108+
(ArrayPartition((isnothing(d[i]) ? zero(A.x[i]) : d[i] for i in 1:length(d))...),)
109+
end
110+
A.x, literal_ArrayPartition_x_adjoint
111+
end
112+
113+
end

src/RecursiveArrayTools.jl

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,7 @@ using RecipesBase, StaticArraysCore, Statistics,
99
ArrayInterface, LinearAlgebra
1010
using SymbolicIndexingInterface
1111

12-
import ChainRulesCore
13-
import ChainRulesCore: NoTangent
14-
import ZygoteRules, Adapt
15-
16-
using FillArrays
12+
import Adapt
1713

1814
import Tables, IteratorInterfaceExtensions
1915

@@ -24,23 +20,19 @@ include("utils.jl")
2420
include("vector_of_array.jl")
2521
include("tabletraits.jl")
2622
include("array_partition.jl")
27-
include("zygote.jl")
2823

2924
function Base.show(io::IO, x::Union{ArrayPartition, AbstractVectorOfArray})
3025
invoke(show, Tuple{typeof(io), Any}, io, x)
3126
end
3227

3328
import GPUArraysCore
3429
Base.convert(T::Type{<:GPUArraysCore.AbstractGPUArray}, VA::AbstractVectorOfArray) = T(VA)
35-
function ChainRulesCore.rrule(T::Type{<:GPUArraysCore.AbstractGPUArray},
36-
xs::AbstractVectorOfArray)
37-
T(xs), ȳ -> (NoTangent(), ȳ)
38-
end
3930

4031
import Requires
4132
@static if !isdefined(Base, :get_extension)
4233
function __init__()
4334
Requires.@require Tracker="9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" begin include("../ext/RecursiveArrayToolsTrackerExt.jl") end
35+
Requires.@require Zygote="e88e6eb3-aa80-5325-afca-941959d7151f" begin include("../ext/RecursiveArrayToolsZygoteExt.jl") end
4436
end
4537
end
4638

src/array_partition.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ Base.ones(A::ArrayPartition, dims::NTuple{N, Int}) where {N} = ones(A)
115115

116116
# mutable iff all components of ArrayPartition are mutable
117117
@generated function ArrayInterface.ismutable(::Type{<:ArrayPartition{T, S}}) where {T, S
118-
}
118+
}
119119
res = all(ArrayInterface.ismutable, S.parameters)
120120
return :($res)
121121
end

src/zygote.jl

Lines changed: 0 additions & 160 deletions
This file was deleted.

0 commit comments

Comments
 (0)