
r3d_18(3D_resnet) do not work!! (When I import it in my torch.nn.Module) #4083


Closed
NeighborhoodCoding opened this issue Jun 18, 2021 · 7 comments


@NeighborhoodCoding

Hi, I just loaded the pre-trained 3D ResNets
(https://pytorch.org/vision/stable/models.html#video-classification).
The model I loaded is ResNet 3D 18 (r3d_18), and my input shape is [1 x 3 x 192 x 112 x 112],
where 1 is the batch size, 3 is the channels [R, G, B], 192 is the video length, and 112 and 112 are W and H.

I got this error in forward():



  File "train_total.py", line 497, in forward
    timed_CNN_out_front = self.CNNlayers_front(pixel.cuda())
  File "/home/ai/anaconda3/envs/mos/lib/python3.6/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/ai/anaconda3/envs/mos/lib/python3.6/site-packages/torch/nn/modules/container.py", line 92, in forward
    input = module(input)
  File "/home/ai/anaconda3/envs/mos/lib/python3.6/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/ai/anaconda3/envs/mos/lib/python3.6/site-packages/torch/nn/modules/linear.py", line 87, in forward
    return F.linear(input, self.weight, self.bias)
  File "/home/ai/anaconda3/envs/mos/lib/python3.6/site-packages/torch/nn/functional.py", line 1372, in linear
    output = input.matmul(weight.t())
RuntimeError: size mismatch, m1: [512 x 1], m2: [512 x 400] at /opt/anaconda/conda-bld/pytorch-base_1600153523661/work/aten/src/THC/generic/THCTensorMathBlas.cu:290



Why does this error occur? I read that the 3D ResNet requires the input shape I mentioned above.

@NicolasHug
Member

Hi @sungmanhong , I cannot reproduce your issue with:

import torch
import torchvision
model = torchvision.models.video.r3d_18(pretrained=False)
img = torch.randn(1, 3, 192, 112, 112).to(torch.float)
model(img)

which runs fine.

Could you please provide a minimal reproducible example showing the error, as well as the output of

wget https://raw.githubusercontent.com/pytorch/pytorch/master/torch/utils/collect_env.py
# For security purposes, please check the contents of collect_env.py before running it.
python collect_env.py

as requested in the issue template? Thanks

@NeighborhoodCoding
Author

NeighborhoodCoding commented Jun 21, 2021

Thanks for your awesome support (including on my other question), and sorry for the late response; it was a holiday in Korea.
Yes, the code above runs fine!

But I found that if I import the r3d_18 model inside a custom torch.nn.Module class (for re-training instead of freezing the layers),
the error occurs.
Here is an example (I tried to share it as a Colab notebook, but a CUDA memory error occurred, so I am just sharing the code):

import numpy as np
import pandas as pd
import random
import torch
import torch.nn as nn
from torch.optim import Adam, lr_scheduler
import torch.nn.functional as F
from torch.utils.data import Dataset
from torchvision import transforms, models
import h5py
from PIL import Image
import os
import datetime
from scipy import stats
import logging

class CustomNN(torch.nn.Module):
    def __init__(self):
        super(CustomNN, self).__init__()
        self.CNNlayers_front = nn.Sequential(*list(models.video.r3d_18(pretrained=True).children()))
        self.QoS = nn.Linear(512, 1)

    def forward(self, pixel):
        pixel = self.CNNlayers_front(pixel.cuda())
        pixel = self.QoS(pixel)
        return pixel

model = CustomNN()
outputs = model(torch.zeros([1, 3, 4, 112, 112]))   #batch, RGB(3), video_length, width, height 

You can reduce the video_length, or the width or height, if you get a CUDA memory error.
I got this error when I ran the code above:

RuntimeError: size mismatch, m1: [512 x 1], m2: [512 x 400] at /opt/anaconda/conda-bld/pytorch-base_1600153523661/work/aten/src/THC/generic/THCTensorMathBlas.cu:290

or

RuntimeError: mat1 dim 1 must match mat2 dim 0

I found that it occurs in the last layer (linear? ReLU?).
The last linear layer of r3d_18 takes 512 features and returns 400 (the final classification outputs), but for some reason
there is a size mismatch. I suspect the batch dimension may be duplicated because I imported the model inside a custom nn.Module, so the last layer of r3d_18 receives a dimension of size 1 instead of 512 and raises the error above.

However, your code runs fine.

import torch
import torchvision
model = torchvision.models.video.r3d_18(pretrained=False)
img = torch.randn(1, 3, 192, 112, 112).to(torch.float)
model(img)

So I suspect there is some dimension mismatch when the batch dimension is inserted in forward().
I will examine that part. Thanks.

ADD: When I do not load the last layer, it runs fine.
The only difference between the two versions is between
nn.Sequential(*list(models.video.r3d_18(pretrained=True).children())[0:6])
and
nn.Sequential(*list(models.video.r3d_18(pretrained=True).children()))


class CustomNN(torch.nn.Module):
    def __init__(self):
        super(CustomNN, self).__init__()
        self.CNNlayers_front = nn.Sequential(*list(models.video.r3d_18(pretrained=True).children())[0:6])
        self.QoS = nn.Linear(512, 1)

    def forward(self, pixel):
        pixel = self.CNNlayers_front(pixel.cuda())
        #pixel = self.QoS(pixel)
        return pixel

model = CustomNN()
outputs = model(torch.zeros([1, 3, 4, 112, 112]))   #batch, RGB(3), video_length, width, height 

The problem is in the last layer of r3d_18 when I load the model inside a torch.nn.Module class.
I also updated the issue title to be more direct.
Thanks.

@NeighborhoodCoding NeighborhoodCoding changed the title r3d_18(3D_resnet) do not work!! r3d_18(3D_resnet) do not work!! (When I import in my torch.nn.Module) Jun 21, 2021
@NeighborhoodCoding NeighborhoodCoding changed the title r3d_18(3D_resnet) do not work!! (When I import in my torch.nn.Module) r3d_18(3D_resnet) do not work!! (When I import it in my torch.nn.Module) Jun 21, 2021
@NeighborhoodCoding
Author

NeighborhoodCoding commented Jun 21, 2021

Hello, I just made a very simple reproduction of the error, as follows:

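import torch
import torch.nn as nn
from torchvision import models

# Wrapping all children in nn.Sequential drops the flatten step from forward()
model = nn.Sequential(*list(models.video.r3d_18(pretrained=True).children())[:]).to('cuda')
img = torch.randn(1, 3, 192, 112, 112).to(torch.float).to('cuda')
model(img)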

which results in the error below.
Can't I use nn.Sequential(*list(models.video.r3d_18(pretrained=True).children())[:])?
I'm using nn.Sequential to do transfer learning.



Traceback (most recent call last):
  File "train.py", line 155, in <module>
    model(img)
  File "/home/hsm/anaconda3/envs/reproducibleresearch/lib/python3.6/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/hsm/anaconda3/envs/reproducibleresearch/lib/python3.6/site-packages/torch/nn/modules/container.py", line 119, in forward
    input = module(input)
  File "/home/hsm/anaconda3/envs/reproducibleresearch/lib/python3.6/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/hsm/anaconda3/envs/reproducibleresearch/lib/python3.6/site-packages/torch/nn/modules/linear.py", line 94, in forward
    return F.linear(input, self.weight, self.bias)
  File "/home/hsm/anaconda3/envs/reproducibleresearch/lib/python3.6/site-packages/torch/nn/functional.py", line 1753, in linear
    return torch._C._nn.linear(input, weight, bias)
RuntimeError: mat1 dim 1 must match mat2 dim 0
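
Both error messages can be reproduced without torchvision. The following is a minimal sketch, assuming the tensor that reaches the final linear layer has shape [1, 512, 1, 1, 1]; that shape is an assumption chosen to match the `m1: [512 x 1]` in the older message, and the variable names are illustrative only. `nn.Linear` treats the last dimension (size 1) as the feature dimension, so the matrix product fails until the tensor is flattened to [1, 512].

```python
import torch
import torch.nn as nn

# Same in/out sizes as the layer in the traceback (512 -> 400).
fc = nn.Linear(512, 400)

# Assumed shape of the un-flattened input: [batch, channels, T, H, W].
pooled = torch.randn(1, 512, 1, 1, 1)

try:
    fc(pooled)  # last dim is 1, but fc expects 512 features
    error_message = None
except RuntimeError as exc:
    error_message = str(exc)
print("unflattened input:", error_message)

# Flattening everything after the batch dimension gives [1, 512].
flat = pooled.flatten(1)
out = fc(flat)
print("flattened input:", tuple(out.shape))
```

The exact wording of the RuntimeError depends on the PyTorch version, which is why the thread shows two different messages for the same problem.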

@NicolasHug
Member

Here is an example (I tried to share it as a Colab notebook, but a CUDA memory error occurred, so I am just sharing the code):

Thanks for sharing the reproducible example. A small piece of code like that is way better than sharing a notebook :)

Regarding the issue: it seems that the core problem is that

import torch
import torch.nn as nn
import torchvision

img = torch.randn(1, 3, 8, 112, 112).to(torch.float)

model = torchvision.models.video.r3d_18(pretrained=False)
seq = nn.Sequential(*list(model.children()))
seq(img).shape
# error

fails, which is a bit unexpected. Looking at the forward() pass of r3d_18, there's a call to flatten which is not present when calling children(), which causes the dimension mismatch issue:

def forward(self, x):
    x = self.stem(x)
    x = self.layer1(x)
    x = self.layer2(x)
    x = self.layer3(x)
    x = self.layer4(x)
    x = self.avgpool(x)
    # Flatten the layer to fc
    x = x.flatten(1)
    x = self.fc(x)
    return x

A simple fix for you would be to introduce this Flatten in the Sequential model like so:

seq = nn.Sequential(*list(model.children())[:-1], nn.Flatten(1), list(model.children())[-1])
seq(img).shape
# passes

I wonder however whether we should always have nn.Sequential(*list(model.children())) == model, or if this doesn't hold in general anyway. Thoughts @fmassa?

@fmassa
Member

fmassa commented Jun 21, 2021

I wonder however whether we should always have nn.Sequential(*list(model.children())) == model, or if this doesn't hold in general anyway. Thoughts @fmassa?

This doesn't always hold, and only works for a very limited subset of models.

A more robust way of doing this would be to rely on FX for the intermediate feature extraction #3597, but I still need to get that PR to completion.

@NeighborhoodCoding
Author

NeighborhoodCoding commented Jun 21, 2021

Thank you for helping me. I didn't realize the flatten step was missing because I lack detailed debugging skills, so my question was a naive one. Thank you for your kind reply.

I think PyTorch is nice because we can selectively import layers with nn.Sequential(*list(model.children())). For example, when doing transfer learning on images or videos, being able to choose which layers to import from a pre-trained ResNet (not all layers) makes coding flexible and easy for new researchers.

Of course, it may be difficult to make every model fully compatible with nn.Sequential(*list(model.children())), but if compatibility improved, I think it would help researchers in many areas (in my case, 3D ResNet for video classification or regression).

@NicolasHug
Member

Thanks for the feedback @sungmanhong , you're right that ideally model ablation / tweaking could be a bit easier. As @fmassa mentioned, hopefully #3597 will make this easier and will be available in the next release. We'll make sure to write good docs to explain how to perform that.

Since the original problem is resolved here, I'll close the issue if you don't mind
