Dilated Conv3d segfaults (cpu) #9264


Closed

mys007 opened this issue Jul 9, 2018 · 3 comments

@mys007
Contributor

mys007 commented Jul 9, 2018

I have encountered a segfault when running `Conv3d` with dilation under PyTorch 0.4.0 on the CPU for some specific tensor sizes. This is the code to reproduce the behavior:

import torch
import torch.nn as nn

# Conv3d(in_channels=96, out_channels=256, kernel_size=4, stride=1, padding=0, dilation=2)
mod = nn.Conv3d(96, 256, 4, 1, 0, 2)     # does not crash with dilation = 1
mod(torch.zeros([1, 96, 111, 63, 111]))  # segfaults on CPU

And this is the stack trace from gdb:

#0  0x00007fffbfb0c3d7 in THNN_Floatvol2col(float const*, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, float*)
    () from /home/martin/miniconda3/envs/BAI/lib/python3.6/site-packages/torch/lib/libATen.so
#1  0x00007fffbfc240cb in THNN_FloatVolumetricDilatedConvolution_updateOutput ()
   from /home/martin/miniconda3/envs/BAI/lib/python3.6/site-packages/torch/lib/libATen.so
#2  0x00007fffbf42ae43 in at::CPUFloatType::thnn_conv_dilated3d_forward(at::Tensor const&, at::Tensor const&, at::ArrayRef<long>, at::Tensor const&, at::ArrayRef<long>, at::ArrayRef<long>, at::ArrayRef<long>) const () from /home/martin/miniconda3/envs/BAI/lib/python3.6/site-packages/torch/lib/libATen.so
#3  0x00007fffdeda145d in torch::autograd::VariableType::thnn_conv_dilated3d_forward (this=0x555556570a40, self=..., weight=..., kernel_size=..., bias=..., 
    stride=..., padding=..., dilation=...) at torch/csrc/autograd/generated/VariableType.cpp:17592
#4  0x00007fffbf625595 in at::Type::thnn_conv_dilated3d(at::Tensor const&, at::Tensor const&, at::ArrayRef<long>, at::Tensor const&, at::ArrayRef<long>, at::ArrayRef<long>, at::ArrayRef<long>) const () from /home/martin/miniconda3/envs/BAI/lib/python3.6/site-packages/torch/lib/libATen.so
#5  0x00007fffded3ec54 in torch::autograd::VariableType::thnn_conv_dilated3d (this=0x555556570a40, self=..., weight=..., kernel_size=..., bias=..., stride=..., 
    padding=..., dilation=...) at torch/csrc/autograd/generated/VariableType.cpp:17528
#6  0x00007fffbf313a2b in at::native::_convolution_nogroup(at::Tensor const&, at::Tensor const&, at::Tensor const&, at::ArrayRef<long>, at::ArrayRef<long>, at::ArrayRef<long>, bool, at::ArrayRef<long>) () from /home/martin/miniconda3/envs/BAI/lib/python3.6/site-packages/torch/lib/libATen.so
#7  0x00007fffbf620bae in at::Type::_convolution_nogroup(at::Tensor const&, at::Tensor const&, at::Tensor const&, at::ArrayRef<long>, at::ArrayRef<long>, at::ArrayRef<long>, bool, at::ArrayRef<long>) const () from /home/martin/miniconda3/envs/BAI/lib/python3.6/site-packages/torch/lib/libATen.so
#8  0x00007fffded40db7 in torch::autograd::VariableType::_convolution_nogroup (this=0x555556570a40, input=..., weight=..., bias=..., stride=..., padding=..., 
    dilation=..., transposed=false, output_padding=...) at torch/csrc/autograd/generated/VariableType.cpp:18403
#9  0x00007fffbf318fe8 in at::native::_convolution(at::Tensor const&, at::Tensor const&, at::Tensor const&, at::ArrayRef<long>, at::ArrayRef<long>, at::ArrayRef<long>, bool, at::ArrayRef<long>, long, bool, bool, bool) () from /home/martin/miniconda3/envs/BAI/lib/python3.6/site-packages/torch/lib/libATen.so
#10 0x00007fffbf620b64 in at::Type::_convolution(at::Tensor const&, at::Tensor const&, at::Tensor const&, at::ArrayRef<long>, at::ArrayRef<long>, at::ArrayRef<long>, bool, at::ArrayRef<long>, long, bool, bool, bool) const () from /home/martin/miniconda3/envs/BAI/lib/python3.6/site-packages/torch/lib/libATen.so
#11 0x00007fffded40a6e in torch::autograd::VariableType::_convolution (this=0x555556570a40, input=..., weight=..., bias=..., stride=..., padding=..., 
    dilation=..., transposed=false, output_padding=..., groups=1, benchmark=false, deterministic=false, cudnn_enabled=true)
    at torch/csrc/autograd/generated/VariableType.cpp:18386
#12 0x00007fffbf313e0e in at::native::convolution(at::Tensor const&, at::Tensor const&, at::Tensor const&, at::ArrayRef<long>, at::ArrayRef<long>, at::ArrayRef<long>, bool, at::ArrayRef<long>, long) () from /home/martin/miniconda3/envs/BAI/lib/python3.6/site-packages/torch/lib/libATen.so
#13 0x00007fffbf620afe in at::Type::convolution(at::Tensor const&, at::Tensor const&, at::Tensor const&, at::ArrayRef<long>, at::ArrayRef<long>, at::ArrayRef<long>, bool, at::ArrayRef<long>, long) const () from /home/martin/miniconda3/envs/BAI/lib/python3.6/site-packages/torch/lib/libATen.so
#14 0x00007fffdece124c in torch::autograd::VariableType::convolution (this=0x555556570a40, input=..., weight=..., bias=..., stride=..., padding=..., 
    dilation=..., transposed=<optimised out>, output_padding=..., groups=1) at torch/csrc/autograd/generated/VariableType.cpp:18368
#15 0x00007fffbf3141aa in at::native::conv3d(at::Tensor const&, at::Tensor const&, at::Tensor const&, at::ArrayRef<long>, at::ArrayRef<long>, at::ArrayRef<long>, long) () from /home/martin/miniconda3/envs/BAI/lib/python3.6/site-packages/torch/lib/libATen.so
#16 0x00007fffbf620d22 in at::Type::conv3d(at::Tensor const&, at::Tensor const&, at::Tensor const&, at::ArrayRef<long>, at::ArrayRef<long>, at::ArrayRef<long>, long) const () from /home/martin/miniconda3/envs/BAI/lib/python3.6/site-packages/torch/lib/libATen.so
#17 0x00007fffdece147e in torch::autograd::VariableType::conv3d (this=0x555556570a40, input=..., weight=..., bias=..., stride=..., padding=..., dilation=..., 
    groups=1) at torch/csrc/autograd/generated/VariableType.cpp:18446
#18 0x00007fffdee94387 in at::conv3d (groups=1, dilation=..., padding=..., stride=..., bias=..., weight=..., input=...)
    at /opt/conda/conda-bld/pytorch_1524584710464/work/torch/lib/tmp_install/include/ATen/Functions.h:3044
#19 torch::autograd::dispatch_conv3d (groups=1, dilation=..., padding=..., stride=..., bias=..., weight=..., input=...)
    at torch/csrc/autograd/generated/python_torch_functions_dispatch.h:1073
#20 torch::autograd::THPVariable_conv3d (self=<optimised out>, args=<optimised out>, kwargs=<optimised out>)
    at torch/csrc/autograd/generated/python_torch_functions.cpp:1662

And this is the environment information:

Collecting environment information...
PyTorch version: 0.4.0
Is debug build: No
CUDA used to build PyTorch: 8.0.61

OS: Ubuntu 16.04.4 LTS
GCC version: (Ubuntu 5.4.0-6ubuntu1~16.04.10) 5.4.0 20160609
CMake version: Could not collect

Python version: 3.6
Is CUDA available: No
CUDA runtime version: No CUDA
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA

Versions of relevant libraries:
[pip] numpy (1.14.5)
[pip] torch (0.4.0)
[pip] torchvision (0.2.1)
[conda] pytorch                   0.4.0           py36_cuda8.0.61_cudnn7.1.2_1    pytorch
[conda] torchvision               0.2.1                    py36_1    pytorch

Thanks for your help!

@ssnl
Collaborator

ssnl commented Jul 9, 2018

repro'ed on master!

@ssnl
Collaborator

ssnl commented Jul 9, 2018

oh it overflowed int range.... fix incoming
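
For context, here is a back-of-the-envelope check of where the overflow comes from. This is a sketch, assuming the usual `vol2col` column-buffer layout of `channels * kT*kH*kW * outT*outH*outW` elements; for the repro's shapes the element count exceeds what a signed 32-bit `int` can represent:

#include <climits>
#include <cstdint>
#include <cstdio>

int main() {
    // Repro shapes: Conv3d(96, 256, kernel_size=4, stride=1, padding=0, dilation=2)
    // applied to a (1, 96, 111, 63, 111) input.
    int64_t channels = 96, k = 4, dilation = 2;
    int64_t eff = dilation * (k - 1) + 1;   // effective kernel extent: 2*(4-1)+1 = 7
    int64_t od = 111 - eff + 1;             // 105
    int64_t oh = 63  - eff + 1;             // 57
    int64_t ow = 111 - eff + 1;             // 105

    // Elements in the vol2col column buffer.
    int64_t numel = channels * k * k * k * od * oh * ow;
    std::printf("numel   = %lld\n", (long long)numel);   // 3861043200
    std::printf("INT_MAX = %d\n", INT_MAX);              // 2147483647
    // On a typical two's-complement platform the value wraps negative
    // when truncated to 32 bits: -433924096.
    std::printf("as int  = %d\n", static_cast<int>(numel));
    return 0;
}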

@mys007
Contributor Author

mys007 commented Jul 9, 2018

Well done, Tongzhou, that was really brisk!!

goodlux pushed a commit to goodlux/pytorch that referenced this issue Aug 15, 2018
Summary:
Fixes pytorch#9264.

There can be so many elements in the output of `vol2col` that the count overflows the `int` range! This PR changes 3d conv to use `int64_t` mostly.

Also fixes some unused var warnings (cc goldsborough)
Pull Request resolved: pytorch#9274

Differential Revision: D8770682

Pulled By: SsnL

fbshipit-source-id: f6e37f1aa56fe1009dd4c9bcbc042244e47252db
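
The fix pattern the commit message describes — shown here as a minimal sketch with a hypothetical helper, not the actual diff from pytorch#9274 — is to widen the size/index arithmetic from `int` to `int64_t` so products like the one above are computed exactly:

#include <cstdint>

// Minimal sketch (hypothetical helper, not PyTorch's code): with `int`
// parameters this product wraps past INT_MAX for the repro shapes; with
// int64_t throughout, as the PR does for the 3d conv path, it stays exact.
int64_t vol2col_numel(int64_t channels,
                      int64_t kT, int64_t kH, int64_t kW,
                      int64_t outT, int64_t outH, int64_t outW) {
    return channels * kT * kH * kW * outT * outH * outW;
}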