diff --git a/docs/src/array_types.md b/docs/src/array_types.md index 79a59adf..d3fb25b0 100644 --- a/docs/src/array_types.md +++ b/docs/src/array_types.md @@ -13,4 +13,5 @@ mapping and iteration functions, and more. VectorOfArray DiffEqArray ArrayPartition +NamedArrayPartition ``` diff --git a/src/RecursiveArrayTools.jl b/src/RecursiveArrayTools.jl index b3ee8fc8..28b46235 100644 --- a/src/RecursiveArrayTools.jl +++ b/src/RecursiveArrayTools.jl @@ -21,6 +21,7 @@ include("utils.jl") include("vector_of_array.jl") include("tabletraits.jl") include("array_partition.jl") +include("named_array_partition.jl") function Base.show(io::IO, x::Union{ArrayPartition, AbstractVectorOfArray}) invoke(show, Tuple{typeof(io), Any}, io, x) @@ -52,6 +53,6 @@ export recursivecopy, recursivecopy!, recursivefill!, vecvecapply, copyat_or_pus vecvec_to_mat, recursive_one, recursive_mean, recursive_bottom_eltype, recursive_unitless_bottom_eltype, recursive_unitless_eltype -export ArrayPartition +export ArrayPartition, NamedArrayPartition end # module diff --git a/src/named_array_partition.jl b/src/named_array_partition.jl new file mode 100644 index 00000000..873fc49d --- /dev/null +++ b/src/named_array_partition.jl @@ -0,0 +1,114 @@ +""" + NamedArrayPartition(; kwargs...) + NamedArrayPartition(x::NamedTuple) + +Similar to an `ArrayPartition` but the individual arrays can be accessed via the +constructor-specified names. However, unlike `ArrayPartition`, each individual array +must have the same element type. +""" +struct NamedArrayPartition{T, A<:ArrayPartition{T}, NT<:NamedTuple} <: AbstractVector{T} + array_partition::A + names_to_indices::NT +end +NamedArrayPartition(; kwargs...) = NamedArrayPartition(NamedTuple(kwargs)) +function NamedArrayPartition(x::NamedTuple) + names_to_indices = NamedTuple(Pair(symbol, index) for (index, symbol) in enumerate(keys(x))) + + # enforce homogeneity of eltypes + @assert all(eltype.(values(x)) .== eltype(first(x))) + T = eltype(first(x)) + S = typeof(values(x)) + return NamedArrayPartition(ArrayPartition{T, S}(values(x)), names_to_indices) +end + +# Note: overloading `getproperty` means we cannot access `NamedArrayPartition` +# fields except through `getfield` and accessor functions. +ArrayPartition(x::NamedArrayPartition) = getfield(x, :array_partition) + +Base.Array(x::NamedArrayPartition) = Array(ArrayPartition(x)) + +Base.zero(x::NamedArrayPartition{T, S, TN}) where {T, S, TN} = + NamedArrayPartition{T, S, TN}(zero(ArrayPartition(x)), getfield(x, :names_to_indices)) +Base.zero(A::NamedArrayPartition, dims::NTuple{N, Int}) where {N} = zero(A) # ignore dims since named array partitions are vectors + + +Base.propertynames(x::NamedArrayPartition) = propertynames(getfield(x, :names_to_indices)) +Base.getproperty(x::NamedArrayPartition, s::Symbol) = + getindex(ArrayPartition(x).x, getproperty(getfield(x, :names_to_indices), s)) + +# this enables x.s = some_array. +@inline function Base.setproperty!(x::NamedArrayPartition, s::Symbol, v) + index = getproperty(getfield(x, :names_to_indices), s) + ArrayPartition(x).x[index] .= v +end + +# print out NamedArrayPartition as a NamedTuple +Base.summary(x::NamedArrayPartition) = string(typeof(x), " with arrays:") +Base.show(io::IO, m::MIME"text/plain", x::NamedArrayPartition) = + show(io, m, NamedTuple(Pair.(keys(getfield(x, :names_to_indices)), ArrayPartition(x).x))) + +Base.size(x::NamedArrayPartition) = size(ArrayPartition(x)) +Base.length(x::NamedArrayPartition) = length(ArrayPartition(x)) +Base.getindex(x::NamedArrayPartition, args...) = getindex(ArrayPartition(x), args...) + +Base.setindex!(x::NamedArrayPartition, args...) = setindex!(ArrayPartition(x), args...) +Base.map(f, x::NamedArrayPartition) = NamedArrayPartition(map(f, ArrayPartition(x)), getfield(x, :names_to_indices)) +Base.mapreduce(f, op, x::NamedArrayPartition) = mapreduce(f, op, ArrayPartition(x)) +# Base.filter(f, x::NamedArrayPartition) = filter(f, ArrayPartition(x)) + +Base.similar(x::NamedArrayPartition{T, S, NT}) where {T, S, NT} = + NamedArrayPartition{T, S, NT}(similar(ArrayPartition(x)), getfield(x, :names_to_indices)) + +# broadcasting +Base.BroadcastStyle(::Type{<:NamedArrayPartition}) = Broadcast.ArrayStyle{NamedArrayPartition}() +function Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{NamedArrayPartition}}, + ::Type{ElType}) where {ElType} + x = find_NamedArrayPartition(bc) + return NamedArrayPartition(similar(ArrayPartition(x)), getfield(x, :names_to_indices)) +end + +# when broadcasting with ArrayPartition + another array type, the output is the other array tupe +Base.BroadcastStyle(::Broadcast.ArrayStyle{NamedArrayPartition}, ::Broadcast.DefaultArrayStyle{1}) = + Broadcast.DefaultArrayStyle{1}() + +# hook into ArrayPartition broadcasting routines +@inline RecursiveArrayTools.npartitions(x::NamedArrayPartition) = npartitions(ArrayPartition(x)) +@inline RecursiveArrayTools.unpack(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{NamedArrayPartition}}, i) = + Broadcast.Broadcasted(bc.f, RecursiveArrayTools.unpack_args(i, bc.args)) +@inline RecursiveArrayTools.unpack(x::NamedArrayPartition, i) = unpack(ArrayPartition(x), i) + +Base.copy(A::NamedArrayPartition{T,S,NT}) where {T,S,NT} = + NamedArrayPartition{T,S,NT}(copy(ArrayPartition(A)), getfield(A, :names_to_indices)) + +@inline NamedArrayPartition(f::F, N, names_to_indices) where F<:Function = + NamedArrayPartition(ArrayPartition(ntuple(f, Val(N))), names_to_indices) + +@inline function Base.copy(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{NamedArrayPartition}}) + N = npartitions(bc) + @inline function f(i) + copy(unpack(bc, i)) + end + x = find_NamedArrayPartition(bc) + NamedArrayPartition(f, N, getfield(x, :names_to_indices)) +end + +@inline function Base.copyto!(dest::NamedArrayPartition, + bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{NamedArrayPartition}}) + N = npartitions(dest, bc) + @inline function f(i) + copyto!(ArrayPartition(dest).x[i], unpack(bc, i)) + end + ntuple(f, Val(N)) + return dest +end + +# `x = find_NamedArrayPartition(x)` returns the first `NamedArrayPartition` among broadcast arguments. +find_NamedArrayPartition(bc::Base.Broadcast.Broadcasted) = find_NamedArrayPartition(bc.args) +find_NamedArrayPartition(args::Tuple) = + find_NamedArrayPartition(find_NamedArrayPartition(args[1]), Base.tail(args)) +find_NamedArrayPartition(x) = x +find_NamedArrayPartition(::Tuple{}) = nothing +find_NamedArrayPartition(x::NamedArrayPartition, rest) = x +find_NamedArrayPartition(::Any, rest) = find_NamedArrayPartition(rest) + + diff --git a/test/named_array_partition_tests.jl b/test/named_array_partition_tests.jl new file mode 100644 index 00000000..e823246a --- /dev/null +++ b/test/named_array_partition_tests.jl @@ -0,0 +1,34 @@ +using RecursiveArrayTools, Test + +@testset "NamedArrayPartition tests" begin + x = NamedArrayPartition(a = ones(10), b = rand(20)) + @test typeof(@. sin(x * x^2 / x - 1)) <: NamedArrayPartition + @test typeof(x.^2) <: NamedArrayPartition + @test x.a ≈ ones(10) + @test typeof(x .+ x[1:end]) <: Vector # test broadcast precedence + @test all(x .== x[1:end]) + y = copy(x) + @test zero(x, (10, 20)) == zero(x) # test that ignoring dims works + @test typeof(zero(x)) <: NamedArrayPartition + @test (y .*= 2).a[1] ≈ 2 # test in-place bcast + + @test length(Array(x))==30 + @test typeof(Array(x)) <: Array + @test propertynames(x) == (:a, :b) + + x = NamedArrayPartition(a = ones(1), b = 2*ones(1)) + @test Base.summary(x) == string(typeof(x), " with arrays:") + io = IOBuffer() + Base.show(io, MIME"text/plain"(), x) + @test String(take!(io)) == "(a = [1.0], b = [2.0])" + + using StructArrays + using StaticArrays: SVector + x = NamedArrayPartition(a = StructArray{SVector{2, Float64}}((ones(5), 2*ones(5))), + b = StructArray{SVector{2, Float64}}((3 * ones(2,2), 4*ones(2,2)))) + @test typeof(x.a) <: StructVector{<:SVector{2}} + @test typeof(x.b) <: StructArray{<:SVector{2}, 2} + @test typeof((x->x[1]).(x)) <: NamedArrayPartition + @test typeof(map(x->x[1], x)) <: NamedArrayPartition +end + diff --git a/test/runtests.jl b/test/runtests.jl index 5819f70d..16285ee5 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -27,9 +27,12 @@ end @time @safetestset "Utils Tests" begin include("utils_test.jl") end + @time @safetestset "NamedArrayPartition Tests" begin + include("named_array_partition_tests.jl") + end @time @safetestset "Partitions Tests" begin include("partitions_test.jl") - end + end @time @safetestset "VecOfArr Indexing Tests" begin include("basic_indexing.jl") end