
[models] IntermediateLayerGetter doesn't properly handle architectures with modules referenced multiple times #3802


Closed
frgfm opened this issue May 9, 2021 · 3 comments


@frgfm
Contributor

frgfm commented May 9, 2021

🐛 Minor trouble with IntermediateLayerGetter

Since IntermediateLayerGetter uses the .named_children() method of torch.nn.Module, which yields each module object only once, a layer that is referenced multiple times in the architecture can only be selected under its first name; its later names are rejected. Iterating over the private ._modules dict of the same object (._modules.items()) instead would fix it. Happy to open a PR if you think it's a good idea.
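To see the root cause in isolation, the snippet below compares what named_children() yields with the keys stored in the private _modules dict for a small model that reuses one ReLU instance. It only uses torch (no torchvision), and it reads _modules directly, which is an internal attribute and may change between PyTorch releases.

```python
from torch import nn

# One ReLU instance reused at positions "1" and "3"
act_layer = nn.ReLU()
mod = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1), act_layer, nn.Conv2d(8, 16, 3, padding=1), act_layer
)

# named_children() remembers modules it has already yielded and skips repeats,
# so the second registration of act_layer ("3") never shows up
print([name for name, _ in mod.named_children()])  # ['0', '1', '2']

# The private _modules dict keeps one entry per registered name
print(list(mod._modules.keys()))  # ['0', '1', '2', '3']
```

Any name-based lookup built on named_children() therefore cannot see "3", even though Sequential.forward still runs the shared ReLU at that position.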

To Reproduce

from torch import nn
from torchvision.models._utils import IntermediateLayerGetter

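act_layer = nn.ReLU()
# The same ReLU instance is registered twice, under the names "1" and "3"
mod = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), act_layer, nn.Conv2d(8, 16, 3, padding=1), act_layer)
# Ask for the output of child "3", returned under the key "0"
layer_getter = IntermediateLayerGetter(mod, {"3": "0"})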

raises

ValueError: return_layers are not present in model

Expected behavior

The above snippet should run without raising any error.
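A minimal sketch of the fix proposed in this issue, assuming IntermediateLayerGetter keeps its torchvision 0.9 shape (an nn.ModuleDict that copies the model's children in order up to the last requested one). ReuseAwareLayerGetter is a hypothetical name, not a torchvision API, and the class relies on the private _modules attribute.

```python
from collections import OrderedDict
from typing import Dict

import torch
from torch import nn


class ReuseAwareLayerGetter(nn.ModuleDict):
    """Hypothetical IntermediateLayerGetter variant (not a torchvision API).

    Children are read from the private ``model._modules`` dict instead of
    ``model.named_children()``, so a module registered under several names
    can be selected under any of them.
    """

    def __init__(self, model: nn.Module, return_layers: Dict[str, str]) -> None:
        # Keep every registered name, including repeated references
        children = [(name, m) for name, m in model._modules.items() if m is not None]
        if not set(return_layers).issubset(name for name, _ in children):
            raise ValueError("return_layers are not present in model")
        remaining = {str(k): str(v) for k, v in return_layers.items()}
        layers = OrderedDict()
        # Copy children in order, stopping after the last requested one
        for name, module in children:
            layers[name] = module
            remaining.pop(name, None)
            if not remaining:
                break
        super().__init__(layers)
        self.return_layers = {str(k): str(v) for k, v in return_layers.items()}

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        out = OrderedDict()
        # ModuleDict.items() iterates over _modules, so the shared ReLU
        # runs at both positions, as in the original Sequential
        for name, module in self.items():
            x = module(x)
            if name in self.return_layers:
                out[self.return_layers[name]] = x
        return out


act_layer = nn.ReLU()
mod = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1), act_layer, nn.Conv2d(8, 16, 3, padding=1), act_layer
)
getter = ReuseAwareLayerGetter(mod, {"3": "0"})  # no ValueError here
out = getter(torch.rand(1, 3, 32, 32))
print(out["0"].shape)  # torch.Size([1, 16, 32, 32])
```

Only the iteration source in __init__ differs from the named_children() approach; registering the same ReLU under both "1" and "3" in the ModuleDict is allowed, and parameters() still counts the shared module once.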

Environment

  • PyTorch / torchvision Version: 1.8.0 / 0.9.0
  • OS: Linux
  • How you installed PyTorch / torchvision: conda
  • Python version: 3.8
  • CUDA/cuDNN version: CUDA 11.1 / cuDNN 8.0.5
  • GPU models and configuration: GeForce RTX 2070 with Max-Q Design (driver 450.119.03)
@fmassa
Member

fmassa commented May 10, 2021

Hi,

Thanks for opening this issue!

Indeed, this is a current limitation of IntermediateLayerGetter. IntermediateLayerGetter is very fragile and only supports a very limited set of use cases, but we are actively looking into improving it; its next iteration is in #3597.

Also, a note of warning: IntermediateLayerGetter is a private API for now (present in _utils), so we might break its API without deprecation warnings.

@fmassa
Member

fmassa commented Sep 8, 2021

While this is not fixed in IntermediateLayerGetter itself, the new feature provided in #4302 should handle this case without problems.

As such, I'm closing this issue, but let us know if you face any problems with #4302

@fmassa fmassa closed this as completed Sep 8, 2021
@frgfm
Contributor Author

frgfm commented Sep 21, 2021

Apologies for the delay: I kept those notifications pending to make sure I had time to investigate the corresponding PR. Thanks a lot @fmassa, that looks like exactly what was needed 🙏
