-
Notifications
You must be signed in to change notification settings - Fork 7.1k
Refactor Segmentation models #4646
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
5e5a68c
93cdcf6
7213a9d
bee91a1
c4c5ac4
c341ed3
43c677e
f534046
de1d2ad
ede7980
198683a
ad34b89
7a96c40
39392d7
7fef773
49816fc
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,3 @@ | ||
from .segmentation import * | ||
from .fcn import * | ||
from .deeplabv3 import * | ||
from .lraspp import * |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,13 +1,29 @@ | ||
from typing import List | ||
from typing import List, Optional | ||
|
||
import torch | ||
from torch import nn | ||
from torch.nn import functional as F | ||
|
||
from ._utils import _SimpleSegmentationModel | ||
from .. import mobilenetv3 | ||
from .. import resnet | ||
from ..feature_extraction import create_feature_extractor | ||
from ._utils import _SimpleSegmentationModel, _load_weights | ||
from .fcn import FCNHead | ||
|
||
|
||
__all__ = ["DeepLabV3"] | ||
__all__ = [ | ||
"DeepLabV3", | ||
"deeplabv3_resnet50", | ||
"deeplabv3_resnet101", | ||
"deeplabv3_mobilenet_v3_large", | ||
] | ||
|
||
|
||
model_urls = { | ||
"deeplabv3_resnet50_coco": "https://download.pytorch.org/models/deeplabv3_resnet50_coco-cd0a2569.pth", | ||
"deeplabv3_resnet101_coco": "https://download.pytorch.org/models/deeplabv3_resnet101_coco-586e9e4e.pth", | ||
"deeplabv3_mobilenet_v3_large_coco": "https://download.pytorch.org/models/deeplabv3_mobilenet_v3_large-fc3c493d.pth", | ||
} | ||
|
||
|
||
class DeepLabV3(_SimpleSegmentationModel): | ||
|
@@ -95,3 +111,131 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: | |
_res.append(conv(x)) | ||
res = torch.cat(_res, dim=1) | ||
return self.project(res) | ||
|
||
|
||
def _deeplabv3_resnet(
    backbone: resnet.ResNet,
    num_classes: int,
    aux: Optional[bool],
) -> DeepLabV3:
    """Assemble a DeepLabV3 model around an already-constructed ResNet backbone.

    Args:
        backbone: ResNet instance; its ``layer4`` output is exposed as ``"out"``
            and, when ``aux`` is truthy, its ``layer3`` output as ``"aux"``.
        num_classes: passed as the second argument to both heads.
        aux: when truthy, an ``FCNHead`` auxiliary classifier is attached;
            otherwise the auxiliary classifier is ``None``.

    Returns:
        A ``DeepLabV3`` built from the feature extractor, the main head and the
        (optional) auxiliary head.
    """
    # Insertion order (layer4 first, layer3 second) is kept on purpose.
    wanted_nodes = {"layer4": "out", "layer3": "aux"} if aux else {"layer4": "out"}
    feature_extractor = create_feature_extractor(backbone, wanted_nodes)

    # The auxiliary head is created before the main head, so module
    # construction happens in the same order as before.
    # 1024 / 2048 are presumably the channel widths of layer3 / layer4.
    if aux:
        aux_head = FCNHead(1024, num_classes)
    else:
        aux_head = None
    main_head = DeepLabHead(2048, num_classes)
    return DeepLabV3(feature_extractor, main_head, aux_head)
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Overall we simplify the code by splitting the previous massive |
||
|
||
|
||
def _deeplabv3_mobilenetv3(
    backbone: mobilenetv3.MobileNetV3,
    num_classes: int,
    aux: Optional[bool],
) -> DeepLabV3:
    """Assemble a DeepLabV3 model around an already-constructed MobileNetV3.

    Only ``backbone.features`` is used. The main head reads from the last
    feature block; the optional auxiliary head reads from the block found
    four positions from the end of the stage-index list.

    Args:
        backbone: MobileNetV3 instance providing a ``features`` sequence.
        num_classes: passed as the second argument to both heads.
        aux: when truthy, an ``FCNHead`` auxiliary classifier is attached;
            otherwise the auxiliary classifier is ``None``.

    Returns:
        A ``DeepLabV3`` built from the feature extractor and the head(s).
    """
    features = backbone.features

    # Blocks flagged with ``_is_cn`` are the strided ones, i.e. the locations
    # of C1, ..., Cn-1. The first block (C0 / conv1) and the last block (Cn)
    # are always part of the list.
    strided = [idx for idx, block in enumerate(features) if getattr(block, "_is_cn", False)]
    stage_indices = [0, *strided, len(features) - 1]

    out_pos = stage_indices[-1]  # C5, output_stride = 16
    out_inplanes = features[out_pos].out_channels
    aux_pos = stage_indices[-4]  # C2, output_stride = 8
    aux_inplanes = features[aux_pos].out_channels

    wanted_nodes = {str(out_pos): "out"}
    if aux:
        wanted_nodes.update({str(aux_pos): "aux"})
    feature_extractor = create_feature_extractor(features, wanted_nodes)

    # Auxiliary head first, main head second: same construction order as before.
    if aux:
        aux_head = FCNHead(aux_inplanes, num_classes)
    else:
        aux_head = None
    main_head = DeepLabHead(out_inplanes, num_classes)
    return DeepLabV3(feature_extractor, main_head, aux_head)
|
||
|
||
def deeplabv3_resnet50( | ||
pretrained: bool = False, | ||
progress: bool = True, | ||
num_classes: int = 21, | ||
aux_loss: Optional[bool] = None, | ||
pretrained_backbone: bool = True, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Instead of |
||
) -> DeepLabV3: | ||
"""Constructs a DeepLabV3 model with a ResNet-50 backbone. | ||
|
||
Args: | ||
pretrained (bool): If True, returns a model pre-trained on COCO train2017 which | ||
contains the same classes as Pascal VOC | ||
progress (bool): If True, displays a progress bar of the download to stderr | ||
num_classes (int): number of output classes of the model (including the background) | ||
aux_loss (bool, optional): If True, it uses an auxiliary loss | ||
pretrained_backbone (bool): If True, the backbone will be pre-trained. | ||
""" | ||
if pretrained: | ||
aux_loss = True | ||
pretrained_backbone = False | ||
|
||
backbone = resnet.resnet50(pretrained=pretrained_backbone, replace_stride_with_dilation=[False, True, True]) | ||
model = _deeplabv3_resnet(backbone, num_classes, aux_loss) | ||
|
||
if pretrained: | ||
arch = "deeplabv3_resnet50_coco" | ||
_load_weights(arch, model, model_urls.get(arch, None), progress) | ||
Comment on lines
+178
to
+180
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. nit: I suppose you've left this here (instead of putting it in Same for the backbone retrieval code There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yes exactly. I was going back and forth about this. If I were to put it in the builder methods, I would have to copy-paste the whole thing... I think an additional final clean up would be necessary prior to moving the prototype work to main and there we would be able to move things around. This is a great candidate for such clean up. |
||
return model | ||
|
||
|
||
def deeplabv3_resnet101(
    pretrained: bool = False,
    progress: bool = True,
    num_classes: int = 21,
    aux_loss: Optional[bool] = None,
    pretrained_backbone: bool = True,
) -> DeepLabV3:
    """Constructs a DeepLabV3 model with a ResNet-101 backbone.

    Args:
        pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
            contains the same classes as Pascal VOC
        progress (bool): If True, displays a progress bar of the download to stderr
        num_classes (int): The number of classes
        aux_loss (bool, optional): If True, include an auxiliary classifier
        pretrained_backbone (bool): If True, the backbone will be pre-trained.

    Note:
        When ``pretrained`` is True, ``aux_loss`` is forced to True and
        ``pretrained_backbone`` to False, whatever the caller passed.
    """
    if pretrained:
        # Presumably the full checkpoint already covers the backbone and
        # includes auxiliary-classifier weights; confirm against the checkpoint.
        aux_loss, pretrained_backbone = True, False

    dilation = [False, True, True]
    backbone = resnet.resnet101(pretrained=pretrained_backbone, replace_stride_with_dilation=dilation)
    model = _deeplabv3_resnet(backbone, num_classes, aux_loss)
    if not pretrained:
        return model

    arch = "deeplabv3_resnet101_coco"
    _load_weights(arch, model, model_urls.get(arch), progress)
    return model
|
||
|
||
def deeplabv3_mobilenet_v3_large(
    pretrained: bool = False,
    progress: bool = True,
    num_classes: int = 21,
    aux_loss: Optional[bool] = None,
    pretrained_backbone: bool = True,
) -> DeepLabV3:
    """Constructs a DeepLabV3 model with a MobileNetV3-Large backbone.

    Args:
        pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
            contains the same classes as Pascal VOC
        progress (bool): If True, displays a progress bar of the download to stderr
        num_classes (int): number of output classes of the model (including the background)
        aux_loss (bool, optional): If True, it uses an auxiliary loss
        pretrained_backbone (bool): If True, the backbone will be pre-trained.

    Note:
        When ``pretrained`` is True, ``aux_loss`` is forced to True and
        ``pretrained_backbone`` to False, whatever the caller passed.
    """
    if pretrained:
        # Presumably the full checkpoint already covers the backbone and
        # includes auxiliary-classifier weights; confirm against the checkpoint.
        aux_loss, pretrained_backbone = True, False

    backbone = mobilenetv3.mobilenet_v3_large(pretrained=pretrained_backbone, dilated=True)
    model = _deeplabv3_mobilenetv3(backbone, num_classes, aux_loss)
    if not pretrained:
        return model

    arch = "deeplabv3_mobilenet_v3_large_coco"
    _load_weights(arch, model, model_urls.get(arch), progress)
    return model
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We make all private model builders accept pre-initialized backbones instead of passing the `backbone_name`. This allows us to reuse the methods in the multi-pretrained weights project.