
Commit 3d01356

also for normalisation layers
1 parent: c6265f7

File tree

2 files changed: +21 −14 lines


src/layers/normalise.jl

Lines changed: 20 additions & 14 deletions
@@ -179,11 +179,12 @@ testmode!(m::AlphaDropout, mode=true) =
   (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m)

 """
-    LayerNorm(size..., λ=identity; affine=true, ϵ=1fe-5)
+    LayerNorm(size..., λ=identity; affine=true, eps=1f-5)

 A [normalisation layer](https://arxiv.org/abs/1607.06450) designed to be
 used with recurrent hidden states.
 The argument `size` should be an integer or a tuple of integers.
+
 In the forward pass, the layer normalises the mean and standard
 deviation of the input, then applies the elementwise activation `λ`.
 The input is normalised along the first `length(size)` dimensions
@@ -217,9 +218,10 @@ struct LayerNorm{F,D,T,N}
   affine::Bool
 end

-function LayerNorm(size::Tuple{Vararg{Int}}, λ=identity; affine::Bool=true, ϵ::Real=1f-5)
+function LayerNorm(size::Tuple{Vararg{Int}}, λ=identity; affine::Bool=true, eps::Real=1f-5, ϵ=nothing)
+  ε = Losses._greek_ascii_depwarn(ϵ => eps, :LayerNorm, "ϵ" => "eps")
   diag = affine ? Scale(size..., λ) : λ!=identity ? Base.Fix1(broadcast, λ) : identity
-  return LayerNorm(λ, diag, ϵ, size, affine)
+  return LayerNorm(λ, diag, ε, size, affine)
 end
 LayerNorm(size::Integer...; kw...) = LayerNorm(Int.(size); kw...)
 LayerNorm(size_act...; kw...) = LayerNorm(Int.(size_act[1:end-1]), size_act[end]; kw...)
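The helper used here, `Losses._greek_ascii_depwarn`, is what lets the deprecated Greek keyword keep working while the ASCII spelling becomes the documented one. As a rough illustration only (a hypothetical sketch with made-up names, not Flux's actual implementation), a helper with this calling convention could behave like so, given that the Greek keyword defaults to `nothing` in the new signatures:

# Hypothetical sketch of a helper with the same calling convention as
# `Losses._greek_ascii_depwarn(ϵ => eps, :LayerNorm, "ϵ" => "eps")` -- not Flux's real code.
function greek_ascii_depwarn_sketch(values::Pair, layer::Symbol, names::Pair)
  greek_val, ascii_val = values    # value of the deprecated Greek kwarg, value of the ASCII kwarg
  greek_name, ascii_name = names   # their printed names, for the warning message
  greek_val === nothing && return ascii_val   # Greek kwarg not supplied: use the ASCII value
  Base.depwarn("keyword `$greek_name` is deprecated for `$layer`, please use `$ascii_name` instead", layer)
  return greek_val                 # still honour the deprecated keyword
end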
@@ -287,7 +289,7 @@ ChainRulesCore.@non_differentiable _track_stats!(::Any...)
     BatchNorm(channels::Integer, λ=identity;
               initβ=zeros32, initγ=ones32,
               affine = true, track_stats = true,
-              ϵ=1f-5, momentum= 0.1f0)
+              eps=1f-5, momentum= 0.1f0)

 [Batch Normalization](https://arxiv.org/abs/1502.03167) layer.
 `channels` should be the size of the channel dimension in your data (see below).
@@ -340,15 +342,17 @@ end
 function BatchNorm(chs::Int, λ=identity;
                    initβ=zeros32, initγ=ones32,
                    affine=true, track_stats=true,
-                   ϵ=1f-5, momentum=0.1f0)
+                   eps::Real=1f-5, momentum::Real=0.1f0, ϵ=nothing)
+
+  ε = Losses._greek_ascii_depwarn(ϵ => eps, :BatchNorm, "ϵ" => "eps")

   β = affine ? initβ(chs) : nothing
   γ = affine ? initγ(chs) : nothing
   μ = track_stats ? zeros32(chs) : nothing
   σ² = track_stats ? ones32(chs) : nothing

   return BatchNorm(λ, β, γ,
-          μ, σ², ϵ, momentum,
+          μ, σ², ε, momentum,
           affine, track_stats,
           nothing, chs)
 end
@@ -379,7 +383,7 @@ end
     InstanceNorm(channels::Integer, λ=identity;
                  initβ=zeros32, initγ=ones32,
                  affine=false, track_stats=false,
-                 ϵ=1f-5, momentum=0.1f0)
+                 eps=1f-5, momentum=0.1f0)

 [Instance Normalization](https://arxiv.org/abs/1607.08022) layer.
 `channels` should be the size of the channel dimension in your data (see below).
@@ -430,19 +434,20 @@ end
 function InstanceNorm(chs::Int, λ=identity;
                       initβ=zeros32, initγ=ones32,
                       affine=false, track_stats=false,
-                      ϵ=1f-5, momentum=0.1f0)
+                      eps::Real=1f-5, momentum::Real=0.1f0, ϵ=nothing)

   if track_stats
     Base.depwarn("`track_stats=true` will be removed from InstanceNorm in Flux 0.14. The default value is `track_stats=false`, which will work as before.", :InstanceNorm)
   end
+  ε = Losses._greek_ascii_depwarn(ϵ => eps, :InstanceNorm, "ϵ" => "eps")

   β = affine ? initβ(chs) : nothing
   γ = affine ? initγ(chs) : nothing
   μ = track_stats ? zeros32(chs) : nothing
   σ² = track_stats ? ones32(chs) : nothing

   return InstanceNorm(λ, β, γ,
-          μ, σ², ϵ, momentum,
+          μ, σ², ε, momentum,
           affine, track_stats,
           nothing, chs)
 end
@@ -473,7 +478,7 @@ end
     GroupNorm(channels::Integer, G::Integer, λ=identity;
               initβ=zeros32, initγ=ones32,
               affine=true, track_stats=false,
-              ϵ=1f-5, momentum=0.1f0)
+              eps=1f-5, momentum=0.1f0)

 [Group Normalization](https://arxiv.org/abs/1803.08494) layer.

@@ -532,11 +537,12 @@ trainable(gn::GroupNorm) = hasaffine(gn) ? (β = gn.β, γ = gn.γ) : (;)
 function GroupNorm(chs::Int, G::Int, λ=identity;
                    initβ=zeros32, initγ=ones32,
                    affine=true, track_stats=false,
-                   ϵ=1f-5, momentum=0.1f0)
+                   eps::Real=1f-5, momentum::Real=0.1f0, ϵ=nothing)

-if track_stats
+  if track_stats
     Base.depwarn("`track_stats=true` will be removed from GroupNorm in Flux 0.14. The default value is `track_stats=false`, which will work as before.", :GroupNorm)
-end
+  end
+  ε = Losses._greek_ascii_depwarn(ϵ => eps, :GroupNorm, "ϵ" => "eps")

   chs % G == 0 || error("The number of groups ($(G)) must divide the number of channels ($chs)")

@@ -548,7 +554,7 @@ end
   return GroupNorm(G, λ,
             β, γ,
             μ, σ²,
-            ϵ, momentum,
+            ε, momentum,
             affine, track_stats,
             nothing, chs)
 end
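
Taken together, after this commit each normalisation layer documents the ASCII `eps` keyword and routes a Greek `ϵ`, if supplied, through the deprecation helper. A small usage sketch (illustrative values; the deprecation warning is only printed when Julia's depwarn output is enabled):

using Flux

bn = BatchNorm(3, relu; eps = 1f-4)   # new ASCII keyword spelling
ln = LayerNorm(8; ϵ = 1f-4)           # old Greek keyword: still accepted, via the deprecation path
gn = GroupNorm(4, 2)                  # defaults are unchanged: eps = 1f-5, momentum = 0.1f0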

test/losses.jl

Lines changed: 1 addition & 0 deletions
@@ -1,5 +1,6 @@
 using Test
 using Flux: onehotbatch, σ
+using Statistics: mean

 using Flux.Losses: mse, label_smoothing, crossentropy, logitcrossentropy, binarycrossentropy, logitbinarycrossentropy
 using Flux.Losses: xlogx, xlogy
