Skip to content

Commit 082833f

Browse files
ssnl
authored and
Rob Kunkle
committed
3d conv should use int64_t (pytorch#9274)
Summary: Fixes pytorch#9264 . There can be so many elements in the output of `vol2col` so it overflows `int` range! This PR changes 3d conv to use `int64_t` mostly. Also fixes some unused var warning (cc goldsborough ) Pull Request resolved: pytorch#9274 Differential Revision: D8770682 Pulled By: SsnL fbshipit-source-id: f6e37f1aa56fe1009dd4c9bcbc042244e47252db
1 parent 3a792a7 commit 082833f

7 files changed

+122
-124
lines changed

aten/src/THCUNN/VolumetricConvolution.cu

Lines changed: 66 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -8,30 +8,30 @@
88
// Borrowed from Theano
99
// Authors: Arjun Jain, Frédéric Bastien, Jan Schlüter, Nicolas Ballas
1010
template <typename Dtype>
11-
__global__ void im3d2col_kernel(const int n, const Dtype* data_im,
12-
const int height, const int width, const int depth,
13-
const int kernel_h, const int kernel_w, const int kernel_d,
14-
const int pad_h, const int pad_w, const int pad_d,
15-
const int stride_h, const int stride_w, const int stride_d,
16-
const int height_col, const int width_col, const int depth_col,
11+
__global__ void im3d2col_kernel(const int64_t n, const Dtype* data_im,
12+
const int64_t height, const int64_t width, const int64_t depth,
13+
const int64_t kernel_h, const int64_t kernel_w, const int64_t kernel_d,
14+
const int64_t pad_h, const int64_t pad_w, const int64_t pad_d,
15+
const int64_t stride_h, const int64_t stride_w, const int64_t stride_d,
16+
const int64_t height_col, const int64_t width_col, const int64_t depth_col,
1717
Dtype* data_col)
1818
{
1919
CUDA_KERNEL_LOOP(index, n)
2020
{
21-
int d_out = index % depth_col;
22-
int w_index = index / depth_col;
23-
int w_out = w_index % width_col;
24-
int h_index = w_index / width_col;
25-
int h_out = h_index % height_col;
21+
int64_t d_out = index % depth_col;
22+
int64_t w_index = index / depth_col;
23+
int64_t w_out = w_index % width_col;
24+
int64_t h_index = w_index / width_col;
25+
int64_t h_out = h_index % height_col;
2626

27-
int channel_in = h_index / height_col;
27+
int64_t channel_in = h_index / height_col;
2828
//channel_in = 1;
2929

30-
int channel_out = channel_in * kernel_h * kernel_w * kernel_d;
30+
int64_t channel_out = channel_in * kernel_h * kernel_w * kernel_d;
3131

32-
int h_in = h_out * stride_h - pad_h;
33-
int w_in = w_out * stride_w - pad_w;
34-
int d_in = d_out * stride_d - pad_d;
32+
int64_t h_in = h_out * stride_h - pad_h;
33+
int64_t w_in = w_out * stride_w - pad_w;
34+
int64_t d_in = d_out * stride_d - pad_d;
3535

3636
Dtype* data_col_ptr = data_col;
3737
data_col_ptr += channel_out * (height_col * width_col * depth_col) +
@@ -41,15 +41,15 @@ __global__ void im3d2col_kernel(const int n, const Dtype* data_im,
4141
data_im_ptr += channel_in * (height * width * depth) +
4242
h_in * (width * depth) + w_in * depth + d_in;
4343

44-
for (int i = 0; i < kernel_h; ++i)
44+
for (int64_t i = 0; i < kernel_h; ++i)
4545
{
46-
int h = h_in + i;
47-
for (int j = 0; j < kernel_w; ++j)
46+
int64_t h = h_in + i;
47+
for (int64_t j = 0; j < kernel_w; ++j)
4848
{
49-
int w = w_in + j;
50-
for (int k = 0; k < kernel_d; ++k)
49+
int64_t w = w_in + j;
50+
for (int64_t k = 0; k < kernel_d; ++k)
5151
{
52-
int d = d_in + k;
52+
int64_t d = d_in + k;
5353
*data_col_ptr = (h >= 0 && w >= 0 && d >= 0 &&
5454
h < height && w < width && d < depth) ?
5555
data_im_ptr[i * (width * depth) + j *depth + k] : ScalarConvert<int, Dtype>::to(0);
@@ -61,19 +61,19 @@ __global__ void im3d2col_kernel(const int n, const Dtype* data_im,
6161
}
6262

6363
template <typename Dtype>
64-
void im3d2col(cudaStream_t stream, const Dtype* data_im, const int channels,
65-
const int height, const int width, const int depth,
66-
const int kernel_h, const int kernel_w, const int kernel_d,
67-
const int pad_h, const int pad_w, const int pad_d,
68-
const int stride_h, const int stride_w, const int stride_d,
64+
void im3d2col(cudaStream_t stream, const Dtype* data_im, const int64_t channels,
65+
const int64_t height, const int64_t width, const int64_t depth,
66+
const int64_t kernel_h, const int64_t kernel_w, const int64_t kernel_d,
67+
const int64_t pad_h, const int64_t pad_w, const int64_t pad_d,
68+
const int64_t stride_h, const int64_t stride_w, const int64_t stride_d,
6969
Dtype* data_col)
7070
{
7171
// We are going to launch channels * height_col * width_col * depth_col kernels, each
7272
// kernel responsible for copying a single-channel grid.
73-
int height_col = (height + 2 * pad_h - kernel_h) / stride_h + 1;
74-
int width_col = (width + 2 * pad_w - kernel_w) / stride_w + 1;
75-
int depth_col = (depth + 2 * pad_d - kernel_d) / stride_d + 1;
76-
int num_kernels = channels * height_col * width_col * depth_col;
73+
int64_t height_col = (height + 2 * pad_h - kernel_h) / stride_h + 1;
74+
int64_t width_col = (width + 2 * pad_w - kernel_w) / stride_w + 1;
75+
int64_t depth_col = (depth + 2 * pad_d - kernel_d) / stride_d + 1;
76+
int64_t num_kernels = channels * height_col * width_col * depth_col;
7777
im3d2col_kernel<<<GET_BLOCKS(num_kernels),
7878
CUDA_NUM_THREADS, 0, stream>>>(num_kernels, data_im,
7979
height, width, depth,
@@ -86,42 +86,42 @@ void im3d2col(cudaStream_t stream, const Dtype* data_im, const int channels,
8686
}
8787

8888
template <typename Dtype, typename Acctype>
89-
__global__ void col2im3d_kernel(const int n, const Dtype* data_col,
90-
const int height, const int width, const int depth,
91-
const int channels,
92-
const int patch_h, const int patch_w, const int patch_d,
93-
const int pad_h, const int pad_w, const int pad_d,
94-
const int stride_h, const int stride_w, const int stride_d,
95-
const int height_col, const int width_col, const int depth_col,
89+
__global__ void col2im3d_kernel(const int64_t n, const Dtype* data_col,
90+
const int64_t height, const int64_t width, const int64_t depth,
91+
const int64_t channels,
92+
const int64_t patch_h, const int64_t patch_w, const int64_t patch_d,
93+
const int64_t pad_h, const int64_t pad_w, const int64_t pad_d,
94+
const int64_t stride_h, const int64_t stride_w, const int64_t stride_d,
95+
const int64_t height_col, const int64_t width_col, const int64_t depth_col,
9696
Dtype* data_im)
9797
{
9898
CUDA_KERNEL_LOOP(index, n)
9999
{
100100
Acctype val = 0;
101-
int d = index % depth + pad_d;
102-
int w_index = index / depth;
103-
int w = w_index % width + pad_w;
104-
int h_index = w_index / width;
105-
int h = h_index % height + pad_h;
106-
int c = h_index / height;
101+
int64_t d = index % depth + pad_d;
102+
int64_t w_index = index / depth;
103+
int64_t w = w_index % width + pad_w;
104+
int64_t h_index = w_index / width;
105+
int64_t h = h_index % height + pad_h;
106+
int64_t c = h_index / height;
107107

108108
// compute the start and end of the output
109-
int d_col_start = (d < patch_d) ? 0 : (d - patch_d) / stride_d + 1;
110-
int d_col_end = min(d / stride_d + 1, depth_col);
111-
int w_col_start = (w < patch_w) ? 0 : (w - patch_w) / stride_w + 1;
112-
int w_col_end = min(w / stride_w + 1, width_col);
113-
int h_col_start = (h < patch_h) ? 0 : (h - patch_h) / stride_h + 1;
114-
int h_col_end = min(h / stride_h + 1, height_col);
109+
int64_t d_col_start = (d < patch_d) ? 0 : (d - patch_d) / stride_d + 1;
110+
int64_t d_col_end = min(d / stride_d + 1, depth_col);
111+
int64_t w_col_start = (w < patch_w) ? 0 : (w - patch_w) / stride_w + 1;
112+
int64_t w_col_end = min(w / stride_w + 1, width_col);
113+
int64_t h_col_start = (h < patch_h) ? 0 : (h - patch_h) / stride_h + 1;
114+
int64_t h_col_end = min(h / stride_h + 1, height_col);
115115

116-
int offset =
116+
int64_t offset =
117117
(c * patch_h * patch_w * patch_d + h * patch_w * patch_d + w * patch_d + d) * height_col * width_col * depth_col;
118118

119-
int coeff_h_col = (1 - stride_h * patch_w * patch_d * height_col) * width_col * depth_col;
120-
int coeff_w_col = (1 - stride_w * patch_d * height_col * width_col) * depth_col;
121-
int coeff_d_col = (1 - stride_d * height_col * width_col * depth_col);
122-
for (int d_col = d_col_start; d_col < d_col_end; ++d_col)
123-
for (int h_col = h_col_start; h_col < h_col_end; ++h_col) {
124-
for (int w_col = w_col_start; w_col < w_col_end; ++w_col) {
119+
int64_t coeff_h_col = (1 - stride_h * patch_w * patch_d * height_col) * width_col * depth_col;
120+
int64_t coeff_w_col = (1 - stride_w * patch_d * height_col * width_col) * depth_col;
121+
int64_t coeff_d_col = (1 - stride_d * height_col * width_col * depth_col);
122+
for (int64_t d_col = d_col_start; d_col < d_col_end; ++d_col)
123+
for (int64_t h_col = h_col_start; h_col < h_col_end; ++h_col) {
124+
for (int64_t w_col = w_col_start; w_col < w_col_end; ++w_col) {
125125
val += data_col[offset + h_col * coeff_h_col + w_col * coeff_w_col + d_col * coeff_d_col];
126126
}
127127
}
@@ -130,17 +130,17 @@ __global__ void col2im3d_kernel(const int n, const Dtype* data_col,
130130
}
131131

132132
template <typename Dtype, typename Acctype>
133-
void col2im3d(cudaStream_t stream, const Dtype* data_col, const int channels,
134-
const int height, const int width, const int depth,
135-
const int patch_h, const int patch_w, const int patch_d,
136-
const int pad_h, const int pad_w, const int pad_d,
137-
const int stride_h, const int stride_w, const int stride_d,
133+
void col2im3d(cudaStream_t stream, const Dtype* data_col, const int64_t channels,
134+
const int64_t height, const int64_t width, const int64_t depth,
135+
const int64_t patch_h, const int64_t patch_w, const int64_t patch_d,
136+
const int64_t pad_h, const int64_t pad_w, const int64_t pad_d,
137+
const int64_t stride_h, const int64_t stride_w, const int64_t stride_d,
138138
Dtype* data_im)
139139
{
140-
int height_col = (height + 2 * pad_h - patch_h) / stride_h + 1;
141-
int width_col = (width + 2 * pad_w - patch_w) / stride_w + 1;
142-
int depth_col = (depth + 2 * pad_d - patch_d) / stride_d + 1;
143-
int num_kernels = channels * height * width * depth;
140+
int64_t height_col = (height + 2 * pad_h - patch_h) / stride_h + 1;
141+
int64_t width_col = (width + 2 * pad_w - patch_w) / stride_w + 1;
142+
int64_t depth_col = (depth + 2 * pad_d - patch_d) / stride_d + 1;
143+
int64_t num_kernels = channels * height * width * depth;
144144

145145
// To avoid involving atomic operations, we will launch one kernel per
146146
// bottom dimension, and then in the kernel add up the top dimensions.

aten/src/THCUNN/generic/Im2Col.cu

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,6 @@ static inline void THNN_(Im2Col_shapeCheck)(
3131
int inputWidth = THCTensor_(size)(state, input, dim_batch + 3);
3232
int outputHeight = (inputHeight + 2 * padH - (dH * (kH - 1) + 1)) / sH + 1;
3333
int outputWidth = (inputWidth + 2 * padW - (dW * (kW - 1) + 1)) / sW + 1;
34-
int nOutputPlane = nInputPlane * kW * kH;
35-
int outputLength = outputHeight * outputWidth;
3634

3735
if (outputHeight < 1 || outputWidth < 1) {
3836
THError("Given input with spatial size (%d, %d), kernel_size=(%d, %d), "

aten/src/THCUNN/generic/VolumetricConvolution.cu

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -47,11 +47,11 @@ static inline void THNN_(VolumetricConvolution_shapeCheck)
4747
if (weight == NULL) {
4848
weight = gradWeight;
4949
}
50-
int nOutputPlane = (int)weight->size[0];
51-
int nInputPlane = (int)weight->size[1];
52-
int kT = (int)weight->size[2];
53-
int kH = (int)weight->size[3];
54-
int kW = (int)weight->size[4];
50+
int64_t nOutputPlane = weight->size[0];
51+
int64_t nInputPlane = weight->size[1];
52+
int64_t kT = weight->size[2];
53+
int64_t kH = weight->size[3];
54+
int64_t kW = weight->size[4];
5555

5656
THArgCheck(kT > 0 && kW > 0 && kH > 0, 4,
5757
"kernel size should be greater than zero, but got kT: %d kH: %d kW: %d", kT, kH, kW);
@@ -267,11 +267,11 @@ void THNN_(VolumetricConvolution_updateGradInput)(
267267
int padT, int padW, int padH)
268268
{
269269

270-
int nOutputPlane = (int)weight->size[0];
271-
int nInputPlane = (int)weight->size[1];
272-
int kT = (int)weight->size[2];
273-
int kH = (int)weight->size[3];
274-
int kW = (int)weight->size[4];
270+
int64_t nOutputPlane = weight->size[0];
271+
int64_t nInputPlane = weight->size[1];
272+
int64_t kT = weight->size[2];
273+
int64_t kH = weight->size[3];
274+
int64_t kW = weight->size[4];
275275

276276
THCTensor *gradColumns = finput;
277277

@@ -507,7 +507,7 @@ void THNN_(VolumetricConvolution_accGradParameters)(
507507
#endif
508508
}
509509
}
510-
510+
511511
// Free
512512
THCTensor_(free)(state, input_n);
513513
THCTensor_(free)(state, gradOutput_n);

aten/src/THCUNN/generic/VolumetricDilatedMaxPooling.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -377,7 +377,7 @@ void THNN_(VolumetricDilatedMaxPooling_updateGradInput)(
377377
THCDeviceTensor<THCIndex_t, 4> cudaIndices =
378378
toDeviceTensor<THCIndex_t, 4>(state, indices1);
379379

380-
int totalZ = outputTime * inputSlices * batchSize;
380+
int64_t totalZ = outputTime * inputSlices * batchSize;
381381
int offsetZ = 0;
382382
dim3 block(32, 8);
383383

aten/src/THNN/generic/Col2Im.c

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,8 @@ static void THNN_(im2col)(const real* data_im, const int channels,
7070
int h_offset = (c_col / kernel_w) % kernel_h;
7171
int c_im = c_col / kernel_h / kernel_w;
7272
for (int h_col = 0; h_col < height_col; ++h_col) {
73+
int h_im = h_col * stride_h - pad_h + h_offset * dilation_h;
7374
for (int w_col = 0; w_col < width_col; ++w_col) {
74-
int h_im = h_col * stride_h - pad_h + h_offset * dilation_h;
7575
int w_im = w_col * stride_w - pad_w + w_offset * dilation_w;
7676
data_col[(c_col * height_col + h_col) * width_col + w_col] =
7777
(h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) ?
@@ -98,8 +98,8 @@ static void THNN_(col2im)(const real* data_col, const int channels,
9898
int h_offset = (c_col / kernel_w) % kernel_h;
9999
int c_im = c_col / kernel_h / kernel_w;
100100
for (int h_col = 0; h_col < height_col; ++h_col) {
101+
int h_im = h_col * stride_h - pad_h + h_offset * dilation_h;
101102
for (int w_col = 0; w_col < width_col; ++w_col) {
102-
int h_im = h_col * stride_h - pad_h + h_offset * dilation_h;
103103
int w_im = w_col * stride_w - pad_w + w_offset * dilation_w;
104104
if (h_im >= 0 && h_im < height && w_im >= 0 && w_im < width)
105105
data_im[(c_im * height + h_im) * width + w_im] +=

aten/src/THNN/generic/VolumetricDilatedConvolution.c

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -95,8 +95,8 @@ void THNN_(VolumetricDilatedConvolution_updateOutput)(
9595
dilationT, dilationH, dilationW, 0);
9696

9797
// Params:
98-
int nInputPlane = weight->size[1];
99-
int nOutputPlane = weight->size[0];
98+
int64_t nInputPlane = weight->size[1];
99+
int64_t nOutputPlane = weight->size[0];
100100

101101
input = THTensor_(newContiguous)(input);
102102
weight = THTensor_(newContiguous)(weight);
@@ -230,8 +230,8 @@ void THNN_(VolumetricDilatedConvolution_updateGradInput)(
230230
dilationT, dilationH, dilationW, 0);
231231

232232
// Params
233-
int nInputPlane = weight->size[1];
234-
int nOutputPlane = weight->size[0];
233+
int64_t nInputPlane = weight->size[1];
234+
int64_t nOutputPlane = weight->size[0];
235235

236236
input = THTensor_(newContiguous)(input);
237237
gradOutput = THTensor_(newContiguous)(gradOutput);

0 commit comments

Comments
 (0)