diff --git a/src/mapreduce.jl b/src/mapreduce.jl index 202dddf7..0b747e0b 100644 --- a/src/mapreduce.jl +++ b/src/mapreduce.jl @@ -158,6 +158,9 @@ end end end +@inline _mapreduce(f, op, D::Tuple{<:Any}, init, sz::Size{S}, a::StaticArray) where {S} = + _mapreduce(f, op, first(D), init, sz, a) + @generated function _mapfoldl(f, op, dims::Val{D}, init, ::Size{S}, a::StaticArray) where {S,D} N = length(S) @@ -209,6 +212,14 @@ reduce(::typeof(hcat), A::StaticArray{<:Tuple,<:StaticVecOrMatLike}) = @inline _reduce(op, a::StaticArray, dims, init = _InitialValue()) = _mapreduce(identity, op, dims, init, Size(a), a) +@inline function _reduce(op, a::StaticArray, dims::Tuple, init = _InitialValue()) + b = _reduce(op, a, first(dims)) + return _reduce(op, b, Base.tail(dims), init) +end +_reduce(op, a::StaticArray, dims::Tuple{}, ::_InitialValue) = a +_reduce(op, a::StaticArray, dims::Tuple{}, init) = op.(init, a) + + ################ ## (map)foldl ## ################ diff --git a/test/mapreduce.jl b/test/mapreduce.jl index 5c7c78ca..9cacc21a 100644 --- a/test/mapreduce.jl +++ b/test/mapreduce.jl @@ -37,9 +37,12 @@ using Statistics: mean v2 = [4,3,2,1]; sv2 = SVector{4}(v2) @test reduce(+, sv1) === reduce(+, v1) @test reduce(+, sv1; init=0) === reduce(+, v1; init=0) + @test reduce(+, sv1; init=99) === reduce(+, v1; init=99) @test reduce(max, sa; dims=Val(1), init=-1.) === SMatrix{1,J}(reduce(max, a, dims=1, init=-1.)) @test reduce(max, sa; dims=1, init=-1.) === SMatrix{1,J}(reduce(max, a, dims=1, init=-1.)) @test reduce(max, sa; dims=2, init=-1.) === SMatrix{I,1}(reduce(max, a, dims=2, init=-1.)) + @test reduce(*, sa; dims=(1,2), init=2.0) ≈ SMatrix{1,1}(reduce(*, a, dims=(1,2), init=2.0)) + @test reduce(*, sa; dims=(), init=(1.0+im)) === SMatrix{I,J}(reduce(*, a, dims=(), init=(1.0+im))) @test mapreduce(-, +, sv1) === mapreduce(-, +, v1) @test mapreduce(-, +, sv1; init=0) === mapreduce(-, +, v1, init=0) @test mapreduce(*, +, sv1, sv2) === 40 @@ -47,6 +50,11 @@ using Statistics: mean @test mapreduce(x->x^2, max, sa; dims=Val(1), init=-1.) == SMatrix{1,J}(mapreduce(x->x^2, max, a, dims=1, init=-1.)) @test mapreduce(x->x^2, max, sa; dims=1, init=-1.) == SMatrix{1,J}(mapreduce(x->x^2, max, a, dims=1, init=-1.)) @test mapreduce(x->x^2, max, sa; dims=2, init=-1.) == SMatrix{I,1}(mapreduce(x->x^2, max, a, dims=2, init=-1.)) + + # Zero-dim array + a0 = fill(rand()); sa0 = SArray{Tuple{}}(a0) + @test reduce(+, sa0) === reduce(+, a0) + @test reduce(/, sa0, dims=(), init=1.2) === SArray{Tuple{}}(reduce(/, a0, dims=(), init=1.2)) end @testset "[map]foldl" begin @@ -102,6 +110,7 @@ using Statistics: mean RSArray1 = SArray{Tuple{1,J,K}} # reduced in dimension 1 RSArray2 = SArray{Tuple{I,1,K}} # reduced in dimension 2 RSArray3 = SArray{Tuple{I,J,1}} # reduced in dimension 3 + RSArray13 = SArray{Tuple{1,J,1}} # reduced in dimension 1 and 3 a = randn(I,J,K); sa = OSArray(a) b = rand(Bool,I,J,K); sb = OSArray(b) z = zeros(I,J,K); sz = OSArray(z) @@ -111,8 +120,11 @@ using Statistics: mean @test sum(sa) === sum(a) @test sum(abs2, sa) === sum(abs2, a) @test sum(sa, dims=2) === RSArray2(sum(a, dims=2)) + @test sum(sa, dims=(2,)) === RSArray2(sum(a, dims=2)) @test sum(sa, dims=Val(2)) === RSArray2(sum(a, dims=2)) + @test sum(sa, dims=(1,3)) === RSArray13(sum(a, dims=(1,3))) @test sum(abs2, sa; dims=2) === RSArray2(sum(abs2, a, dims=2)) + @test sum(abs2, sa; dims=(2,)) === RSArray2(sum(abs2, a, dims=2)) @test sum(abs2, sa; dims=Val(2)) === RSArray2(sum(abs2, a, dims=2)) @test prod(sa) === prod(a)