@@ -147,27 +147,69 @@ end
147
147
VA. t,VA. u
148
148
end
149
149
150
- # Broadcast
151
-
152
- # add_idxs(x,expr) = expr
153
- # add_idxs{T<:AbstractVectorOfArray}(::Type{T},expr) = :($(expr)[i])
154
- # add_idxs{T<:AbstractArray}(::Type{Vector{T}},expr) = :($(expr)[i])
155
- #=
156
- @generated function Base.broadcast!(f,A::AbstractVectorOfArray,B...)
157
- exs = ((add_idxs(B[i],:(B[$i])) for i in eachindex(B))...)
158
- :(for i in eachindex(A)
159
- broadcast!(f,A[i],$(exs...))
160
- end)
161
- end
162
-
163
- @generated function Base.broadcast(f,B::Union{Number,AbstractVectorOfArray}...)
164
- arr_idx = 0
165
- for (i,b) in enumerate(B)
166
- if b <: ArrayPartition
167
- arr_idx = i
168
- break
150
+ # # broadcasting
151
+
152
+ struct VectorOfArrayStyle <: Broadcast.AbstractArrayStyle{Any} end
153
+ VectorOfArrayStyle (:: Any ) = VectorOfArrayStyle ()
154
+ VectorOfArrayStyle (:: Any , :: Any ) = VectorOfArrayStyle ()
155
+
156
+ # promotion rules
157
+ # @inline function Broadcast.BroadcastStyle(::VectorOfArrayStyle{AStyle}, ::VectorOfArrayStyle{BStyle}) where {AStyle, BStyle}
158
+ # VectorOfArrayStyle(Broadcast.BroadcastStyle(AStyle(), BStyle()))
159
+ # end
160
+ Broadcast. BroadcastStyle (:: VectorOfArrayStyle , :: Broadcast.BroadcastStyle ) = VectorOfArrayStyle ()
161
+ Broadcast. BroadcastStyle (:: VectorOfArrayStyle , :: Broadcast.DefaultArrayStyle{N} ) where N = Broadcast. DefaultArrayStyle {N} ()
162
+
163
+ function Broadcast. BroadcastStyle (:: Type{<:AbstractVectorOfArray{T,S}} ) where {T, S}
164
+ VectorOfArrayStyle ()
165
+ end
166
+
167
+ @inline function Base. copy (bc:: Broadcast.Broadcasted{VectorOfArrayStyle} )
168
+ N = narrays (bc)
169
+ x = unpack_voa (bc, 1 )
170
+ VectorOfArray (map (1 : N) do i
171
+ copy (unpack_voa (bc, i))
172
+ end )
173
+ end
174
+
175
+ @inline function Base. copyto! (dest:: AbstractVectorOfArray , bc:: Broadcast.Broadcasted{VectorOfArrayStyle} )
176
+ N = narrays (bc)
177
+ @inbounds for i in 1 : N
178
+ copyto! (dest[i], unpack_voa (bc, i))
169
179
end
170
- end
171
- :(A = similar(B[$arr_idx]); broadcast!(f,A,B...); A)
180
+ dest
172
181
end
173
- =#
182
+
183
+ # # broadcasting utils
184
+
185
+ """
186
+ narrays(A...)
187
+
188
+ Retrieve number of arrays in the AbstractVectorOfArrays of a broadcast
189
+ """
190
+ narrays (A) = 0
191
+ narrays (A:: AbstractVectorOfArray ) = length (A)
192
+ narrays (bc:: Broadcast.Broadcasted ) = _narrays (bc. args)
193
+ narrays (A, Bs... ) = common_length (narrays (A), _narrays (Bs))
194
+
195
+ common_length (a, b) =
196
+ a == 0 ? b :
197
+ (b == 0 ? a :
198
+ (a == b ? a :
199
+ throw (DimensionMismatch (" number of arrays must be equal" ))))
200
+
201
+ _narrays (args:: AbstractVectorOfArray ) = length (args)
202
+ @inline _narrays (args:: Tuple ) = common_length (narrays (args[1 ]), _narrays (Base. tail (args)))
203
+ _narrays (args:: Tuple{Any} ) = _narrays (args[1 ])
204
+ _narrays (args:: Tuple{} ) = 0
205
+
206
+ # drop axes because it is easier to recompute
207
+ @inline unpack_voa (bc:: Broadcast.Broadcasted{Style} , i) where Style = Broadcast. Broadcasted {Style} (bc. f, unpack_args_voa (i, bc. args))
208
+ @inline unpack_voa (bc:: Broadcast.Broadcasted{VectorOfArrayStyle} , i) = Broadcast. Broadcasted (bc. f, unpack_args_voa (i, bc. args))
209
+ unpack_voa (x,:: Any ) = x
210
+ unpack_voa (x:: AbstractVectorOfArray , i) = x. u[i]
211
+ unpack_voa (x:: AbstractArray{T,N} , i) where {T,N} = @view x[ntuple (x-> Colon (),N- 1 )... ,i]
212
+
213
+ @inline unpack_args_voa (i, args:: Tuple ) = (unpack_voa (args[1 ], i), unpack_args_voa (i, Base. tail (args))... )
214
+ unpack_args_voa (i, args:: Tuple{Any} ) = (unpack_voa (args[1 ], i),)
215
+ unpack_args_voa (:: Any , args:: Tuple{} ) = ()
0 commit comments