Skip to content
Merged
3 changes: 2 additions & 1 deletion src/RecursiveArrayTools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ include("utils.jl")
include("vector_of_array.jl")
include("tabletraits.jl")
include("array_partition.jl")
include("named_array_partition.jl")

function Base.show(io::IO, x::Union{ArrayPartition, AbstractVectorOfArray})
invoke(show, Tuple{typeof(io), Any}, io, x)
Expand Down Expand Up @@ -52,6 +53,6 @@ export recursivecopy, recursivecopy!, recursivefill!, vecvecapply, copyat_or_pus
vecvec_to_mat, recursive_one, recursive_mean, recursive_bottom_eltype,
recursive_unitless_bottom_eltype, recursive_unitless_eltype

export ArrayPartition
export ArrayPartition, NamedArrayPartition

end # module
114 changes: 114 additions & 0 deletions src/named_array_partition.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
"""
NamedArrayPartition(; kwargs...)
NamedArrayPartition(x::NamedTuple)

Similar to an `ArrayPartition` but the individual arrays can be accessed via the
constructor-specified names. However, unlike `ArrayPartition`, each individual array
must have the same element type.
"""
struct NamedArrayPartition{T, A<:ArrayPartition{T}, NT<:NamedTuple} <: AbstractVector{T}
array_partition::A
names_to_indices::NT
end
NamedArrayPartition(; kwargs...) = NamedArrayPartition(NamedTuple(kwargs))
function NamedArrayPartition(x::NamedTuple)
names_to_indices = NamedTuple(Pair(symbol, index) for (index, symbol) in enumerate(keys(x)))

# enforce homogeneity of eltypes
@assert all(eltype.(values(x)) .== eltype(first(x)))
T = eltype(first(x))
S = typeof(values(x))
return NamedArrayPartition(ArrayPartition{T, S}(values(x)), names_to_indices)
end

# Note: overloading `getproperty` means we cannot access `NamedArrayPartition`
# fields except through `getfield` and accessor functions.
ArrayPartition(x::NamedArrayPartition) = getfield(x, :array_partition)

Base.Array(x::NamedArrayPartition) = Array(ArrayPartition(x))

Base.zero(x::NamedArrayPartition{T, S, TN}) where {T, S, TN} =
NamedArrayPartition{T, S, TN}(zero(ArrayPartition(x)), getfield(x, :names_to_indices))
Base.zero(A::NamedArrayPartition, dims::NTuple{N, Int}) where {N} = zero(A) # ignore dims since named array partitions are vectors


Base.propertynames(x::NamedArrayPartition) = propertynames(getfield(x, :names_to_indices))
Base.getproperty(x::NamedArrayPartition, s::Symbol) =
getindex(ArrayPartition(x).x, getproperty(getfield(x, :names_to_indices), s))

# this enables x.s = some_array.
@inline function Base.setproperty!(x::NamedArrayPartition, s::Symbol, v)
index = getproperty(getfield(x, :names_to_indices), s)
ArrayPartition(x).x[index] .= v
end

# print out NamedArrayPartition as a NamedTuple
Base.summary(x::NamedArrayPartition) = string(typeof(x), " with arrays:")
Base.show(io::IO, m::MIME"text/plain", x::NamedArrayPartition) =
show(io, m, NamedTuple(Pair.(keys(getfield(x, :names_to_indices)), ArrayPartition(x).x)))

Base.size(x::NamedArrayPartition) = size(ArrayPartition(x))
Base.length(x::NamedArrayPartition) = length(ArrayPartition(x))
Base.getindex(x::NamedArrayPartition, args...) = getindex(ArrayPartition(x), args...)

Base.setindex!(x::NamedArrayPartition, args...) = setindex!(ArrayPartition(x), args...)
Base.map(f, x::NamedArrayPartition) = NamedArrayPartition(map(f, ArrayPartition(x)), getfield(x, :names_to_indices))
Base.mapreduce(f, op, x::NamedArrayPartition) = mapreduce(f, op, ArrayPartition(x))
# Base.filter(f, x::NamedArrayPartition) = filter(f, ArrayPartition(x))

Base.similar(x::NamedArrayPartition{T, S, NT}) where {T, S, NT} =
NamedArrayPartition{T, S, NT}(similar(ArrayPartition(x)), getfield(x, :names_to_indices))

# broadcasting
Base.BroadcastStyle(::Type{<:NamedArrayPartition}) = Broadcast.ArrayStyle{NamedArrayPartition}()
function Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{NamedArrayPartition}},
::Type{ElType}) where {ElType}
x = find_NamedArrayPartition(bc)
return NamedArrayPartition(similar(ArrayPartition(x)), getfield(x, :names_to_indices))
end

# when broadcasting with ArrayPartition + another array type, the output is the other array tupe
Base.BroadcastStyle(::Broadcast.ArrayStyle{NamedArrayPartition}, ::Broadcast.DefaultArrayStyle{1}) =
Broadcast.DefaultArrayStyle{1}()

# hook into ArrayPartition broadcasting routines
@inline RecursiveArrayTools.npartitions(x::NamedArrayPartition) = npartitions(ArrayPartition(x))
@inline RecursiveArrayTools.unpack(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{NamedArrayPartition}}, i) =
Broadcast.Broadcasted(bc.f, RecursiveArrayTools.unpack_args(i, bc.args))
@inline RecursiveArrayTools.unpack(x::NamedArrayPartition, i) = unpack(ArrayPartition(x), i)

Base.copy(A::NamedArrayPartition{T,S,NT}) where {T,S,NT} =
NamedArrayPartition{T,S,NT}(copy(ArrayPartition(A)), getfield(A, :names_to_indices))

@inline NamedArrayPartition(f::F, N, names_to_indices) where F<:Function =
NamedArrayPartition(ArrayPartition(ntuple(f, Val(N))), names_to_indices)

@inline function Base.copy(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{NamedArrayPartition}})
N = npartitions(bc)
@inline function f(i)
copy(unpack(bc, i))
end
x = find_NamedArrayPartition(bc)
NamedArrayPartition(f, N, getfield(x, :names_to_indices))
end

@inline function Base.copyto!(dest::NamedArrayPartition,
bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{NamedArrayPartition}})
N = npartitions(dest, bc)
@inline function f(i)
copyto!(ArrayPartition(dest).x[i], unpack(bc, i))
end
ntuple(f, Val(N))
return dest
end

# `x = find_NamedArrayPartition(x)` returns the first `NamedArrayPartition` among broadcast arguments.
find_NamedArrayPartition(bc::Base.Broadcast.Broadcasted) = find_NamedArrayPartition(bc.args)
find_NamedArrayPartition(args::Tuple) =
find_NamedArrayPartition(find_NamedArrayPartition(args[1]), Base.tail(args))
find_NamedArrayPartition(x) = x
find_NamedArrayPartition(::Tuple{}) = nothing
find_NamedArrayPartition(x::NamedArrayPartition, rest) = x
find_NamedArrayPartition(::Any, rest) = find_NamedArrayPartition(rest)


32 changes: 32 additions & 0 deletions test/named_array_partition_tests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
using RecursiveArrayTools, Test

@testset "NamedArrayPartition tests" begin
x = NamedArrayPartition(a = ones(10), b = rand(20))
@test typeof(@. sin(x * x^2 / x - 1)) <: NamedArrayPartition
@test typeof(x.^2) <: NamedArrayPartition
@test x.a ≈ ones(10)
@test typeof(x .+ x[1:end]) <: Vector # test broadcast precedence
@test all(x .== x[1:end])
y = copy(x)
@test zero(x, (10, 20)) == zero(x) # test that ignoring dims works
@test typeof(zero(x)) <: NamedArrayPartition
@test (y .*= 2).a[1] ≈ 2 # test in-place bcast

@test length(Array(x))==30
@test typeof(Array(x)) <: Array
@test propertynames(x) == (:a, :b)

x = NamedArrayPartition(a = ones(1), b = 2*ones(1))
@test Base.summary(x) == string(typeof(x), " with arrays:")
@test (@capture_out Base.show(stdout, MIME"text/plain"(), x)) == "(a = [1.0], b = [2.0])"

using StructArrays
using StaticArrays: SVector
x = NamedArrayPartition(a = StructArray{SVector{2, Float64}}((ones(5), 2*ones(5))),
b = StructArray{SVector{2, Float64}}((3 * ones(2,2), 4*ones(2,2))))
@test typeof(x.a) <: StructVector{<:SVector{2}}
@test typeof(x.b) <: StructArray{<:SVector{2}, 2}
@test typeof((x->x[1]).(x)) <: NamedArrayPartition
@test typeof(map(x->x[1], x)) <: NamedArrayPartition
end

5 changes: 4 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,12 @@ end
@time @safetestset "Utils Tests" begin
include("utils_test.jl")
end
@time @safetestset "NamedArrayPartition Tests" begin
include("named_array_partition_tests.jl")
end
@time @safetestset "Partitions Tests" begin
include("partitions_test.jl")
end
end
@time @safetestset "VecOfArr Indexing Tests" begin
include("basic_indexing.jl")
end
Expand Down