Skip to content

adding NamedArrayPartition type #293

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Jan 4, 2024
Merged
1 change: 1 addition & 0 deletions docs/src/array_types.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,5 @@ mapping and iteration functions, and more.
VectorOfArray
DiffEqArray
ArrayPartition
NamedArrayPartition
```
3 changes: 2 additions & 1 deletion src/RecursiveArrayTools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ include("utils.jl")
include("vector_of_array.jl")
include("tabletraits.jl")
include("array_partition.jl")
include("named_array_partition.jl")

function Base.show(io::IO, x::Union{ArrayPartition, AbstractVectorOfArray})
invoke(show, Tuple{typeof(io), Any}, io, x)
Expand Down Expand Up @@ -52,6 +53,6 @@ export recursivecopy, recursivecopy!, recursivefill!, vecvecapply, copyat_or_pus
vecvec_to_mat, recursive_one, recursive_mean, recursive_bottom_eltype,
recursive_unitless_bottom_eltype, recursive_unitless_eltype

export ArrayPartition
export ArrayPartition, NamedArrayPartition

end # module
114 changes: 114 additions & 0 deletions src/named_array_partition.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
"""
NamedArrayPartition(; kwargs...)
NamedArrayPartition(x::NamedTuple)

Similar to an `ArrayPartition` but the individual arrays can be accessed via the
constructor-specified names. However, unlike `ArrayPartition`, each individual array
must have the same element type.
"""
struct NamedArrayPartition{T, A<:ArrayPartition{T}, NT<:NamedTuple} <: AbstractVector{T}
array_partition::A
names_to_indices::NT
end
NamedArrayPartition(; kwargs...) = NamedArrayPartition(NamedTuple(kwargs))
function NamedArrayPartition(x::NamedTuple)
names_to_indices = NamedTuple(Pair(symbol, index) for (index, symbol) in enumerate(keys(x)))

# enforce homogeneity of eltypes
@assert all(eltype.(values(x)) .== eltype(first(x)))
T = eltype(first(x))
S = typeof(values(x))
return NamedArrayPartition(ArrayPartition{T, S}(values(x)), names_to_indices)
end

# Note: overloading `getproperty` means we cannot access `NamedArrayPartition`
# fields except through `getfield` and accessor functions.
ArrayPartition(x::NamedArrayPartition) = getfield(x, :array_partition)

Base.Array(x::NamedArrayPartition) = Array(ArrayPartition(x))

Base.zero(x::NamedArrayPartition{T, S, TN}) where {T, S, TN} =
NamedArrayPartition{T, S, TN}(zero(ArrayPartition(x)), getfield(x, :names_to_indices))
Base.zero(A::NamedArrayPartition, dims::NTuple{N, Int}) where {N} = zero(A) # ignore dims since named array partitions are vectors


Base.propertynames(x::NamedArrayPartition) = propertynames(getfield(x, :names_to_indices))
Base.getproperty(x::NamedArrayPartition, s::Symbol) =
getindex(ArrayPartition(x).x, getproperty(getfield(x, :names_to_indices), s))

# this enables x.s = some_array.
@inline function Base.setproperty!(x::NamedArrayPartition, s::Symbol, v)
index = getproperty(getfield(x, :names_to_indices), s)
ArrayPartition(x).x[index] .= v
end

# print out NamedArrayPartition as a NamedTuple
Base.summary(x::NamedArrayPartition) = string(typeof(x), " with arrays:")
Base.show(io::IO, m::MIME"text/plain", x::NamedArrayPartition) =
show(io, m, NamedTuple(Pair.(keys(getfield(x, :names_to_indices)), ArrayPartition(x).x)))

Base.size(x::NamedArrayPartition) = size(ArrayPartition(x))
Base.length(x::NamedArrayPartition) = length(ArrayPartition(x))
Base.getindex(x::NamedArrayPartition, args...) = getindex(ArrayPartition(x), args...)

Base.setindex!(x::NamedArrayPartition, args...) = setindex!(ArrayPartition(x), args...)
Base.map(f, x::NamedArrayPartition) = NamedArrayPartition(map(f, ArrayPartition(x)), getfield(x, :names_to_indices))
Base.mapreduce(f, op, x::NamedArrayPartition) = mapreduce(f, op, ArrayPartition(x))
# Base.filter(f, x::NamedArrayPartition) = filter(f, ArrayPartition(x))

Base.similar(x::NamedArrayPartition{T, S, NT}) where {T, S, NT} =
NamedArrayPartition{T, S, NT}(similar(ArrayPartition(x)), getfield(x, :names_to_indices))

# broadcasting
Base.BroadcastStyle(::Type{<:NamedArrayPartition}) = Broadcast.ArrayStyle{NamedArrayPartition}()
function Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{NamedArrayPartition}},
::Type{ElType}) where {ElType}
x = find_NamedArrayPartition(bc)
return NamedArrayPartition(similar(ArrayPartition(x)), getfield(x, :names_to_indices))
end

# when broadcasting with ArrayPartition + another array type, the output is the other array tupe
Base.BroadcastStyle(::Broadcast.ArrayStyle{NamedArrayPartition}, ::Broadcast.DefaultArrayStyle{1}) =
Broadcast.DefaultArrayStyle{1}()

# hook into ArrayPartition broadcasting routines
@inline RecursiveArrayTools.npartitions(x::NamedArrayPartition) = npartitions(ArrayPartition(x))
@inline RecursiveArrayTools.unpack(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{NamedArrayPartition}}, i) =
Broadcast.Broadcasted(bc.f, RecursiveArrayTools.unpack_args(i, bc.args))
@inline RecursiveArrayTools.unpack(x::NamedArrayPartition, i) = unpack(ArrayPartition(x), i)

Base.copy(A::NamedArrayPartition{T,S,NT}) where {T,S,NT} =
NamedArrayPartition{T,S,NT}(copy(ArrayPartition(A)), getfield(A, :names_to_indices))

@inline NamedArrayPartition(f::F, N, names_to_indices) where F<:Function =
NamedArrayPartition(ArrayPartition(ntuple(f, Val(N))), names_to_indices)

@inline function Base.copy(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{NamedArrayPartition}})
N = npartitions(bc)
@inline function f(i)
copy(unpack(bc, i))
end
x = find_NamedArrayPartition(bc)
NamedArrayPartition(f, N, getfield(x, :names_to_indices))
end

@inline function Base.copyto!(dest::NamedArrayPartition,
bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{NamedArrayPartition}})
N = npartitions(dest, bc)
@inline function f(i)
copyto!(ArrayPartition(dest).x[i], unpack(bc, i))
end
ntuple(f, Val(N))
return dest
end

# `x = find_NamedArrayPartition(x)` returns the first `NamedArrayPartition` among broadcast arguments.
find_NamedArrayPartition(bc::Base.Broadcast.Broadcasted) = find_NamedArrayPartition(bc.args)
find_NamedArrayPartition(args::Tuple) =
find_NamedArrayPartition(find_NamedArrayPartition(args[1]), Base.tail(args))
find_NamedArrayPartition(x) = x
find_NamedArrayPartition(::Tuple{}) = nothing
find_NamedArrayPartition(x::NamedArrayPartition, rest) = x
find_NamedArrayPartition(::Any, rest) = find_NamedArrayPartition(rest)


34 changes: 34 additions & 0 deletions test/named_array_partition_tests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
using RecursiveArrayTools, Test

@testset "NamedArrayPartition tests" begin
x = NamedArrayPartition(a = ones(10), b = rand(20))
@test typeof(@. sin(x * x^2 / x - 1)) <: NamedArrayPartition
@test typeof(x.^2) <: NamedArrayPartition
@test x.a ≈ ones(10)
@test typeof(x .+ x[1:end]) <: Vector # test broadcast precedence
@test all(x .== x[1:end])
y = copy(x)
@test zero(x, (10, 20)) == zero(x) # test that ignoring dims works
@test typeof(zero(x)) <: NamedArrayPartition
@test (y .*= 2).a[1] ≈ 2 # test in-place bcast

@test length(Array(x))==30
@test typeof(Array(x)) <: Array
@test propertynames(x) == (:a, :b)

x = NamedArrayPartition(a = ones(1), b = 2*ones(1))
@test Base.summary(x) == string(typeof(x), " with arrays:")
io = IOBuffer()
Base.show(io, MIME"text/plain"(), x)
@test String(take!(io)) == "(a = [1.0], b = [2.0])"

using StructArrays
using StaticArrays: SVector
x = NamedArrayPartition(a = StructArray{SVector{2, Float64}}((ones(5), 2*ones(5))),
b = StructArray{SVector{2, Float64}}((3 * ones(2,2), 4*ones(2,2))))
@test typeof(x.a) <: StructVector{<:SVector{2}}
@test typeof(x.b) <: StructArray{<:SVector{2}, 2}
@test typeof((x->x[1]).(x)) <: NamedArrayPartition
@test typeof(map(x->x[1], x)) <: NamedArrayPartition
end

5 changes: 4 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,12 @@ end
@time @safetestset "Utils Tests" begin
include("utils_test.jl")
end
@time @safetestset "NamedArrayPartition Tests" begin
include("named_array_partition_tests.jl")
end
@time @safetestset "Partitions Tests" begin
include("partitions_test.jl")
end
end
@time @safetestset "VecOfArr Indexing Tests" begin
include("basic_indexing.jl")
end
Expand Down