diff --git a/src/array_partition.jl b/src/array_partition.jl index 0689f1fb..83c8a891 100644 --- a/src/array_partition.jl +++ b/src/array_partition.jl @@ -306,6 +306,8 @@ _npartitions(args::Tuple{}) = 0 # drop axes because it is easier to recompute @inline unpack(bc::Broadcast.Broadcasted{Style}, i) where Style = Broadcast.Broadcasted(bc.f, unpack_args(i, bc.args)) @inline unpack(bc::Broadcast.Broadcasted{ArrayPartitionStyle{Style}}, i) where Style = Broadcast.Broadcasted(bc.f, unpack_args(i, bc.args)) +@inline unpack(bc::Broadcast.Broadcasted{Style}, i) where Style <: Broadcast.DefaultArrayStyle = Broadcast.Broadcasted{Style}(bc.f, unpack_args(i, bc.args)) +@inline unpack(bc::Broadcast.Broadcasted{ArrayPartitionStyle{Style}}, i) where Style <: Broadcast.DefaultArrayStyle = Broadcast.Broadcasted{Style}(bc.f, unpack_args(i, bc.args)) unpack(x,::Any) = x unpack(x::ArrayPartition, i) = x.x[i] diff --git a/test/partitions_test.jl b/test/partitions_test.jl index b9049b30..6b4db521 100644 --- a/test/partitions_test.jl +++ b/test/partitions_test.jl @@ -172,6 +172,13 @@ up = ap .+ 1 up = 2 .* ap .+ 1 @test typeof(ap) == typeof(up) +# Test that `zeros()` does not get screwed up +ap = ArrayPartition(zeros(),[1.0]) +up = ap .+ 1 +@test typeof(ap) == typeof(up) + +up = 2 .* ap .+ 1 +@test typeof(ap) == typeof(up) @testset "ArrayInterface.ismutable(ArrayPartition($a, $b)) == $r" for (a, b, r) in ((1,2, false), ([1], 2, false), ([1], [2], true)) @test ArrayInterface.ismutable(ArrayPartition(a, b)) == r