From 9920fed013fe266e978ce546b69cd5760f46d4c1 Mon Sep 17 00:00:00 2001 From: Mateusz Baran Date: Fri, 23 Jul 2021 23:53:50 +0200 Subject: [PATCH 1/2] speed up some use cases of ArrayPartition --- src/array_partition.jl | 22 ++++++++++------------ test/partitions_test.jl | 2 +- 2 files changed, 11 insertions(+), 13 deletions(-) diff --git a/src/array_partition.jl b/src/array_partition.jl index 0903d25b..761ac547 100644 --- a/src/array_partition.jl +++ b/src/array_partition.jl @@ -7,9 +7,9 @@ end ArrayPartition(x...) = ArrayPartition((x...,)) function ArrayPartition(x::S, ::Type{Val{copy_x}}=Val{false}) where {S<:Tuple,copy_x} - T = promote_type(recursive_bottom_eltype.(x)...) + T = promote_type(map(recursive_bottom_eltype,x)...) if copy_x - return ArrayPartition{T,S}(copy.(x)) + return ArrayPartition{T,S}(map(copy,x)) else return ArrayPartition{T,S}(x) end @@ -81,31 +81,31 @@ end for op in (:+, :-) @eval begin function Base.$op(A::ArrayPartition, B::ArrayPartition) - Base.broadcast($op, A, B) + ArrayPartition(map((x, y)->Base.broadcast($op, x, y), A.x, B.x)) end function Base.$op(A::ArrayPartition, B::Number) - Base.broadcast($op, A, B) + ArrayPartition(map(y->Base.broadcast($op, y, B), A.x)) end function Base.$op(A::Number, B::ArrayPartition) - Base.broadcast($op, A, B) + ArrayPartition(map(y->Base.broadcast($op, A, y), B.x)) end end end for op in (:*, :/) @eval function Base.$op(A::ArrayPartition, B::Number) - Base.broadcast($op, A, B) + ArrayPartition(map(y->Base.broadcast($op, y, B), A.x)) end end function Base.:*(A::Number, B::ArrayPartition) - Base.broadcast(*, A, B) + ArrayPartition(map(y->Base.broadcast(*, A, y), B.x)) end function Base.:\(A::Number, B::ArrayPartition) - Base.broadcast(/, B, A) + ArrayPartition(map(y->Base.broadcast(/, y, A), B.x)) end Base.:(==)(A::ArrayPartition,B::ArrayPartition) = A.x == B.x @@ -134,7 +134,7 @@ end function Base.copyto!(A::ArrayPartition,src::ArrayPartition) @assert length(src) == length(A) if size.(A.x) == size.(src.x) - A .= src + map(copyto!, A.x, src.x) else cnt = 0 for i in eachindex(A.x) @@ -281,9 +281,7 @@ end @inline function Base.copyto!(dest::ArrayPartition, bc::Broadcast.Broadcasted{ArrayPartitionStyle{Style}}) where Style N = npartitions(dest, bc) - @inbounds for i in 1:N - copyto!(dest.x[i], unpack(bc, i)) - end + ntuple(i->copyto!(dest.x[i], unpack(bc, i)), Val(N)) dest end diff --git a/test/partitions_test.jl b/test/partitions_test.jl index cce344b3..0866d647 100644 --- a/test/partitions_test.jl +++ b/test/partitions_test.jl @@ -101,7 +101,7 @@ _scalar_op(y) = y + 1 # Can't do `@inferred(_scalar_op.(x))` so we wrap that in a function: _broadcast_wrapper(y) = _scalar_op.(y) # Issue #8 -# @inferred _broadcast_wrapper(x) +@inferred _broadcast_wrapper(x) # Testing map @test map(x->x^2, x) == ArrayPartition(x.x[1].^2, x.x[2].^2) From 5de3bb1362e311d2483626af7a5ddecf500ba4bf Mon Sep 17 00:00:00 2001 From: Mateusz Baran Date: Sat, 24 Jul 2021 22:42:55 +0200 Subject: [PATCH 2/2] inline one more local function --- src/array_partition.jl | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/array_partition.jl b/src/array_partition.jl index 761ac547..93ae79f9 100644 --- a/src/array_partition.jl +++ b/src/array_partition.jl @@ -281,7 +281,10 @@ end @inline function Base.copyto!(dest::ArrayPartition, bc::Broadcast.Broadcasted{ArrayPartitionStyle{Style}}) where Style N = npartitions(dest, bc) - ntuple(i->copyto!(dest.x[i], unpack(bc, i)), Val(N)) + @inline function f(i) + copyto!(dest.x[i], unpack(bc, i)) + end + ntuple(f, Val(N)) dest end