5
5
depthwiseconv_im2col!(y, x, w, cdims, col=similar(x); alpha=1, beta=0)
6
6
7
7
Perform a depthwise convolution using im2col and GEMM, store the result in `y`.
8
-
9
- See `conv_im2col!()` for an explanation of optional parameters.
8
+ See [`conv_im2col!`](@ref) for explanation of optional parameters.
10
9
"""
11
10
depthwiseconv_im2col!
12
11
@@ -48,27 +47,32 @@ function depthwiseconv_im2col!(
48
47
end
49
48
50
49
"""
51
- ∇depthwiseconv_filter_im2col!(dw, w, dy, cdims, col=similar(dw); alpha=1, beta)
50
+ ∇depthwiseconv_filter_im2col!(dw, w, dy, cdims, col=similar(dw, ∇filter_im2col_dims(cdims));
51
+ alpha=1, beta=0)
52
52
53
- Depthwise conv2d backward pass onto the weights using im2col and GEMM.
54
- See the documentation for `conv_im2col!()` for explanation of optional parameters.
53
+ Depthwise conv backward pass onto the weights using im2col and GEMM.
54
+ See [`conv_im2col!`](@ref) for explanation of optional parameters.
55
55
"""
56
56
∇depthwiseconv_filter_im2col!
57
57
58
58
function ∇depthwiseconv_filter_im2col! (
59
59
dw:: AbstractArray{T,5} , x:: AbstractArray{T,5} ,
60
60
dy:: AbstractArray{T,5} , cdims:: DepthwiseConvDims ;
61
- col:: AbstractArray{T,3} = similar (dw, im2col_dims (cdims)),
61
+ col:: AbstractArray{T,3} = similar (dw, ∇filter_im2col_dims (cdims)),
62
62
alpha:: T = T (1 ), beta:: T = T (0 )) where T
63
63
check_dims (size (x), size (dw), size (dy), cdims)
64
64
65
65
M = prod (kernel_size (cdims))
66
66
N = channel_multiplier (cdims)
67
67
K = prod (output_size (cdims))
68
68
69
- @threads for batch_idx in 1 : size (x)[end ]
69
+ for batch_idx in 1 : size (x, 5 )
70
+ # Because we accumulate over batches in this loop, we must set `beta` equal
71
+ # to `1.0` after the first sample.
72
+ beta′ = batch_idx == 1 ? beta : T (1 )
73
+
70
74
# col_slice is a thread-local workspace
71
- col_slice = view (col, :, :, threadid () )
75
+ col_slice = view (col, :, :, 1 )
72
76
im2col! (col_slice, view (x, :, :, :, :, batch_idx), cdims)
73
77
74
78
# We do a separate convolution for each channel in x, as we must
@@ -78,22 +82,18 @@ function ∇depthwiseconv_filter_im2col!(
78
82
col_ptr = pointer (col_slice, (c_in - 1 )* M* K + 1 )
79
83
dy_ptr = pointer (dy, (batch_idx - 1 )* N* K* channels_in (cdims) + (c_in - 1 )* K* N + 1 )
80
84
dw_ptr = pointer (dw, (c_in - 1 )* M* N + 1 )
81
- gemm! (Val (true ), Val (false ), M, N, K, alpha, col_ptr, dy_ptr, beta, dw_ptr)
85
+ gemm! (Val (true ), Val (false ), M, N, K, alpha, col_ptr, dy_ptr, beta′ , dw_ptr)
82
86
end
83
87
end
84
-
85
- # Because we accumulate over batches in this loop, we must set `beta` equal
86
- # to `1.0` from this point on.
87
- beta = T (1 )
88
88
end
89
89
return dw
90
90
end
91
91
92
92
"""
93
- depthwiseconv2d_Δx_im2col !(dx, w, dy, cdims, col=similar(dx); alpha=1, beta=0)
93
+ ∇depthwiseconv_data_im2col!(dx, w, dy, cdims, col=similar(dx); alpha=1, beta=0)
94
94
95
95
Depthwise conv2d backward pass onto the input using im2col and GEMM.
96
- See the documentation for `conv_im2col!()` for explanation of optional parameters.
96
+ See [`conv_im2col!`](@ref) for explanation of optional parameters.
97
97
"""
98
98
∇depthwiseconv_data_im2col!
99
99
0 commit comments