Skip to content

Commit a354738

Browse files
Merge pull request #103 from SciML/myb/bcbc
More robust broadcast
2 parents 9a7c1e4 + c89d00c commit a354738

File tree

3 files changed

+30
-25
lines changed

3 files changed

+30
-25
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "RecursiveArrayTools"
22
uuid = "731186ca-8d62-57ce-b412-fbd966d074cd"
33
authors = ["Chris Rackauckas <[email protected]>"]
4-
version = "2.4.0"
4+
version = "2.4.1"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"

src/vector_of_array.jl

Lines changed: 14 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -129,8 +129,8 @@ Base.vec(VA::AbstractVectorOfArray) = vec(convert(Array,VA)) # Allocates
129129
@inline Statistics.cor(VA::AbstractVectorOfArray;kwargs...) = cor(Array(VA);kwargs...)
130130

131131
# make it show just like its data
132-
Base.show(io::IO, x::AbstractVectorOfArray) = show(io, x.u)
133-
Base.show(io::IO, m::MIME"text/plain", x::AbstractVectorOfArray) = show(io, m, x.u)
132+
Base.show(io::IO, x::AbstractVectorOfArray) = Base.print_array(io, x.u)
133+
Base.show(io::IO, m::MIME"text/plain", x::AbstractVectorOfArray) = (println(io, summary(x), ':'); show(io, m, x.u))
134134
Base.summary(A::AbstractVectorOfArray) = string("VectorOfArray{",eltype(A),",",ndims(A),"}")
135135

136136
Base.show(io::IO, x::AbstractDiffEqArray) = (print(io,"t: ");show(io, x.t);println(io);print(io,"u: ");show(io, x.u))
@@ -149,30 +149,25 @@ end
149149

150150
## broadcasting
151151

152-
struct VectorOfArrayStyle <: Broadcast.AbstractArrayStyle{Any} end
153-
VectorOfArrayStyle(::Any) = VectorOfArrayStyle()
154-
VectorOfArrayStyle(::Any, ::Any) = VectorOfArrayStyle()
152+
struct VectorOfArrayStyle{N} <: Broadcast.AbstractArrayStyle{N} end # N is only used when voa sees other abstract arrays
153+
VectorOfArrayStyle(::Val{N}) where N = VectorOfArrayStyle{N}()
155154

156-
# promotion rules
157-
#@inline function Broadcast.BroadcastStyle(::VectorOfArrayStyle{AStyle}, ::VectorOfArrayStyle{BStyle}) where {AStyle, BStyle}
158-
# VectorOfArrayStyle(Broadcast.BroadcastStyle(AStyle(), BStyle()))
159-
#end
160-
Broadcast.BroadcastStyle(::VectorOfArrayStyle, ::Broadcast.BroadcastStyle) = VectorOfArrayStyle()
161-
Broadcast.BroadcastStyle(::VectorOfArrayStyle, ::Broadcast.DefaultArrayStyle{N}) where N = Broadcast.DefaultArrayStyle{N}()
155+
# The order is important here. We want to override Base.Broadcast.DefaultArrayStyle to return another Base.Broadcast.DefaultArrayStyle.
156+
Broadcast.BroadcastStyle(a::VectorOfArrayStyle, ::Base.Broadcast.DefaultArrayStyle{0}) = a
157+
Broadcast.BroadcastStyle(::VectorOfArrayStyle{N}, a::Base.Broadcast.DefaultArrayStyle{M}) where {M,N} = Base.Broadcast.DefaultArrayStyle(Val(max(M, N)))
158+
Broadcast.BroadcastStyle(::VectorOfArrayStyle{N}, a::Base.Broadcast.AbstractArrayStyle{M}) where {M,N} = typeof(a)(Val(max(M, N)))
159+
Broadcast.BroadcastStyle(::VectorOfArrayStyle{M}, ::VectorOfArrayStyle{N}) where {M,N} = VectorOfArrayStyle(Val(max(M, N)))
160+
Broadcast.BroadcastStyle(::Type{<:AbstractVectorOfArray{T,N}}) where {T,N} = VectorOfArrayStyle{N}()
162161

163-
function Broadcast.BroadcastStyle(::Type{<:AbstractVectorOfArray{T,S}}) where {T, S}
164-
VectorOfArrayStyle()
165-
end
166-
167-
@inline function Base.copy(bc::Broadcast.Broadcasted{VectorOfArrayStyle})
162+
@inline function Base.copy(bc::Broadcast.Broadcasted{<:VectorOfArrayStyle})
168163
N = narrays(bc)
169164
x = unpack_voa(bc, 1)
170165
VectorOfArray(map(1:N) do i
171166
copy(unpack_voa(bc, i))
172167
end)
173168
end
174169

175-
@inline function Base.copyto!(dest::AbstractVectorOfArray, bc::Broadcast.Broadcasted{VectorOfArrayStyle})
170+
@inline function Base.copyto!(dest::AbstractVectorOfArray, bc::Broadcast.Broadcasted{<:VectorOfArrayStyle})
176171
N = narrays(bc)
177172
@inbounds for i in 1:N
178173
copyto!(dest[i], unpack_voa(bc, i))
@@ -201,11 +196,11 @@ common_length(a, b) =
201196
_narrays(args::AbstractVectorOfArray) = length(args)
202197
@inline _narrays(args::Tuple) = common_length(narrays(args[1]), _narrays(Base.tail(args)))
203198
_narrays(args::Tuple{Any}) = _narrays(args[1])
204-
_narrays(args::Tuple{}) = 0
199+
_narrays(::Any) = 0
205200

206201
# drop axes because it is easier to recompute
207202
@inline unpack_voa(bc::Broadcast.Broadcasted{Style}, i) where Style = Broadcast.Broadcasted{Style}(bc.f, unpack_args_voa(i, bc.args))
208-
@inline unpack_voa(bc::Broadcast.Broadcasted{VectorOfArrayStyle}, i) = Broadcast.Broadcasted(bc.f, unpack_args_voa(i, bc.args))
203+
@inline unpack_voa(bc::Broadcast.Broadcasted{<:VectorOfArrayStyle}, i) = Broadcast.Broadcasted(bc.f, unpack_args_voa(i, bc.args))
209204
unpack_voa(x,::Any) = x
210205
unpack_voa(x::AbstractVectorOfArray, i) = x.u[i]
211206
unpack_voa(x::AbstractArray{T,N}, i) where {T,N} = @view x[ntuple(x->Colon(),N-1)...,i]

test/basic_indexing.jl

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,16 @@ using RecursiveArrayTools, Test
44
recs = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
55
testa = cat(recs..., dims=2)
66
testva = VectorOfArray(recs)
7+
8+
# broadcast with array
9+
X = rand(3, 3)
10+
mulX = testva .* X
11+
ref = mapreduce((x,y)->x.*y, hcat, testva, eachcol(X))
12+
@test mulX == ref
13+
fill!(mulX, 0)
14+
mulX .= testva .* X
15+
@test mulX == ref
16+
717
t = [1,2,3]
818
diffeq = DiffEqArray(recs,t)
919
@test Array(testva) == [1 4 7
@@ -73,11 +83,6 @@ testva = VectorOfArray(recs) #TODO: clearly this printed form is nonsense
7383
@test testva[:, 1] == recs[1]
7484
testva[1:2, 1:2]
7585

76-
# Test broadcast
77-
a = testva .+ rand(3,3)
78-
a.= testva
79-
@test all(a .== testva)
80-
8186
recs = [rand(2,2) for i in 1:5]
8287
testva = VectorOfArray(recs)
8388
@test Array(testva) isa Array{Float64,3}
@@ -97,3 +102,8 @@ w = v .* v
97102
x = copy(v)
98103
x .= v .* v
99104
@test x.u == w.u
105+
106+
# broadcast with number
107+
w = v .+ 1
108+
@test w isa VectorOfArray
109+
@test w.u == map(x -> x .+ 1, v.u)

0 commit comments

Comments
 (0)