diff --git a/torchvision/models/detection/backbone_utils.py b/torchvision/models/detection/backbone_utils.py
index 70a7b40bd50..a2a45c43733 100644
--- a/torchvision/models/detection/backbone_utils.py
+++ b/torchvision/models/detection/backbone_utils.py
@@ -1,8 +1,9 @@
 import warnings
+from typing import List, Optional
 
 from torch import nn
 from torchvision.ops import misc as misc_nn_ops
-from torchvision.ops.feature_pyramid_network import FeaturePyramidNetwork, LastLevelMaxPool
+from torchvision.ops.feature_pyramid_network import FeaturePyramidNetwork, LastLevelMaxPool, ExtraFPNBlock
 
 from .. import mobilenet
 from .. import resnet
@@ -92,7 +93,15 @@ def resnet_fpn_backbone(
             default a ``LastLevelMaxPool`` is used.
     """
     backbone = resnet.__dict__[backbone_name](pretrained=pretrained, norm_layer=norm_layer)
+    return _resnet_backbone_config(backbone, trainable_layers, returned_layers, extra_blocks)
+
+def _resnet_backbone_config(
+    backbone: resnet.ResNet,
+    trainable_layers: int,
+    returned_layers: Optional[List[int]],
+    extra_blocks: Optional[ExtraFPNBlock],
+):
 
     # select layers that wont be frozen
     assert 0 <= trainable_layers <= 5
     layers_to_train = ["layer4", "layer3", "layer2", "layer1", "conv1"][:trainable_layers]
diff --git a/torchvision/prototype/models/__init__.py b/torchvision/prototype/models/__init__.py
index b792ca6ecf7..2879a37dbf9 100644
--- a/torchvision/prototype/models/__init__.py
+++ b/torchvision/prototype/models/__init__.py
@@ -1 +1,2 @@
 from .resnet import *
+from . import detection
diff --git a/torchvision/prototype/models/_meta.py b/torchvision/prototype/models/_meta.py
index 87eb338d51a..056ffdfa249 100644
--- a/torchvision/prototype/models/_meta.py
+++ b/torchvision/prototype/models/_meta.py
@@ -1006,3 +1006,98 @@
     "ear",
     "toilet tissue",
 ]
+
+# To be replaced with torchvision.datasets.find("coco").info.categories
+_COCO_CATEGORIES = [
+    "__background__",
+    "person",
+    "bicycle",
+    "car",
+    "motorcycle",
+    "airplane",
+    "bus",
+    "train",
+    "truck",
+    "boat",
+    "traffic light",
+    "fire hydrant",
+    "N/A",
+    "stop sign",
+    "parking meter",
+    "bench",
+    "bird",
+    "cat",
+    "dog",
+    "horse",
+    "sheep",
+    "cow",
+    "elephant",
+    "bear",
+    "zebra",
+    "giraffe",
+    "N/A",
+    "backpack",
+    "umbrella",
+    "N/A",
+    "N/A",
+    "handbag",
+    "tie",
+    "suitcase",
+    "frisbee",
+    "skis",
+    "snowboard",
+    "sports ball",
+    "kite",
+    "baseball bat",
+    "baseball glove",
+    "skateboard",
+    "surfboard",
+    "tennis racket",
+    "bottle",
+    "N/A",
+    "wine glass",
+    "cup",
+    "fork",
+    "knife",
+    "spoon",
+    "bowl",
+    "banana",
+    "apple",
+    "sandwich",
+    "orange",
+    "broccoli",
+    "carrot",
+    "hot dog",
+    "pizza",
+    "donut",
+    "cake",
+    "chair",
+    "couch",
+    "potted plant",
+    "bed",
+    "N/A",
+    "dining table",
+    "N/A",
+    "N/A",
+    "toilet",
+    "N/A",
+    "tv",
+    "laptop",
+    "mouse",
+    "remote",
+    "keyboard",
+    "cell phone",
+    "microwave",
+    "oven",
+    "toaster",
+    "sink",
+    "refrigerator",
+    "N/A",
+    "book",
+    "clock",
+    "vase",
+    "scissors",
+    "teddy bear",
+    "hair drier",
+    "toothbrush",
+]
diff --git a/torchvision/prototype/models/detection/__init__.py b/torchvision/prototype/models/detection/__init__.py
new file mode 100644
index 00000000000..e7e08d7ac8e
--- /dev/null
+++ b/torchvision/prototype/models/detection/__init__.py
@@ -0,0 +1 @@
+from .faster_rcnn import *
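Note on the refactor above: splitting `resnet_fpn_backbone` into "build the ResNet trunk" plus `_resnet_backbone_config` lets the prototype builder below construct the trunk via the new `weights` enum while reusing the exact freeze-layers/FPN wiring. A minimal sketch of that shared tail; hypothetical usage of a private helper, not part of the patch:

import torch
from torchvision.models import resnet
from torchvision.models.detection.backbone_utils import _resnet_backbone_config

# build the trunk any way you like (legacy flag shown here)...
body = resnet.resnet50(pretrained=False)

# ...then the shared tail freezes layers and attaches the FPN
backbone = _resnet_backbone_config(body, 3, None, None)  # trainable_layers, returned_layers, extra_blocks
features = backbone(torch.rand(1, 3, 224, 224))  # OrderedDict of multi-scale feature maps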
diff --git a/torchvision/prototype/models/detection/backbone_utils.py b/torchvision/prototype/models/detection/backbone_utils.py
new file mode 100644
index 00000000000..9893ebf8e5d
--- /dev/null
+++ b/torchvision/prototype/models/detection/backbone_utils.py
@@ -0,0 +1,14 @@
+from ....models.detection.backbone_utils import misc_nn_ops, _resnet_backbone_config
+from .. import resnet
+
+
+def resnet_fpn_backbone(
+    backbone_name,
+    weights,
+    norm_layer=misc_nn_ops.FrozenBatchNorm2d,
+    trainable_layers=3,
+    returned_layers=None,
+    extra_blocks=None,
+):
+    backbone = resnet.__dict__[backbone_name](weights=weights, norm_layer=norm_layer)
+    return _resnet_backbone_config(backbone, trainable_layers, returned_layers, extra_blocks)
diff --git a/torchvision/prototype/models/detection/faster_rcnn.py b/torchvision/prototype/models/detection/faster_rcnn.py
new file mode 100644
index 00000000000..0b27eb50a37
--- /dev/null
+++ b/torchvision/prototype/models/detection/faster_rcnn.py
@@ -0,0 +1,60 @@
+import warnings
+from typing import Any, Optional
+
+from ....models.detection.faster_rcnn import FasterRCNN, overwrite_eps, _validate_trainable_layers
+from ...transforms.presets import CocoEval
+from .._api import Weights, WeightEntry
+from .._meta import _COCO_CATEGORIES
+from ..resnet import ResNet50Weights
+from .backbone_utils import resnet_fpn_backbone
+
+
+__all__ = ["FasterRCNN", "FasterRCNNResNet50FPNWeights", "fasterrcnn_resnet50_fpn"]
+
+
+class FasterRCNNResNet50FPNWeights(Weights):
+    Coco_RefV1 = WeightEntry(
+        url="https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth",
+        transforms=CocoEval,
+        meta={
+            "categories": _COCO_CATEGORIES,
+            "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-resnet-50-fpn",
+            "map": 37.0,
+        },
+    )
+
+
+def fasterrcnn_resnet50_fpn(
+    weights: Optional[FasterRCNNResNet50FPNWeights] = None,
+    weights_backbone: Optional[ResNet50Weights] = None,
+    progress: bool = True,
+    num_classes: int = 91,
+    trainable_backbone_layers: Optional[int] = None,
+    **kwargs: Any,
+) -> FasterRCNN:
+    if "pretrained" in kwargs:
+        warnings.warn("The argument pretrained is deprecated, please use weights instead.")
+        weights = FasterRCNNResNet50FPNWeights.Coco_RefV1 if kwargs.pop("pretrained") else None
+    weights = FasterRCNNResNet50FPNWeights.verify(weights)
+    if "pretrained_backbone" in kwargs:
+        warnings.warn("The argument pretrained_backbone is deprecated, please use weights_backbone instead.")
+        weights_backbone = ResNet50Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None
+    weights_backbone = ResNet50Weights.verify(weights_backbone)
+
+    if weights is not None:
+        weights_backbone = None
+        num_classes = len(weights.meta["categories"])
+
+    trainable_backbone_layers = _validate_trainable_layers(
+        weights is not None or weights_backbone is not None, trainable_backbone_layers, 5, 3
+    )
+
+    backbone = resnet_fpn_backbone("resnet50", weights_backbone, trainable_layers=trainable_backbone_layers)
+    model = FasterRCNN(backbone, num_classes=num_classes, **kwargs)
+
+    if weights is not None:
+        model.load_state_dict(weights.state_dict(progress=progress))
+        if weights == FasterRCNNResNet50FPNWeights.Coco_RefV1:
+            overwrite_eps(model, 0.0)
+
+    return model
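Taken together, the two files above give Faster R-CNN the same multi-weight API the prototype classification models already expose. An end-to-end sketch; hypothetical usage, with module paths and enum names as introduced in this patch:

import torch
from torchvision.prototype.models.detection import FasterRCNNResNet50FPNWeights, fasterrcnn_resnet50_fpn
from torchvision.prototype.transforms.presets import CocoEval

# new style: an explicit enum entry replaces pretrained=True and carries its
# own meta (categories, recipe, mAP) plus the matching preprocessing preset
model = fasterrcnn_resnet50_fpn(weights=FasterRCNNResNet50FPNWeights.Coco_RefV1).eval()

img, _ = CocoEval()(torch.randint(0, 256, (3, 480, 640), dtype=torch.uint8))
with torch.no_grad():
    prediction = model([img])[0]  # dict with "boxes", "labels", "scores"

# the deprecated flag still works, warns, and maps onto the same enum entry
legacy = fasterrcnn_resnet50_fpn(pretrained=True)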
"ImageNetEval"] +__all__ = ["CocoEval", "ImageNetEval"] -# Allows handling of both PIL and Tensor images -class ConvertImageDtype(nn.Module): - def __init__(self, dtype: torch.dtype) -> None: - super().__init__() - self.dtype = dtype - - def forward(self, img: Tensor) -> Tensor: +class CocoEval(nn.Module): + def forward( + self, img: Tensor, target: Optional[Dict[str, Tensor]] = None + ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]: if not isinstance(img, Tensor): img = F.pil_to_tensor(img) - return F.convert_image_dtype(img, self.dtype) + return F.convert_image_dtype(img, torch.float), target -class ImageNetEval: +class ImageNetEval(nn.Module): def __init__( self, crop_size: int, @@ -31,14 +28,14 @@ def __init__( std: Tuple[float, ...] = (0.229, 0.224, 0.225), interpolation: T.InterpolationMode = T.InterpolationMode.BILINEAR, ) -> None: - self.transforms = T.Compose( - [ - T.Resize(resize_size, interpolation=interpolation), - T.CenterCrop(crop_size), - ConvertImageDtype(dtype=torch.float), - T.Normalize(mean=mean, std=std), - ] - ) - - def __call__(self, img: Tensor) -> Tensor: - return self.transforms(img) + super().__init__() + self._resize = T.Resize(resize_size, interpolation=interpolation) + self._crop = T.CenterCrop(crop_size) + self._normalize = T.Normalize(mean=mean, std=std) + + def forward(self, img: Tensor) -> Tensor: + img = self._crop(self._resize(img)) + if not isinstance(img, Tensor): + img = F.pil_to_tensor(img) + img = F.convert_image_dtype(img, torch.float) + return self._normalize(img)