@@ -21,78 +21,104 @@ layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in;
layout(set = 0, binding = 2) uniform PRECISION sampler3D kernel_in;
layout(set = 0, binding = 3) uniform PRECISION sampler3D bias_in;

- layout(set = 0, binding = 4) uniform PRECISION restrict Out_channels {
-   int data;
- }
- out_channels;
-
- layout(set = 0, binding = 5) uniform PRECISION restrict In_length {
-   int data;
- }
- in_length;
-
- layout(set = 0, binding = 6) uniform PRECISION restrict Kernel_size {
-   int data;
- }
- kernel_size;
+ layout(set = 0, binding = 4) uniform PRECISION restrict OutLimits {
+   ivec3 out_limits;
+ };
+
+ layout(set = 0, binding = 5) uniform PRECISION restrict InSizes {
+   ivec4 in_sizes;
+ };
+
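+ // Convolution hyper-parameters. kernel_size, stride, padding and dilation have
+ // their usual conv1d meanings; in_group_size and out_group_size are the number
+ // of input / output channels per group (in_C / G and out_C / G in the notation
+ // introduced below).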
+ layout(set = 0, binding = 6) uniform PRECISION restrict Params {
+   int kernel_size;
+   int stride;
+   int padding;
+   int dilation;
+   int in_group_size;
+   int out_group_size;
+ };

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

- /*
-  * This implementation optimize for simplicity (and partially performance) for a
-  * (1, C, L) where C == groups. Hence we only focus on calculating the rolling
-  * kernel of the L dimension.
- */
+ // Let us define
+ //
+ // input = (N, in_C, in_L),
+ // output = (N, out_C, out_L),
+ // groups = G,
+ // kernel = K,
+ //
+ // which results in shapes
+ //
+ // weight = (out_C, in_C / G, K),
+ // bias = (out_C,).
+ //
+ // This implementation performs out_C shader invocations, where each invocation
+ // calculates the rolling kernel of the length dimension for each batch, i.e.,
+ // computes out_L * N results.
+ //
+ // Note that we can rewrite this implementation as out_L * out_C * ceil(N / 4)
+ // shader invocations, where each invocation computes 1 result. But that
+ // performs worse.
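+ //
+ // For example, with N = 1, in_C = 8, out_C = 4, G = 2 and K = 3, the weight
+ // tensor has shape (4, 4, 3), the bias has shape (4,), and each of the 4
+ // invocations computes out_L results.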
void main() {
  const ivec3 pos = ivec3(gl_GlobalInvocationID);

-   // The global workgroup should have taken care of it. We only perform one
-   // work item for each 1d tensor on lengths
-   if (pos.x >= 1) {
+   if (any(greaterThanEqual(pos, out_limits))) {
    return;
  }

-   int c = pos.y;
-   if (c >= out_channels.data) {
-     return;
-   }
-
-   // Assume n = 1, do not handle n > 1 case for now.
-   int n = pos.z;
-   if (n >= 1) {
-     return;
-   }
-
-   vec4 bias = texelFetch(bias_in, ivec3(c, 0, 0), 0);
-
-   for (int i = 0; i < in_length.data - kernel_size.data + 1; ++i) {
-     vec4 v = vec4(0);
-     for (int k = 0; k < kernel_size.data; ++k) {
-       const ivec3 in_pos = ivec3(i + k, c, 0);
-       const vec4 input_value = texelFetch(image_in, in_pos, 0);
-
-       // Note that we are reading weight in the inner loop, this could be
-       // improved by moving it before the outer loop. Since the weight vector is
-       // contant for the entire call.
-
-       // weight in input-space: (c, 0, k);
-       // notice that c is 4-packed. We need to mod 4 to get the actual weight.
-       const ivec3 w_pos = ivec3(k, 0, c / 4);
-       const vec4 weight = texelFetch(kernel_in, w_pos, 0);
-
-       float w = weight.x;
-       if (c % 4 == 1) {
-         w = weight.y;
-       } else if (c % 4 == 2) {
-         w = weight.z;
-       } else if (c % 4 == 3) {
-         w = weight.w;
+   int in_length = in_sizes.x;
+   int batch_size = in_sizes.z;
+
+   // "out_c" is the output's channel index where we write our result.
+   // Across shader invocations, this is the only value that varies.
+   int out_c = pos.y;
+   vec4 bias = texelFetch(bias_in, ivec3(out_c, 0, 0), 0);
+
+   // "in_c" tracks the input's channel start index.
+   // We iterate over the input group that corresponds to the output group.
+   int c_start = (out_c / out_group_size) * in_group_size;
+   int c_end = c_start + in_group_size;
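+   // For example, with in_C = 8, out_C = 4 and G = 2 (so in_group_size = 4 and
+   // out_group_size = 2), out_c = 3 maps to c_start = 4 and c_end = 8.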
+
+   // "in_l" tracks the input's length start index for our input-kernel overlay
+   // region.
+   int l_start = -padding;
+   int l_end = in_length + padding - dilation * (kernel_size - 1);
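+   // The in_l loop below therefore runs ceil((l_end - l_start) / stride) times,
+   // i.e. (in_length + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1
+   // with integer division, which is exactly conv1d's output length out_L.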
+
+   // Since the input/output tensors are channel-packed, which is along the
+   // batch dimension, we can batch-read/write four elements at a time.
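+   // (Texel (x, y, z) of image_in thus holds input[4z .. 4z+3][y][x] in its four
+   // components, and image_out is laid out the same way.)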
+   for (int n = 0; n < batch_size; n += 4) {
+     // "out_l" tracks the output's length index where we write our result.
+     int out_l = 0;
+
+     for (int in_l = l_start; in_l < l_end; in_l += stride, ++out_l) {
+       vec4 sum = vec4(0);
+
+       for (int in_c = c_start; in_c < c_end; ++in_c) {
+         // "k" tracks the kernel's index for our input-kernel computation.
+         // It reads out-of-bound zeros, but trying to avoid them complicates
+         // for-loop conditions, which results in worse performance.
+         for (int k = 0; k < kernel_size; k += 4) {
+           // Since the weight tensor is width-packed, which is along the length
+           // dimension, we can batch-read four elements at a time.
+           const ivec3 w_pos = ivec3(k / 4, in_c % in_group_size, out_c);
+           const vec4 weight = texelFetch(kernel_in, w_pos, 0);
+
+           const ivec3 in_pos_0 = ivec3(in_l + k * dilation, in_c, n / 4);
+           sum = fma(weight.xxxx, texelFetch(image_in, in_pos_0, 0), sum);
+
+           const ivec3 in_pos_1 = ivec3(in_l + (k + 1) * dilation, in_c, n / 4);
+           sum = fma(weight.yyyy, texelFetch(image_in, in_pos_1, 0), sum);
+
+           const ivec3 in_pos_2 = ivec3(in_l + (k + 2) * dilation, in_c, n / 4);
+           sum = fma(weight.zzzz, texelFetch(image_in, in_pos_2, 0), sum);
+
+           const ivec3 in_pos_3 = ivec3(in_l + (k + 3) * dilation, in_c, n / 4);
+           sum = fma(weight.wwww, texelFetch(image_in, in_pos_3, 0), sum);
+         }
      }

-       v += w * input_value.x;
+       ivec3 out_pos = ivec3(out_l, out_c, n / 4);
+       imageStore(image_out, out_pos, sum + bias.x);
    }
-
-     ivec3 out_pos = ivec3(i, c, 0);
-     imageStore(image_out, out_pos, vec4(v.x + bias.x, 0, 0, 0));
  }
}