diff --git a/src/array_partition.jl b/src/array_partition.jl index 0903d25b..93ae79f9 100644 --- a/src/array_partition.jl +++ b/src/array_partition.jl @@ -7,9 +7,9 @@ end ArrayPartition(x...) = ArrayPartition((x...,)) function ArrayPartition(x::S, ::Type{Val{copy_x}}=Val{false}) where {S<:Tuple,copy_x} - T = promote_type(recursive_bottom_eltype.(x)...) + T = promote_type(map(recursive_bottom_eltype,x)...) if copy_x - return ArrayPartition{T,S}(copy.(x)) + return ArrayPartition{T,S}(map(copy,x)) else return ArrayPartition{T,S}(x) end @@ -81,31 +81,31 @@ end for op in (:+, :-) @eval begin function Base.$op(A::ArrayPartition, B::ArrayPartition) - Base.broadcast($op, A, B) + ArrayPartition(map((x, y)->Base.broadcast($op, x, y), A.x, B.x)) end function Base.$op(A::ArrayPartition, B::Number) - Base.broadcast($op, A, B) + ArrayPartition(map(y->Base.broadcast($op, y, B), A.x)) end function Base.$op(A::Number, B::ArrayPartition) - Base.broadcast($op, A, B) + ArrayPartition(map(y->Base.broadcast($op, A, y), B.x)) end end end for op in (:*, :/) @eval function Base.$op(A::ArrayPartition, B::Number) - Base.broadcast($op, A, B) + ArrayPartition(map(y->Base.broadcast($op, y, B), A.x)) end end function Base.:*(A::Number, B::ArrayPartition) - Base.broadcast(*, A, B) + ArrayPartition(map(y->Base.broadcast(*, A, y), B.x)) end function Base.:\(A::Number, B::ArrayPartition) - Base.broadcast(/, B, A) + ArrayPartition(map(y->Base.broadcast(/, y, A), B.x)) end Base.:(==)(A::ArrayPartition,B::ArrayPartition) = A.x == B.x @@ -134,7 +134,7 @@ end function Base.copyto!(A::ArrayPartition,src::ArrayPartition) @assert length(src) == length(A) if size.(A.x) == size.(src.x) - A .= src + map(copyto!, A.x, src.x) else cnt = 0 for i in eachindex(A.x) @@ -281,9 +281,10 @@ end @inline function Base.copyto!(dest::ArrayPartition, bc::Broadcast.Broadcasted{ArrayPartitionStyle{Style}}) where Style N = npartitions(dest, bc) - @inbounds for i in 1:N + @inline function f(i) copyto!(dest.x[i], unpack(bc, i)) end + ntuple(f, Val(N)) dest end diff --git a/test/partitions_test.jl b/test/partitions_test.jl index cce344b3..0866d647 100644 --- a/test/partitions_test.jl +++ b/test/partitions_test.jl @@ -101,7 +101,7 @@ _scalar_op(y) = y + 1 # Can't do `@inferred(_scalar_op.(x))` so we wrap that in a function: _broadcast_wrapper(y) = _scalar_op.(y) # Issue #8 -# @inferred _broadcast_wrapper(x) +@inferred _broadcast_wrapper(x) # Testing map @test map(x->x^2, x) == ArrayPartition(x.x[1].^2, x.x[2].^2)