Skip to content

Commit b039107

Browse files
lbittarello and nalimilan
authored and committed
Simplify weights (#526)
1 parent 95b794a commit b039107

File tree

3 files changed

+151
-147
lines changed

3 files changed

+151
-147
lines changed

src/deprecates.jl

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -35,3 +35,7 @@ end
3535
@deprecate wmedian(v::RealVector, w::RealVector) median(v, weights(w))
3636

3737
@deprecate quantile(v::AbstractArray{<:Real}) quantile(v, [.0, .25, .5, .75, 1.0])
38+
39+
### Deprecated September 2019
40+
@deprecate sum(A::AbstractArray, w::AbstractWeights, dims::Int) sum(A, w, dims=dims)
41+
@deprecate values(wv::AbstractWeights) convert(Vector, wv)

src/weights.jl

Lines changed: 45 additions & 50 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
1-
###### Weight vector #####
1+
##### Weight vector #####
22
abstract type AbstractWeights{S<:Real, T<:Real, V<:AbstractVector{T}} <: AbstractVector{T} end
33

44
"""
@@ -18,12 +18,24 @@ macro weights(name)
1818
end
1919

2020
length(wv::AbstractWeights) = length(wv.values)
21-
values(wv::AbstractWeights) = wv.values
2221
sum(wv::AbstractWeights) = wv.sum
2322
isempty(wv::AbstractWeights) = isempty(wv.values)
2423
size(wv::AbstractWeights) = size(wv.values)
2524

26-
Base.getindex(wv::AbstractWeights, i) = getindex(wv.values, i)
25+
Base.convert(::Type{Vector}, wv::AbstractWeights) = convert(Vector, wv.values)
26+
27+
@propagate_inbounds function Base.getindex(wv::AbstractWeights, i::Integer)
28+
@boundscheck checkbounds(wv, i)
29+
@inbounds wv.values[i]
30+
end
31+
32+
@propagate_inbounds function Base.getindex(wv::W, i::AbstractArray) where W <: AbstractWeights
33+
@boundscheck checkbounds(wv, i)
34+
@inbounds v = wv.values[i]
35+
W(v, sum(v))
36+
end
37+
38+
Base.getindex(wv::W, ::Colon) where {W <: AbstractWeights} = W(copy(wv.values), sum(wv))
2739

2840
@propagate_inbounds function Base.setindex!(wv::AbstractWeights, v::Real, i::Int)
2941
s = v - wv[i]
@@ -247,7 +259,7 @@ eweights(n::Integer, λ::Real) = eweights(1:n, λ)
247259
eweights(t::AbstractVector, r::AbstractRange, λ::Real) =
248260
eweights(something.(indexin(t, r)), λ)
249261

250-
# NOTE: No variance correction is implemented for exponential weights
262+
# NOTE: no variance correction is implemented for exponential weights
251263

252264
struct UnitWeights{T<:Real} <: AbstractWeights{Int, T, V where V<:Vector{T}}
253265
len::Int
@@ -260,23 +272,24 @@ Construct a `UnitWeights` vector with length `s` and weight elements of type `T`
260272
All weight elements are identically one.
261273
""" UnitWeights
262274

263-
values(wv::UnitWeights{T}) where T = fill(one(T), length(wv))
264275
sum(wv::UnitWeights{T}) where T = convert(T, length(wv))
265276
isempty(wv::UnitWeights) = iszero(wv.len)
266277
length(wv::UnitWeights) = wv.len
267278
size(wv::UnitWeights) = Tuple(length(wv))
268279

280+
Base.convert(::Type{Vector}, wv::UnitWeights{T}) where {T} = ones(T, length(wv))
281+
269282
@propagate_inbounds function Base.getindex(wv::UnitWeights{T}, i::Integer) where T
270283
@boundscheck checkbounds(wv, i)
271284
one(T)
272285
end
273286

274287
@propagate_inbounds function Base.getindex(wv::UnitWeights{T}, i::AbstractArray{<:Int}) where T
275288
@boundscheck checkbounds(wv, i)
276-
fill(one(T), size(i))
289+
UnitWeights{T}(length(i))
277290
end
278291

279-
Base.getindex(wv::UnitWeights{T}, ::Colon) where T = fill(one(T), length(wv))
292+
Base.getindex(wv::UnitWeights{T}, ::Colon) where {T} = UnitWeights{T}(wv.len)
280293

281294
"""
282295
uweights(s::Integer)
@@ -315,7 +328,7 @@ This definition is equivalent to the correction applied to unweighted data.
315328
corrected ? (1 / (w.len - 1)) : (1 / w.len)
316329
end
317330

318-
##### Equality tests #####
331+
#### Equality tests #####
319332

320333
for w in (AnalyticWeights, FrequencyWeights, ProbabilityWeights, Weights)
321334
@eval begin
@@ -341,22 +354,7 @@ Compute the weighted sum of an array `v` with weights `w`, optionally over the d
341354
"""
342355
wsum(v::AbstractVector, w::AbstractVector) = dot(v, w)
343356
wsum(v::AbstractArray, w::AbstractVector) = dot(vec(v), w)
344-
345-
# Note: the methods for BitArray and SparseMatrixCSC are to avoid ambiguities
346-
Base.sum(v::BitArray, w::AbstractWeights) = wsum(v, values(w))
347-
Base.sum(v::SparseArrays.SparseMatrixCSC, w::AbstractWeights) = wsum(v, values(w))
348-
Base.sum(v::AbstractArray, w::AbstractWeights) = dot(v, values(w))
349-
350-
for v in (AbstractArray{<:Number}, BitArray, SparseArrays.SparseMatrixCSC, AbstractArray)
351-
@eval begin
352-
function Base.sum(v::$v, w::UnitWeights)
353-
if length(v) != length(w)
354-
throw(DimensionMismatch("Inconsistent array dimension."))
355-
end
356-
return sum(v)
357-
end
358-
end
359-
end
357+
wsum(v::AbstractArray, w::AbstractVector, dims::Colon) = wsum(v, w)
360358

361359
## wsum along dimension
362360
#
@@ -392,7 +390,6 @@ end
392390
# (d) A is a general dense array with eltype <: BlasReal:
393391
# dim <= 2: delegate to (a) and (b)
394392
# otherwise, decompose A into multiple pages
395-
#
396393

397394
function _wsum1!(R::AbstractArray, A::AbstractVector, w::AbstractVector, init::Bool)
398395
r = wsum(A, w)
@@ -455,7 +452,8 @@ function _wsumN!(R::StridedArray{T}, A::DenseArray{T,N}, w::StridedVector{T}, di
455452
return R
456453
end
457454

458-
# General Cartesian-based weighted sum across dimensions
455+
## general Cartesian-based weighted sum across dimensions
456+
459457
@generated function _wsum_general!(R::AbstractArray{RT}, f::supertype(typeof(abs)),
460458
A::AbstractArray{T,N}, w::AbstractVector{WT}, dim::Int, init::Bool) where {T,RT,WT,N}
461459
quote
@@ -512,7 +510,6 @@ end
512510
end
513511
end
514512

515-
516513
# N = 1
517514
_wsum!(R::StridedArray{T}, A::DenseArray{T,1}, w::StridedVector{T}, dim::Int, init::Bool) where {T<:BlasReal} =
518515
_wsum1!(R, A, w, init)
@@ -533,7 +530,6 @@ _wsum!(R::AbstractArray, A::AbstractArray, w::AbstractVector, dim::Int, init::Bo
533530
wsumtype(::Type{T}, ::Type{W}) where {T,W} = typeof(zero(T) * zero(W) + zero(T) * zero(W))
534531
wsumtype(::Type{T}, ::Type{T}) where {T<:BlasReal} = T
535532

536-
537533
"""
538534
wsum!(R, A, w, dim; init=true)
539535
@@ -559,19 +555,21 @@ function wsum(A::AbstractArray{<:Number}, w::UnitWeights, dim::Int)
559555
return sum(A, dims=dim)
560556
end
561557

562-
# extended sum! and wsum
558+
## extended sum! and wsum
563559

564560
Base.sum!(R::AbstractArray, A::AbstractArray, w::AbstractWeights{<:Real}, dim::Int; init::Bool=true) =
565-
wsum!(R, A, values(w), dim; init=init)
561+
wsum!(R, A, w, dim; init=init)
566562

567-
Base.sum(A::AbstractArray{<:Number}, w::AbstractWeights{<:Real}, dim::Int) = wsum(A, values(w), dim)
563+
Base.sum(A::AbstractArray, w::AbstractWeights{<:Real}; dims::Union{Colon,Int}=:) =
564+
wsum(A, w, dims)
568565

569-
function Base.sum(A::AbstractArray{<:Number}, w::UnitWeights, dim::Int)
570-
size(A, dim) != length(w) && throw(DimensionMismatch("Inconsistent array dimension."))
571-
return sum(A, dims=dim)
566+
function Base.sum(A::AbstractArray, w::UnitWeights; dims::Union{Colon,Int}=:)
567+
a = (dims === :) ? length(A) : size(A, dims)
568+
a != length(w) && throw(DimensionMismatch("Inconsistent array dimension."))
569+
return sum(A, dims=dims)
572570
end
573571

574-
###### Weighted means #####
572+
##### Weighted means #####
575573

576574
"""
577575
wmean(v, w::AbstractVector)
@@ -589,9 +587,10 @@ end
589587
Compute the weighted mean of array `A` with weight vector `w`
590588
(of type `AbstractWeights`) along dimension `dims`, and write results to `R`.
591589
"""
592-
mean!(R::AbstractArray, A::AbstractArray, w::AbstractWeights;
593-
dims::Union{Nothing,Int}=nothing) = _mean!(R, A, w, dims)
594-
_mean!(R::AbstractArray, A::AbstractArray, w::AbstractWeights, dims::Nothing) = throw(ArgumentError("dims argument must be provided"))
590+
mean!(R::AbstractArray, A::AbstractArray, w::AbstractWeights; dims::Union{Nothing,Int}=nothing) =
591+
_mean!(R, A, w, dims)
592+
_mean!(R::AbstractArray, A::AbstractArray, w::AbstractWeights, dims::Nothing) =
593+
throw(ArgumentError("dims argument must be provided"))
595594
_mean!(R::AbstractArray, A::AbstractArray, w::AbstractWeights, dims::Int) =
596595
rmul!(Base.sum!(R, A, w, dims), inv(sum(w)))
597596

@@ -611,24 +610,21 @@ w = rand(n)
611610
mean(x, weights(w))
612611
```
613612
"""
614-
mean(A::AbstractArray, w::AbstractWeights; dims::Union{Nothing,Int}=nothing) =
613+
mean(A::AbstractArray, w::AbstractWeights; dims::Union{Colon,Int}=:) =
615614
_mean(A, w, dims)
616-
_mean(A::AbstractArray, w::AbstractWeights, dims::Nothing) =
615+
_mean(A::AbstractArray, w::AbstractWeights, dims::Colon) =
617616
sum(A, w) / sum(w)
618617
_mean(A::AbstractArray{T}, w::AbstractWeights{W}, dims::Int) where {T,W} =
619618
_mean!(similar(A, wmeantype(T, W), Base.reduced_indices(axes(A), dims)), A, w, dims)
620619

621-
function _mean(A::AbstractArray, w::UnitWeights, dims::Nothing)
622-
length(A) != length(w) && throw(DimensionMismatch("Inconsistent array dimension."))
623-
return mean(A)
624-
end
625-
626-
function _mean(A::AbstractArray, w::UnitWeights, dims::Int)
627-
size(A, dims) != length(w) && throw(DimensionMismatch("Inconsistent array dimension."))
620+
function mean(A::AbstractArray, w::UnitWeights; dims::Union{Colon,Int}=:)
621+
a = (dims === :) ? length(A) : size(A, dims)
622+
a != length(w) && throw(DimensionMismatch("Inconsistent array dimension."))
628623
return mean(A, dims=dims)
629624
end
630625

631-
###### Weighted quantile #####
626+
##### Weighted quantile #####
627+
632628
"""
633629
quantile(v, w::AbstractWeights, p)
634630
@@ -723,9 +719,8 @@ end
723719

724720
quantile(v::RealVector, w::AbstractWeights{<:Real}, p::Number) = quantile(v, w, [p])[1]
725721

722+
##### Weighted median #####
726723

727-
728-
###### Weighted median #####
729724
"""
730725
median(v::RealVector, w::AbstractWeights)
731726

0 commit comments

Comments
 (0)