From bb33f3955a12b5173fee6f740338ec639e6d4a3e Mon Sep 17 00:00:00 2001 From: Klaus Crusius Date: Thu, 7 Jan 2021 12:03:42 +0100 Subject: [PATCH 1/2] fix ArrayInterface.ismutable(::ArrayPartition) --- .gitignore | 1 + src/array_partition.jl | 5 +++++ test/partitions_test.jl | 6 +++++- 3 files changed, 11 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 3f02ca74..49a815a8 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,4 @@ *.jl.*.cov *.jl.mem Manifest.toml +.vscode diff --git a/src/array_partition.jl b/src/array_partition.jl index 45043081..688c567b 100644 --- a/src/array_partition.jl +++ b/src/array_partition.jl @@ -70,6 +70,11 @@ end # ignore dims since array partitions are vectors Base.ones(A::ArrayPartition, dims::NTuple{N,Int}) where {N} = ones(A) +# mutable iff all components of ArrayPartition are mutable +function ArrayInterface.ismutable(::Type{<:ArrayPartition{T,S}}) where {T,S} + all(ArrayInterface.ismutable, S.parameters) +end + ## vector space operations for op in (:+, :-) diff --git a/test/partitions_test.jl b/test/partitions_test.jl index 89f4a198..0b5f9b9b 100644 --- a/test/partitions_test.jl +++ b/test/partitions_test.jl @@ -1,4 +1,4 @@ -using RecursiveArrayTools, Test, Statistics +using RecursiveArrayTools, Test, Statistics, ArrayInterface A = (rand(5),rand(5)) p = ArrayPartition(A) @test (p.x[1][1],p.x[2][1]) == (p[1],p[6]) @@ -135,3 +135,7 @@ function foo(y, x) end foo(xcde0, xce0) #@test 0 == @allocated foo(xcde0, xce0) + +@testset "ArrayInterface.ismutable(ArrayPartition($a, $b)) == $r" for (a, b, r) in ((1,2, false), ([1], 2, false), ([1], [2], true)) + @test ArrayInterface.ismutable(ArrayPartition(a, b)) == r +end From 3402cb5c6fda5f78924787aa2e1238148b57e515 Mon Sep 17 00:00:00 2001 From: Klaus Crusius Date: Fri, 8 Jan 2021 11:08:35 +0100 Subject: [PATCH 2/2] ismutable as generated function --- src/array_partition.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/array_partition.jl b/src/array_partition.jl index 688c567b..a2db9a0a 100644 --- a/src/array_partition.jl +++ b/src/array_partition.jl @@ -71,8 +71,9 @@ end Base.ones(A::ArrayPartition, dims::NTuple{N,Int}) where {N} = ones(A) # mutable iff all components of ArrayPartition are mutable -function ArrayInterface.ismutable(::Type{<:ArrayPartition{T,S}}) where {T,S} - all(ArrayInterface.ismutable, S.parameters) +@generated function ArrayInterface.ismutable(::Type{<:ArrayPartition{T,S}}) where {T,S} + res = all(ArrayInterface.ismutable, S.parameters) + return :( $res ) end ## vector space operations