Skip to content

Commit 134ac15

Browse files
authored
Merge pull request #237 from tejank10/scalar_pad_stride
Scalar pad and stride
2 parents 7e3cf45 + 7726a5b commit 134ac15

File tree

1 file changed

+9
-6
lines changed

1 file changed

+9
-6
lines changed

src/layers/conv.jl

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
using NNlib: conv
22

3+
@generated sub2(::Type{Val{N}}) where N = :(Val{$(N-2)})
4+
5+
expand(N, i::Tuple) = i
6+
expand(N, i::Integer) = ntuple(_ -> i, N)
7+
38
"""
49
Conv(size, in=>out)
510
Conv(size, in=>out, relu)
@@ -21,14 +26,12 @@ struct Conv{N,F,A,V}
2126
dilation::NTuple{N,Int}
2227
end
2328

24-
Conv(w::AbstractArray{T}, b::AbstractVector{T}, σ = identity;
25-
stride = 1, pad = 0, dilation=1) where T =
26-
Conv(σ, w, b, stride, pad, dilation)
29+
Conv(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
30+
stride = 1, pad = 0, dilation = 1) where {T,N} =
31+
Conv(σ, w, b, expand.(sub2(Val{N}), (stride, pad, dilation))...)
2732

2833
Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; init = initn,
29-
stride::NTuple{N,Integer} = map(_->1,k),
30-
pad::NTuple{N,Integer} = map(_->0,k),
31-
dilation::NTuple{N,Integer} = map(_->1,k)) where N =
34+
stride = 1, pad = 0, dilation = 1) where N =
3235
Conv(param(init(k..., ch...)), param(zeros(ch[2])), σ,
3336
stride = stride, pad = pad, dilation = dilation)
3437

0 commit comments

Comments
 (0)