Skip to content

Commit 901356d

Browse files
committed
Fix a bug in CUDA pool op
1 parent bdf678d commit 901356d

File tree

1 file changed

+6
-2
lines changed
  • onnxruntime/core/providers/cuda/nn

1 file changed

+6
-2
lines changed

onnxruntime/core/providers/cuda/nn/pool.cc

+6-2
Original file line numberDiff line numberDiff line change
@@ -182,16 +182,20 @@ Status Pool<T, PoolType, NHWC>::ComputeInternal(OpKernelContext* context) const
182182
if (NHWC) {
183183
x_dims_cudnn.insert(x_dims_cudnn.begin() + 1, 1);
184184
y_dims_cudnn.insert(y_dims_cudnn.begin() + 1, 1);
185+
ORT_ENFORCE(pads.size() >= kernel_shape.size());
186+
pads.insert(pads.begin() + kernel_shape.size(), 0);
187+
pads.insert(pads.end(), 0);
185188
kernel_shape.insert(kernel_shape.begin() + 1, 1);
186189
strides.insert(strides.begin() + 1, 1);
187190
} else {
188191
x_dims_cudnn.push_back(1);
189192
y_dims_cudnn.push_back(1);
193+
ORT_ENFORCE(pads.size() >= kernel_shape.size());
194+
pads.insert(pads.begin() + kernel_shape.size(), 0);
195+
pads.insert(pads.end(), 0);
190196
kernel_shape.push_back(1);
191197
strides.push_back(1);
192198
}
193-
pads.insert(pads.begin() + kernel_shape.size(), 0);
194-
pads.insert(pads.end(), 0);
195199
}
196200

197201
cudnnPoolingMode_t mode = CUDNN_POOLING_MAX;

0 commit comments

Comments
 (0)