 
 #include "core/providers/cuda/cu_inc/common.cuh"
 #include "core/providers/cuda/shared_inc/fast_divmod.h"
+#include "core/providers/cuda/shared_inc/cuda_utils.h"
 
 namespace onnxruntime {
 namespace cuda {
-template <typename T>
+template <typename T, bool Layout>
 __global__ void MaxPoolWithIndexKernel(
     int64_t batch,
     int64_t channels,
@@ -44,11 +45,27 @@ __global__ void MaxPoolWithIndexKernel(
   int id = blockIdx.x * blockDim.x + threadIdx.x;
   if (id >= output_size) return;
 
+  auto compute_offset =
+      [height, width, depth, channels](int n_index, int c_index, int h_index, int w_index, int d_index) -> int64_t {
+        if constexpr (Layout == LAYOUT_NCHW) {
+          return (((n_index * channels + c_index) * height + h_index) * width + w_index) * depth + d_index;
+        } else if constexpr (Layout == LAYOUT_NHWC) {
+          return (((n_index * height + h_index) * width + w_index) * depth + d_index) * channels + c_index;
+        }
+      };
+
   int d_index, w_index, h_index, c_index, n_index, id_tmp;
-  fdm_d.divmod(id, id_tmp, d_index);
-  fdm_w.divmod(id_tmp, id_tmp, w_index);
-  fdm_h.divmod(id_tmp, id_tmp, h_index);
-  fdm_c.divmod(id_tmp, n_index, c_index);
+  if constexpr (Layout == LAYOUT_NCHW) {
+    fdm_d.divmod(id, id_tmp, d_index);
+    fdm_w.divmod(id_tmp, id_tmp, w_index);
+    fdm_h.divmod(id_tmp, id_tmp, h_index);
+    fdm_c.divmod(id_tmp, n_index, c_index);
+  } else if constexpr (Layout == LAYOUT_NHWC) {
+    fdm_c.divmod(id, id_tmp, c_index);
+    fdm_d.divmod(id_tmp, id_tmp, d_index);
+    fdm_w.divmod(id_tmp, id_tmp, w_index);
+    fdm_h.divmod(id_tmp, n_index, h_index);
+  }
 
   int64_t d_start = d_index * stride_d - pad_d;
   int64_t w_start = w_index * stride_w - pad_w;
@@ -64,29 +81,45 @@ __global__ void MaxPoolWithIndexKernel(
   int64_t d_index_max = -1;
   int64_t w_index_max = -1;
   int64_t h_index_max = -1;
-  int64_t offset = (n_index * channels + c_index) * height * width * depth;
+  int64_t offset = compute_offset(n_index, c_index, 0, 0, 0);
   const T* p_slice = p_input + offset;
-  T maxval = p_slice[h_start * width * depth + w_start * depth + d_start] - (T)1;
+  T maxval = p_slice[compute_offset(0, 0, h_start, w_start, d_start)] - (T)1;
   for (int64_t d = d_start; d < d_end; d += dilation_d) {
     for (int64_t w = w_start; w < w_end; w += dilation_w) {
       for (int64_t h = h_start; h < h_end; h += dilation_h) {
-        if (p_slice[h * width * depth + w * depth + d] > maxval) {
+        auto pool_offset = compute_offset(0, 0, h, w, d);
+        if (p_slice[pool_offset] > maxval) {
           h_index_max = h;
           w_index_max = w;
           d_index_max = d;
-          maxval = static_cast<float>(p_slice[h * width * depth + w * depth + d]);
+          maxval = static_cast<float>(p_slice[pool_offset]);
         }
       }
     }
   }
-  p_output[id] = p_input[offset + h_index_max * width * depth + w_index_max * depth + d_index_max];
+  p_output[id] = p_input[offset + compute_offset(0, 0, h_index_max, w_index_max, d_index_max)];
+
   if (p_indices) {
-    p_indices[id] = storage_order == 0 ? offset + h_index_max * width * depth + w_index_max * depth + d_index_max
-                                       : offset + h_index_max + w_index_max * height + d_index_max * width * height;
+    if constexpr (Layout == LAYOUT_NCHW) {
+      p_indices[id] = storage_order == 0 ? offset + h_index_max * width * depth + w_index_max * depth + d_index_max
+                                         : offset + h_index_max + w_index_max * height + d_index_max * width * height;
+    } else if constexpr (Layout == LAYOUT_NHWC) {
+      // The tests currently have to be provided in NHWC layout so that tests do not fail. When converting between
+      // layouts, does it make sense to do an index conversion as well?
+      // Storing indices in NHWC layout isn't critical as they are supposed to be used by Unpooling operations
+      // which currently assume that indices reference Tensors in NHWC layout.
+      int64_t id_nchw =
+          (((n_index * channels + c_index) * pooled_height + h_index) * pooled_width + w_index) * pooled_depth + d_index;
+      int64_t offset_nchw = (n_index * channels + c_index) * width * height * depth;
+
+      p_indices[id_nchw] = (storage_order == 0)
+                               ? offset_nchw + h_index_max * width * depth + w_index_max * depth + d_index_max
+                               : offset_nchw + h_index_max + w_index_max * height + d_index_max * width * height;
+    }
   }
 }
 
-template <typename T>
+template <typename T, bool Layout>
 void MaxPoolWithIndex(
     cudaStream_t stream,
     const TensorShape& input_shape,
@@ -99,14 +132,29 @@ void MaxPoolWithIndex(
     const T* p_input,
     T* p_output,
     int64_t* p_indices) {
-  int64_t batchs = input_shape[0];
-  int64_t channels = input_shape[1];
-  int64_t height = input_shape[2];
-  int64_t width = kernel_shape.size() > 1 ? input_shape[3] : 1;
-  int64_t depth = kernel_shape.size() > 2 ? input_shape[4] : 1;
-  int64_t pooled_height = output_shape[2];
-  int64_t pooled_width = kernel_shape.size() > 1 ? output_shape[3] : 1;
-  int64_t pooled_depth = kernel_shape.size() > 2 ? output_shape[4] : 1;
+  int64_t batchs, channels, height, width, depth;
+  int64_t pooled_height, pooled_width, pooled_depth;
+  if constexpr (Layout == LAYOUT_NCHW) {
+    batchs = input_shape[0];
+    channels = input_shape[1];
+    height = input_shape[2];
+    width = kernel_shape.size() > 1 ? input_shape[3] : 1;
+    depth = kernel_shape.size() > 2 ? input_shape[4] : 1;
+
+    pooled_height = output_shape[2];
+    pooled_width = kernel_shape.size() > 1 ? output_shape[3] : 1;
+    pooled_depth = kernel_shape.size() > 2 ? output_shape[4] : 1;
+  } else if constexpr (Layout == LAYOUT_NHWC) {
+    batchs = input_shape[0];
+    height = input_shape[1];
+    width = kernel_shape.size() > 1 ? input_shape[2] : 1;
+    depth = kernel_shape.size() > 2 ? input_shape[3] : 1;
+    channels = input_shape[input_shape.NumDimensions() - 1];
+
+    pooled_height = output_shape[1];
+    pooled_width = kernel_shape.size() > 1 ? output_shape[2] : 1;
+    pooled_depth = kernel_shape.size() > 2 ? output_shape[3] : 1;
+  }
   int64_t kernel_h = kernel_shape[0];
   int64_t kernel_w = kernel_shape.size() > 1 ? kernel_shape[1] : 1;
   int64_t kernel_d = kernel_shape.size() > 2 ? kernel_shape[2] : 1;
@@ -130,7 +178,7 @@ void MaxPoolWithIndex(
   fast_divmod fdm_d(static_cast<int>(pooled_depth));
 
   int blocksPerGrid = (int)((output_size + GridDim::maxThreadsPerBlock - 1) / GridDim::maxThreadsPerBlock);
-  MaxPoolWithIndexKernel<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
+  MaxPoolWithIndexKernel<T, Layout><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
       batchs,
       channels,
       height,
@@ -162,8 +210,8 @@ void MaxPoolWithIndex(
       p_indices);
 }
 
-#define INSTANTIATEMAXPOOLWITHINDEX(T) \
-  template void MaxPoolWithIndex<T>( \
+#define INSTANTIATEMAXPOOLWITHINDEX(T, Layout) \
+  template void MaxPoolWithIndex<T, Layout>( \
       cudaStream_t stream, \
       const TensorShape& input_shape, \
       const TensorShape& output_shape, \
@@ -176,11 +224,19 @@ void MaxPoolWithIndex(
       T* p_output, \
       int64_t* p_indices);
 
-INSTANTIATEMAXPOOLWITHINDEX(float)
-INSTANTIATEMAXPOOLWITHINDEX(double)
-INSTANTIATEMAXPOOLWITHINDEX(half)
-INSTANTIATEMAXPOOLWITHINDEX(int8_t)
-INSTANTIATEMAXPOOLWITHINDEX(uint8_t)
+INSTANTIATEMAXPOOLWITHINDEX(float, LAYOUT_NCHW)
+INSTANTIATEMAXPOOLWITHINDEX(double, LAYOUT_NCHW)
+INSTANTIATEMAXPOOLWITHINDEX(half, LAYOUT_NCHW)
+INSTANTIATEMAXPOOLWITHINDEX(int8_t, LAYOUT_NCHW)
+INSTANTIATEMAXPOOLWITHINDEX(uint8_t, LAYOUT_NCHW)
+
+#ifdef ENABLE_CUDA_NHWC_OPS
+INSTANTIATEMAXPOOLWITHINDEX(float, LAYOUT_NHWC)
+INSTANTIATEMAXPOOLWITHINDEX(double, LAYOUT_NHWC)
+INSTANTIATEMAXPOOLWITHINDEX(half, LAYOUT_NHWC)
+INSTANTIATEMAXPOOLWITHINDEX(int8_t, LAYOUT_NHWC)
+INSTANTIATEMAXPOOLWITHINDEX(uint8_t, LAYOUT_NHWC)
+#endif
 
 
 }  // namespace cuda
 }  // namespace onnxruntime
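
Since the whole change hinges on how an (n, c, h, w, d) coordinate maps to a flat buffer offset in each layout, below is a minimal host-side sketch (illustration only, not part of the diff; the dimension values and variable names are made up) that reproduces the arithmetic of the kernel's compute_offset lambda for NCHW and NHWC, the divmod decode order used in the NHWC branch, and the coordinate re-encoding that underlies the id_nchw-style index conversion (which the kernel performs with the pooled dimensions).

// Host-side sketch of the layout offset arithmetic; example dimensions are arbitrary.
#include <cassert>
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t channels = 3, height = 4, width = 5, depth = 2;

  // Example coordinates of one element.
  const int64_t n = 1, c = 2, h = 3, w = 1, d = 0;

  // NCHW: D varies fastest, then W, then H; C and N are outermost.
  int64_t nchw = (((n * channels + c) * height + h) * width + w) * depth + d;

  // NHWC: C varies fastest, so the channels of one spatial location are contiguous.
  int64_t nhwc = (((n * height + h) * width + w) * depth + d) * channels + c;

  // Recovering the coordinates from the flat NHWC offset mirrors the order of the
  // fdm_* divmod calls in the NHWC branch: strip C first, then D, W, H.
  int64_t rem = nhwc;
  int64_t c2 = rem % channels;  rem /= channels;
  int64_t d2 = rem % depth;     rem /= depth;
  int64_t w2 = rem % width;     rem /= width;
  int64_t h2 = rem % height;    rem /= height;
  int64_t n2 = rem;
  assert(c2 == c && d2 == d && w2 == w && h2 == h && n2 == n);

  // Re-encoding the same coordinates with the NCHW formula yields the NCHW flat index;
  // this is the conversion the NHWC branch applies (with pooled dims) to build id_nchw.
  int64_t converted = (((n2 * channels + c2) * height + h2) * width + w2) * depth + d2;
  assert(converted == nchw);

  printf("nchw offset = %lld, nhwc offset = %lld\n", (long long)nchw, (long long)nhwc);
  return 0;
}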