From 28f9101dca0634e90cbe016ed72fbde916d4fbb6 Mon Sep 17 00:00:00 2001 From: Jeff Eldredge Date: Sat, 17 Apr 2021 07:26:57 -0700 Subject: [PATCH 1/2] This allows DefaultArrayStyle to be broadcast according to normal rules --- src/array_partition.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/array_partition.jl b/src/array_partition.jl index 0689f1fb..83c8a891 100644 --- a/src/array_partition.jl +++ b/src/array_partition.jl @@ -306,6 +306,8 @@ _npartitions(args::Tuple{}) = 0 # drop axes because it is easier to recompute @inline unpack(bc::Broadcast.Broadcasted{Style}, i) where Style = Broadcast.Broadcasted(bc.f, unpack_args(i, bc.args)) @inline unpack(bc::Broadcast.Broadcasted{ArrayPartitionStyle{Style}}, i) where Style = Broadcast.Broadcasted(bc.f, unpack_args(i, bc.args)) +@inline unpack(bc::Broadcast.Broadcasted{Style}, i) where Style <: Broadcast.DefaultArrayStyle = Broadcast.Broadcasted{Style}(bc.f, unpack_args(i, bc.args)) +@inline unpack(bc::Broadcast.Broadcasted{ArrayPartitionStyle{Style}}, i) where Style <: Broadcast.DefaultArrayStyle = Broadcast.Broadcasted{Style}(bc.f, unpack_args(i, bc.args)) unpack(x,::Any) = x unpack(x::ArrayPartition, i) = x.x[i] From ffcbf5ec59aa5f594ff1c2641b100c938c1d8431 Mon Sep 17 00:00:00 2001 From: Jeff Eldredge Date: Sat, 17 Apr 2021 07:52:53 -0700 Subject: [PATCH 2/2] Unit tests added --- test/partitions_test.jl | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/test/partitions_test.jl b/test/partitions_test.jl index b9049b30..6b4db521 100644 --- a/test/partitions_test.jl +++ b/test/partitions_test.jl @@ -172,6 +172,13 @@ up = ap .+ 1 up = 2 .* ap .+ 1 @test typeof(ap) == typeof(up) +# Test that `zeros()` does not get screwed up +ap = ArrayPartition(zeros(),[1.0]) +up = ap .+ 1 +@test typeof(ap) == typeof(up) + +up = 2 .* ap .+ 1 +@test typeof(ap) == typeof(up) @testset "ArrayInterface.ismutable(ArrayPartition($a, $b)) == $r" for (a, b, r) in ((1,2, false), ([1], 2, false), ([1], [2], true)) @test ArrayInterface.ismutable(ArrayPartition(a, b)) == r