From 6877ebfcdac92b0b931ff42616f67dfe3aa1518f Mon Sep 17 00:00:00 2001 From: Milan Bouchet-Valat Date: Sat, 9 Sep 2023 15:10:08 +0200 Subject: [PATCH 1/2] Call user function only once in `mean` Override the standard `mapreduce` machinery to promote accumulator type. This avoids calling the function twice, which can be confusing. --- src/Statistics.jl | 33 ++++++++++++++++++++++++++++++--- test/runtests.jl | 2 +- 2 files changed, 31 insertions(+), 4 deletions(-) diff --git a/src/Statistics.jl b/src/Statistics.jl index 560b227d..613c2593 100644 --- a/src/Statistics.jl +++ b/src/Statistics.jl @@ -44,6 +44,8 @@ if !isdefined(Base, :mean) """ mean(itr) = mean(identity, itr) + _mean_promote(x::T, y::S) where {T,S} = convert(promote_type(T, S), y) + """ mean(f, itr) @@ -178,7 +180,33 @@ if !isdefined(Base, :mean) """ mean(A::AbstractArray; dims=:) = _mean(identity, A, dims) - _mean_promote(x::T, y::S) where {T,S} = convert(promote_type(T, S), y) + struct _InitType end + + Base.add_sum(x::_InitType, y::Any) = y/1 + + Base._mapreduce_dim(f, op, ::_InitType, A::Base.AbstractArrayOrBroadcasted, dims) = + Base.mapreducedim!(f, op, Base.reducedim_init(f, op, A, dims), A) + Base._mapreduce_dim(f, op, ::_InitType, A::Base.AbstractArrayOrBroadcasted, ::Colon) = + Base.mapfoldl_impl(f, op, _InitType(), A) + promote_add(x::T, y::S) where {T,S} = + Base.add_sum(convert(promote_type(T, S), x), + convert(promote_type(T, S), y)) + + function Base.reducedim_init(f, op::typeof(promote_add), A::AbstractArray, region) + Base._reducedim_init(f, op, zero, mean, A, region) + end + function Base._reducedim_init(f, op::typeof(promote_add), fv, fop, A, region) + T = Base._realtype(f, Base.promote_union(eltype(A))) + if T !== Any && applicable(zero, T) + x = f(zero(T)/1) + z = op(fv(x), fv(x)) + Tr = z isa T ? 
T : typeof(z) + else + z = fv(fop(f, A)) + Tr = typeof(z) + end + return Base.reducedim_initarray(A, region, z, Tr) + end # ::Dims is there to force specializing on Colon (as it is a Function) function _mean(f, A::AbstractArray, dims::Dims=:) where Dims @@ -188,8 +216,7 @@ if !isdefined(Base, :mean) else n = mapreduce(i -> size(A, i), *, unique(dims); init=1) end - x1 = f(first(A)) / 1 - result = sum(x -> _mean_promote(x1, f(x)), A, dims=dims) + result = mapreduce(f, promote_add, A, dims=dims, init=_InitType()) if dims === (:) return result / n else diff --git a/test/runtests.jl b/test/runtests.jl index 6903f90b..3e44c301 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -161,7 +161,7 @@ end ≈ float(typemax(Int))) end let x = rand(10000) # mean should use sum's accurate pairwise algorithm - @test mean(x) == sum(x) / length(x) + @test mean(x) == sum(x; init=0.0) / length(x) end @test mean(Number[1, 1.5, 2+3im]) === 1.5+1im # mixed-type array @test mean(v for v in Number[1, 1.5, 2+3im]) === 1.5+1im From 1595ca65a14e316798472e995d914518e2833a23 Mon Sep 17 00:00:00 2001 From: Milan Bouchet-Valat Date: Sat, 9 Sep 2023 17:41:41 +0200 Subject: [PATCH 2/2] Alternative, simpler approach --- src/Statistics.jl | 21 ++++++++------------- test/runtests.jl | 2 +- 2 files changed, 9 insertions(+), 14 deletions(-) diff --git a/src/Statistics.jl b/src/Statistics.jl index 613c2593..cdab5990 100644 --- a/src/Statistics.jl +++ b/src/Statistics.jl @@ -180,17 +180,12 @@ if !isdefined(Base, :mean) """ mean(A::AbstractArray; dims=:) = _mean(identity, A, dims) - struct _InitType end - - Base.add_sum(x::_InitType, y::Any) = y/1 - - Base._mapreduce_dim(f, op, ::_InitType, A::Base.AbstractArrayOrBroadcasted, dims) = - Base.mapreducedim!(f, op, Base.reducedim_init(f, op, A, dims), A) - Base._mapreduce_dim(f, op, ::_InitType, A::Base.AbstractArrayOrBroadcasted, ::Colon) = - Base.mapfoldl_impl(f, op, _InitType(), A) - promote_add(x::T, y::S) where {T,S} = - 
Base.add_sum(convert(promote_type(T, S), x), - convert(promote_type(T, S), y)) + promote_add_type(x::S, y::T) where {S, T} = + promote_type(typeof(zero(S)/1), typeof(zero(T)/1)) + function promote_add(x::Any, y::Any) + T = promote_add_type(x, y) + return Base.add_sum(convert(T, x), convert(T, y)) + end function Base.reducedim_init(f, op::typeof(promote_add), A::AbstractArray, region) Base._reducedim_init(f, op, zero, mean, A, region) @@ -198,7 +193,7 @@ if !isdefined(Base, :mean) function Base._reducedim_init(f, op::typeof(promote_add), fv, fop, A, region) T = Base._realtype(f, Base.promote_union(eltype(A))) if T !== Any && applicable(zero, T) - x = f(zero(T)/1) + x = f(zero(T))/1 # /1 added for mean z = op(fv(x), fv(x)) Tr = z isa T ? T : typeof(z) else @@ -216,7 +211,7 @@ if !isdefined(Base, :mean) else n = mapreduce(i -> size(A, i), *, unique(dims); init=1) end - result = mapreduce(f, promote_add, A, dims=dims, init=_InitType()) + result = mapreduce(f, promote_add, A, dims=dims) if dims === (:) return result / n else diff --git a/test/runtests.jl b/test/runtests.jl index 3e44c301..6903f90b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -161,7 +161,7 @@ end ≈ float(typemax(Int))) end let x = rand(10000) # mean should use sum's accurate pairwise algorithm - @test mean(x) == sum(x; init=0.0) / length(x) + @test mean(x) == sum(x) / length(x) end @test mean(Number[1, 1.5, 2+3im]) === 1.5+1im # mixed-type array @test mean(v for v in Number[1, 1.5, 2+3im]) === 1.5+1im