From ccb7899624b170a9a86d6fdf21d8124ff1b364a1 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Thu, 14 Oct 2021 12:58:47 +0100 Subject: [PATCH 1/5] Adding FasterRCNN ResNet50. --- torchvision/prototype/models/__init__.py | 1 + torchvision/prototype/models/_meta.py | 95 +++++++++++++++++++ .../prototype/models/detection/__init__.py | 1 + .../models/detection/backbone_utils.py | 37 ++++++++ .../prototype/models/detection/faster_rcnn.py | 59 ++++++++++++ torchvision/prototype/transforms/presets.py | 41 ++++---- 6 files changed, 212 insertions(+), 22 deletions(-) create mode 100644 torchvision/prototype/models/detection/__init__.py create mode 100644 torchvision/prototype/models/detection/backbone_utils.py create mode 100644 torchvision/prototype/models/detection/faster_rcnn.py diff --git a/torchvision/prototype/models/__init__.py b/torchvision/prototype/models/__init__.py index b792ca6ecf7..2879a37dbf9 100644 --- a/torchvision/prototype/models/__init__.py +++ b/torchvision/prototype/models/__init__.py @@ -1 +1,2 @@ from .resnet import * +from . import detection diff --git a/torchvision/prototype/models/_meta.py b/torchvision/prototype/models/_meta.py index 87eb338d51a..056ffdfa249 100644 --- a/torchvision/prototype/models/_meta.py +++ b/torchvision/prototype/models/_meta.py @@ -1006,3 +1006,98 @@ "ear", "toilet tissue", ] + +# To be replaced with torchvision.datasets.find("coco").info.categories +_COCO_CATEGORIES = [ + "__background__", + "person", + "bicycle", + "car", + "motorcycle", + "airplane", + "bus", + "train", + "truck", + "boat", + "traffic light", + "fire hydrant", + "N/A", + "stop sign", + "parking meter", + "bench", + "bird", + "cat", + "dog", + "horse", + "sheep", + "cow", + "elephant", + "bear", + "zebra", + "giraffe", + "N/A", + "backpack", + "umbrella", + "N/A", + "N/A", + "handbag", + "tie", + "suitcase", + "frisbee", + "skis", + "snowboard", + "sports ball", + "kite", + "baseball bat", + "baseball glove", + "skateboard", + "surfboard", + "tennis racket", + "bottle", + "N/A", + "wine glass", + "cup", + "fork", + "knife", + "spoon", + "bowl", + "banana", + "apple", + "sandwich", + "orange", + "broccoli", + "carrot", + "hot dog", + "pizza", + "donut", + "cake", + "chair", + "couch", + "potted plant", + "bed", + "N/A", + "dining table", + "N/A", + "N/A", + "toilet", + "N/A", + "tv", + "laptop", + "mouse", + "remote", + "keyboard", + "cell phone", + "microwave", + "oven", + "toaster", + "sink", + "refrigerator", + "N/A", + "book", + "clock", + "vase", + "scissors", + "teddy bear", + "hair drier", + "toothbrush", +] diff --git a/torchvision/prototype/models/detection/__init__.py b/torchvision/prototype/models/detection/__init__.py new file mode 100644 index 00000000000..e7e08d7ac8e --- /dev/null +++ b/torchvision/prototype/models/detection/__init__.py @@ -0,0 +1 @@ +from .faster_rcnn import * diff --git a/torchvision/prototype/models/detection/backbone_utils.py b/torchvision/prototype/models/detection/backbone_utils.py new file mode 100644 index 00000000000..9806773ac7f --- /dev/null +++ b/torchvision/prototype/models/detection/backbone_utils.py @@ -0,0 +1,37 @@ +from ....models.detection.backbone_utils import misc_nn_ops, LastLevelMaxPool, BackboneWithFPN +from .. 
import resnet + + +def resnet_fpn_backbone( + backbone_name, + weights, + norm_layer=misc_nn_ops.FrozenBatchNorm2d, + trainable_layers=3, + returned_layers=None, + extra_blocks=None, +): + backbone = resnet.__dict__[backbone_name](weights=weights, norm_layer=norm_layer) + + # COPY-PASTED CODE FROM torchvision.models.detection.backbone_utils.resnet_fpn_backbone + # ===================================================================================== + assert 0 <= trainable_layers <= 5 + layers_to_train = ["layer4", "layer3", "layer2", "layer1", "conv1"][:trainable_layers] + if trainable_layers == 5: + layers_to_train.append("bn1") + for name, parameter in backbone.named_parameters(): + if all([not name.startswith(layer) for layer in layers_to_train]): + parameter.requires_grad_(False) + + if extra_blocks is None: + extra_blocks = LastLevelMaxPool() + + if returned_layers is None: + returned_layers = [1, 2, 3, 4] + assert min(returned_layers) > 0 and max(returned_layers) < 5 + return_layers = {f"layer{k}": str(v) for v, k in enumerate(returned_layers)} + + in_channels_stage2 = backbone.inplanes // 8 + in_channels_list = [in_channels_stage2 * 2 ** (i - 1) for i in returned_layers] + out_channels = 256 + return BackboneWithFPN(backbone, return_layers, in_channels_list, out_channels, extra_blocks=extra_blocks) + # ===================================================================================== diff --git a/torchvision/prototype/models/detection/faster_rcnn.py b/torchvision/prototype/models/detection/faster_rcnn.py new file mode 100644 index 00000000000..c3e72efe448 --- /dev/null +++ b/torchvision/prototype/models/detection/faster_rcnn.py @@ -0,0 +1,59 @@ +import warnings +from typing import Any, Optional + +from ....models.detection.faster_rcnn import FasterRCNN, overwrite_eps, _validate_trainable_layers +from ...transforms.presets import CocoEval +from .._api import Weights, WeightEntry +from .._meta import _COCO_CATEGORIES +from ..resnet import ResNet50Weights +from .backbone_utils import resnet_fpn_backbone + + +__all__ = ["FasterRCNN", "FasterRCNNResNet50FPNWeights", "fasterrcnn_resnet50_fpn"] + + +class FasterRCNNResNet50FPNWeights(Weights): + Coco_RefV1 = WeightEntry( + url="https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth", + transforms=CocoEval, + meta={ + "categories": _COCO_CATEGORIES, + "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-resnet-50-fpn", + "map": 37.0, + }, + ) + + +def fasterrcnn_resnet50_fpn( + weights: Optional[FasterRCNNResNet50FPNWeights] = None, + weights_backbone: Optional[ResNet50Weights] = ResNet50Weights.ImageNet1K_RefV1, # TODO: Should we default to None? 
+ progress: bool = True, + num_classes: int = 91, + trainable_backbone_layers: Optional[int] = None, + **kwargs: Any, +) -> FasterRCNN: + if "pretrained" in kwargs: + warnings.warn("The argument pretrained is deprecated, please use weights instead.") + weights = FasterRCNNResNet50FPNWeights.Coco_RefV1 if kwargs.pop("pretrained") else None + weights = FasterRCNNResNet50FPNWeights.verify(weights) + if "pretrained_backbone" in kwargs: + warnings.warn("The argument pretrained_backbone is deprecated, please use weights_backbone instead.") + weights_backbone = ResNet50Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None + weights_backbone = ResNet50Weights.verify(weights_backbone) + + if weights is not None: + weights_backbone = None + num_classes = len(weights.meta["categories"]) + + trainable_backbone_layers = _validate_trainable_layers( + weights is not None or weights_backbone is not None, trainable_backbone_layers, 5, 3 + ) + + backbone = resnet_fpn_backbone("resnet50", weights_backbone, trainable_layers=trainable_backbone_layers) + model = FasterRCNN(backbone, num_classes=num_classes, **kwargs) + + if weights is not None: + model.load_state_dict(weights.state_dict(progress=progress)) + overwrite_eps(model, 0.0) + + return model diff --git a/torchvision/prototype/transforms/presets.py b/torchvision/prototype/transforms/presets.py index 56d6fc60e02..81ad8cefbfe 100644 --- a/torchvision/prototype/transforms/presets.py +++ b/torchvision/prototype/transforms/presets.py @@ -1,4 +1,4 @@ -from typing import Tuple +from typing import Dict, Optional, Tuple import torch from torch import Tensor, nn @@ -7,22 +7,19 @@ from ...transforms import functional as F -__all__ = ["ConvertImageDtype", "ImageNetEval"] +__all__ = ["CocoEval", "ImageNetEval"] -# Allows handling of both PIL and Tensor images -class ConvertImageDtype(nn.Module): - def __init__(self, dtype: torch.dtype) -> None: - super().__init__() - self.dtype = dtype - - def forward(self, img: Tensor) -> Tensor: +class CocoEval(nn.Module): + def forward( + self, img: Tensor, target: Optional[Dict[str, Tensor]] = None + ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]: if not isinstance(img, Tensor): img = F.pil_to_tensor(img) - return F.convert_image_dtype(img, self.dtype) + return F.convert_image_dtype(img, torch.float), target -class ImageNetEval: +class ImageNetEval(nn.Module): def __init__( self, crop_size: int, @@ -31,14 +28,14 @@ def __init__( std: Tuple[float, ...] = (0.229, 0.224, 0.225), interpolation: T.InterpolationMode = T.InterpolationMode.BILINEAR, ) -> None: - self.transforms = T.Compose( - [ - T.Resize(resize_size, interpolation=interpolation), - T.CenterCrop(crop_size), - ConvertImageDtype(dtype=torch.float), - T.Normalize(mean=mean, std=std), - ] - ) - - def __call__(self, img: Tensor) -> Tensor: - return self.transforms(img) + super().__init__() + self._resize = T.Resize(resize_size, interpolation=interpolation) + self._crop = T.CenterCrop(crop_size) + self._normalize = T.Normalize(mean=mean, std=std) + + def forward(self, img: Tensor) -> Tensor: + img = self._crop(self._resize(img)) + if not isinstance(img, Tensor): + img = F.pil_to_tensor(img) + img = F.convert_image_dtype(img, torch.float) + return self._normalize(img) From 7c256a2f6608da11aed7ee8e9e733861bca8fe43 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Thu, 14 Oct 2021 16:03:31 +0100 Subject: [PATCH 2/5] Refactoring to remove duplicate code. 
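Moves the layer-freezing and FPN wiring that the prototype builder had
copy-pasted into a shared private helper, _resnet_backbone_config, which both
the stable resnet_fpn_backbone and the prototype variant now call. A minimal
sketch of the resulting call chain (the resnet50 construction below is
illustrative only and not part of this patch):

    from torchvision.models import resnet
    from torchvision.models.detection.backbone_utils import (
        misc_nn_ops,
        _resnet_backbone_config,
    )

    # Plain classification backbone with frozen batch norm, as used by the
    # detection reference scripts.
    backbone = resnet.resnet50(pretrained=False, norm_layer=misc_nn_ops.FrozenBatchNorm2d)

    # Wrap it with an FPN; with trainable_layers=3 the helper freezes
    # conv1/bn1/layer1 and keeps layer2-layer4 trainable.
    fpn_backbone = _resnet_backbone_config(
        backbone,
        trainable_layers=3,
        returned_layers=None,  # defaults to [1, 2, 3, 4]
        extra_blocks=None,     # defaults to LastLevelMaxPool()
    )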
--- .../models/detection/backbone_utils.py | 3 +++ .../models/detection/backbone_utils.py | 27 ++----------------- 2 files changed, 5 insertions(+), 25 deletions(-) diff --git a/torchvision/models/detection/backbone_utils.py b/torchvision/models/detection/backbone_utils.py index 70a7b40bd50..22b9231f517 100644 --- a/torchvision/models/detection/backbone_utils.py +++ b/torchvision/models/detection/backbone_utils.py @@ -92,7 +92,10 @@ def resnet_fpn_backbone( default a ``LastLevelMaxPool`` is used. """ backbone = resnet.__dict__[backbone_name](pretrained=pretrained, norm_layer=norm_layer) + return _resnet_backbone_config(backbone, trainable_layers, returned_layers, extra_blocks) + +def _resnet_backbone_config(backbone, trainable_layers, returned_layers, extra_blocks): # select layers that wont be frozen assert 0 <= trainable_layers <= 5 layers_to_train = ["layer4", "layer3", "layer2", "layer1", "conv1"][:trainable_layers] diff --git a/torchvision/prototype/models/detection/backbone_utils.py b/torchvision/prototype/models/detection/backbone_utils.py index 9806773ac7f..9893ebf8e5d 100644 --- a/torchvision/prototype/models/detection/backbone_utils.py +++ b/torchvision/prototype/models/detection/backbone_utils.py @@ -1,4 +1,4 @@ -from ....models.detection.backbone_utils import misc_nn_ops, LastLevelMaxPool, BackboneWithFPN +from ....models.detection.backbone_utils import misc_nn_ops, _resnet_backbone_config from .. import resnet @@ -11,27 +11,4 @@ def resnet_fpn_backbone( extra_blocks=None, ): backbone = resnet.__dict__[backbone_name](weights=weights, norm_layer=norm_layer) - - # COPY-PASTED CODE FROM torchvision.models.detection.backbone_utils.resnet_fpn_backbone - # ===================================================================================== - assert 0 <= trainable_layers <= 5 - layers_to_train = ["layer4", "layer3", "layer2", "layer1", "conv1"][:trainable_layers] - if trainable_layers == 5: - layers_to_train.append("bn1") - for name, parameter in backbone.named_parameters(): - if all([not name.startswith(layer) for layer in layers_to_train]): - parameter.requires_grad_(False) - - if extra_blocks is None: - extra_blocks = LastLevelMaxPool() - - if returned_layers is None: - returned_layers = [1, 2, 3, 4] - assert min(returned_layers) > 0 and max(returned_layers) < 5 - return_layers = {f"layer{k}": str(v) for v, k in enumerate(returned_layers)} - - in_channels_stage2 = backbone.inplanes // 8 - in_channels_list = [in_channels_stage2 * 2 ** (i - 1) for i in returned_layers] - out_channels = 256 - return BackboneWithFPN(backbone, return_layers, in_channels_list, out_channels, extra_blocks=extra_blocks) - # ===================================================================================== + return _resnet_backbone_config(backbone, trainable_layers, returned_layers, extra_blocks) From 802d45a7fca56c46921e5e1ec8f05df7732343a2 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Thu, 14 Oct 2021 18:12:12 +0100 Subject: [PATCH 3/5] Adding typing info. 
--- torchvision/models/detection/backbone_utils.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/torchvision/models/detection/backbone_utils.py b/torchvision/models/detection/backbone_utils.py index 22b9231f517..a2a45c43733 100644 --- a/torchvision/models/detection/backbone_utils.py +++ b/torchvision/models/detection/backbone_utils.py @@ -1,8 +1,9 @@ import warnings +from typing import List, Optional from torch import nn from torchvision.ops import misc as misc_nn_ops -from torchvision.ops.feature_pyramid_network import FeaturePyramidNetwork, LastLevelMaxPool +from torchvision.ops.feature_pyramid_network import FeaturePyramidNetwork, LastLevelMaxPool, ExtraFPNBlock from .. import mobilenet from .. import resnet @@ -95,7 +96,12 @@ def resnet_fpn_backbone( return _resnet_backbone_config(backbone, trainable_layers, returned_layers, extra_blocks) -def _resnet_backbone_config(backbone, trainable_layers, returned_layers, extra_blocks): +def _resnet_backbone_config( + backbone: resnet.ResNet, + trainable_layers: int, + returned_layers: Optional[List[int]], + extra_blocks: Optional[ExtraFPNBlock], +): # select layers that wont be frozen assert 0 <= trainable_layers <= 5 layers_to_train = ["layer4", "layer3", "layer2", "layer1", "conv1"][:trainable_layers] From 0146a2384b4d963f5cc47e6f890b098824acf6c5 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Thu, 14 Oct 2021 18:31:01 +0100 Subject: [PATCH 4/5] Setting weights_backbone=None as default value. --- torchvision/prototype/models/detection/faster_rcnn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/prototype/models/detection/faster_rcnn.py b/torchvision/prototype/models/detection/faster_rcnn.py index c3e72efe448..f6d38c8584e 100644 --- a/torchvision/prototype/models/detection/faster_rcnn.py +++ b/torchvision/prototype/models/detection/faster_rcnn.py @@ -26,7 +26,7 @@ class FasterRCNNResNet50FPNWeights(Weights): def fasterrcnn_resnet50_fpn( weights: Optional[FasterRCNNResNet50FPNWeights] = None, - weights_backbone: Optional[ResNet50Weights] = ResNet50Weights.ImageNet1K_RefV1, # TODO: Should we default to None? + weights_backbone: Optional[ResNet50Weights] = None, progress: bool = True, num_classes: int = 91, trainable_backbone_layers: Optional[int] = None, From 32f2a2816b62d5c6cbd4b07dcce73a720639450e Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 15 Oct 2021 13:38:48 +0100 Subject: [PATCH 5/5] Overwrite eps only for specific weights. --- torchvision/prototype/models/detection/faster_rcnn.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchvision/prototype/models/detection/faster_rcnn.py b/torchvision/prototype/models/detection/faster_rcnn.py index f6d38c8584e..0b27eb50a37 100644 --- a/torchvision/prototype/models/detection/faster_rcnn.py +++ b/torchvision/prototype/models/detection/faster_rcnn.py @@ -54,6 +54,7 @@ def fasterrcnn_resnet50_fpn( if weights is not None: model.load_state_dict(weights.state_dict(progress=progress)) - overwrite_eps(model, 0.0) + if weights == FasterRCNNResNet50FPNWeights.Coco_RefV1: + overwrite_eps(model, 0.0) return model
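For reference, a rough end-to-end usage sketch of the API added by this series
(the builder, weights enum and CocoEval preset are the names introduced in
patch 1; the random uint8 tensor is only a stand-in for a real image):

    import torch

    from torchvision.prototype.models.detection import (
        fasterrcnn_resnet50_fpn,
        FasterRCNNResNet50FPNWeights,
    )
    from torchvision.prototype.transforms.presets import CocoEval

    # Detector with the COCO-pretrained weights; num_classes is derived from
    # the weights meta-data, so the backbone weights argument is ignored here.
    model = fasterrcnn_resnet50_fpn(weights=FasterRCNNResNet50FPNWeights.Coco_RefV1)
    model.eval()

    # CocoEval converts PIL/uint8 input to float tensors and passes the
    # optional target through unchanged.
    preprocess = CocoEval()
    img, _ = preprocess(torch.randint(0, 256, (3, 480, 640), dtype=torch.uint8))

    with torch.no_grad():
        predictions = model([img])  # list of dicts with boxes, labels, scores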