From 28202e3084f2413c5b5c53f9909f19369de6f5f5 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sat, 8 May 2021 15:24:10 -0400 Subject: [PATCH 1/4] allow dims::Tuple in sum --- src/mapreduce.jl | 6 ++++++ test/mapreduce.jl | 3 +++ 2 files changed, 9 insertions(+) diff --git a/src/mapreduce.jl b/src/mapreduce.jl index 202dddf7..7b92b2d4 100644 --- a/src/mapreduce.jl +++ b/src/mapreduce.jl @@ -158,6 +158,12 @@ end end end +@inline function _mapreduce(f, op, D::Tuple, init, sz::Size{S}, a::StaticArray) where {S} + b = _mapreduce(f, op, first(D), init, sz, a) + return _mapreduce(f, op, Base.tail(D), init, Size(b), b) +end +_mapreduce(f, op, D::Tuple{}, init, sz::Size{S}, a::StaticArray) where {S} = a + @generated function _mapfoldl(f, op, dims::Val{D}, init, ::Size{S}, a::StaticArray) where {S,D} N = length(S) diff --git a/test/mapreduce.jl b/test/mapreduce.jl index 5c7c78ca..78fb96ad 100644 --- a/test/mapreduce.jl +++ b/test/mapreduce.jl @@ -102,6 +102,7 @@ using Statistics: mean RSArray1 = SArray{Tuple{1,J,K}} # reduced in dimension 1 RSArray2 = SArray{Tuple{I,1,K}} # reduced in dimension 2 RSArray3 = SArray{Tuple{I,J,1}} # reduced in dimension 3 + RSArray13 = SArray{Tuple{1,J,1}} # reduced in dimension 1 and 3 a = randn(I,J,K); sa = OSArray(a) b = rand(Bool,I,J,K); sb = OSArray(b) z = zeros(I,J,K); sz = OSArray(z) @@ -111,9 +112,11 @@ using Statistics: mean @test sum(sa) === sum(a) @test sum(abs2, sa) === sum(abs2, a) @test sum(sa, dims=2) === RSArray2(sum(a, dims=2)) + @test sum(sa, dims=(2,)) === RSArray2(sum(a, dims=2)) @test sum(sa, dims=Val(2)) === RSArray2(sum(a, dims=2)) @test sum(abs2, sa; dims=2) === RSArray2(sum(abs2, a, dims=2)) @test sum(abs2, sa; dims=Val(2)) === RSArray2(sum(abs2, a, dims=2)) + @test_broken sum(abs2, sa; dims=(1,3)) === RSArray13(sum(abs2, a, dims=(1,3))) @test prod(sa) === prod(a) @test prod(abs2, sa) === prod(abs2, a) From 7a663281472f033b9e69d5ddb200db918b599aa1 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sat, 8 May 2021 15:46:33 -0400 Subject: [PATCH 2/4] multiple dimensions for reduce only, not mapreduce --- src/mapreduce.jl | 15 ++++++++++----- test/mapreduce.jl | 3 ++- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/src/mapreduce.jl b/src/mapreduce.jl index 7b92b2d4..3da430b1 100644 --- a/src/mapreduce.jl +++ b/src/mapreduce.jl @@ -158,11 +158,8 @@ end end end -@inline function _mapreduce(f, op, D::Tuple, init, sz::Size{S}, a::StaticArray) where {S} - b = _mapreduce(f, op, first(D), init, sz, a) - return _mapreduce(f, op, Base.tail(D), init, Size(b), b) -end -_mapreduce(f, op, D::Tuple{}, init, sz::Size{S}, a::StaticArray) where {S} = a +@inline _mapreduce(f, op, D::Tuple{<:Any}, init, sz::Size{S}, a::StaticArray) where {S} = + _mapreduce(f, op, first(D), init, sz, a) @generated function _mapfoldl(f, op, dims::Val{D}, init, ::Size{S}, a::StaticArray) where {S,D} @@ -215,6 +212,14 @@ reduce(::typeof(hcat), A::StaticArray{<:Tuple,<:StaticVecOrMatLike}) = @inline _reduce(op, a::StaticArray, dims, init = _InitialValue()) = _mapreduce(identity, op, dims, init, Size(a), a) +@inline function _reduce(op, a::StaticArray, dims::Tuple, init = _InitialValue()) + b = _reduce(op, a, first(dims), init) + return _reduce(op, b, Base.tail(dims)) +end +_reduce(op, a::StaticArray, dims::Tuple{}, ::_InitialValue) = a +_reduce(op, a::StaticArray, dims::Tuple{}, init) = op.(init, a) + + ################ ## (map)foldl ## ################ diff --git a/test/mapreduce.jl b/test/mapreduce.jl index 78fb96ad..ec58a321 100644 --- a/test/mapreduce.jl +++ b/test/mapreduce.jl @@ -114,9 +114,10 @@ using Statistics: mean @test sum(sa, dims=2) === RSArray2(sum(a, dims=2)) @test sum(sa, dims=(2,)) === RSArray2(sum(a, dims=2)) @test sum(sa, dims=Val(2)) === RSArray2(sum(a, dims=2)) + @test sum(sa, dims=(1,3)) === RSArray13(sum(a, dims=(1,3))) @test sum(abs2, sa; dims=2) === RSArray2(sum(abs2, a, dims=2)) + @test sum(abs2, sa; dims=(2,)) === RSArray2(sum(abs2, a, dims=2)) @test sum(abs2, sa; dims=Val(2)) === RSArray2(sum(abs2, a, dims=2)) - @test_broken sum(abs2, sa; dims=(1,3)) === RSArray13(sum(abs2, a, dims=(1,3))) @test prod(sa) === prod(a) @test prod(abs2, sa) === prod(abs2, a) From 1a4aa623fe0462c3f280fdebfb1ff7d048eb2b20 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sat, 8 May 2021 15:58:18 -0400 Subject: [PATCH 3/4] fix tests? --- src/mapreduce.jl | 4 ++-- test/mapreduce.jl | 8 ++++++++ 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/src/mapreduce.jl b/src/mapreduce.jl index 3da430b1..0b747e0b 100644 --- a/src/mapreduce.jl +++ b/src/mapreduce.jl @@ -213,8 +213,8 @@ reduce(::typeof(hcat), A::StaticArray{<:Tuple,<:StaticVecOrMatLike}) = _mapreduce(identity, op, dims, init, Size(a), a) @inline function _reduce(op, a::StaticArray, dims::Tuple, init = _InitialValue()) - b = _reduce(op, a, first(dims), init) - return _reduce(op, b, Base.tail(dims)) + b = _reduce(op, a, first(dims)) + return _reduce(op, b, Base.tail(dims), init) end _reduce(op, a::StaticArray, dims::Tuple{}, ::_InitialValue) = a _reduce(op, a::StaticArray, dims::Tuple{}, init) = op.(init, a) diff --git a/test/mapreduce.jl b/test/mapreduce.jl index ec58a321..8158fdba 100644 --- a/test/mapreduce.jl +++ b/test/mapreduce.jl @@ -37,9 +37,12 @@ using Statistics: mean v2 = [4,3,2,1]; sv2 = SVector{4}(v2) @test reduce(+, sv1) === reduce(+, v1) @test reduce(+, sv1; init=0) === reduce(+, v1; init=0) + @test reduce(+, sv1; init=99) === reduce(+, v1; init=99) @test reduce(max, sa; dims=Val(1), init=-1.) === SMatrix{1,J}(reduce(max, a, dims=1, init=-1.)) @test reduce(max, sa; dims=1, init=-1.) === SMatrix{1,J}(reduce(max, a, dims=1, init=-1.)) @test reduce(max, sa; dims=2, init=-1.) === SMatrix{I,1}(reduce(max, a, dims=2, init=-1.)) + @test reduce(*, sa; dims=(1,2), init=2.0) === SMatrix{1,1}(reduce(*, a, dims=(1,2), init=2.0)) + @test reduce(*, sa; dims=(), init=(1.0+im)) === SMatrix{I,J}(reduce(*, a, dims=(), init=(1.0+im))) @test mapreduce(-, +, sv1) === mapreduce(-, +, v1) @test mapreduce(-, +, sv1; init=0) === mapreduce(-, +, v1, init=0) @test mapreduce(*, +, sv1, sv2) === 40 @@ -47,6 +50,11 @@ using Statistics: mean @test mapreduce(x->x^2, max, sa; dims=Val(1), init=-1.) == SMatrix{1,J}(mapreduce(x->x^2, max, a, dims=1, init=-1.)) @test mapreduce(x->x^2, max, sa; dims=1, init=-1.) == SMatrix{1,J}(mapreduce(x->x^2, max, a, dims=1, init=-1.)) @test mapreduce(x->x^2, max, sa; dims=2, init=-1.) == SMatrix{I,1}(mapreduce(x->x^2, max, a, dims=2, init=-1.)) + + # Zero-dim array + a0 = fill(rand()); sa0 = SArray{Tuple{}}(a0) + @test reduce(+, sa0) === reduce(+, a0) + @test reduce(/, sa0, dims=(), init=1.2) === SArray{Tuple{}}(reduce(/, a0, dims=(), init=1.2)) end @testset "[map]foldl" begin From 4a986db8b8765702bc56e1d80481bfb4f6f9e64f Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sat, 8 May 2021 20:58:10 -0400 Subject: [PATCH 4/4] one isapprox --- test/mapreduce.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/mapreduce.jl b/test/mapreduce.jl index 8158fdba..9cacc21a 100644 --- a/test/mapreduce.jl +++ b/test/mapreduce.jl @@ -41,7 +41,7 @@ using Statistics: mean @test reduce(max, sa; dims=Val(1), init=-1.) === SMatrix{1,J}(reduce(max, a, dims=1, init=-1.)) @test reduce(max, sa; dims=1, init=-1.) === SMatrix{1,J}(reduce(max, a, dims=1, init=-1.)) @test reduce(max, sa; dims=2, init=-1.) === SMatrix{I,1}(reduce(max, a, dims=2, init=-1.)) - @test reduce(*, sa; dims=(1,2), init=2.0) === SMatrix{1,1}(reduce(*, a, dims=(1,2), init=2.0)) + @test reduce(*, sa; dims=(1,2), init=2.0) ≈ SMatrix{1,1}(reduce(*, a, dims=(1,2), init=2.0)) @test reduce(*, sa; dims=(), init=(1.0+im)) === SMatrix{I,J}(reduce(*, a, dims=(), init=(1.0+im))) @test mapreduce(-, +, sv1) === mapreduce(-, +, v1) @test mapreduce(-, +, sv1; init=0) === mapreduce(-, +, v1, init=0)