From 4d8130e60b500014c36419d94530e7cf66958996 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 21 Dec 2023 13:08:14 +0530 Subject: [PATCH 1/6] feat: add ability to set VectorOfArray with Array using broadcast --- src/vector_of_array.jl | 15 ++++++++++++++- test/interface_tests.jl | 8 ++++++++ 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/src/vector_of_array.jl b/src/vector_of_array.jl index 42c24089..ae5901fb 100644 --- a/src/vector_of_array.jl +++ b/src/vector_of_array.jl @@ -663,7 +663,20 @@ end bc = Broadcast.flatten(bc) N = narrays(bc) @inbounds for i in 1:N - if dest[:, i] isa AbstractArray && !isa(dest[:, i], StaticArraysCore.SArray) + if dest[:, i] isa AbstractArray && ArrayInterface.ismutable(dest[:, i]) + copyto!(dest[:, i], unpack_voa(bc, i)) + else + dest[:, i] = copy(unpack_voa(bc, i)) + end + end + dest +end + +@inline function Base.copyto!(dest::AbstractVectorOfArray, + bc::Broadcast.Broadcasted{<:Broadcast.DefaultArrayStyle}) + bc = Broadcast.flatten(bc) + @inbounds for i in 1:length(dest.u) + if dest[:, i] isa AbstractArray && ArrayInterface.ismutable(dest[:, i]) copyto!(dest[:, i], unpack_voa(bc, i)) else dest[:, i] = copy(unpack_voa(bc, i)) diff --git a/test/interface_tests.jl b/test/interface_tests.jl index 4f79c3e6..ad3fec43 100644 --- a/test/interface_tests.jl +++ b/test/interface_tests.jl @@ -125,3 +125,11 @@ z = VectorOfArray([zeros(SVector{2, Float64}), zeros(SVector{2, Float64})]) z .= x .+ y @test z == VectorOfArray([fill(4, SVector{2, Float64}), fill(2, SVector{2, Float64})]) + +yy = [2.0 1.0; 2.0 1.0] +zz = x .+ yy +@test zz == [4.0 2.0; 4.0 2.0] + +z = VectorOfArray([zeros(SVector{2, Float64}), zeros(SVector{2, Float64})]) +z .= zz +@test z == VectorOfArray([fill(4, SVector{2, Float64}), fill(2, SVector{2, Float64})]) From 8854baa7ad3c2b3b7c11602751bcc396c4faf6ef Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Thu, 21 Dec 2023 08:12:17 -0500 Subject: [PATCH 2/6] Update test/interface_tests.jl --- test/interface_tests.jl | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/test/interface_tests.jl b/test/interface_tests.jl index ad3fec43..7f2288c0 100644 --- a/test/interface_tests.jl +++ b/test/interface_tests.jl @@ -126,6 +126,16 @@ z .= x .+ y @test z == VectorOfArray([fill(4, SVector{2, Float64}), fill(2, SVector{2, Float64})]) +u1 = VectorOfArray([fill(2, SVector{2, Float64}), ones(SVector{2, Float64})]) +u2 = VectorOfArray([fill(4, SVector{2, Float64}), 2 .* ones(SVector{2, Float64})]) +u3 = VectorOfArray([fill(4, SVector{2, Float64}), 2 .* ones(SVector{2, Float64})]) + +function f(u1,u2,u3) + u3 .= u1 .+ u2 +end +f(u1,u2,u3) +@test@allocated f(u1,u2,u3) == 0 + yy = [2.0 1.0; 2.0 1.0] zz = x .+ yy @test zz == [4.0 2.0; 4.0 2.0] From 472a6069eaf469690e2cab13a4d7511250bc20e8 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Thu, 21 Dec 2023 08:14:18 -0500 Subject: [PATCH 3/6] Update interface_tests.jl --- test/interface_tests.jl | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/test/interface_tests.jl b/test/interface_tests.jl index 7f2288c0..8d8fa804 100644 --- a/test/interface_tests.jl +++ b/test/interface_tests.jl @@ -134,7 +134,7 @@ function f(u1,u2,u3) u3 .= u1 .+ u2 end f(u1,u2,u3) -@test@allocated f(u1,u2,u3) == 0 +@test @allocated f(u1,u2,u3) == 0 yy = [2.0 1.0; 2.0 1.0] zz = x .+ yy @@ -143,3 +143,9 @@ zz = x .+ yy z = VectorOfArray([zeros(SVector{2, Float64}), zeros(SVector{2, Float64})]) z .= zz @test z == VectorOfArray([fill(4, SVector{2, Float64}), fill(2, SVector{2, Float64})]) + +function f!(z,zz) + z .= zz +end +f!(z,zz) +@test @allocated f!(z,zz) == 0 From adcbc85fb98a17f39b866b91acae05e7f7780707 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 21 Dec 2023 18:50:02 +0530 Subject: [PATCH 4/6] fixup! Update interface_tests.jl --- test/interface_tests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/interface_tests.jl b/test/interface_tests.jl index 8d8fa804..2ee949d2 100644 --- a/test/interface_tests.jl +++ b/test/interface_tests.jl @@ -148,4 +148,4 @@ function f!(z,zz) z .= zz end f!(z,zz) -@test @allocated f!(z,zz) == 0 +@test (@allocated f!(z,zz)) == 0 From fed6857b253b7e2c64d28b7eed2a3daeb5dbbf39 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Thu, 21 Dec 2023 08:33:37 -0500 Subject: [PATCH 5/6] Update test/interface_tests.jl --- test/interface_tests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/interface_tests.jl b/test/interface_tests.jl index 2ee949d2..cba4727a 100644 --- a/test/interface_tests.jl +++ b/test/interface_tests.jl @@ -134,7 +134,7 @@ function f(u1,u2,u3) u3 .= u1 .+ u2 end f(u1,u2,u3) -@test @allocated f(u1,u2,u3) == 0 +@test (@allocated f(u1,u2,u3)) == 0 yy = [2.0 1.0; 2.0 1.0] zz = x .+ yy From 300f692a3f42e5538f03264e4c95f29633b53892 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 22 Dec 2023 12:16:24 +0530 Subject: [PATCH 6/6] refactor: avoid allocations when broadcasting over VoA containing SArrays --- src/vector_of_array.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/vector_of_array.jl b/src/vector_of_array.jl index ae5901fb..933c542d 100644 --- a/src/vector_of_array.jl +++ b/src/vector_of_array.jl @@ -666,7 +666,8 @@ end if dest[:, i] isa AbstractArray && ArrayInterface.ismutable(dest[:, i]) copyto!(dest[:, i], unpack_voa(bc, i)) else - dest[:, i] = copy(unpack_voa(bc, i)) + unpacked = unpack_voa(bc, i) + dest[:, i] = unpacked.f(unpacked.args...) end end dest @@ -679,7 +680,8 @@ end if dest[:, i] isa AbstractArray && ArrayInterface.ismutable(dest[:, i]) copyto!(dest[:, i], unpack_voa(bc, i)) else - dest[:, i] = copy(unpack_voa(bc, i)) + unpacked = unpack_voa(bc, i) + dest[:, i] = unpacked.f(unpacked.args...) end end dest