Skip to content

Commit 8719310

Browse files
committed
Minor cleanup and refactoring
1 parent ba0e222 commit 8719310

File tree

2 files changed

+42
-31
lines changed

2 files changed

+42
-31
lines changed

src/weights.jl

Lines changed: 38 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -200,21 +200,35 @@ end
200200
@weights ExponentialWeights false
201201

202202
"""
203-
ExponentialWeights
203+
ExponentialWeights(vs)
204204
205-
# Fields
205+
Construct an `ExponentialWeights` vector with weight values `vs`, which must sum to 1.
206206
207-
* `λ::Float64`: is a smoothing factor or rate paremeter between 0 .. 1.
208-
As this value approaches 0 the resulting weights will be almost equal(),
209-
while values closer to 1 will put higher weight on the end elements of the vector.
207+
Exponential weights are a common form of temporal weights which assign exponentially
208+
greater weight to past observations, which in this case corresponds to the tail end of
209+
the vector.
210+
"""
211+
function ExponentialWeights(vs::V) where {T<:Real, V<:AbstractVector{T}}
212+
s = sum(vs)
213+
s one(T) || throw(ArgumentError("weight values do not sum to 1 (got $s)"))
214+
ExponentialWeights{T, T, V}(vs, s)
215+
end
210216

211-
When called with a desired length `n` (`Int`) a vector of length `n` will
212-
be returned, where each element is set to `λ * (1 - λ)^(1 - i)`.
217+
"""
218+
eweights(n, λ)
213219
214-
# Usage
220+
Construct an [`ExponentialWeights`](@ref) vector with length `n`,
221+
where each element in position ``i`` is set to ``λ (1 - λ)^{1 - i}``.
222+
The entire set of weights are then normalized to sum to 1.
215223
216-
```julia
217-
w = ExponentialWeights(10, 0.3)
224+
``λ`` is a smoothing factor or rate parameter such that ``0 < λ \\leq 1``.
225+
As this value approaches 0, the resulting weights will be almost equal,
226+
while values closer to 1 will put greater weight on the tail elements of the vector.
227+
228+
# Examples
229+
230+
```julia-repl
231+
julia> eweights(10, 0.3)
218232
10-element ExponentialWeights{Float64,Float64,Array{Float64,1}}:
219233
0.012458
220234
0.0177971
@@ -228,41 +242,35 @@ w = ExponentialWeights(10, 0.3)
228242
0.308721
229243
```
230244
"""
231-
function ExponentialWeights(vs::V) where {T<:Real, V<:AbstractVector{T}}
232-
s = sum(vs)
233-
s one(T) || throw(ArgumentError("weight values do not sum to 1 (got $s)"))
234-
ExponentialWeights{T, T, V}(vs, s)
235-
end
236-
237-
function ExponentialWeights(n::Integer, λ::Real)
238-
n > 0 || throw(ArgumentError("cannot construct weights of length < 1"))
245+
function eweights(n::Integer, λ::Real)
246+
n > 0 || throw(ArgumentError("cannot construct exponential weights of length < 1"))
239247
0 < λ <= 1 || throw(ArgumentError("smoothing factor must be between 0 and 1"))
240248
w0 = map(i -> λ * (1 - λ)^(1 - i), 1:n)
241249
s = sum(w0)
242-
ExponentialWeights(w0 / s)
250+
w0 ./= s
251+
ExponentialWeights{typeof(s), eltype(w0), typeof(w0)}(w0, s)
243252
end
244253

245254
"""
246-
eweights(n, λ)
247-
248-
Construct an `ExponentialWeights` vector with length `n`,
249-
where each element in position ``i`` is set to ``λ * (1 - λ)^(1 - i)``.
250-
The entire set of weights are then normalized so that they sum to 1.0
255+
eweights(vs)
251256
252-
``λ`` is a smoothing factor or rate parameter between 0 and 1.
253-
As this value approaches 0 the resulting weights will be almost equal,
254-
while values closer to 1 will put higher weight on the end elements of the vector.
257+
Construct an [`ExponentialWeights`](@ref) vector using the given array.
255258
"""
256-
eweights(n::Integer, λ::Real) = ExponentialWeights(n, λ)
259+
eweights(v::RealVector) = ExponentialWeights(v)
260+
eweights(v::RealArray) = ExponentialWeights(vec(v))
257261

258262
"""
259263
varcorrection(w::ExponentialWeights, corrected=false)
260264
261265
* `corrected=true`: ``\\frac{1}{1 - \\sum {w^2}}``
262-
* `corrected=false`: ``1.0``
266+
* `corrected=false`: ``1``
263267
"""
264268
@inline function varcorrection(w::ExponentialWeights, corrected::Bool=false)
265-
corrected ? 1 / (1 - sum(x -> x^2, w)) : 1.0
269+
if corrected
270+
1 / (1 - sum(abs2, w.values))
271+
else
272+
1 / one(w.sum) # just 1 promoted to the same type as the other branch
273+
end
266274
end
267275

268276
##### Equality tests #####

test/weights.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -504,7 +504,8 @@ end
504504
θ = 5.25
505505
λ = 1 - exp(-1 / θ) # simple conversion for the more common/readable method
506506

507-
w = ExponentialWeights(4, λ)
507+
v =*(1-λ)^(1-i) for i = 1:4]
508+
w = ExponentialWeights(v ./ sum(v))
508509

509510
@test round.(w, digits=4) == [0.1837, 0.2222, 0.2688, 0.3253]
510511
@test eweights(4, λ) w
@@ -513,6 +514,8 @@ end
513514
@testset "Failure Conditions" begin
514515
@test_throws ArgumentError eweights(0, 0.3)
515516
@test_throws ArgumentError eweights(1, 1.1)
517+
@test_throws ArgumentError eweights(rand(4))
518+
@test_throws ArgumentError eweights(rand(4, 4))
516519
@test_throws ArgumentError ExponentialWeights(rand(4))
517520
end
518521

0 commit comments

Comments
 (0)