Skip to content

Commit 6cca897

Browse files
Merge branch 'voa_broadcast'
2 parents 6c7df3c + d77105a commit 6cca897

File tree

2 files changed

+76
-23
lines changed

2 files changed

+76
-23
lines changed

src/vector_of_array.jl

Lines changed: 64 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -147,27 +147,69 @@ end
147147
VA.t,VA.u
148148
end
149149

150-
# Broadcast
151-
152-
#add_idxs(x,expr) = expr
153-
#add_idxs{T<:AbstractVectorOfArray}(::Type{T},expr) = :($(expr)[i])
154-
#add_idxs{T<:AbstractArray}(::Type{Vector{T}},expr) = :($(expr)[i])
155-
#=
156-
@generated function Base.broadcast!(f,A::AbstractVectorOfArray,B...)
157-
exs = ((add_idxs(B[i],:(B[$i])) for i in eachindex(B))...)
158-
:(for i in eachindex(A)
159-
broadcast!(f,A[i],$(exs...))
160-
end)
161-
end
162-
163-
@generated function Base.broadcast(f,B::Union{Number,AbstractVectorOfArray}...)
164-
arr_idx = 0
165-
for (i,b) in enumerate(B)
166-
if b <: ArrayPartition
167-
arr_idx = i
168-
break
150+
## broadcasting
151+
152+
struct VectorOfArrayStyle <: Broadcast.AbstractArrayStyle{Any} end
153+
VectorOfArrayStyle(::Any) = VectorOfArrayStyle()
154+
VectorOfArrayStyle(::Any, ::Any) = VectorOfArrayStyle()
155+
156+
# promotion rules
157+
#@inline function Broadcast.BroadcastStyle(::VectorOfArrayStyle{AStyle}, ::VectorOfArrayStyle{BStyle}) where {AStyle, BStyle}
158+
# VectorOfArrayStyle(Broadcast.BroadcastStyle(AStyle(), BStyle()))
159+
#end
160+
Broadcast.BroadcastStyle(::VectorOfArrayStyle, ::Broadcast.BroadcastStyle) = VectorOfArrayStyle()
161+
Broadcast.BroadcastStyle(::VectorOfArrayStyle, ::Broadcast.DefaultArrayStyle{N}) where N = Broadcast.DefaultArrayStyle{N}()
162+
163+
function Broadcast.BroadcastStyle(::Type{<:AbstractVectorOfArray{T,S}}) where {T, S}
164+
VectorOfArrayStyle()
165+
end
166+
167+
@inline function Base.copy(bc::Broadcast.Broadcasted{VectorOfArrayStyle})
168+
N = narrays(bc)
169+
x = unpack_voa(bc, 1)
170+
VectorOfArray(map(1:N) do i
171+
copy(unpack_voa(bc, i))
172+
end)
173+
end
174+
175+
@inline function Base.copyto!(dest::AbstractVectorOfArray, bc::Broadcast.Broadcasted{VectorOfArrayStyle})
176+
N = narrays(bc)
177+
@inbounds for i in 1:N
178+
copyto!(dest[i], unpack_voa(bc, i))
169179
end
170-
end
171-
:(A = similar(B[$arr_idx]); broadcast!(f,A,B...); A)
180+
dest
172181
end
173-
=#
182+
183+
## broadcasting utils
184+
185+
"""
186+
narrays(A...)
187+
188+
Retrieve number of arrays in the AbstractVectorOfArrays of a broadcast
189+
"""
190+
narrays(A) = 0
191+
narrays(A::AbstractVectorOfArray) = length(A)
192+
narrays(bc::Broadcast.Broadcasted) = _narrays(bc.args)
193+
narrays(A, Bs...) = common_length(narrays(A), _narrays(Bs))
194+
195+
common_length(a, b) =
196+
a == 0 ? b :
197+
(b == 0 ? a :
198+
(a == b ? a :
199+
throw(DimensionMismatch("number of arrays must be equal"))))
200+
201+
_narrays(args::AbstractVectorOfArray) = length(args)
202+
@inline _narrays(args::Tuple) = common_length(narrays(args[1]), _narrays(Base.tail(args)))
203+
_narrays(args::Tuple{Any}) = _narrays(args[1])
204+
_narrays(args::Tuple{}) = 0
205+
206+
# drop axes because it is easier to recompute
207+
@inline unpack_voa(bc::Broadcast.Broadcasted{Style}, i) where Style = Broadcast.Broadcasted{Style}(bc.f, unpack_args_voa(i, bc.args))
208+
@inline unpack_voa(bc::Broadcast.Broadcasted{VectorOfArrayStyle}, i) = Broadcast.Broadcasted(bc.f, unpack_args_voa(i, bc.args))
209+
unpack_voa(x,::Any) = x
210+
unpack_voa(x::AbstractVectorOfArray, i) = x.u[i]
211+
unpack_voa(x::AbstractArray{T,N}, i) where {T,N} = @view x[ntuple(x->Colon(),N-1)...,i]
212+
213+
@inline unpack_args_voa(i, args::Tuple) = (unpack_voa(args[1], i), unpack_args_voa(i, Base.tail(args))...)
214+
unpack_args_voa(i, args::Tuple{Any}) = (unpack_voa(args[1], i),)
215+
unpack_args_voa(::Any, args::Tuple{}) = ()

test/basic_indexing.jl

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using RecursiveArrayTools
1+
using RecursiveArrayTools, Test
22

33
# Example Problem
44
recs = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
@@ -86,3 +86,14 @@ v = VectorOfArray([zeros(20), zeros(10,10), zeros(3,3,3)])
8686
v[CartesianIndex((2, 3, 2, 3))] = 1
8787
@test v[CartesianIndex((2, 3, 2, 3))] == 1
8888
@test v.u[3][2, 3, 2] == 1
89+
90+
v = VectorOfArray([rand(20), rand(10,10), rand(3,3,3)])
91+
w = v .* v
92+
@test w isa VectorOfArray
93+
@test w[1] isa Vector
94+
@test w[1] == v[1] .* v[1]
95+
@test w[2] == v[2] .* v[2]
96+
@test w[3] == v[3] .* v[3]
97+
x = copy(v)
98+
x .= v .* v
99+
@test x.u == w.u

0 commit comments

Comments
 (0)