diff --git a/mypy.ini b/mypy.ini
index a2733d3ae3b..a9d62f38e7b 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -21,10 +21,6 @@ ignore_errors=True
 
 ignore_errors = True
 
-[mypy-torchvision.models.detection.backbone_utils]
-
-ignore_errors = True
-
 [mypy-torchvision.models.detection.transform]
 
 ignore_errors = True
diff --git a/torchvision/models/detection/backbone_utils.py b/torchvision/models/detection/backbone_utils.py
index a2a45c43733..e697e06929a 100644
--- a/torchvision/models/detection/backbone_utils.py
+++ b/torchvision/models/detection/backbone_utils.py
@@ -1,7 +1,7 @@
 import warnings
-from typing import List, Optional
+from typing import Callable, Dict, Optional, List
 
-from torch import nn
+from torch import nn, Tensor
 from torchvision.ops import misc as misc_nn_ops
 from torchvision.ops.feature_pyramid_network import FeaturePyramidNetwork, LastLevelMaxPool, ExtraFPNBlock
 
@@ -29,7 +29,14 @@ class BackboneWithFPN(nn.Module):
         out_channels (int): the number of channels in the FPN
     """
 
-    def __init__(self, backbone, return_layers, in_channels_list, out_channels, extra_blocks=None):
+    def __init__(
+        self,
+        backbone: nn.Module,
+        return_layers: Dict[str, str],
+        in_channels_list: List[int],
+        out_channels: int,
+        extra_blocks: Optional[ExtraFPNBlock] = None,
+    ) -> None:
         super(BackboneWithFPN, self).__init__()
 
         if extra_blocks is None:
@@ -43,20 +50,20 @@ def __init__(self, backbone, return_layers, in_channels_list, out_channels, extr
         )
         self.out_channels = out_channels
 
-    def forward(self, x):
+    def forward(self, x: Tensor) -> Dict[str, Tensor]:
         x = self.body(x)
         x = self.fpn(x)
         return x
 
 
 def resnet_fpn_backbone(
-    backbone_name,
-    pretrained,
-    norm_layer=misc_nn_ops.FrozenBatchNorm2d,
-    trainable_layers=3,
-    returned_layers=None,
-    extra_blocks=None,
-):
+    backbone_name: str,
+    pretrained: bool,
+    norm_layer: Callable[..., nn.Module] = misc_nn_ops.FrozenBatchNorm2d,
+    trainable_layers: int = 3,
+    returned_layers: Optional[List[int]] = None,
+    extra_blocks: Optional[ExtraFPNBlock] = None,
+) -> BackboneWithFPN:
     """
     Constructs a specified ResNet backbone with FPN on top. Freezes the specified number of layers in the backbone.
 
@@ -80,7 +87,7 @@ def resnet_fpn_backbone(
         backbone_name (string): resnet architecture. Possible values are 'ResNet', 'resnet18', 'resnet34', 'resnet50',
             'resnet101', 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 'wide_resnet50_2', 'wide_resnet101_2'
         pretrained (bool): If True, returns a model with backbone pre-trained on Imagenet
-        norm_layer (torchvision.ops): it is recommended to use the default value. For details visit:
+        norm_layer (callable): it is recommended to use the default value. For details visit:
            (https://github.com/facebookresearch/maskrcnn-benchmark/issues/267)
         trainable_layers (int): number of trainable (not frozen) resnet layers starting from final block.
            Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable.
@@ -101,7 +108,8 @@ def _resnet_backbone_config(
     trainable_layers: int,
     returned_layers: Optional[List[int]],
     extra_blocks: Optional[ExtraFPNBlock],
-):
+) -> BackboneWithFPN:
+
     # select layers that wont be frozen
     assert 0 <= trainable_layers <= 5
     layers_to_train = ["layer4", "layer3", "layer2", "layer1", "conv1"][:trainable_layers]
@@ -125,8 +133,13 @@ def _resnet_backbone_config(
     return BackboneWithFPN(backbone, return_layers, in_channels_list, out_channels, extra_blocks=extra_blocks)
 
 
-def _validate_trainable_layers(pretrained, trainable_backbone_layers, max_value, default_value):
-    # dont freeze any layers if pretrained model or backbone is not used
+def _validate_trainable_layers(
+    pretrained: bool,
+    trainable_backbone_layers: Optional[int],
+    max_value: int,
+    default_value: int,
+) -> int:
+    # don't freeze any layers if pretrained model or backbone is not used
     if not pretrained:
         if trainable_backbone_layers is not None:
             warnings.warn(
@@ -144,14 +157,15 @@ def _validate_trainable_layers(pretrained, trainable_backbone_layers, max_value,
 
 
 def mobilenet_backbone(
-    backbone_name,
-    pretrained,
-    fpn,
-    norm_layer=misc_nn_ops.FrozenBatchNorm2d,
-    trainable_layers=2,
-    returned_layers=None,
-    extra_blocks=None,
-):
+    backbone_name: str,
+    pretrained: bool,
+    fpn: bool,
+    norm_layer: Callable[..., nn.Module] = misc_nn_ops.FrozenBatchNorm2d,
+    trainable_layers: int = 2,
+    returned_layers: Optional[List[int]] = None,
+    extra_blocks: Optional[ExtraFPNBlock] = None,
+) -> nn.Module:
+
     backbone = mobilenet.__dict__[backbone_name](pretrained=pretrained, norm_layer=norm_layer).features
 
     # Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
@@ -185,5 +199,5 @@ def mobilenet_backbone(
             # depthwise linear combination of channels to reduce their size
             nn.Conv2d(backbone[-1].out_channels, out_channels, 1),
         )
-        m.out_channels = out_channels
+        m.out_channels = out_channels  # type: ignore[assignment]
     return m
diff --git a/torchvision/prototype/models/detection/backbone_utils.py b/torchvision/prototype/models/detection/backbone_utils.py
index 9893ebf8e5d..d95b1ab52f0 100644
--- a/torchvision/prototype/models/detection/backbone_utils.py
+++ b/torchvision/prototype/models/detection/backbone_utils.py
@@ -1,14 +1,20 @@
-from ....models.detection.backbone_utils import misc_nn_ops, _resnet_backbone_config
+from typing import Callable, Optional, List
+
+from torch import nn
+
+from ....models.detection.backbone_utils import misc_nn_ops, _resnet_backbone_config, BackboneWithFPN, ExtraFPNBlock
 from .. import resnet
+from .._api import Weights
 
 
 def resnet_fpn_backbone(
-    backbone_name,
-    weights,
-    norm_layer=misc_nn_ops.FrozenBatchNorm2d,
-    trainable_layers=3,
-    returned_layers=None,
-    extra_blocks=None,
-):
+    backbone_name: str,
+    weights: Optional[Weights],
+    norm_layer: Callable[..., nn.Module] = misc_nn_ops.FrozenBatchNorm2d,
+    trainable_layers: int = 3,
+    returned_layers: Optional[List[int]] = None,
+    extra_blocks: Optional[ExtraFPNBlock] = None,
+) -> BackboneWithFPN:
+
     backbone = resnet.__dict__[backbone_name](weights=weights, norm_layer=norm_layer)
     return _resnet_backbone_config(backbone, trainable_layers, returned_layers, extra_blocks)
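
Not part of the patch above, just a minimal sketch of how the newly annotated stable entry point is typically exercised (pretrained=False is chosen here only to avoid a weight download; the output keys follow the default returned_layers plus the LastLevelMaxPool level):

    import torch
    from torchvision.models.detection.backbone_utils import resnet_fpn_backbone

    # Build a ResNet-50 + FPN backbone; the call now checks against the typed signature.
    backbone = resnet_fpn_backbone("resnet50", pretrained=False, trainable_layers=3)

    # forward(x: Tensor) -> Dict[str, Tensor]: one 256-channel feature map per FPN level.
    feats = backbone(torch.rand(1, 3, 224, 224))
    print({name: tuple(f.shape) for name, f in feats.items()})  # keys '0'..'3' and 'pool'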