diff --git a/torchvision/models/segmentation/__init__.py b/torchvision/models/segmentation/__init__.py index fb6633d7fb5..1765502d693 100644 --- a/torchvision/models/segmentation/__init__.py +++ b/torchvision/models/segmentation/__init__.py @@ -1,4 +1,3 @@ -from .segmentation import * from .fcn import * from .deeplabv3 import * from .lraspp import * diff --git a/torchvision/models/segmentation/_utils.py b/torchvision/models/segmentation/_utils.py index 0e9a9477838..1286482dde5 100644 --- a/torchvision/models/segmentation/_utils.py +++ b/torchvision/models/segmentation/_utils.py @@ -4,6 +4,8 @@ from torch import nn, Tensor from torch.nn import functional as F +from ..._internally_replaced_utils import load_state_dict_from_url + class _SimpleSegmentationModel(nn.Module): __constants__ = ["aux_classifier"] @@ -32,3 +34,10 @@ def forward(self, x: Tensor) -> Dict[str, Tensor]: result["aux"] = x return result + + +def _load_weights(arch: str, model: nn.Module, model_url: Optional[str], progress: bool) -> None: + if model_url is None: + raise ValueError("No checkpoint is available for {}".format(arch)) + state_dict = load_state_dict_from_url(model_url, progress=progress) + model.load_state_dict(state_dict) diff --git a/torchvision/models/segmentation/deeplabv3.py b/torchvision/models/segmentation/deeplabv3.py index a8f06bd89bd..3b2a1d12d9a 100644 --- a/torchvision/models/segmentation/deeplabv3.py +++ b/torchvision/models/segmentation/deeplabv3.py @@ -1,13 +1,29 @@ -from typing import List +from typing import List, Optional import torch from torch import nn from torch.nn import functional as F -from ._utils import _SimpleSegmentationModel +from .. import mobilenetv3 +from .. import resnet +from ..feature_extraction import create_feature_extractor +from ._utils import _SimpleSegmentationModel, _load_weights +from .fcn import FCNHead -__all__ = ["DeepLabV3"] +__all__ = [ + "DeepLabV3", + "deeplabv3_resnet50", + "deeplabv3_resnet101", + "deeplabv3_mobilenet_v3_large", +] + + +model_urls = { + "deeplabv3_resnet50_coco": "https://download.pytorch.org/models/deeplabv3_resnet50_coco-cd0a2569.pth", + "deeplabv3_resnet101_coco": "https://download.pytorch.org/models/deeplabv3_resnet101_coco-586e9e4e.pth", + "deeplabv3_mobilenet_v3_large_coco": "https://download.pytorch.org/models/deeplabv3_mobilenet_v3_large-fc3c493d.pth", +} class DeepLabV3(_SimpleSegmentationModel): @@ -95,3 +111,131 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: _res.append(conv(x)) res = torch.cat(_res, dim=1) return self.project(res) + + +def _deeplabv3_resnet( + backbone: resnet.ResNet, + num_classes: int, + aux: Optional[bool], +) -> DeepLabV3: + return_layers = {"layer4": "out"} + if aux: + return_layers["layer3"] = "aux" + backbone = create_feature_extractor(backbone, return_layers) + + aux_classifier = FCNHead(1024, num_classes) if aux else None + classifier = DeepLabHead(2048, num_classes) + return DeepLabV3(backbone, classifier, aux_classifier) + + +def _deeplabv3_mobilenetv3( + backbone: mobilenetv3.MobileNetV3, + num_classes: int, + aux: Optional[bool], +) -> DeepLabV3: + backbone = backbone.features + # Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks. + # The first and last blocks are always included because they are the C0 (conv1) and Cn. + stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)] + [len(backbone) - 1] + out_pos = stage_indices[-1] # use C5 which has output_stride = 16 + out_inplanes = backbone[out_pos].out_channels + aux_pos = stage_indices[-4] # use C2 here which has output_stride = 8 + aux_inplanes = backbone[aux_pos].out_channels + return_layers = {str(out_pos): "out"} + if aux: + return_layers[str(aux_pos)] = "aux" + backbone = create_feature_extractor(backbone, return_layers) + + aux_classifier = FCNHead(aux_inplanes, num_classes) if aux else None + classifier = DeepLabHead(out_inplanes, num_classes) + return DeepLabV3(backbone, classifier, aux_classifier) + + +def deeplabv3_resnet50( + pretrained: bool = False, + progress: bool = True, + num_classes: int = 21, + aux_loss: Optional[bool] = None, + pretrained_backbone: bool = True, +) -> DeepLabV3: + """Constructs a DeepLabV3 model with a ResNet-50 backbone. + + Args: + pretrained (bool): If True, returns a model pre-trained on COCO train2017 which + contains the same classes as Pascal VOC + progress (bool): If True, displays a progress bar of the download to stderr + num_classes (int): number of output classes of the model (including the background) + aux_loss (bool, optional): If True, it uses an auxiliary loss + pretrained_backbone (bool): If True, the backbone will be pre-trained. + """ + if pretrained: + aux_loss = True + pretrained_backbone = False + + backbone = resnet.resnet50(pretrained=pretrained_backbone, replace_stride_with_dilation=[False, True, True]) + model = _deeplabv3_resnet(backbone, num_classes, aux_loss) + + if pretrained: + arch = "deeplabv3_resnet50_coco" + _load_weights(arch, model, model_urls.get(arch, None), progress) + return model + + +def deeplabv3_resnet101( + pretrained: bool = False, + progress: bool = True, + num_classes: int = 21, + aux_loss: Optional[bool] = None, + pretrained_backbone: bool = True, +) -> DeepLabV3: + """Constructs a DeepLabV3 model with a ResNet-101 backbone. + + Args: + pretrained (bool): If True, returns a model pre-trained on COCO train2017 which + contains the same classes as Pascal VOC + progress (bool): If True, displays a progress bar of the download to stderr + num_classes (int): The number of classes + aux_loss (bool, optional): If True, include an auxiliary classifier + pretrained_backbone (bool): If True, the backbone will be pre-trained. + """ + if pretrained: + aux_loss = True + pretrained_backbone = False + + backbone = resnet.resnet101(pretrained=pretrained_backbone, replace_stride_with_dilation=[False, True, True]) + model = _deeplabv3_resnet(backbone, num_classes, aux_loss) + + if pretrained: + arch = "deeplabv3_resnet101_coco" + _load_weights(arch, model, model_urls.get(arch, None), progress) + return model + + +def deeplabv3_mobilenet_v3_large( + pretrained: bool = False, + progress: bool = True, + num_classes: int = 21, + aux_loss: Optional[bool] = None, + pretrained_backbone: bool = True, +) -> DeepLabV3: + """Constructs a DeepLabV3 model with a MobileNetV3-Large backbone. + + Args: + pretrained (bool): If True, returns a model pre-trained on COCO train2017 which + contains the same classes as Pascal VOC + progress (bool): If True, displays a progress bar of the download to stderr + num_classes (int): number of output classes of the model (including the background) + aux_loss (bool, optional): If True, it uses an auxiliary loss + pretrained_backbone (bool): If True, the backbone will be pre-trained. + """ + if pretrained: + aux_loss = True + pretrained_backbone = False + + backbone = mobilenetv3.mobilenet_v3_large(pretrained=pretrained_backbone, dilated=True) + model = _deeplabv3_mobilenetv3(backbone, num_classes, aux_loss) + + if pretrained: + arch = "deeplabv3_mobilenet_v3_large_coco" + _load_weights(arch, model, model_urls.get(arch, None), progress) + return model diff --git a/torchvision/models/segmentation/fcn.py b/torchvision/models/segmentation/fcn.py index 6a935e9ac48..fe226be2ce1 100644 --- a/torchvision/models/segmentation/fcn.py +++ b/torchvision/models/segmentation/fcn.py @@ -1,9 +1,19 @@ +from typing import Optional + from torch import nn -from ._utils import _SimpleSegmentationModel +from .. import resnet +from ..feature_extraction import create_feature_extractor +from ._utils import _SimpleSegmentationModel, _load_weights + + +__all__ = ["FCN", "fcn_resnet50", "fcn_resnet101"] -__all__ = ["FCN"] +model_urls = { + "fcn_resnet50_coco": "https://download.pytorch.org/models/fcn_resnet50_coco-1167a1af.pth", + "fcn_resnet101_coco": "https://download.pytorch.org/models/fcn_resnet101_coco-7ecb50ca.pth", +} class FCN(_SimpleSegmentationModel): @@ -35,3 +45,78 @@ def __init__(self, in_channels: int, channels: int) -> None: ] super(FCNHead, self).__init__(*layers) + + +def _fcn_resnet( + backbone: resnet.ResNet, + num_classes: int, + aux: Optional[bool], +) -> FCN: + return_layers = {"layer4": "out"} + if aux: + return_layers["layer3"] = "aux" + backbone = create_feature_extractor(backbone, return_layers) + + aux_classifier = FCNHead(1024, num_classes) if aux else None + classifier = FCNHead(2048, num_classes) + return FCN(backbone, classifier, aux_classifier) + + +def fcn_resnet50( + pretrained: bool = False, + progress: bool = True, + num_classes: int = 21, + aux_loss: Optional[bool] = None, + pretrained_backbone: bool = True, +) -> FCN: + """Constructs a Fully-Convolutional Network model with a ResNet-50 backbone. + + Args: + pretrained (bool): If True, returns a model pre-trained on COCO train2017 which + contains the same classes as Pascal VOC + progress (bool): If True, displays a progress bar of the download to stderr + num_classes (int): number of output classes of the model (including the background) + aux_loss (bool, optional): If True, it uses an auxiliary loss + pretrained_backbone (bool): If True, the backbone will be pre-trained. + """ + if pretrained: + aux_loss = True + pretrained_backbone = False + + backbone = resnet.resnet50(pretrained=pretrained_backbone, replace_stride_with_dilation=[False, True, True]) + model = _fcn_resnet(backbone, num_classes, aux_loss) + + if pretrained: + arch = "fcn_resnet50_coco" + _load_weights(arch, model, model_urls.get(arch, None), progress) + return model + + +def fcn_resnet101( + pretrained: bool = False, + progress: bool = True, + num_classes: int = 21, + aux_loss: Optional[bool] = None, + pretrained_backbone: bool = True, +) -> FCN: + """Constructs a Fully-Convolutional Network model with a ResNet-101 backbone. + + Args: + pretrained (bool): If True, returns a model pre-trained on COCO train2017 which + contains the same classes as Pascal VOC + progress (bool): If True, displays a progress bar of the download to stderr + num_classes (int): number of output classes of the model (including the background) + aux_loss (bool, optional): If True, it uses an auxiliary loss + pretrained_backbone (bool): If True, the backbone will be pre-trained. + """ + if pretrained: + aux_loss = True + pretrained_backbone = False + + backbone = resnet.resnet101(pretrained=pretrained_backbone, replace_stride_with_dilation=[False, True, True]) + model = _fcn_resnet(backbone, num_classes, aux_loss) + + if pretrained: + arch = "fcn_resnet101_coco" + _load_weights(arch, model, model_urls.get(arch, None), progress) + return model diff --git a/torchvision/models/segmentation/lraspp.py b/torchvision/models/segmentation/lraspp.py index 654e2811315..df4b21e2ee9 100644 --- a/torchvision/models/segmentation/lraspp.py +++ b/torchvision/models/segmentation/lraspp.py @@ -1,11 +1,20 @@ from collections import OrderedDict -from typing import Dict +from typing import Any, Dict from torch import nn, Tensor from torch.nn import functional as F +from .. import mobilenetv3 +from ..feature_extraction import create_feature_extractor +from ._utils import _load_weights -__all__ = ["LRASPP"] + +__all__ = ["LRASPP", "lraspp_mobilenet_v3_large"] + + +model_urls = { + "lraspp_mobilenet_v3_large_coco": "https://download.pytorch.org/models/lraspp_mobilenet_v3_large-d234d4ea.pth", +} class LRASPP(nn.Module): @@ -68,3 +77,47 @@ def forward(self, input: Dict[str, Tensor]) -> Tensor: x = F.interpolate(x, size=low.shape[-2:], mode="bilinear", align_corners=False) return self.low_classifier(low) + self.high_classifier(x) + + +def _lraspp_mobilenetv3(backbone: mobilenetv3.MobileNetV3, num_classes: int) -> LRASPP: + backbone = backbone.features + # Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks. + # The first and last blocks are always included because they are the C0 (conv1) and Cn. + stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)] + [len(backbone) - 1] + low_pos = stage_indices[-4] # use C2 here which has output_stride = 8 + high_pos = stage_indices[-1] # use C5 which has output_stride = 16 + low_channels = backbone[low_pos].out_channels + high_channels = backbone[high_pos].out_channels + backbone = create_feature_extractor(backbone, {str(low_pos): "low", str(high_pos): "high"}) + + return LRASPP(backbone, low_channels, high_channels, num_classes) + + +def lraspp_mobilenet_v3_large( + pretrained: bool = False, + progress: bool = True, + num_classes: int = 21, + pretrained_backbone: bool = True, + **kwargs: Any, +) -> LRASPP: + """Constructs a Lite R-ASPP Network model with a MobileNetV3-Large backbone. + + Args: + pretrained (bool): If True, returns a model pre-trained on COCO train2017 which + contains the same classes as Pascal VOC + progress (bool): If True, displays a progress bar of the download to stderr + num_classes (int): number of output classes of the model (including the background) + pretrained_backbone (bool): If True, the backbone will be pre-trained. + """ + if kwargs.pop("aux_loss", False): + raise NotImplementedError("This model does not use auxiliary loss") + if pretrained: + pretrained_backbone = False + + backbone = mobilenetv3.mobilenet_v3_large(pretrained=pretrained_backbone, dilated=True) + model = _lraspp_mobilenetv3(backbone, num_classes) + + if pretrained: + arch = "lraspp_mobilenet_v3_large_coco" + _load_weights(arch, model, model_urls.get(arch, None), progress) + return model diff --git a/torchvision/models/segmentation/segmentation.py b/torchvision/models/segmentation/segmentation.py index c19e36e4705..1c1d56f487c 100644 --- a/torchvision/models/segmentation/segmentation.py +++ b/torchvision/models/segmentation/segmentation.py @@ -1,242 +1,10 @@ -from typing import Any, Optional +import warnings -from torch import nn +# Import all methods/classes for BC: +from . import * # noqa: F401, F403 -from ..._internally_replaced_utils import load_state_dict_from_url -from .. import mobilenetv3 -from .. import resnet -from ..feature_extraction import create_feature_extractor -from .deeplabv3 import DeepLabHead, DeepLabV3 -from .fcn import FCN, FCNHead -from .lraspp import LRASPP - -__all__ = [ - "fcn_resnet50", - "fcn_resnet101", - "deeplabv3_resnet50", - "deeplabv3_resnet101", - "deeplabv3_mobilenet_v3_large", - "lraspp_mobilenet_v3_large", -] - - -model_urls = { - "fcn_resnet50_coco": "https://download.pytorch.org/models/fcn_resnet50_coco-1167a1af.pth", - "fcn_resnet101_coco": "https://download.pytorch.org/models/fcn_resnet101_coco-7ecb50ca.pth", - "deeplabv3_resnet50_coco": "https://download.pytorch.org/models/deeplabv3_resnet50_coco-cd0a2569.pth", - "deeplabv3_resnet101_coco": "https://download.pytorch.org/models/deeplabv3_resnet101_coco-586e9e4e.pth", - "deeplabv3_mobilenet_v3_large_coco": "https://download.pytorch.org/models/deeplabv3_mobilenet_v3_large-fc3c493d.pth", - "lraspp_mobilenet_v3_large_coco": "https://download.pytorch.org/models/lraspp_mobilenet_v3_large-d234d4ea.pth", -} - - -def _segm_model( - name: str, backbone_name: str, num_classes: int, aux: Optional[bool], pretrained_backbone: bool = True -) -> nn.Module: - if "resnet" in backbone_name: - backbone = resnet.__dict__[backbone_name]( - pretrained=pretrained_backbone, replace_stride_with_dilation=[False, True, True] - ) - out_layer = "layer4" - out_inplanes = 2048 - aux_layer = "layer3" - aux_inplanes = 1024 - elif "mobilenet_v3" in backbone_name: - backbone = mobilenetv3.__dict__[backbone_name](pretrained=pretrained_backbone, dilated=True).features - - # Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks. - # The first and last blocks are always included because they are the C0 (conv1) and Cn. - stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)] + [len(backbone) - 1] - out_pos = stage_indices[-1] # use C5 which has output_stride = 16 - out_layer = str(out_pos) - out_inplanes = backbone[out_pos].out_channels - aux_pos = stage_indices[-4] # use C2 here which has output_stride = 8 - aux_layer = str(aux_pos) - aux_inplanes = backbone[aux_pos].out_channels - else: - raise NotImplementedError("backbone {} is not supported as of now".format(backbone_name)) - - return_layers = {out_layer: "out"} - if aux: - return_layers[aux_layer] = "aux" - backbone = create_feature_extractor(backbone, return_layers) - - aux_classifier = None - if aux: - aux_classifier = FCNHead(aux_inplanes, num_classes) - - model_map = { - "deeplabv3": (DeepLabHead, DeepLabV3), - "fcn": (FCNHead, FCN), - } - classifier = model_map[name][0](out_inplanes, num_classes) - base_model = model_map[name][1] - - model = base_model(backbone, classifier, aux_classifier) - return model - - -def _load_model( - arch_type: str, - backbone: str, - pretrained: bool, - progress: bool, - num_classes: int, - aux_loss: Optional[bool], - **kwargs: Any, -) -> nn.Module: - if pretrained: - aux_loss = True - kwargs["pretrained_backbone"] = False - model = _segm_model(arch_type, backbone, num_classes, aux_loss, **kwargs) - if pretrained: - _load_weights(model, arch_type, backbone, progress) - return model - - -def _load_weights(model: nn.Module, arch_type: str, backbone: str, progress: bool) -> None: - arch = arch_type + "_" + backbone + "_coco" - model_url = model_urls.get(arch, None) - if model_url is None: - raise NotImplementedError("pretrained {} is not supported as of now".format(arch)) - else: - state_dict = load_state_dict_from_url(model_url, progress=progress) - model.load_state_dict(state_dict) - - -def _segm_lraspp_mobilenetv3(backbone_name: str, num_classes: int, pretrained_backbone: bool = True) -> LRASPP: - backbone = mobilenetv3.__dict__[backbone_name](pretrained=pretrained_backbone, dilated=True).features - - # Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks. - # The first and last blocks are always included because they are the C0 (conv1) and Cn. - stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)] + [len(backbone) - 1] - low_pos = stage_indices[-4] # use C2 here which has output_stride = 8 - high_pos = stage_indices[-1] # use C5 which has output_stride = 16 - low_channels = backbone[low_pos].out_channels - high_channels = backbone[high_pos].out_channels - - backbone = create_feature_extractor(backbone, {str(low_pos): "low", str(high_pos): "high"}) - - model = LRASPP(backbone, low_channels, high_channels, num_classes) - return model - - -def fcn_resnet50( - pretrained: bool = False, - progress: bool = True, - num_classes: int = 21, - aux_loss: Optional[bool] = None, - **kwargs: Any, -) -> nn.Module: - """Constructs a Fully-Convolutional Network model with a ResNet-50 backbone. - - Args: - pretrained (bool): If True, returns a model pre-trained on COCO train2017 which - contains the same classes as Pascal VOC - progress (bool): If True, displays a progress bar of the download to stderr - num_classes (int): number of output classes of the model (including the background) - aux_loss (bool): If True, it uses an auxiliary loss - """ - return _load_model("fcn", "resnet50", pretrained, progress, num_classes, aux_loss, **kwargs) - - -def fcn_resnet101( - pretrained: bool = False, - progress: bool = True, - num_classes: int = 21, - aux_loss: Optional[bool] = None, - **kwargs: Any, -) -> nn.Module: - """Constructs a Fully-Convolutional Network model with a ResNet-101 backbone. - - Args: - pretrained (bool): If True, returns a model pre-trained on COCO train2017 which - contains the same classes as Pascal VOC - progress (bool): If True, displays a progress bar of the download to stderr - num_classes (int): number of output classes of the model (including the background) - aux_loss (bool): If True, it uses an auxiliary loss - """ - return _load_model("fcn", "resnet101", pretrained, progress, num_classes, aux_loss, **kwargs) - - -def deeplabv3_resnet50( - pretrained: bool = False, - progress: bool = True, - num_classes: int = 21, - aux_loss: Optional[bool] = None, - **kwargs: Any, -) -> nn.Module: - """Constructs a DeepLabV3 model with a ResNet-50 backbone. - - Args: - pretrained (bool): If True, returns a model pre-trained on COCO train2017 which - contains the same classes as Pascal VOC - progress (bool): If True, displays a progress bar of the download to stderr - num_classes (int): number of output classes of the model (including the background) - aux_loss (bool): If True, it uses an auxiliary loss - """ - return _load_model("deeplabv3", "resnet50", pretrained, progress, num_classes, aux_loss, **kwargs) - - -def deeplabv3_resnet101( - pretrained: bool = False, - progress: bool = True, - num_classes: int = 21, - aux_loss: Optional[bool] = None, - **kwargs: Any, -) -> nn.Module: - """Constructs a DeepLabV3 model with a ResNet-101 backbone. - - Args: - pretrained (bool): If True, returns a model pre-trained on COCO train2017 which - contains the same classes as Pascal VOC - progress (bool): If True, displays a progress bar of the download to stderr - num_classes (int): The number of classes - aux_loss (bool): If True, include an auxiliary classifier - """ - return _load_model("deeplabv3", "resnet101", pretrained, progress, num_classes, aux_loss, **kwargs) - - -def deeplabv3_mobilenet_v3_large( - pretrained: bool = False, - progress: bool = True, - num_classes: int = 21, - aux_loss: Optional[bool] = None, - **kwargs: Any, -) -> nn.Module: - """Constructs a DeepLabV3 model with a MobileNetV3-Large backbone. - - Args: - pretrained (bool): If True, returns a model pre-trained on COCO train2017 which - contains the same classes as Pascal VOC - progress (bool): If True, displays a progress bar of the download to stderr - num_classes (int): number of output classes of the model (including the background) - aux_loss (bool): If True, it uses an auxiliary loss - """ - return _load_model("deeplabv3", "mobilenet_v3_large", pretrained, progress, num_classes, aux_loss, **kwargs) - - -def lraspp_mobilenet_v3_large( - pretrained: bool = False, progress: bool = True, num_classes: int = 21, **kwargs: Any -) -> nn.Module: - """Constructs a Lite R-ASPP Network model with a MobileNetV3-Large backbone. - - Args: - pretrained (bool): If True, returns a model pre-trained on COCO train2017 which - contains the same classes as Pascal VOC - progress (bool): If True, displays a progress bar of the download to stderr - num_classes (int): number of output classes of the model (including the background) - """ - if kwargs.pop("aux_loss", False): - raise NotImplementedError("This model does not use auxiliary loss") - - backbone_name = "mobilenet_v3_large" - if pretrained: - kwargs["pretrained_backbone"] = False - model = _segm_lraspp_mobilenetv3(backbone_name, num_classes, **kwargs) - - if pretrained: - _load_weights(model, "lraspp", backbone_name, progress) - - return model +warnings.warn( + "The 'torchvision.models.segmentation.segmentation' module is deprecated. Please use directly the parent module " + "instead." +)