Skip to content

Multi-pretrained weight support - FasterRCNN ResNet50 #4613

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Oct 15, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion torchvision/models/detection/backbone_utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import warnings
from typing import List, Optional

from torch import nn
from torchvision.ops import misc as misc_nn_ops
from torchvision.ops.feature_pyramid_network import FeaturePyramidNetwork, LastLevelMaxPool
from torchvision.ops.feature_pyramid_network import FeaturePyramidNetwork, LastLevelMaxPool, ExtraFPNBlock

from .. import mobilenet
from .. import resnet
Expand Down Expand Up @@ -92,7 +93,15 @@ def resnet_fpn_backbone(
default a ``LastLevelMaxPool`` is used.
"""
backbone = resnet.__dict__[backbone_name](pretrained=pretrained, norm_layer=norm_layer)
return _resnet_backbone_config(backbone, trainable_layers, returned_layers, extra_blocks)


def _resnet_backbone_config(
backbone: resnet.ResNet,
trainable_layers: int,
returned_layers: Optional[List[int]],
extra_blocks: Optional[ExtraFPNBlock],
):
# select layers that wont be frozen
assert 0 <= trainable_layers <= 5
layers_to_train = ["layer4", "layer3", "layer2", "layer1", "conv1"][:trainable_layers]
Expand Down
1 change: 1 addition & 0 deletions torchvision/prototype/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .resnet import *
from . import detection
95 changes: 95 additions & 0 deletions torchvision/prototype/models/_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -1006,3 +1006,98 @@
"ear",
"toilet tissue",
]

# To be replaced with torchvision.datasets.find("coco").info.categories
_COCO_CATEGORIES = [
"__background__",
"person",
"bicycle",
"car",
"motorcycle",
"airplane",
"bus",
"train",
"truck",
"boat",
"traffic light",
"fire hydrant",
"N/A",
"stop sign",
"parking meter",
"bench",
"bird",
"cat",
"dog",
"horse",
"sheep",
"cow",
"elephant",
"bear",
"zebra",
"giraffe",
"N/A",
"backpack",
"umbrella",
"N/A",
"N/A",
"handbag",
"tie",
"suitcase",
"frisbee",
"skis",
"snowboard",
"sports ball",
"kite",
"baseball bat",
"baseball glove",
"skateboard",
"surfboard",
"tennis racket",
"bottle",
"N/A",
"wine glass",
"cup",
"fork",
"knife",
"spoon",
"bowl",
"banana",
"apple",
"sandwich",
"orange",
"broccoli",
"carrot",
"hot dog",
"pizza",
"donut",
"cake",
"chair",
"couch",
"potted plant",
"bed",
"N/A",
"dining table",
"N/A",
"N/A",
"toilet",
"N/A",
"tv",
"laptop",
"mouse",
"remote",
"keyboard",
"cell phone",
"microwave",
"oven",
"toaster",
"sink",
"refrigerator",
"N/A",
"book",
"clock",
"vase",
"scissors",
"teddy bear",
"hair drier",
"toothbrush",
]
1 change: 1 addition & 0 deletions torchvision/prototype/models/detection/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .faster_rcnn import *
14 changes: 14 additions & 0 deletions torchvision/prototype/models/detection/backbone_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from ....models.detection.backbone_utils import misc_nn_ops, _resnet_backbone_config
from .. import resnet


def resnet_fpn_backbone(
backbone_name,
weights,
norm_layer=misc_nn_ops.FrozenBatchNorm2d,
trainable_layers=3,
returned_layers=None,
extra_blocks=None,
):
backbone = resnet.__dict__[backbone_name](weights=weights, norm_layer=norm_layer)
Copy link
Contributor Author

@datumbox datumbox Oct 14, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately I'm forced to copy the whole function just to change the pretrained to weights param. I refactored to minimize copy-pasted code.

return _resnet_backbone_config(backbone, trainable_layers, returned_layers, extra_blocks)
60 changes: 60 additions & 0 deletions torchvision/prototype/models/detection/faster_rcnn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import warnings
from typing import Any, Optional

from ....models.detection.faster_rcnn import FasterRCNN, overwrite_eps, _validate_trainable_layers
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Inherit as much as possible. The changes below will be moved on the existing files once we move to torchvision.

from ...transforms.presets import CocoEval
from .._api import Weights, WeightEntry
from .._meta import _COCO_CATEGORIES
from ..resnet import ResNet50Weights
from .backbone_utils import resnet_fpn_backbone


__all__ = ["FasterRCNN", "FasterRCNNResNet50FPNWeights", "fasterrcnn_resnet50_fpn"]


class FasterRCNNResNet50FPNWeights(Weights):
Coco_RefV1 = WeightEntry(
url="https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth",
transforms=CocoEval,
meta={
"categories": _COCO_CATEGORIES,
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-resnet-50-fpn",
"map": 37.0,
},
)


def fasterrcnn_resnet50_fpn(
weights: Optional[FasterRCNNResNet50FPNWeights] = None,
weights_backbone: Optional[ResNet50Weights] = None,
progress: bool = True,
num_classes: int = 91,
trainable_backbone_layers: Optional[int] = None,
**kwargs: Any,
) -> FasterRCNN:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
weights = FasterRCNNResNet50FPNWeights.Coco_RefV1 if kwargs.pop("pretrained") else None
weights = FasterRCNNResNet50FPNWeights.verify(weights)
if "pretrained_backbone" in kwargs:
warnings.warn("The argument pretrained_backbone is deprecated, please use weights_backbone instead.")
weights_backbone = ResNet50Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None
weights_backbone = ResNet50Weights.verify(weights_backbone)

if weights is not None:
weights_backbone = None
num_classes = len(weights.meta["categories"])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should probably raise an error / warning if the user modifies the num_classes and passes a weights argument. Otherwise they might silently think that we are doing magic inside

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good. I'll add this check to resnet as well.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought about this and it's a bit problematic. The num_classes parameter has a default value in all of our model builders. So to see i it was modified, we need to see if the default value was changed which can lead to messy code. An alternative approach could be to throw a warning if the num_classes != len(weights.meta["categories"]) but still overwrite it to make the life of users easier.

Because it's not clear how this should be handled, I'm going to merge the PR to unblock the work but I'm happy to discuss the policy here and update everywhere in a follow up PR.


trainable_backbone_layers = _validate_trainable_layers(
weights is not None or weights_backbone is not None, trainable_backbone_layers, 5, 3
)

backbone = resnet_fpn_backbone("resnet50", weights_backbone, trainable_layers=trainable_backbone_layers)
model = FasterRCNN(backbone, num_classes=num_classes, **kwargs)

if weights is not None:
model.load_state_dict(weights.state_dict(progress=progress))
if weights == FasterRCNNResNet50FPNWeights.Coco_RefV1:
overwrite_eps(model, 0.0)

return model
41 changes: 19 additions & 22 deletions torchvision/prototype/transforms/presets.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Tuple
from typing import Dict, Optional, Tuple

import torch
from torch import Tensor, nn
Expand All @@ -7,22 +7,19 @@
from ...transforms import functional as F


__all__ = ["ConvertImageDtype", "ImageNetEval"]
__all__ = ["CocoEval", "ImageNetEval"]


# Allows handling of both PIL and Tensor images
class ConvertImageDtype(nn.Module):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed the standalone transform to avoid introducing a new class here.

def __init__(self, dtype: torch.dtype) -> None:
super().__init__()
self.dtype = dtype

def forward(self, img: Tensor) -> Tensor:
class CocoEval(nn.Module):
def forward(
self, img: Tensor, target: Optional[Dict[str, Tensor]] = None
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
if not isinstance(img, Tensor):
img = F.pil_to_tensor(img)
return F.convert_image_dtype(img, self.dtype)
return F.convert_image_dtype(img, torch.float), target


class ImageNetEval:
class ImageNetEval(nn.Module):
def __init__(
self,
crop_size: int,
Expand All @@ -31,14 +28,14 @@ def __init__(
std: Tuple[float, ...] = (0.229, 0.224, 0.225),
interpolation: T.InterpolationMode = T.InterpolationMode.BILINEAR,
) -> None:
self.transforms = T.Compose(
[
T.Resize(resize_size, interpolation=interpolation),
T.CenterCrop(crop_size),
ConvertImageDtype(dtype=torch.float),
T.Normalize(mean=mean, std=std),
]
)

def __call__(self, img: Tensor) -> Tensor:
return self.transforms(img)
super().__init__()
self._resize = T.Resize(resize_size, interpolation=interpolation)
self._crop = T.CenterCrop(crop_size)
self._normalize = T.Normalize(mean=mean, std=std)

def forward(self, img: Tensor) -> Tensor:
img = self._crop(self._resize(img))
if not isinstance(img, Tensor):
img = F.pil_to_tensor(img)
img = F.convert_image_dtype(img, torch.float)
return self._normalize(img)