Skip to content

Commit f7415e8

Browse files
Merge pull request #374 from jlchan/jc/fix_bcast_multidim_VoA
Fix multi-dimensional `VectorOfArray` broadcast
2 parents eacbe3f + b7f9c84 commit f7415e8

File tree

2 files changed

+39
-12
lines changed

2 files changed

+39
-12
lines changed

src/vector_of_array.jl

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,16 +28,15 @@ the `VectorOfArray` into a matrix/tensor. Also, `vecarr_to_vectors(VA::AbstractV
2828
returns a vector of the series for each component, that is, `A[i,:]` for each `i`.
2929
A plot recipe is provided, which plots the `A[i,:]` series.
3030
31-
There is also support for `VectorOfArray` with constructed from multi-dimensional arrays
32-
31+
There is also support for `VectorOfArray` constructed from multi-dimensional arrays
3332
```julia
3433
VectorOfArray(u::AbstractArray{AT}) where {T, N, AT <: AbstractArray{T, N}}
3534
```
3635
3736
where `IndexStyle(typeof(u)) isa IndexLinear`.
3837
"""
3938
mutable struct VectorOfArray{T, N, A} <: AbstractVectorOfArray{T, N, A}
40-
u::A # A <: AbstractVector{<: AbstractArray{T, N - 1}}
39+
u::A # A <: AbstractArray{<: AbstractArray{T, N - 1}}
4140
end
4241
# VectorOfArray with an added series for time
4342

@@ -719,7 +718,7 @@ end
719718
# for VectorOfArray with multi-dimensional parent arrays of arrays where all elements are the same type
720719
function Base.similar(vec::VectorOfArray{
721720
T, N, AT}) where {T, N, AT <: AbstractArray{<:AbstractArray{T}}}
722-
return VectorOfArray(similar(Base.parent(vec)))
721+
return VectorOfArray(similar.(Base.parent(vec)))
723722
end
724723

725724
# special-case when the multi-dimensional parent array is just an AbstractVector (call the old method)
@@ -728,6 +727,7 @@ function Base.similar(vec::VectorOfArray{
728727
return Base.similar(vec, eltype(vec))
729728
end
730729

730+
731731
# fill!
732732
# For DiffEqArray it ignores ts and fills only u
733733
function Base.fill!(VA::AbstractVectorOfArray, x)
@@ -840,12 +840,37 @@ end
840840
# make vectorofarrays broadcastable so they aren't collected
841841
Broadcast.broadcastable(x::AbstractVectorOfArray) = x
842842

843+
# recurse through broadcast arguments and return a parent array for
844+
# the first VoA or DiffEqArray in the bc arguments
845+
function find_VoA_parent(args)
846+
arg = Base.first(args)
847+
if arg isa AbstractDiffEqArray
848+
# if first(args) is a DiffEqArray, use the underlying
849+
# field `u` of DiffEqArray as a parent array.
850+
return arg.u
851+
elseif arg isa AbstractVectorOfArray
852+
return parent(arg)
853+
else
854+
return find_VoA_parent(Base.tail(args))
855+
end
856+
end
857+
843858
@inline function Base.copy(bc::Broadcast.Broadcasted{<:VectorOfArrayStyle})
844859
bc = Broadcast.flatten(bc)
845-
N = narrays(bc)
846-
VectorOfArray(map(1:N) do i
847-
copy(unpack_voa(bc, i))
848-
end)
860+
861+
parent = find_VoA_parent(bc.args)
862+
863+
if parent isa AbstractVector
864+
# this is the default behavior in v3.15.0
865+
N = narrays(bc)
866+
return VectorOfArray(map(1:N) do i
867+
copy(unpack_voa(bc, i))
868+
end)
869+
else # if parent isa AbstractArray
870+
return VectorOfArray(map(enumerate(Iterators.product(axes(parent)...))) do (i, _)
871+
copy(unpack_voa(bc, i))
872+
end)
873+
end
849874
end
850875

851876
for (type, N_expr) in [

test/basic_indexing.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -238,8 +238,8 @@ a[[1, 3, 8]]
238238
# multidimensional array of arrays
239239
####################################################################
240240

241-
u_matrix = VectorOfArray(fill([1, 2], 2, 3))
242-
u_vector = VectorOfArray(vec(u_matrix.u))
241+
u_matrix = VectorOfArray([[1, 2] for i in 1:2, j in 1:3])
242+
u_vector = VectorOfArray([[1, 2] for i in 1:6])
243243

244244
# test broadcasting
245245
function foo!(u)
@@ -248,11 +248,13 @@ function foo!(u)
248248
end
249249
foo!(u_matrix)
250250
foo!(u_vector)
251-
@test u_matrix u_vector
251+
@test all(u_matrix .== [3, 10])
252+
@test all(vec(u_matrix) .≈ vec(u_vector))
252253

253254
# test that, for VectorOfArray with multi-dimensional parent arrays,
254-
# `similar` preserves the structure of the parent array
255+
# broadcast and `similar` preserve the structure of the parent array
255256
@test typeof(parent(similar(u_matrix))) == typeof(parent(u_matrix))
257+
@test typeof(parent((x->x).(u_matrix))) == typeof(parent(u_matrix))
256258

257259
# test efficiency
258260
num_allocs = @allocations foo!(u_matrix)

0 commit comments

Comments
 (0)