From 181dbb67374b6ca22695f847d1c00234948f672e Mon Sep 17 00:00:00 2001 From: Mateusz Baran Date: Wed, 16 Mar 2022 15:11:17 +0100 Subject: [PATCH] Fix shaped similar for ArrayPartition --- src/array_partition.jl | 20 ++++++++++++++++---- test/partitions_test.jl | 6 ++++-- 2 files changed, 20 insertions(+), 6 deletions(-) diff --git a/src/array_partition.jl b/src/array_partition.jl index 06da62f5..67db9f21 100644 --- a/src/array_partition.jl +++ b/src/array_partition.jl @@ -19,8 +19,14 @@ end Base.similar(A::ArrayPartition{T,S}) where {T,S} = ArrayPartition{T,S}(similar.(A.x)) -# ignore dims since array partitions are vectors -Base.similar(A::ArrayPartition, dims::NTuple{N,Int}) where {N} = similar(A) +# return ArrayPartition when possible, otherwise next best thing of the correct size +function Base.similar(A::ArrayPartition, dims::NTuple{N,Int}) where {N} + if dims == size(A) + return similar(A) + else + return similar(A.x[1], eltype(A), dims) + end +end # similar array partition of common type @inline function Base.similar(A::ArrayPartition, ::Type{T}) where {T} @@ -28,8 +34,14 @@ Base.similar(A::ArrayPartition, dims::NTuple{N,Int}) where {N} = similar(A) ArrayPartition(i->similar(A.x[i], T), N) end -# ignore dims since array partitions are vectors -Base.similar(A::ArrayPartition, ::Type{T}, dims::NTuple{N,Int}) where {T,N} = similar(A, T) +# return ArrayPartition when possible, otherwise next best thing of the correct size +function Base.similar(A::ArrayPartition, ::Type{T}, dims::NTuple{N,Int}) where {T,N} + if dims == size(A) + return similar(A, T) + else + return similar(A.x[1], T, dims) + end +end # similar array partition with different types function Base.similar(A::ArrayPartition, ::Type{T}, ::Type{S}, R::DataType...) where {T, S} diff --git a/test/partitions_test.jl b/test/partitions_test.jl index 65f77195..21d7382f 100644 --- a/test/partitions_test.jl +++ b/test/partitions_test.jl @@ -61,9 +61,11 @@ x = ArrayPartition([1, 2], [3.0, 4.0]) # similar partitions @inferred similar(x) -@inferred similar(x, (2, 2)) +@test similar(x, (4,)) isa ArrayPartition{Float64} +@test (@inferred similar(x, (2, 2))) isa AbstractMatrix{Float64} @inferred similar(x, Int) -@inferred similar(x, Int, (2, 2)) +@test similar(x, Int, (4,)) isa ArrayPartition{Int} +@test (@inferred similar(x, Int, (2, 2))) isa AbstractMatrix{Int} # @inferred similar(x, Int, Float64) # zero