Skip to content

Commit e6d2680

Browse files
committed
Merge branch 'jc/NamedArrayPartition' of https://github.com/jlchan/RecursiveArrayTools.jl into jc/NamedArrayPartition
2 parents 7e10283 + 0339fa4 commit e6d2680

23 files changed

+700
-454
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ RecipesBase = "0.7, 0.8, 1.0"
3939
Requires = "1.0"
4040
StaticArraysCore = "1.1"
4141
Statistics = "1"
42-
SymbolicIndexingInterface = "0.1, 0.2"
42+
SymbolicIndexingInterface = "0.3"
4343
Tables = "1"
4444
Zygote = "0.6.56"
4545
julia = "1.6"

docs/make.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,13 @@ cp("./docs/Project.toml", "./docs/src/assets/Project.toml", force = true)
66
include("pages.jl")
77

88
makedocs(sitename = "RecursiveArrayTools.jl",
9-
authors = "Chris Rackauckas",
10-
modules = [RecursiveArrayTools],
11-
clean = true, doctest = false, linkcheck = true,
12-
warnonly = [:missing_docs],
13-
format = Documenter.HTML(assets = ["assets/favicon.ico"],
14-
canonical = "https://docs.sciml.ai/RecursiveArrayTools/stable/"),
15-
pages = pages)
9+
authors = "Chris Rackauckas",
10+
modules = [RecursiveArrayTools],
11+
clean = true, doctest = false, linkcheck = true,
12+
warnonly = [:missing_docs],
13+
format = Documenter.HTML(assets = ["assets/favicon.ico"],
14+
canonical = "https://docs.sciml.ai/RecursiveArrayTools/stable/"),
15+
pages = pages)
1616

1717
deploydocs(repo = "github.com/SciML/RecursiveArrayTools.jl.git";
18-
push_preview = true)
18+
push_preview = true)

ext/RecursiveArrayToolsMeasurementsExt.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@ import RecursiveArrayTools
44
isdefined(Base, :get_extension) ? (import Measurements) : (import ..Measurements)
55

66
function RecursiveArrayTools.recursive_unitless_bottom_eltype(a::Type{
7-
<:Measurements.Measurement
8-
})
7+
<:Measurements.Measurement,
8+
})
99
typeof(oneunit(a))
1010
end
1111

ext/RecursiveArrayToolsMonteCarloMeasurementsExt.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,18 @@
11
module RecursiveArrayToolsMonteCarloMeasurementsExt
22

33
import RecursiveArrayTools
4-
isdefined(Base, :get_extension) ? (import MonteCarloMeasurements) : (import ..MonteCarloMeasurements)
4+
isdefined(Base, :get_extension) ? (import MonteCarloMeasurements) :
5+
(import ..MonteCarloMeasurements)
56

67
function RecursiveArrayTools.recursive_unitless_bottom_eltype(a::Type{
7-
<:MonteCarloMeasurements.Particles
8-
})
8+
<:MonteCarloMeasurements.Particles,
9+
})
910
typeof(one(a))
1011
end
1112

12-
function RecursiveArrayTools.recursive_unitless_eltype(a::Type{<:MonteCarloMeasurements.Particles})
13+
function RecursiveArrayTools.recursive_unitless_eltype(a::Type{
14+
<:MonteCarloMeasurements.Particles,
15+
})
1316
typeof(one(a))
1417
end
1518

ext/RecursiveArrayToolsTrackerExt.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,12 @@ import RecursiveArrayTools
44
isdefined(Base, :get_extension) ? (import Tracker) : (import ..Tracker)
55

66
function RecursiveArrayTools.recursivecopy!(b::AbstractArray{T, N},
7-
a::AbstractArray{T2, N}) where {
8-
T <:
9-
Tracker.TrackedArray,
10-
T2 <:
11-
Tracker.TrackedArray,
12-
N}
7+
a::AbstractArray{T2, N}) where {
8+
T <:
9+
Tracker.TrackedArray,
10+
T2 <:
11+
Tracker.TrackedArray,
12+
N}
1313
@inbounds for i in eachindex(a)
1414
b[i] = copy(a[i])
1515
end

ext/RecursiveArrayToolsZygoteExt.jl

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ end
1414
ChainRulesCore.ProjectTo(x::VectorOfArray) = ChainRulesCore.ProjectTo{VectorOfArray}()
1515

1616
function ChainRulesCore.rrule(T::Type{<:RecursiveArrayTools.GPUArraysCore.AbstractGPUArray},
17-
xs::AbstractVectorOfArray)
17+
xs::AbstractVectorOfArray)
1818
T(xs), ȳ -> (ChainRulesCore.NoTangent(), ȳ)
1919
end
2020

@@ -28,7 +28,7 @@ end
2828
end
2929

3030
@adjoint function getindex(VA::AbstractVectorOfArray,
31-
i::Union{BitArray, AbstractArray{Bool}})
31+
i::Union{BitArray, AbstractArray{Bool}})
3232
function AbstractVectorOfArray_getindex_adjoint(Δ)
3333
Δ′ = [(i[j] ? Δ[j] : FillArrays.Fill(zero(eltype(x)), size(x)))
3434
for (x, j) in zip(VA.u, 1:length(VA))]
@@ -48,7 +48,7 @@ end
4848
end
4949

5050
@adjoint function getindex(VA::AbstractVectorOfArray,
51-
i::Union{Int, AbstractArray{Int}})
51+
i::Union{Int, AbstractArray{Int}})
5252
function AbstractVectorOfArray_getindex_adjoint(Δ)
5353
Δ′ = [(i[j] ? Δ[j] : FillArrays.Fill(zero(eltype(x)), size(x)))
5454
for (x, j) in zip(VA.u, 1:length(VA))]
@@ -65,8 +65,8 @@ end
6565
end
6666

6767
@adjoint function getindex(VA::AbstractVectorOfArray, i::Int,
68-
j::Union{Int, AbstractArray{Int}, CartesianIndex,
69-
Colon, BitArray, AbstractArray{Bool}}...)
68+
j::Union{Int, AbstractArray{Int}, CartesianIndex,
69+
Colon, BitArray, AbstractArray{Bool}}...)
7070
function AbstractVectorOfArray_getindex_adjoint(Δ)
7171
Δ′ = VectorOfArray([zero(x) for (x, j) in zip(VA.u, 1:length(VA))])
7272
Δ′[i, j...] = Δ
@@ -76,11 +76,11 @@ end
7676
end
7777

7878
@adjoint function ArrayPartition(x::S,
79-
::Type{Val{copy_x}} = Val{false}) where {
80-
S <:
81-
Tuple,
82-
copy_x
83-
}
79+
::Type{Val{copy_x}} = Val{false}) where {
80+
S <:
81+
Tuple,
82+
copy_x,
83+
}
8484
function ArrayPartition_adjoint(_y)
8585
y = Array(_y)
8686
starts = vcat(0, cumsum(reduce(vcat, length.(x))))
@@ -93,14 +93,21 @@ end
9393

9494
@adjoint function VectorOfArray(u)
9595
VectorOfArray(u),
96-
y -> (VectorOfArray([y[ntuple(x -> Colon(), ndims(y) - 1)..., i]
97-
for i in 1:size(y)[end]]),)
96+
y -> begin
97+
y isa Ref && (y = VectorOfArray(y[].u))
98+
(VectorOfArray([y[ntuple(x -> Colon(), ndims(y.u) - 1)..., i]
99+
for i in 1:size(y.u)[end]]),)
100+
end
98101
end
99102

100103
@adjoint function DiffEqArray(u, t)
101104
DiffEqArray(u, t),
102-
y -> (DiffEqArray([y[ntuple(x -> Colon(), ndims(y) - 1)..., i] for i in 1:size(y)[end]],
103-
t), nothing)
105+
y -> begin
106+
y isa Ref && (y = VectorOfArray(y[].u))
107+
(DiffEqArray([y[ntuple(x -> Colon(), ndims(y.u) - 1)..., i]
108+
for i in 1:size(y.u)[end]],
109+
t), nothing)
110+
end
104111
end
105112

106113
@adjoint function literal_getproperty(A::ArrayPartition, ::Val{:x})

src/RecursiveArrayTools.jl

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,14 @@ module RecursiveArrayTools
66

77
using DocStringExtensions
88
using RecipesBase, StaticArraysCore, Statistics,
9-
ArrayInterface, LinearAlgebra
9+
ArrayInterface, LinearAlgebra
1010
using SymbolicIndexingInterface
1111

1212
import Adapt
1313

1414
import Tables, IteratorInterfaceExtensions
1515

16-
abstract type AbstractVectorOfArray{T, N, A} <: AbstractArray{T, N} end
16+
abstract type AbstractVectorOfArray{T, N, A} end
1717
abstract type AbstractDiffEqArray{T, N, A} <: AbstractVectorOfArray{T, N, A} end
1818

1919
include("utils.jl")
@@ -32,18 +32,24 @@ Base.convert(T::Type{<:GPUArraysCore.AbstractGPUArray}, VA::AbstractVectorOfArra
3232
import Requires
3333
@static if !isdefined(Base, :get_extension)
3434
function __init__()
35-
Requires.@require Measurements="eff96d63-e80a-5855-80a2-b1b0885c5ab7" begin include("../ext/RecursiveArrayToolsMeasurementsExt.jl") end
36-
Requires.@require Tracker="9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" begin include("../ext/RecursiveArrayToolsTrackerExt.jl") end
37-
Requires.@require Zygote="e88e6eb3-aa80-5325-afca-941959d7151f" begin include("../ext/RecursiveArrayToolsZygoteExt.jl") end
35+
Requires.@require Measurements="eff96d63-e80a-5855-80a2-b1b0885c5ab7" begin
36+
include("../ext/RecursiveArrayToolsMeasurementsExt.jl")
37+
end
38+
Requires.@require Tracker="9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" begin
39+
include("../ext/RecursiveArrayToolsTrackerExt.jl")
40+
end
41+
Requires.@require Zygote="e88e6eb3-aa80-5325-afca-941959d7151f" begin
42+
include("../ext/RecursiveArrayToolsZygoteExt.jl")
43+
end
3844
end
3945
end
4046

4147
export VectorOfArray, DiffEqArray, AbstractVectorOfArray, AbstractDiffEqArray,
42-
AllObserved, vecarr_to_vectors, tuples
48+
AllObserved, vecarr_to_vectors, tuples
4349

4450
export recursivecopy, recursivecopy!, recursivefill!, vecvecapply, copyat_or_push!,
45-
vecvec_to_mat, recursive_one, recursive_mean, recursive_bottom_eltype,
46-
recursive_unitless_bottom_eltype, recursive_unitless_eltype
51+
vecvec_to_mat, recursive_one, recursive_mean, recursive_bottom_eltype,
52+
recursive_unitless_bottom_eltype, recursive_unitless_eltype
4753

4854
export ArrayPartition, NamedArrayPartition
4955

0 commit comments

Comments
 (0)