
Refactor Segmentation models #4646


Merged: 16 commits merged on Oct 19, 2021
1 change: 0 additions & 1 deletion torchvision/models/segmentation/__init__.py
@@ -1,4 +1,3 @@
-from .segmentation import *
from .fcn import *
from .deeplabv3 import *
from .lraspp import *
9 changes: 9 additions & 0 deletions torchvision/models/segmentation/_utils.py
@@ -4,6 +4,8 @@
from torch import nn, Tensor
from torch.nn import functional as F

from ..._internally_replaced_utils import load_state_dict_from_url


class _SimpleSegmentationModel(nn.Module):
    __constants__ = ["aux_classifier"]
@@ -32,3 +34,10 @@ def forward(self, x: Tensor) -> Dict[str, Tensor]:
            result["aux"] = x

        return result


def _load_weights(arch: str, model: nn.Module, model_url: Optional[str], progress: bool) -> None:
    if model_url is None:
        raise ValueError("No checkpoint is available for {}".format(arch))
    state_dict = load_state_dict_from_url(model_url, progress=progress)
    model.load_state_dict(state_dict)
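
To illustrate the new helper, here is a minimal sketch (not part of the diff) of how the public builders further down call it; the model and the model_urls lookup are stand-ins for what each builder supplies, and the ValueError path is what you hit when an architecture has no released checkpoint:

```python
from torch import nn

from torchvision.models.segmentation._utils import _load_weights  # post-PR module layout

# Stand-in model and URL table; the real builders pass the freshly built
# segmentation model and their module-level model_urls dict.
model = nn.Conv2d(3, 21, kernel_size=1)
model_urls = {}  # no checkpoint registered for this hypothetical arch

try:
    _load_weights("my_custom_arch", model, model_urls.get("my_custom_arch", None), progress=True)
except ValueError as err:
    print(err)  # No checkpoint is available for my_custom_arch
```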
150 changes: 147 additions & 3 deletions torchvision/models/segmentation/deeplabv3.py
@@ -1,13 +1,29 @@
-from typing import List
from typing import List, Optional

import torch
from torch import nn
from torch.nn import functional as F

-from ._utils import _SimpleSegmentationModel
from .. import mobilenetv3
from .. import resnet
from ..feature_extraction import create_feature_extractor
from ._utils import _SimpleSegmentationModel, _load_weights
from .fcn import FCNHead


-__all__ = ["DeepLabV3"]
__all__ = [
    "DeepLabV3",
    "deeplabv3_resnet50",
    "deeplabv3_resnet101",
    "deeplabv3_mobilenet_v3_large",
]


model_urls = {
    "deeplabv3_resnet50_coco": "https://download.pytorch.org/models/deeplabv3_resnet50_coco-cd0a2569.pth",
    "deeplabv3_resnet101_coco": "https://download.pytorch.org/models/deeplabv3_resnet101_coco-586e9e4e.pth",
    "deeplabv3_mobilenet_v3_large_coco": "https://download.pytorch.org/models/deeplabv3_mobilenet_v3_large-fc3c493d.pth",
}


class DeepLabV3(_SimpleSegmentationModel):
@@ -95,3 +111,131 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
            _res.append(conv(x))
        res = torch.cat(_res, dim=1)
        return self.project(res)


def _deeplabv3_resnet(
    backbone: resnet.ResNet,

Review comment from @datumbox (Contributor, Author), Oct 18, 2021:
We make all private model builders accept the pre-initialized backbones instead of passing the backbone_name. This allows us to reuse the methods on the multi-pretrained weights project.

    num_classes: int,
    aux: Optional[bool],
) -> DeepLabV3:
    return_layers = {"layer4": "out"}
    if aux:
        return_layers["layer3"] = "aux"
    backbone = create_feature_extractor(backbone, return_layers)

    aux_classifier = FCNHead(1024, num_classes) if aux else None
    classifier = DeepLabHead(2048, num_classes)
    return DeepLabV3(backbone, classifier, aux_classifier)
Review comment from @datumbox (Contributor, Author):
Overall we simplify the code by splitting the previous massive _segm_model() method.
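
To make the two comments above concrete, here is a hedged sketch (mine, not from the PR) of the new calling convention: the caller builds and configures the backbone, and the private builder only wires up feature extraction and the heads. It assumes the post-refactor torchvision modules.

```python
import torch
from torchvision.models import resnet
from torchvision.models.segmentation.deeplabv3 import _deeplabv3_resnet  # private builder from this PR

# The caller decides how the backbone is constructed (weights, dilation, ...).
backbone = resnet.resnet50(pretrained=False, replace_stride_with_dilation=[False, True, True])

# The builder only attaches the DeepLab head (and optionally the FCN aux head);
# previously _segm_model() received a backbone_name string and did all of this itself.
model = _deeplabv3_resnet(backbone, num_classes=21, aux=True)
model.eval()

with torch.no_grad():
    out = model(torch.rand(1, 3, 224, 224))
print(out["out"].shape, out["aux"].shape)  # both torch.Size([1, 21, 224, 224])
```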



def _deeplabv3_mobilenetv3(
    backbone: mobilenetv3.MobileNetV3,
    num_classes: int,
    aux: Optional[bool],
) -> DeepLabV3:
    backbone = backbone.features
    # Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
    # The first and last blocks are always included because they are the C0 (conv1) and Cn.
    stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)] + [len(backbone) - 1]
    out_pos = stage_indices[-1]  # use C5 which has output_stride = 16
    out_inplanes = backbone[out_pos].out_channels
    aux_pos = stage_indices[-4]  # use C2 here which has output_stride = 8
    aux_inplanes = backbone[aux_pos].out_channels
    return_layers = {str(out_pos): "out"}
    if aux:
        return_layers[str(aux_pos)] = "aux"
    backbone = create_feature_extractor(backbone, return_layers)

    aux_classifier = FCNHead(aux_inplanes, num_classes) if aux else None
    classifier = DeepLabHead(out_inplanes, num_classes)
    return DeepLabV3(backbone, classifier, aux_classifier)
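
The stage_indices bookkeeping above is easiest to see on a real backbone. Below is a small, hedged sketch (not in the diff) that reproduces the index math for mobilenet_v3_large(dilated=True); the channel counts in the comments are what the builder ends up feeding into DeepLabHead and FCNHead:

```python
from torchvision.models import mobilenetv3

features = mobilenetv3.mobilenet_v3_large(dilated=True).features

# Same bookkeeping as _deeplabv3_mobilenetv3: first block, every block that starts
# a new stage (marked with `_is_cn`), and the last block.
stage_indices = [0] + [i for i, b in enumerate(features) if getattr(b, "_is_cn", False)] + [len(features) - 1]
out_pos, aux_pos = stage_indices[-1], stage_indices[-4]

print(stage_indices)
print(out_pos, features[out_pos].out_channels)  # C5 block, 960 channels -> DeepLabHead
print(aux_pos, features[aux_pos].out_channels)  # C2 block, 40 channels  -> FCNHead (aux)
```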


def deeplabv3_resnet50(
    pretrained: bool = False,
    progress: bool = True,
    num_classes: int = 21,
    aux_loss: Optional[bool] = None,
    pretrained_backbone: bool = True,
Review comment from @datumbox (Contributor, Author):
Instead of **kwargs we directly expose the pretrained_backbone parameter publicly.

) -> DeepLabV3:
    """Constructs a DeepLabV3 model with a ResNet-50 backbone.

    Args:
        pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
            contains the same classes as Pascal VOC
        progress (bool): If True, displays a progress bar of the download to stderr
        num_classes (int): number of output classes of the model (including the background)
        aux_loss (bool, optional): If True, it uses an auxiliary loss
        pretrained_backbone (bool): If True, the backbone will be pre-trained.
    """
    if pretrained:
        aux_loss = True
        pretrained_backbone = False

    backbone = resnet.resnet50(pretrained=pretrained_backbone, replace_stride_with_dilation=[False, True, True])
    model = _deeplabv3_resnet(backbone, num_classes, aux_loss)

    if pretrained:
        arch = "deeplabv3_resnet50_coco"
        _load_weights(arch, model, model_urls.get(arch, None), progress)
Review thread on lines +178 to +180:

Member: nit: I suppose you've left this here (instead of putting it in _deeplabv3_resnet) because it will be more aligned with your changes to the new weights? Same for the backbone retrieval code.

@datumbox (Contributor, Author): Yes, exactly. I was going back and forth about this. If I were to put it in the builder methods, I would have to copy-paste the whole thing... I think an additional final clean-up would be necessary prior to moving the prototype work to main, and there we would be able to move things around. This is a great candidate for such a clean-up.

    return model
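
And a hedged usage example (not part of the diff) of the public builder with the newly exposed flag; pretrained_backbone=False is used here only to keep the sketch download-free, and when pretrained=True the builder overrides it anyway:

```python
import torch
from torchvision.models.segmentation import deeplabv3_resnet50

# Randomly initialized weights; the explicit flag replaces the old **kwargs plumbing.
model = deeplabv3_resnet50(pretrained=False, pretrained_backbone=False, num_classes=21, aux_loss=True)
model.eval()

with torch.no_grad():
    out = model(torch.rand(1, 3, 520, 520))

print(sorted(out.keys()))  # ['aux', 'out']
print(out["out"].shape)    # torch.Size([1, 21, 520, 520])
```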


def deeplabv3_resnet101(
    pretrained: bool = False,
    progress: bool = True,
    num_classes: int = 21,
    aux_loss: Optional[bool] = None,
    pretrained_backbone: bool = True,
) -> DeepLabV3:
    """Constructs a DeepLabV3 model with a ResNet-101 backbone.

    Args:
        pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
            contains the same classes as Pascal VOC
        progress (bool): If True, displays a progress bar of the download to stderr
        num_classes (int): The number of classes
        aux_loss (bool, optional): If True, include an auxiliary classifier
        pretrained_backbone (bool): If True, the backbone will be pre-trained.
    """
    if pretrained:
        aux_loss = True
        pretrained_backbone = False

    backbone = resnet.resnet101(pretrained=pretrained_backbone, replace_stride_with_dilation=[False, True, True])
    model = _deeplabv3_resnet(backbone, num_classes, aux_loss)

    if pretrained:
        arch = "deeplabv3_resnet101_coco"
        _load_weights(arch, model, model_urls.get(arch, None), progress)
    return model


def deeplabv3_mobilenet_v3_large(
    pretrained: bool = False,
    progress: bool = True,
    num_classes: int = 21,
    aux_loss: Optional[bool] = None,
    pretrained_backbone: bool = True,
) -> DeepLabV3:
    """Constructs a DeepLabV3 model with a MobileNetV3-Large backbone.

    Args:
        pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
            contains the same classes as Pascal VOC
        progress (bool): If True, displays a progress bar of the download to stderr
        num_classes (int): number of output classes of the model (including the background)
        aux_loss (bool, optional): If True, it uses an auxiliary loss
        pretrained_backbone (bool): If True, the backbone will be pre-trained.
    """
    if pretrained:
        aux_loss = True
        pretrained_backbone = False

    backbone = mobilenetv3.mobilenet_v3_large(pretrained=pretrained_backbone, dilated=True)
    model = _deeplabv3_mobilenetv3(backbone, num_classes, aux_loss)

    if pretrained:
        arch = "deeplabv3_mobilenet_v3_large_coco"
        _load_weights(arch, model, model_urls.get(arch, None), progress)
    return model
89 changes: 87 additions & 2 deletions torchvision/models/segmentation/fcn.py
@@ -1,9 +1,19 @@
from typing import Optional

from torch import nn

-from ._utils import _SimpleSegmentationModel
from .. import resnet
from ..feature_extraction import create_feature_extractor
from ._utils import _SimpleSegmentationModel, _load_weights


__all__ = ["FCN", "fcn_resnet50", "fcn_resnet101"]


-__all__ = ["FCN"]
model_urls = {
    "fcn_resnet50_coco": "https://download.pytorch.org/models/fcn_resnet50_coco-1167a1af.pth",
    "fcn_resnet101_coco": "https://download.pytorch.org/models/fcn_resnet101_coco-7ecb50ca.pth",
}


class FCN(_SimpleSegmentationModel):
@@ -35,3 +45,78 @@ def __init__(self, in_channels: int, channels: int) -> None:
        ]

        super(FCNHead, self).__init__(*layers)


def _fcn_resnet(
    backbone: resnet.ResNet,
    num_classes: int,
    aux: Optional[bool],
) -> FCN:
    return_layers = {"layer4": "out"}
    if aux:
        return_layers["layer3"] = "aux"
    backbone = create_feature_extractor(backbone, return_layers)

    aux_classifier = FCNHead(1024, num_classes) if aux else None
    classifier = FCNHead(2048, num_classes)
    return FCN(backbone, classifier, aux_classifier)
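
For context, a hedged sketch (not from the PR) of why the head channel counts above can be hard-coded: on ResNet-50/101, layer3 produces 1024 channels and layer4 produces 2048, and those tensors come back from the feature extractor under the "aux" and "out" keys:

```python
import torch
from torchvision.models import resnet
from torchvision.models.feature_extraction import create_feature_extractor

backbone = resnet.resnet50(replace_stride_with_dilation=[False, True, True])
extractor = create_feature_extractor(backbone, {"layer4": "out", "layer3": "aux"})

feats = extractor(torch.rand(1, 3, 224, 224))
print(feats["aux"].shape)  # torch.Size([1, 1024, 28, 28]) -> FCNHead(1024, num_classes)
print(feats["out"].shape)  # torch.Size([1, 2048, 28, 28]) -> FCNHead(2048, num_classes)
```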


def fcn_resnet50(
    pretrained: bool = False,
    progress: bool = True,
    num_classes: int = 21,
    aux_loss: Optional[bool] = None,
    pretrained_backbone: bool = True,
) -> FCN:
    """Constructs a Fully-Convolutional Network model with a ResNet-50 backbone.

    Args:
        pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
            contains the same classes as Pascal VOC
        progress (bool): If True, displays a progress bar of the download to stderr
        num_classes (int): number of output classes of the model (including the background)
        aux_loss (bool, optional): If True, it uses an auxiliary loss
        pretrained_backbone (bool): If True, the backbone will be pre-trained.
    """
    if pretrained:
        aux_loss = True
        pretrained_backbone = False

    backbone = resnet.resnet50(pretrained=pretrained_backbone, replace_stride_with_dilation=[False, True, True])
    model = _fcn_resnet(backbone, num_classes, aux_loss)

    if pretrained:
        arch = "fcn_resnet50_coco"
        _load_weights(arch, model, model_urls.get(arch, None), progress)
    return model


def fcn_resnet101(
    pretrained: bool = False,
    progress: bool = True,
    num_classes: int = 21,
    aux_loss: Optional[bool] = None,
    pretrained_backbone: bool = True,
) -> FCN:
    """Constructs a Fully-Convolutional Network model with a ResNet-101 backbone.

    Args:
        pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
            contains the same classes as Pascal VOC
        progress (bool): If True, displays a progress bar of the download to stderr
        num_classes (int): number of output classes of the model (including the background)
        aux_loss (bool, optional): If True, it uses an auxiliary loss
        pretrained_backbone (bool): If True, the backbone will be pre-trained.
    """
    if pretrained:
        aux_loss = True
        pretrained_backbone = False

    backbone = resnet.resnet101(pretrained=pretrained_backbone, replace_stride_with_dilation=[False, True, True])
    model = _fcn_resnet(backbone, num_classes, aux_loss)

    if pretrained:
        arch = "fcn_resnet101_coco"
        _load_weights(arch, model, model_urls.get(arch, None), progress)
    return model
57 changes: 55 additions & 2 deletions torchvision/models/segmentation/lraspp.py
@@ -1,11 +1,20 @@
from collections import OrderedDict
-from typing import Dict
from typing import Any, Dict

from torch import nn, Tensor
from torch.nn import functional as F

from .. import mobilenetv3
from ..feature_extraction import create_feature_extractor
from ._utils import _load_weights

-__all__ = ["LRASPP"]

__all__ = ["LRASPP", "lraspp_mobilenet_v3_large"]


model_urls = {
    "lraspp_mobilenet_v3_large_coco": "https://download.pytorch.org/models/lraspp_mobilenet_v3_large-d234d4ea.pth",
}


class LRASPP(nn.Module):
@@ -68,3 +77,47 @@ def forward(self, input: Dict[str, Tensor]) -> Tensor:
        x = F.interpolate(x, size=low.shape[-2:], mode="bilinear", align_corners=False)

        return self.low_classifier(low) + self.high_classifier(x)


def _lraspp_mobilenetv3(backbone: mobilenetv3.MobileNetV3, num_classes: int) -> LRASPP:
    backbone = backbone.features
    # Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
    # The first and last blocks are always included because they are the C0 (conv1) and Cn.
    stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)] + [len(backbone) - 1]
    low_pos = stage_indices[-4]  # use C2 here which has output_stride = 8
    high_pos = stage_indices[-1]  # use C5 which has output_stride = 16
    low_channels = backbone[low_pos].out_channels
    high_channels = backbone[high_pos].out_channels
    backbone = create_feature_extractor(backbone, {str(low_pos): "low", str(high_pos): "high"})

    return LRASPP(backbone, low_channels, high_channels, num_classes)


def lraspp_mobilenet_v3_large(
    pretrained: bool = False,
    progress: bool = True,
    num_classes: int = 21,
    pretrained_backbone: bool = True,
    **kwargs: Any,
) -> LRASPP:
    """Constructs a Lite R-ASPP Network model with a MobileNetV3-Large backbone.

    Args:
        pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
            contains the same classes as Pascal VOC
        progress (bool): If True, displays a progress bar of the download to stderr
        num_classes (int): number of output classes of the model (including the background)
        pretrained_backbone (bool): If True, the backbone will be pre-trained.
    """
    if kwargs.pop("aux_loss", False):
        raise NotImplementedError("This model does not use auxiliary loss")
    if pretrained:
        pretrained_backbone = False

    backbone = mobilenetv3.mobilenet_v3_large(pretrained=pretrained_backbone, dilated=True)
    model = _lraspp_mobilenetv3(backbone, num_classes)

    if pretrained:
        arch = "lraspp_mobilenet_v3_large_coco"
        _load_weights(arch, model, model_urls.get(arch, None), progress)
    return model
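
Closing with a hedged usage sketch (not in the diff) of the LR-ASPP builder; note that it keeps **kwargs only so that a stray aux_loss request fails loudly rather than being silently ignored:

```python
import torch
from torchvision.models.segmentation import lraspp_mobilenet_v3_large

model = lraspp_mobilenet_v3_large(pretrained=False, pretrained_backbone=False, num_classes=21)
model.eval()

with torch.no_grad():
    out = model(torch.rand(1, 3, 520, 520))
print(out["out"].shape)  # torch.Size([1, 21, 520, 520])

# There is no auxiliary head, so asking for one raises immediately:
try:
    lraspp_mobilenet_v3_large(aux_loss=True)
except NotImplementedError as err:
    print(err)  # This model does not use auxiliary loss
```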