-
-
Notifications
You must be signed in to change notification settings - Fork 69
adding NamedArrayPartition
type
#293
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
12 commits
Select commit
Hold shift + click to select a range
b6f5bf2
adding NamedArrayPartition and tests
jlchan 64cadb3
exporting
jlchan 5e2dc65
adding tests
jlchan 9f61bfc
Update test/named_array_partition_tests.jl
jlchan 0339fa4
Merge branch 'master' into jc/NamedArrayPartition
jlchan 7e10283
bumping version back to 2.39.0
jlchan e6d2680
Merge branch 'jc/NamedArrayPartition' of https://github.com/jlchan/Re…
jlchan f673e37
Merge remote-tracking branch 'origin/master' into jc/NamedArrayPartition
jlchan fcd37dc
Merge branch 'master' into jc/NamedArrayPartition
jlchan 267be45
Update named_array_partition_tests.jl
ChrisRackauckas 18a75d0
replacing @capture_out with IOBuffer
jlchan d663c49
Update array_types.md
ChrisRackauckas File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,4 +13,5 @@ mapping and iteration functions, and more. | |
VectorOfArray | ||
DiffEqArray | ||
ArrayPartition | ||
NamedArrayPartition | ||
``` |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
""" | ||
NamedArrayPartition(; kwargs...) | ||
NamedArrayPartition(x::NamedTuple) | ||
|
||
Similar to an `ArrayPartition` but the individual arrays can be accessed via the | ||
constructor-specified names. However, unlike `ArrayPartition`, each individual array | ||
must have the same element type. | ||
""" | ||
struct NamedArrayPartition{T, A<:ArrayPartition{T}, NT<:NamedTuple} <: AbstractVector{T} | ||
array_partition::A | ||
names_to_indices::NT | ||
end | ||
NamedArrayPartition(; kwargs...) = NamedArrayPartition(NamedTuple(kwargs)) | ||
function NamedArrayPartition(x::NamedTuple) | ||
names_to_indices = NamedTuple(Pair(symbol, index) for (index, symbol) in enumerate(keys(x))) | ||
|
||
# enforce homogeneity of eltypes | ||
@assert all(eltype.(values(x)) .== eltype(first(x))) | ||
T = eltype(first(x)) | ||
S = typeof(values(x)) | ||
return NamedArrayPartition(ArrayPartition{T, S}(values(x)), names_to_indices) | ||
end | ||
|
||
# Note: overloading `getproperty` means we cannot access `NamedArrayPartition` | ||
# fields except through `getfield` and accessor functions. | ||
ArrayPartition(x::NamedArrayPartition) = getfield(x, :array_partition) | ||
|
||
Base.Array(x::NamedArrayPartition) = Array(ArrayPartition(x)) | ||
|
||
Base.zero(x::NamedArrayPartition{T, S, TN}) where {T, S, TN} = | ||
NamedArrayPartition{T, S, TN}(zero(ArrayPartition(x)), getfield(x, :names_to_indices)) | ||
Base.zero(A::NamedArrayPartition, dims::NTuple{N, Int}) where {N} = zero(A) # ignore dims since named array partitions are vectors | ||
|
||
|
||
Base.propertynames(x::NamedArrayPartition) = propertynames(getfield(x, :names_to_indices)) | ||
Base.getproperty(x::NamedArrayPartition, s::Symbol) = | ||
getindex(ArrayPartition(x).x, getproperty(getfield(x, :names_to_indices), s)) | ||
|
||
# this enables x.s = some_array. | ||
@inline function Base.setproperty!(x::NamedArrayPartition, s::Symbol, v) | ||
index = getproperty(getfield(x, :names_to_indices), s) | ||
ArrayPartition(x).x[index] .= v | ||
end | ||
|
||
# print out NamedArrayPartition as a NamedTuple | ||
Base.summary(x::NamedArrayPartition) = string(typeof(x), " with arrays:") | ||
Base.show(io::IO, m::MIME"text/plain", x::NamedArrayPartition) = | ||
show(io, m, NamedTuple(Pair.(keys(getfield(x, :names_to_indices)), ArrayPartition(x).x))) | ||
|
||
Base.size(x::NamedArrayPartition) = size(ArrayPartition(x)) | ||
Base.length(x::NamedArrayPartition) = length(ArrayPartition(x)) | ||
Base.getindex(x::NamedArrayPartition, args...) = getindex(ArrayPartition(x), args...) | ||
|
||
Base.setindex!(x::NamedArrayPartition, args...) = setindex!(ArrayPartition(x), args...) | ||
Base.map(f, x::NamedArrayPartition) = NamedArrayPartition(map(f, ArrayPartition(x)), getfield(x, :names_to_indices)) | ||
Base.mapreduce(f, op, x::NamedArrayPartition) = mapreduce(f, op, ArrayPartition(x)) | ||
# Base.filter(f, x::NamedArrayPartition) = filter(f, ArrayPartition(x)) | ||
|
||
Base.similar(x::NamedArrayPartition{T, S, NT}) where {T, S, NT} = | ||
NamedArrayPartition{T, S, NT}(similar(ArrayPartition(x)), getfield(x, :names_to_indices)) | ||
|
||
# broadcasting | ||
Base.BroadcastStyle(::Type{<:NamedArrayPartition}) = Broadcast.ArrayStyle{NamedArrayPartition}() | ||
function Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{NamedArrayPartition}}, | ||
::Type{ElType}) where {ElType} | ||
x = find_NamedArrayPartition(bc) | ||
return NamedArrayPartition(similar(ArrayPartition(x)), getfield(x, :names_to_indices)) | ||
end | ||
|
||
# when broadcasting with ArrayPartition + another array type, the output is the other array tupe | ||
Base.BroadcastStyle(::Broadcast.ArrayStyle{NamedArrayPartition}, ::Broadcast.DefaultArrayStyle{1}) = | ||
Broadcast.DefaultArrayStyle{1}() | ||
|
||
# hook into ArrayPartition broadcasting routines | ||
@inline RecursiveArrayTools.npartitions(x::NamedArrayPartition) = npartitions(ArrayPartition(x)) | ||
@inline RecursiveArrayTools.unpack(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{NamedArrayPartition}}, i) = | ||
Broadcast.Broadcasted(bc.f, RecursiveArrayTools.unpack_args(i, bc.args)) | ||
@inline RecursiveArrayTools.unpack(x::NamedArrayPartition, i) = unpack(ArrayPartition(x), i) | ||
|
||
Base.copy(A::NamedArrayPartition{T,S,NT}) where {T,S,NT} = | ||
NamedArrayPartition{T,S,NT}(copy(ArrayPartition(A)), getfield(A, :names_to_indices)) | ||
|
||
@inline NamedArrayPartition(f::F, N, names_to_indices) where F<:Function = | ||
NamedArrayPartition(ArrayPartition(ntuple(f, Val(N))), names_to_indices) | ||
|
||
@inline function Base.copy(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{NamedArrayPartition}}) | ||
N = npartitions(bc) | ||
@inline function f(i) | ||
copy(unpack(bc, i)) | ||
end | ||
x = find_NamedArrayPartition(bc) | ||
NamedArrayPartition(f, N, getfield(x, :names_to_indices)) | ||
end | ||
|
||
@inline function Base.copyto!(dest::NamedArrayPartition, | ||
bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{NamedArrayPartition}}) | ||
N = npartitions(dest, bc) | ||
@inline function f(i) | ||
copyto!(ArrayPartition(dest).x[i], unpack(bc, i)) | ||
end | ||
ntuple(f, Val(N)) | ||
return dest | ||
end | ||
|
||
# `x = find_NamedArrayPartition(x)` returns the first `NamedArrayPartition` among broadcast arguments. | ||
find_NamedArrayPartition(bc::Base.Broadcast.Broadcasted) = find_NamedArrayPartition(bc.args) | ||
find_NamedArrayPartition(args::Tuple) = | ||
find_NamedArrayPartition(find_NamedArrayPartition(args[1]), Base.tail(args)) | ||
find_NamedArrayPartition(x) = x | ||
find_NamedArrayPartition(::Tuple{}) = nothing | ||
find_NamedArrayPartition(x::NamedArrayPartition, rest) = x | ||
find_NamedArrayPartition(::Any, rest) = find_NamedArrayPartition(rest) | ||
|
||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
using RecursiveArrayTools, Test | ||
|
||
@testset "NamedArrayPartition tests" begin | ||
x = NamedArrayPartition(a = ones(10), b = rand(20)) | ||
@test typeof(@. sin(x * x^2 / x - 1)) <: NamedArrayPartition | ||
@test typeof(x.^2) <: NamedArrayPartition | ||
@test x.a ≈ ones(10) | ||
@test typeof(x .+ x[1:end]) <: Vector # test broadcast precedence | ||
@test all(x .== x[1:end]) | ||
y = copy(x) | ||
@test zero(x, (10, 20)) == zero(x) # test that ignoring dims works | ||
@test typeof(zero(x)) <: NamedArrayPartition | ||
@test (y .*= 2).a[1] ≈ 2 # test in-place bcast | ||
|
||
@test length(Array(x))==30 | ||
@test typeof(Array(x)) <: Array | ||
@test propertynames(x) == (:a, :b) | ||
|
||
x = NamedArrayPartition(a = ones(1), b = 2*ones(1)) | ||
@test Base.summary(x) == string(typeof(x), " with arrays:") | ||
io = IOBuffer() | ||
Base.show(io, MIME"text/plain"(), x) | ||
@test String(take!(io)) == "(a = [1.0], b = [2.0])" | ||
|
||
using StructArrays | ||
using StaticArrays: SVector | ||
x = NamedArrayPartition(a = StructArray{SVector{2, Float64}}((ones(5), 2*ones(5))), | ||
b = StructArray{SVector{2, Float64}}((3 * ones(2,2), 4*ones(2,2)))) | ||
@test typeof(x.a) <: StructVector{<:SVector{2}} | ||
@test typeof(x.b) <: StructArray{<:SVector{2}, 2} | ||
@test typeof((x->x[1]).(x)) <: NamedArrayPartition | ||
@test typeof(map(x->x[1], x)) <: NamedArrayPartition | ||
end | ||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.