r3d_18 (3D ResNet) does not work when I import it in my torch.nn.Module #4083
Hi @sungmanhong, I cannot reproduce your issue with:

```python
import torch
import torchvision

model = torchvision.models.video.r3d_18(pretrained=False)
img = torch.randn(1, 3, 192, 112, 112).to(torch.float)
model(img)
```

which runs fine. Could you please provide a minimal reproducible example showing the error, as well as the output requested in the issue template? Thanks.
Thanks for your awesome support! (This also covers my other question; sorry for the late response, it was a holiday in Korea.) I found that the error appears when I import the r3d_18 model inside a custom torch.nn.Module class (for re-training, instead of freezing the layers). You can reduce video_length, or the width or height, if you hit a CUDA out-of-memory error.

I found that the failure occurs in the last layer (the Linear? the ReLU?), whereas your code runs fine. So I suspect a dimension mismatch when the batch dimension is handled in forward(). ADDED: for example, when I do not load the last layer, it runs fine. The problem is in the last layer of r3d_18 when I load the model inside a torch.nn.Module class; a sketch of this failing pattern follows below.
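Since the original notebook was not shared in the thread, here is a minimal sketch of the failing pattern described above; the wrapper class name `VideoBackbone` and the 8-frame clip length are illustrative assumptions, not from the original code.

```python
import torch
import torch.nn as nn
import torchvision

class VideoBackbone(nn.Module):  # hypothetical name, not from the thread
    def __init__(self):
        super().__init__()
        base = torchvision.models.video.r3d_18(pretrained=False)
        # Rebuilding the model from its children silently drops the
        # x.flatten(1) call, which lives in VideoResNet.forward() rather
        # than in any child module, so the Linear head gets a 5D tensor.
        self.layers = nn.Sequential(*base.children())

    def forward(self, x):
        return self.layers(x)

model = VideoBackbone()
clip = torch.randn(1, 3, 8, 112, 112)  # (N, C, T, H, W)
try:
    model(clip)
except RuntimeError as e:
    print(e)  # shape mismatch inside the final nn.Linear
```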
Hello, I just made a very simple reconstruction of the error, as follows...

which results in the error below...
Thanks for sharing the reproducible example. A small piece of code like that is way better than sharing a notebook :) Regarding the issue: it seems that the core problem is that

```python
import torch
import torch.nn as nn
import torchvision

img = torch.randn(1, 3, 8, 112, 112).to(torch.float)
model = torchvision.models.video.r3d_18(pretrained=False)
seq = nn.Sequential(*list(model.children()))
seq(img).shape  # error
```

fails, which is a bit unexpected. Looking at vision/torchvision/models/video/resnet.py, lines 226 to 239 at commit 68b128d, the flatten step happens inside VideoResNet.forward() rather than in a child module, so it is lost when the model is rebuilt from model.children(); a sketch of that method follows below.
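For reference, the relevant portion of the linked source, paraphrased rather than copied verbatim, looks roughly like this:

```python
# Paraphrase of VideoResNet.forward() around the linked lines: note the
# x.flatten(1) between avgpool and fc, which is an inline tensor op, not
# a registered child module, so nn.Sequential(*model.children()) skips it.
def forward(self, x):
    x = self.stem(x)
    x = self.layer1(x)
    x = self.layer2(x)
    x = self.layer3(x)
    x = self.layer4(x)
    x = self.avgpool(x)
    x = x.flatten(1)
    x = self.fc(x)
    return x
```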
A simple fix for you would be to introduce an explicit nn.Flatten:

```python
children = list(model.children())
seq = nn.Sequential(*children[:-1], nn.Flatten(1), children[-1])
seq(img).shape  # passes
```
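As a quick sanity check (a sketch, not part of the original thread), the repaired Sequential can be verified to match the original model's output:

```python
import torch
import torch.nn as nn
import torchvision

model = torchvision.models.video.r3d_18(pretrained=False).eval()
children = list(model.children())
seq = nn.Sequential(*children[:-1], nn.Flatten(1), children[-1]).eval()

clip = torch.randn(1, 3, 8, 112, 112)
with torch.no_grad():
    # Both paths share the same modules; with the flatten restored the
    # outputs should be numerically identical.
    assert torch.allclose(model(clip), seq(clip))
```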
I wonder, however, whether we should always have the flatten as an explicit layer in the model, so that this kind of decomposition works out of the box.
This doesn't always hold, and it only works for a very limited subset of models. A more robust way of doing this would be to rely on FX for intermediate feature extraction (#3597), but I still need to get that PR to completion. A sketch of that approach is shown below.
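For readers on a later torchvision release: the FX-based feature-extraction utility that this work led to ships as torchvision.models.feature_extraction.create_feature_extractor (torchvision >= 0.11, i.e. after this thread). The return-node name "avgpool" below is an assumption about this model's traced graph:

```python
import torch
import torchvision
from torchvision.models.feature_extraction import create_feature_extractor

model = torchvision.models.video.r3d_18(pretrained=False)
# Extract the pooled features right before the flatten and fc head.
extractor = create_feature_extractor(model, return_nodes={"avgpool": "features"})

clip = torch.randn(1, 3, 8, 112, 112)
out = extractor(clip)
print(out["features"].shape)
```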
Thank you for helping me. I didn't know that the flatten layer was missing, because I lack detailed debugging skills, so I asked a foolish question; thank you for your kind reply. I think PyTorch is great because we are able to selectively reuse layers through model.children(). Of course, it may be difficult to ensure full compatibility of every model with nn.Sequential.
Thanks for the feedback @sungmanhong, you're right that ideally model ablation / tweaking could be a bit easier. As @fmassa mentioned, hopefully #3597 will make this easier; it should be available in the next release, and we'll make sure to write good docs explaining how to do it. Since the original problem is resolved here, I'll close the issue if you don't mind.
Hi, I just loaded the pre-trained 3D ResNets
(https://pytorch.org/vision/stable/models.html#video-classification).
What I loaded is ResNet 3D 18 (r3d_18), and my input shape is [1 x 3 x 192 x 112 x 112],
where 1 is the batch size, 3 is the [R, G, B] channels, 192 is the video_length, and 112 x 112 is W x H.
I got this error at forward()...
Why does this error come up? I just read that the 3D ResNet requires exactly the shape I mentioned above...