
Commit 247f8c5

mtavenrath and tianleiwu authored and rachguo committed
Fix broken Pooling CUDA NHWC Ops and ensure NCHW / NHWC parity. (#19889)

Fixed all CUDA NHWC pooling operations, which were broken, and enabled the NHWC CUDA pooling tests. Disabled all pooling tests that are not supported by the CUDA EP. This ensures parity between the CUDA NHWC and NCHW kernels and works towards 100% of tests being enabled for the CUDA EP / CUDA NHWC EP.

Co-authored-by: Tianlei Wu <[email protected]>

1 parent a13e5d5 commit 247f8c5
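For readers who want to exercise these kernels: the NHWC CUDA ops are compiled only in NHWC-enabled builds, and the CUDA EP must be asked to prefer the NHWC layout. A minimal sketch, assuming a build with onnxruntime_USE_CUDA_NHWC_OPS=ON and the CUDA EP's "prefer_nhwc" provider option (both worth verifying against your ORT version):

```cpp
#include <onnxruntime_cxx_api.h>

int main() {
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "nhwc-demo");
  Ort::SessionOptions so;

  // Ask the CUDA EP to route supported ops (e.g. MaxPool) to the NHWC
  // kernels registered in the internal NHWC domain instead of the NCHW ones.
  OrtCUDAProviderOptionsV2* cuda_options = nullptr;
  Ort::ThrowOnError(Ort::GetApi().CreateCUDAProviderOptions(&cuda_options));
  const char* keys[] = {"prefer_nhwc"};  // assumption: option name
  const char* values[] = {"1"};
  Ort::ThrowOnError(Ort::GetApi().UpdateCUDAProviderOptions(cuda_options, keys, values, 1));
  so.AppendExecutionProvider_CUDA_V2(*cuda_options);

  Ort::Session session(env, "model.onnx", so);  // hypothetical model path
  Ort::GetApi().ReleaseCUDAProviderOptions(cuda_options);
  return 0;
}
```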

File tree

8 files changed: +253 -98 lines changed

onnxruntime/core/providers/cuda/cuda_nhwc_kernels.cc (+7)

```diff
@@ -70,6 +70,8 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kM
                                                       MaxPool);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 12, float, MaxPool);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 12, MLFloat16, MaxPool);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 12, int8_t, MaxPool);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 12, uint8_t, MaxPool);
 class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 14, 14, float,
                                                       BatchNormalization);
 class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 14, 14, MLFloat16,
@@ -135,6 +137,7 @@ Status RegisterCudaNhwcKernels(KernelRegistry& kernel_registry) {
           kCudaExecutionProvider, kMSInternalNHWCDomain, 10, 10, float, MaxPool)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
           kCudaExecutionProvider, kMSInternalNHWCDomain, 10, 10, MLFloat16, MaxPool)>,
+
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 11,
                                                                   float, AveragePool)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 11,
@@ -147,6 +150,10 @@ Status RegisterCudaNhwcKernels(KernelRegistry& kernel_registry) {
                                                                   float, MaxPool)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 12,
                                                                   MLFloat16, MaxPool)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 12,
+                                                                  int8_t, MaxPool)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 12,
+                                                                  uint8_t, MaxPool)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 11,
                                                                   float, ConvTranspose)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 11,
```
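Kernel lookup is keyed by (domain, opset version, element type), so the int8_t and uint8_t MaxPool variants (element types the ONNX MaxPool op accepts since opset 12) each need their own class declaration plus a BuildKernelCreateInfo entry. A toy, self-contained sketch of that registry pattern (not ORT's actual code; the domain string is an assumption):

```cpp
#include <map>
#include <string>
#include <tuple>

// Toy stand-in for ORT's KernelCreateInfo.
struct KernelCreateInfo {
  std::string op;
};

// Key: (domain, since-version opset, element type).
using Key = std::tuple<std::string, int, std::string>;

std::map<Key, KernelCreateInfo> BuildRegistry() {
  return {
      {{"com.ms.internal.nhwc", 12, "float"}, {"MaxPool"}},
      {{"com.ms.internal.nhwc", 12, "float16"}, {"MaxPool"}},
      {{"com.ms.internal.nhwc", 12, "int8"}, {"MaxPool"}},   // added by this commit
      {{"com.ms.internal.nhwc", 12, "uint8"}, {"MaxPool"}},  // added by this commit
  };
}

int main() {
  auto registry = BuildRegistry();
  // An int8 MaxPool node in the NHWC domain now resolves to a kernel.
  return registry.count({"com.ms.internal.nhwc", 12, "int8"}) ? 0 : 1;
}
```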

onnxruntime/core/providers/cuda/cudnn_common.cc (+22 -7)

```diff
@@ -37,13 +37,28 @@ Status CudnnTensor::Set(gsl::span<const int64_t> input_dims, cudnnDataType_t dat
   TensorPitches pitches(input_dims);
   InlinedVector<int, kTensorShapeSmallBufferElementsSize> dims(rank);
   InlinedVector<int, kTensorShapeSmallBufferElementsSize> strides(rank);
-  for (int i = 0; i < rank; i++) {
-    dims[i] = gsl::narrow_cast<int>(input_dims[i]);
-    strides[i] = gsl::narrow_cast<int>(pitches[i]);
-  }
-  if (is_nhwc) {
-    std::swap(dims[1], dims[rank - 1]);
-    std::swap(strides[1], strides[rank - 1]);
+
+  if (!is_nhwc) {
+    for (int i = 0; i < rank; i++) {
+      dims[i] = gsl::narrow_cast<int>(input_dims[i]);
+      strides[i] = gsl::narrow_cast<int>(pitches[i]);
+    }
+  } else {
+    // NHWDC <-> NCHWD
+
+    // N
+    dims[0] = gsl::narrow_cast<int>(input_dims[0]);
+    strides[0] = gsl::narrow_cast<int>(pitches[0]);
+
+    // HWD
+    for (int i = 1; i < rank - 1; i++) {
+      dims[i + 1] = gsl::narrow_cast<int>(input_dims[i]);
+      strides[i + 1] = gsl::narrow_cast<int>(pitches[i]);
+    }
+
+    // C
+    dims[1] = gsl::narrow_cast<int>(input_dims[rank - 1]);
+    strides[1] = gsl::narrow_cast<int>(pitches[rank - 1]);
   }
   CUDNN_RETURN_IF_ERROR(cudnnSetTensorNdDescriptor(tensor_, dataType, static_cast<int>(rank), dims.data(), strides.data()));
   return Status::OK();
```
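The replaced code filled dims/strides in input order and then swapped entry 1 with the last entry; for a 4-D NHWC shape [N, H, W, C] that yields [N, C, W, H], i.e. H and W end up transposed. The new code shifts the spatial dimensions up by one slot and moves C into slot 1, while keeping the tensor's actual NHWC memory strides. A standalone check of the remapping with assumed sizes (plain C++, not ORT code):

```cpp
#include <cassert>
#include <vector>

int main() {
  // NHWC input: N=2, H=4, W=5, C=3, contiguous in memory, so the
  // row-major pitches over [2, 4, 5, 3] are [60, 15, 3, 1].
  std::vector<int> in_dims = {2, 4, 5, 3};
  std::vector<int> pitches = {60, 15, 3, 1};
  const int rank = 4;

  std::vector<int> dims(rank), strides(rank);
  dims[0] = in_dims[0];  // N
  strides[0] = pitches[0];
  for (int i = 1; i < rank - 1; i++) {  // H, W (and D for 5-D tensors)
    dims[i + 1] = in_dims[i];
    strides[i + 1] = pitches[i];
  }
  dims[1] = in_dims[rank - 1];  // C
  strides[1] = pitches[rank - 1];

  // cuDNN sees NCHW-ordered dims with NHWC strides: C has stride 1.
  assert((dims == std::vector<int>{2, 3, 4, 5}));
  assert((strides == std::vector<int>{60, 1, 15, 3}));
  // The old std::swap(dims[1], dims[3]) would have produced {2, 3, 5, 4}:
  // H and W transposed, which is the breakage this commit fixes.
  return 0;
}
```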

onnxruntime/core/providers/cuda/nn/max_pool_with_index.cu (+85 -29)

```diff
@@ -7,10 +7,11 @@
 
 #include "core/providers/cuda/cu_inc/common.cuh"
 #include "core/providers/cuda/shared_inc/fast_divmod.h"
+#include "core/providers/cuda/shared_inc/cuda_utils.h"
 
 namespace onnxruntime {
 namespace cuda {
-template <typename T>
+template <typename T, bool Layout>
 __global__ void MaxPoolWithIndexKernel(
     int64_t batch,
     int64_t channels,
```
```diff
@@ -44,11 +45,27 @@ __global__ void MaxPoolWithIndexKernel(
   int id = blockIdx.x * blockDim.x + threadIdx.x;
   if (id >= output_size) return;
 
+  auto compute_offset =
+      [height, width, depth, channels](int n_index, int c_index, int h_index, int w_index, int d_index) -> int64_t {
+    if constexpr (Layout == LAYOUT_NCHW) {
+      return (((n_index * channels + c_index) * height + h_index) * width + w_index) * depth + d_index;
+    } else if constexpr (Layout == LAYOUT_NHWC) {
+      return (((n_index * height + h_index) * width + w_index) * depth + d_index) * channels + c_index;
+    }
+  };
+
   int d_index, w_index, h_index, c_index, n_index, id_tmp;
-  fdm_d.divmod(id, id_tmp, d_index);
-  fdm_w.divmod(id_tmp, id_tmp, w_index);
-  fdm_h.divmod(id_tmp, id_tmp, h_index);
-  fdm_c.divmod(id_tmp, n_index, c_index);
+  if constexpr (Layout == LAYOUT_NCHW) {
+    fdm_d.divmod(id, id_tmp, d_index);
+    fdm_w.divmod(id_tmp, id_tmp, w_index);
+    fdm_h.divmod(id_tmp, id_tmp, h_index);
+    fdm_c.divmod(id_tmp, n_index, c_index);
+  } else if constexpr (Layout == LAYOUT_NHWC) {
+    fdm_c.divmod(id, id_tmp, c_index);
+    fdm_d.divmod(id_tmp, id_tmp, d_index);
+    fdm_w.divmod(id_tmp, id_tmp, w_index);
+    fdm_h.divmod(id_tmp, n_index, h_index);
+  }
 
   int64_t d_start = d_index * stride_d - pad_d;
   int64_t w_start = w_index * stride_w - pad_w;
```
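The peeling order of the flattened output index must match which dimension is innermost in memory: depth for NCHW, channels for NHWC. A host-side check of the NHWC order, with plain / and % standing in for fast_divmod and made-up pooled dimensions:

```cpp
#include <cassert>
#include <cstdint>

int main() {
  // Pooled output dims: N=2, C=3, H'=4, W'=5, D'=1 (2-D pooling).
  const int64_t C = 3, H = 4, W = 5, D = 1;
  const int64_t n = 1, c = 2, h = 3, w = 4, d = 0;

  // NHWC flat index: channels are innermost.
  int64_t id = (((n * H + h) * W + w) * D + d) * C + c;

  // Same peeling order as the LAYOUT_NHWC branch above.
  int64_t id_tmp = id;
  int64_t c_i = id_tmp % C; id_tmp /= C;  // fdm_c.divmod
  int64_t d_i = id_tmp % D; id_tmp /= D;  // fdm_d.divmod
  int64_t w_i = id_tmp % W; id_tmp /= W;  // fdm_w.divmod
  int64_t h_i = id_tmp % H; id_tmp /= H;  // fdm_h.divmod, quotient is n
  int64_t n_i = id_tmp;

  assert(c_i == c && d_i == d && w_i == w && h_i == h && n_i == n);
  return 0;
}
```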
```diff
@@ -64,29 +81,45 @@ __global__ void MaxPoolWithIndexKernel(
   int64_t d_index_max = -1;
   int64_t w_index_max = -1;
   int64_t h_index_max = -1;
-  int64_t offset = (n_index * channels + c_index) * height * width * depth;
+  int64_t offset = compute_offset(n_index, c_index, 0, 0, 0);
   const T* p_slice = p_input + offset;
-  T maxval = p_slice[h_start * width * depth + w_start * depth + d_start] - (T)1;
+  T maxval = p_slice[compute_offset(0, 0, h_start, w_start, d_start)] - (T)1;
   for (int64_t d = d_start; d < d_end; d += dilation_d) {
     for (int64_t w = w_start; w < w_end; w += dilation_w) {
       for (int64_t h = h_start; h < h_end; h += dilation_h) {
-        if (p_slice[h * width * depth + w * depth + d] > maxval) {
+        auto pool_offset = compute_offset(0, 0, h, w, d);
+        if (p_slice[pool_offset] > maxval) {
           h_index_max = h;
           w_index_max = w;
           d_index_max = d;
-          maxval = static_cast<float>(p_slice[h * width * depth + w * depth + d]);
+          maxval = static_cast<float>(p_slice[pool_offset]);
         }
       }
     }
   }
-  p_output[id] = p_input[offset + h_index_max * width * depth + w_index_max * depth + d_index_max];
+  p_output[id] = p_input[offset + compute_offset(0, 0, h_index_max, w_index_max, d_index_max)];
+
   if (p_indices) {
-    p_indices[id] = storage_order == 0 ? offset + h_index_max * width * depth + w_index_max * depth + d_index_max
-                                       : offset + h_index_max + w_index_max * height + d_index_max * width * height;
+    if constexpr (Layout == LAYOUT_NCHW) {
+      p_indices[id] = storage_order == 0 ? offset + h_index_max * width * depth + w_index_max * depth + d_index_max
+                                         : offset + h_index_max + w_index_max * height + d_index_max * width * height;
+    } else if constexpr (Layout == LAYOUT_NHWC) {
+      // The tests currently have to provide expected indices in NHWC layout so that they do not fail. When
+      // converting between layouts, does it make sense to do an index conversion as well?
+      // Storing indices in NHWC layout isn't critical, as they are meant to be consumed by unpooling operations,
+      // which currently assume that indices reference tensors in NHWC layout.
+      int64_t id_nchw =
+          (((n_index * channels + c_index) * pooled_height + h_index) * pooled_width + w_index) * pooled_depth + d_index;
+      int64_t offset_nchw = (n_index * channels + c_index) * width * height * depth;
+
+      p_indices[id_nchw] = (storage_order == 0)
+                               ? offset_nchw + h_index_max * width * depth + w_index_max * depth + d_index_max
+                               : offset_nchw + h_index_max + w_index_max * height + d_index_max * width * height;
+    }
   }
 }
 
-template <typename T>
+template <typename T, bool Layout>
 void MaxPoolWithIndex(
     cudaStream_t stream,
     const TensorShape& input_shape,
```
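Note that on the NHWC path the pooled value is written at the NHWC flat id, while the index entry is written at the equivalent NCHW flat id and encodes an offset into the input in NCHW terms (or column-major terms when storage_order == 1). A small host-side illustration of the two flat ids with assumed pooled dimensions:

```cpp
#include <cassert>
#include <cstdint>

int main() {
  // Pooled output dims: N=1, C=3, H'=2, W'=2, D'=1.
  const int64_t C = 3, PH = 2, PW = 2, PD = 1;
  const int64_t n = 0, c = 1, h = 1, w = 0, d = 0;

  // Where the output *value* goes (NHWC flat index)...
  int64_t id_nhwc = (((n * PH + h) * PW + w) * PD + d) * C + c;
  // ...and where the *index* entry goes (NCHW flat index), as in the kernel.
  int64_t id_nchw = (((n * C + c) * PH + h) * PW + w) * PD + d;

  assert(id_nhwc == 7);
  assert(id_nchw == 6);
  return 0;
}
```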
```diff
@@ -99,14 +132,29 @@ void MaxPoolWithIndex(
     const T* p_input,
     T* p_output,
     int64_t* p_indices) {
-  int64_t batchs = input_shape[0];
-  int64_t channels = input_shape[1];
-  int64_t height = input_shape[2];
-  int64_t width = kernel_shape.size() > 1 ? input_shape[3] : 1;
-  int64_t depth = kernel_shape.size() > 2 ? input_shape[4] : 1;
-  int64_t pooled_height = output_shape[2];
-  int64_t pooled_width = kernel_shape.size() > 1 ? output_shape[3] : 1;
-  int64_t pooled_depth = kernel_shape.size() > 2 ? output_shape[4] : 1;
+  int64_t batchs, channels, height, width, depth;
+  int64_t pooled_height, pooled_width, pooled_depth;
+  if constexpr (Layout == LAYOUT_NCHW) {
+    batchs = input_shape[0];
+    channels = input_shape[1];
+    height = input_shape[2];
+    width = kernel_shape.size() > 1 ? input_shape[3] : 1;
+    depth = kernel_shape.size() > 2 ? input_shape[4] : 1;
+
+    pooled_height = output_shape[2];
+    pooled_width = kernel_shape.size() > 1 ? output_shape[3] : 1;
+    pooled_depth = kernel_shape.size() > 2 ? output_shape[4] : 1;
+  } else if constexpr (Layout == LAYOUT_NHWC) {
+    batchs = input_shape[0];
+    height = input_shape[1];
+    width = kernel_shape.size() > 1 ? input_shape[2] : 1;
+    depth = kernel_shape.size() > 2 ? input_shape[3] : 1;
+    channels = input_shape[input_shape.NumDimensions() - 1];
+
+    pooled_height = output_shape[1];
+    pooled_width = kernel_shape.size() > 1 ? output_shape[2] : 1;
+    pooled_depth = kernel_shape.size() > 2 ? output_shape[3] : 1;
+  }
   int64_t kernel_h = kernel_shape[0];
   int64_t kernel_w = kernel_shape.size() > 1 ? kernel_shape[1] : 1;
   int64_t kernel_d = kernel_shape.size() > 2 ? kernel_shape[2] : 1;
@@ -130,7 +178,7 @@ void MaxPoolWithIndex(
   fast_divmod fdm_d(static_cast<int>(pooled_depth));
 
   int blocksPerGrid = (int)((output_size + GridDim::maxThreadsPerBlock - 1) / GridDim::maxThreadsPerBlock);
-  MaxPoolWithIndexKernel<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
+  MaxPoolWithIndexKernel<T, Layout><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
       batchs,
       channels,
       height,
```
162210
p_indices);
163211
}
164212

165-
#define INSTANTIATEMAXPOOLWITHINDEX(T) \
166-
template void MaxPoolWithIndex<T>( \
213+
#define INSTANTIATEMAXPOOLWITHINDEX(T, Layout) \
214+
template void MaxPoolWithIndex<T, Layout>( \
167215
cudaStream_t stream, \
168216
const TensorShape& input_shape, \
169217
const TensorShape& output_shape, \
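The macro gains a Layout argument because a template defined in a .cu translation unit is only linkable from elsewhere for the combinations explicitly instantiated there, and the kernel is now parameterized on both T and Layout. A generic single-file illustration of that mechanism (plain C++, hypothetical names):

```cpp
// --- pool_demo.h: what other translation units see -------------------
template <typename T, bool Layout>
void PoolDemo(const T* in, T* out, int n);

// --- pool_demo.cc: plays the role of max_pool_with_index.cu ----------
template <typename T, bool Layout>
void PoolDemo(const T* in, T* out, int n) {
  for (int i = 0; i < n; ++i) out[i] = in[i];  // trivial body for the demo
}

// Without these explicit instantiations, callers in other translation
// units would fail to link for the missing <T, Layout> combinations.
#define INSTANTIATE_POOL_DEMO(T, Layout) \
  template void PoolDemo<T, Layout>(const T*, T*, int);

INSTANTIATE_POOL_DEMO(float, false)  // cf. INSTANTIATEMAXPOOLWITHINDEX(float, LAYOUT_NCHW)
INSTANTIATE_POOL_DEMO(float, true)   // cf. INSTANTIATEMAXPOOLWITHINDEX(float, LAYOUT_NHWC)
```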
```diff
@@ -176,11 +224,19 @@ void MaxPoolWithIndex(
       T* p_output,       \
       int64_t* p_indices);
 
-INSTANTIATEMAXPOOLWITHINDEX(float)
-INSTANTIATEMAXPOOLWITHINDEX(double)
-INSTANTIATEMAXPOOLWITHINDEX(half)
-INSTANTIATEMAXPOOLWITHINDEX(int8_t)
-INSTANTIATEMAXPOOLWITHINDEX(uint8_t)
+INSTANTIATEMAXPOOLWITHINDEX(float, LAYOUT_NCHW)
+INSTANTIATEMAXPOOLWITHINDEX(double, LAYOUT_NCHW)
+INSTANTIATEMAXPOOLWITHINDEX(half, LAYOUT_NCHW)
+INSTANTIATEMAXPOOLWITHINDEX(int8_t, LAYOUT_NCHW)
+INSTANTIATEMAXPOOLWITHINDEX(uint8_t, LAYOUT_NCHW)
+
+#ifdef ENABLE_CUDA_NHWC_OPS
+INSTANTIATEMAXPOOLWITHINDEX(float, LAYOUT_NHWC)
+INSTANTIATEMAXPOOLWITHINDEX(double, LAYOUT_NHWC)
+INSTANTIATEMAXPOOLWITHINDEX(half, LAYOUT_NHWC)
+INSTANTIATEMAXPOOLWITHINDEX(int8_t, LAYOUT_NHWC)
+INSTANTIATEMAXPOOLWITHINDEX(uint8_t, LAYOUT_NHWC)
+#endif
 
 }  // namespace cuda
 }  // namespace onnxruntime
```

onnxruntime/core/providers/cuda/nn/max_pool_with_index.h (+1 -1)

```diff
@@ -7,7 +7,7 @@
 
 namespace onnxruntime {
 namespace cuda {
-template <typename T>
+template <typename T, bool Layout>
 void MaxPoolWithIndex(
     cudaStream_t stream,
     const TensorShape& input_shape,
```
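Callers now select the layout at compile time through the added bool template parameter. A self-contained sketch of that dispatch pattern; the LAYOUT_NCHW / LAYOUT_NHWC values are assumed to be the false/true constants from cuda_utils.h:

```cpp
#include <cstdio>

// Assumption: mirrors the constants the kernel compares Layout against.
constexpr bool LAYOUT_NCHW = false;
constexpr bool LAYOUT_NHWC = true;

template <typename T, bool Layout>
void MaxPoolWithIndexDemo() {
  if constexpr (Layout == LAYOUT_NHWC) {
    std::printf("NHWC path: channels innermost\n");
  } else {
    std::printf("NCHW path: depth/width/height innermost\n");
  }
}

int main() {
  // Each call compiles a distinct instantiation, matching the explicit
  // instantiation lists in the .cu file.
  MaxPoolWithIndexDemo<float, LAYOUT_NCHW>();
  MaxPoolWithIndexDemo<float, LAYOUT_NHWC>();
  return 0;
}
```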
