diff --git a/src/array_partition.jl b/src/array_partition.jl index 2a19461c..0689f1fb 100644 --- a/src/array_partition.jl +++ b/src/array_partition.jl @@ -304,8 +304,8 @@ _npartitions(args::Tuple{Any}) = npartitions(args[1]) _npartitions(args::Tuple{}) = 0 # drop axes because it is easier to recompute -@inline unpack(bc::Broadcast.Broadcasted{Style}, i) where Style = Broadcast.Broadcasted{Style}(bc.f, unpack_args(i, bc.args)) -@inline unpack(bc::Broadcast.Broadcasted{ArrayPartitionStyle{Style}}, i) where Style = Broadcast.Broadcasted{Style}(bc.f, unpack_args(i, bc.args)) +@inline unpack(bc::Broadcast.Broadcasted{Style}, i) where Style = Broadcast.Broadcasted(bc.f, unpack_args(i, bc.args)) +@inline unpack(bc::Broadcast.Broadcasted{ArrayPartitionStyle{Style}}, i) where Style = Broadcast.Broadcasted(bc.f, unpack_args(i, bc.args)) unpack(x,::Any) = x unpack(x::ArrayPartition, i) = x.x[i] diff --git a/test/partitions_test.jl b/test/partitions_test.jl index 0b5f9b9b..b9049b30 100644 --- a/test/partitions_test.jl +++ b/test/partitions_test.jl @@ -136,6 +136,43 @@ end foo(xcde0, xce0) #@test 0 == @allocated foo(xcde0, xce0) +# Custom AbstractArray types broadcasting +struct MyType{T} <: AbstractVector{T} + data :: Vector{T} +end +Base.similar(A::MyType{T}) where {T} = MyType{T}(similar(A.data)) +Base.similar(A::MyType{T},::Type{S}) where {T,S} = MyType(similar(A.data,S)) + +Base.size(A::MyType) = size(A.data) +Base.getindex(A::MyType, i::Int) = getindex(A.data,i) +Base.setindex!(A::MyType, v, i::Int) = setindex!(A.data,v,i) +Base.IndexStyle(::MyType) = IndexLinear() + +Base.BroadcastStyle(::Type{<:MyType}) = Broadcast.ArrayStyle{MyType}() + +function Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyType}},::Type{T}) where {T} + similar(find_mt(bc),T) +end + +function Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyType}}) + similar(find_mt(bc)) +end + +find_mt(bc::Base.Broadcast.Broadcasted) = find_mt(bc.args) +find_mt(args::Tuple) = find_mt(find_mt(args[1]), Base.tail(args)) +find_mt(x) = x +find_mt(::Tuple{}) = nothing +find_mt(a::MyType, rest) = a +find_mt(::Any, rest) = find_mt(rest) + +ap = ArrayPartition(MyType(ones(10)),collect(1:2)) +up = ap .+ 1 +@test typeof(ap) == typeof(up) + +up = 2 .* ap .+ 1 +@test typeof(ap) == typeof(up) + + @testset "ArrayInterface.ismutable(ArrayPartition($a, $b)) == $r" for (a, b, r) in ((1,2, false), ([1], 2, false), ([1], [2], true)) @test ArrayInterface.ismutable(ArrayPartition(a, b)) == r end