1
1
using NNlib: conv
2
2
3
+ @generated sub2 (:: Type{Val{N}} ) where N = :(Val{$ (N- 2 )})
4
+
5
+ expand (N, i:: Tuple ) = i
6
+ expand (N, i:: Integer ) = ntuple (_ -> i, N)
7
+
3
8
"""
4
9
Conv(size, in=>out)
5
10
Conv(size, in=>out, relu)
@@ -21,14 +26,12 @@ struct Conv{N,F,A,V}
21
26
dilation:: NTuple{N,Int}
22
27
end
23
28
24
- Conv (w:: AbstractArray{T} , b:: AbstractVector{T} , σ = identity;
25
- stride = 1 , pad = 0 , dilation= 1 ) where T =
26
- Conv (σ, w, b, stride, pad, dilation)
29
+ Conv (w:: AbstractArray{T,N } , b:: AbstractVector{T} , σ = identity;
30
+ stride = 1 , pad = 0 , dilation = 1 ) where {T,N} =
31
+ Conv (σ, w, b, expand .( sub2 (Val{N}), ( stride, pad, dilation)) ... )
27
32
28
33
Conv (k:: NTuple{N,Integer} , ch:: Pair{<:Integer,<:Integer} , σ = identity; init = initn,
29
- stride:: NTuple{N,Integer} = map (_-> 1 ,k),
30
- pad:: NTuple{N,Integer} = map (_-> 0 ,k),
31
- dilation:: NTuple{N,Integer} = map (_-> 1 ,k)) where N =
34
+ stride = 1 , pad = 0 , dilation = 1 ) where N =
32
35
Conv (param (init (k... , ch... )), param (zeros (ch[2 ])), σ,
33
36
stride = stride, pad = pad, dilation = dilation)
34
37
0 commit comments