Skip to content

Commit a1d6b31

Browse files
authored
Refactor Segmentation models (#4646)
* Move FCN methods to its own package. * Fix lint. * Move LRASPP methods to their own package. * Move DeepLabV3 methods to their own package. * Adding deprecation warning for torchvision.models.segmentation.segmentation. * Refactoring deeplab. * Setting aux default to false. * Fixing imports. * Passing backbones instead of backbone names to builders. * Fixing mypy. * Addressing review comments. * Correcting typing. * Restoring special handling for references.
1 parent e4a4a29 commit a1d6b31

File tree

6 files changed

+305
-247
lines changed

6 files changed

+305
-247
lines changed
Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
from .segmentation import *
21
from .fcn import *
32
from .deeplabv3 import *
43
from .lraspp import *

torchvision/models/segmentation/_utils.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
from torch import nn, Tensor
55
from torch.nn import functional as F
66

7+
from ..._internally_replaced_utils import load_state_dict_from_url
8+
79

810
class _SimpleSegmentationModel(nn.Module):
911
__constants__ = ["aux_classifier"]
@@ -32,3 +34,10 @@ def forward(self, x: Tensor) -> Dict[str, Tensor]:
3234
result["aux"] = x
3335

3436
return result
37+
38+
39+
def _load_weights(arch: str, model: nn.Module, model_url: Optional[str], progress: bool) -> None:
40+
if model_url is None:
41+
raise ValueError("No checkpoint is available for {}".format(arch))
42+
state_dict = load_state_dict_from_url(model_url, progress=progress)
43+
model.load_state_dict(state_dict)

torchvision/models/segmentation/deeplabv3.py

Lines changed: 147 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,29 @@
1-
from typing import List
1+
from typing import List, Optional
22

33
import torch
44
from torch import nn
55
from torch.nn import functional as F
66

7-
from ._utils import _SimpleSegmentationModel
7+
from .. import mobilenetv3
8+
from .. import resnet
9+
from ..feature_extraction import create_feature_extractor
10+
from ._utils import _SimpleSegmentationModel, _load_weights
11+
from .fcn import FCNHead
812

913

10-
__all__ = ["DeepLabV3"]
14+
__all__ = [
15+
"DeepLabV3",
16+
"deeplabv3_resnet50",
17+
"deeplabv3_resnet101",
18+
"deeplabv3_mobilenet_v3_large",
19+
]
20+
21+
22+
model_urls = {
23+
"deeplabv3_resnet50_coco": "https://download.pytorch.org/models/deeplabv3_resnet50_coco-cd0a2569.pth",
24+
"deeplabv3_resnet101_coco": "https://download.pytorch.org/models/deeplabv3_resnet101_coco-586e9e4e.pth",
25+
"deeplabv3_mobilenet_v3_large_coco": "https://download.pytorch.org/models/deeplabv3_mobilenet_v3_large-fc3c493d.pth",
26+
}
1127

1228

1329
class DeepLabV3(_SimpleSegmentationModel):
@@ -95,3 +111,131 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
95111
_res.append(conv(x))
96112
res = torch.cat(_res, dim=1)
97113
return self.project(res)
114+
115+
116+
def _deeplabv3_resnet(
    backbone: resnet.ResNet,
    num_classes: int,
    aux: Optional[bool],
) -> DeepLabV3:
    """Assemble a DeepLabV3 model on top of an already-built ResNet backbone.

    Args:
        backbone (resnet.ResNet): ResNet whose ``layer4`` (and, with ``aux``,
            ``layer3``) outputs are exposed through a feature extractor.
        num_classes (int): Number of output classes for both heads.
        aux (bool, optional): If truthy, an auxiliary ``FCNHead`` is attached
            to ``layer3``; otherwise no auxiliary classifier is created.

    Returns:
        DeepLabV3: The assembled segmentation model.
    """
    # "layer4" always feeds the main head; "layer3" is only tapped for the aux head.
    return_layers = {"layer4": "out", "layer3": "aux"} if aux else {"layer4": "out"}
    extractor = create_feature_extractor(backbone, return_layers)

    # NOTE(review): 1024 / 2048 presumably match the layer3 / layer4 channel widths of
    # bottleneck ResNets (resnet50/101) — confirm before using other ResNet variants.
    aux_head = FCNHead(1024, num_classes) if aux else None
    main_head = DeepLabHead(2048, num_classes)
    return DeepLabV3(extractor, main_head, aux_head)
129+
130+
131+
def _deeplabv3_mobilenetv3(
    backbone: mobilenetv3.MobileNetV3,
    num_classes: int,
    aux: Optional[bool],
) -> DeepLabV3:
    """Assemble a DeepLabV3 model on top of an already-built MobileNetV3 backbone.

    Args:
        backbone (mobilenetv3.MobileNetV3): Network whose ``features`` sequence
            is used as the feature extractor.
        num_classes (int): Number of output classes for both heads.
        aux (bool, optional): If truthy, an auxiliary ``FCNHead`` is attached to
            the C2 stage; otherwise no auxiliary classifier is created.

    Returns:
        DeepLabV3: The assembled segmentation model.
    """
    features = backbone.features

    # Blocks flagged with ``_is_cn`` mark the strided stages C1, ..., Cn-1. The first
    # and last blocks are always included because they are C0 (conv1) and Cn.
    strided = [idx for idx, block in enumerate(features) if getattr(block, "_is_cn", False)]
    stage_indices = [0, *strided, len(features) - 1]

    out_pos = stage_indices[-1]  # use C5 which has output_stride = 16
    out_inplanes = features[out_pos].out_channels
    aux_pos = stage_indices[-4]  # use C2 here which has output_stride = 8
    aux_inplanes = features[aux_pos].out_channels

    if aux:
        return_layers = {str(out_pos): "out", str(aux_pos): "aux"}
    else:
        return_layers = {str(out_pos): "out"}
    extractor = create_feature_extractor(features, return_layers)

    aux_head = FCNHead(aux_inplanes, num_classes) if aux else None
    main_head = DeepLabHead(out_inplanes, num_classes)
    return DeepLabV3(extractor, main_head, aux_head)
152+
153+
154+
def deeplabv3_resnet50(
    pretrained: bool = False,
    progress: bool = True,
    num_classes: int = 21,
    aux_loss: Optional[bool] = None,
    pretrained_backbone: bool = True,
) -> DeepLabV3:
    """Constructs a DeepLabV3 model with a ResNet-50 backbone.

    When ``pretrained`` is True, ``aux_loss`` is forced to True and
    ``pretrained_backbone`` to False before the model is built; the full
    checkpoint is then loaded into the assembled model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
            contains the same classes as Pascal VOC
        progress (bool): If True, displays a progress bar of the download to stderr
        num_classes (int): number of output classes of the model (including the background)
        aux_loss (bool, optional): If True, it uses an auxiliary loss
        pretrained_backbone (bool): If True, the backbone will be pre-trained.

    Returns:
        DeepLabV3: The constructed model.
    """
    # The full checkpoint is loaded afterwards, so separate backbone weights are not requested.
    use_aux = True if pretrained else aux_loss
    load_backbone = False if pretrained else pretrained_backbone

    trunk = resnet.resnet50(pretrained=load_backbone, replace_stride_with_dilation=[False, True, True])
    net = _deeplabv3_resnet(trunk, num_classes, use_aux)
    if not pretrained:
        return net

    checkpoint_key = "deeplabv3_resnet50_coco"
    _load_weights(checkpoint_key, net, model_urls.get(checkpoint_key), progress)
    return net
182+
183+
184+
def deeplabv3_resnet101(
    pretrained: bool = False,
    progress: bool = True,
    num_classes: int = 21,
    aux_loss: Optional[bool] = None,
    pretrained_backbone: bool = True,
) -> DeepLabV3:
    """Constructs a DeepLabV3 model with a ResNet-101 backbone.

    When ``pretrained`` is True, ``aux_loss`` is forced to True and
    ``pretrained_backbone`` to False before the model is built; the full
    checkpoint is then loaded into the assembled model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
            contains the same classes as Pascal VOC
        progress (bool): If True, displays a progress bar of the download to stderr
        num_classes (int): number of output classes of the model (including the background)
        aux_loss (bool, optional): If True, it uses an auxiliary loss
        pretrained_backbone (bool): If True, the backbone will be pre-trained.

    Returns:
        DeepLabV3: The constructed model.
    """
    # The full checkpoint is loaded afterwards, so separate backbone weights are not requested.
    use_aux = True if pretrained else aux_loss
    load_backbone = False if pretrained else pretrained_backbone

    trunk = resnet.resnet101(pretrained=load_backbone, replace_stride_with_dilation=[False, True, True])
    net = _deeplabv3_resnet(trunk, num_classes, use_aux)
    if not pretrained:
        return net

    checkpoint_key = "deeplabv3_resnet101_coco"
    _load_weights(checkpoint_key, net, model_urls.get(checkpoint_key), progress)
    return net
212+
213+
214+
def deeplabv3_mobilenet_v3_large(
    pretrained: bool = False,
    progress: bool = True,
    num_classes: int = 21,
    aux_loss: Optional[bool] = None,
    pretrained_backbone: bool = True,
) -> DeepLabV3:
    """Constructs a DeepLabV3 model with a MobileNetV3-Large backbone.

    When ``pretrained`` is True, ``aux_loss`` is forced to True and
    ``pretrained_backbone`` to False before the model is built; the full
    checkpoint is then loaded into the assembled model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
            contains the same classes as Pascal VOC
        progress (bool): If True, displays a progress bar of the download to stderr
        num_classes (int): number of output classes of the model (including the background)
        aux_loss (bool, optional): If True, it uses an auxiliary loss
        pretrained_backbone (bool): If True, the backbone will be pre-trained.

    Returns:
        DeepLabV3: The constructed model.
    """
    # The full checkpoint is loaded afterwards, so separate backbone weights are not requested.
    use_aux = True if pretrained else aux_loss
    load_backbone = False if pretrained else pretrained_backbone

    trunk = mobilenetv3.mobilenet_v3_large(pretrained=load_backbone, dilated=True)
    net = _deeplabv3_mobilenetv3(trunk, num_classes, use_aux)
    if not pretrained:
        return net

    checkpoint_key = "deeplabv3_mobilenet_v3_large_coco"
    _load_weights(checkpoint_key, net, model_urls.get(checkpoint_key), progress)
    return net

torchvision/models/segmentation/fcn.py

Lines changed: 87 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,19 @@
1+
from typing import Optional
2+
13
from torch import nn
24

3-
from ._utils import _SimpleSegmentationModel
5+
from .. import resnet
6+
from ..feature_extraction import create_feature_extractor
7+
from ._utils import _SimpleSegmentationModel, _load_weights
8+
9+
10+
__all__ = ["FCN", "fcn_resnet50", "fcn_resnet101"]
411

512

6-
__all__ = ["FCN"]
13+
model_urls = {
14+
"fcn_resnet50_coco": "https://download.pytorch.org/models/fcn_resnet50_coco-1167a1af.pth",
15+
"fcn_resnet101_coco": "https://download.pytorch.org/models/fcn_resnet101_coco-7ecb50ca.pth",
16+
}
717

818

919
class FCN(_SimpleSegmentationModel):
@@ -35,3 +45,78 @@ def __init__(self, in_channels: int, channels: int) -> None:
3545
]
3646

3747
super(FCNHead, self).__init__(*layers)
48+
49+
50+
def _fcn_resnet(
    backbone: resnet.ResNet,
    num_classes: int,
    aux: Optional[bool],
) -> FCN:
    """Assemble an FCN model on top of an already-built ResNet backbone.

    Args:
        backbone (resnet.ResNet): ResNet whose ``layer4`` (and, with ``aux``,
            ``layer3``) outputs are exposed through a feature extractor.
        num_classes (int): Number of output classes for both heads.
        aux (bool, optional): If truthy, an auxiliary ``FCNHead`` is attached
            to ``layer3``; otherwise no auxiliary classifier is created.

    Returns:
        FCN: The assembled segmentation model.
    """
    # "layer4" always feeds the main head; "layer3" is only tapped for the aux head.
    return_layers = {"layer4": "out", "layer3": "aux"} if aux else {"layer4": "out"}
    extractor = create_feature_extractor(backbone, return_layers)

    # NOTE(review): 1024 / 2048 presumably match the layer3 / layer4 channel widths of
    # bottleneck ResNets (resnet50/101) — confirm before using other ResNet variants.
    aux_head = FCNHead(1024, num_classes) if aux else None
    main_head = FCNHead(2048, num_classes)
    return FCN(extractor, main_head, aux_head)
63+
64+
65+
def fcn_resnet50(
    pretrained: bool = False,
    progress: bool = True,
    num_classes: int = 21,
    aux_loss: Optional[bool] = None,
    pretrained_backbone: bool = True,
) -> FCN:
    """Constructs a Fully-Convolutional Network model with a ResNet-50 backbone.

    When ``pretrained`` is True, ``aux_loss`` is forced to True and
    ``pretrained_backbone`` to False before the model is built; the full
    checkpoint is then loaded into the assembled model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
            contains the same classes as Pascal VOC
        progress (bool): If True, displays a progress bar of the download to stderr
        num_classes (int): number of output classes of the model (including the background)
        aux_loss (bool, optional): If True, it uses an auxiliary loss
        pretrained_backbone (bool): If True, the backbone will be pre-trained.

    Returns:
        FCN: The constructed model.
    """
    # The full checkpoint is loaded afterwards, so separate backbone weights are not requested.
    use_aux = True if pretrained else aux_loss
    load_backbone = False if pretrained else pretrained_backbone

    trunk = resnet.resnet50(pretrained=load_backbone, replace_stride_with_dilation=[False, True, True])
    net = _fcn_resnet(trunk, num_classes, use_aux)
    if not pretrained:
        return net

    checkpoint_key = "fcn_resnet50_coco"
    _load_weights(checkpoint_key, net, model_urls.get(checkpoint_key), progress)
    return net
93+
94+
95+
def fcn_resnet101(
    pretrained: bool = False,
    progress: bool = True,
    num_classes: int = 21,
    aux_loss: Optional[bool] = None,
    pretrained_backbone: bool = True,
) -> FCN:
    """Constructs a Fully-Convolutional Network model with a ResNet-101 backbone.

    When ``pretrained`` is True, ``aux_loss`` is forced to True and
    ``pretrained_backbone`` to False before the model is built; the full
    checkpoint is then loaded into the assembled model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
            contains the same classes as Pascal VOC
        progress (bool): If True, displays a progress bar of the download to stderr
        num_classes (int): number of output classes of the model (including the background)
        aux_loss (bool, optional): If True, it uses an auxiliary loss
        pretrained_backbone (bool): If True, the backbone will be pre-trained.

    Returns:
        FCN: The constructed model.
    """
    # The full checkpoint is loaded afterwards, so separate backbone weights are not requested.
    use_aux = True if pretrained else aux_loss
    load_backbone = False if pretrained else pretrained_backbone

    trunk = resnet.resnet101(pretrained=load_backbone, replace_stride_with_dilation=[False, True, True])
    net = _fcn_resnet(trunk, num_classes, use_aux)
    if not pretrained:
        return net

    checkpoint_key = "fcn_resnet101_coco"
    _load_weights(checkpoint_key, net, model_urls.get(checkpoint_key), progress)
    return net

torchvision/models/segmentation/lraspp.py

Lines changed: 55 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,20 @@
11
from collections import OrderedDict
2-
from typing import Dict
2+
from typing import Any, Dict
33

44
from torch import nn, Tensor
55
from torch.nn import functional as F
66

7+
from .. import mobilenetv3
8+
from ..feature_extraction import create_feature_extractor
9+
from ._utils import _load_weights
710

8-
__all__ = ["LRASPP"]
11+
12+
__all__ = ["LRASPP", "lraspp_mobilenet_v3_large"]
13+
14+
15+
model_urls = {
16+
"lraspp_mobilenet_v3_large_coco": "https://download.pytorch.org/models/lraspp_mobilenet_v3_large-d234d4ea.pth",
17+
}
918

1019

1120
class LRASPP(nn.Module):
@@ -68,3 +77,47 @@ def forward(self, input: Dict[str, Tensor]) -> Tensor:
6877
x = F.interpolate(x, size=low.shape[-2:], mode="bilinear", align_corners=False)
6978

7079
return self.low_classifier(low) + self.high_classifier(x)
80+
81+
82+
def _lraspp_mobilenetv3(backbone: mobilenetv3.MobileNetV3, num_classes: int) -> LRASPP:
    """Assemble a Lite R-ASPP model on top of an already-built MobileNetV3 backbone.

    Args:
        backbone (mobilenetv3.MobileNetV3): Network whose ``features`` sequence
            is used as the feature extractor.
        num_classes (int): Number of output classes of the model.

    Returns:
        LRASPP: The assembled segmentation model.
    """
    features = backbone.features

    # Blocks flagged with ``_is_cn`` mark the strided stages C1, ..., Cn-1. The first
    # and last blocks are always included because they are C0 (conv1) and Cn.
    strided = [idx for idx, block in enumerate(features) if getattr(block, "_is_cn", False)]
    stage_indices = [0, *strided, len(features) - 1]

    low_pos = stage_indices[-4]  # use C2 here which has output_stride = 8
    high_pos = stage_indices[-1]  # use C5 which has output_stride = 16
    low_channels = features[low_pos].out_channels
    high_channels = features[high_pos].out_channels

    return_layers = {str(low_pos): "low", str(high_pos): "high"}
    extractor = create_feature_extractor(features, return_layers)

    return LRASPP(extractor, low_channels, high_channels, num_classes)
94+
95+
96+
def lraspp_mobilenet_v3_large(
    pretrained: bool = False,
    progress: bool = True,
    num_classes: int = 21,
    pretrained_backbone: bool = True,
    **kwargs: Any,
) -> LRASPP:
    """Constructs a Lite R-ASPP Network model with a MobileNetV3-Large backbone.

    When ``pretrained`` is True, ``pretrained_backbone`` is forced to False
    before the model is built; the full checkpoint is then loaded into the
    assembled model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
            contains the same classes as Pascal VOC
        progress (bool): If True, displays a progress bar of the download to stderr
        num_classes (int): number of output classes of the model (including the background)
        pretrained_backbone (bool): If True, the backbone will be pre-trained.
        **kwargs: Only ``aux_loss`` is inspected; any other keyword argument is
            accepted and ignored.

    Returns:
        LRASPP: The constructed model.

    Raises:
        NotImplementedError: If a truthy ``aux_loss`` keyword argument is passed.
    """
    wants_aux = kwargs.pop("aux_loss", False)
    if wants_aux:
        raise NotImplementedError("This model does not use auxiliary loss")

    # The full checkpoint is loaded afterwards, so separate backbone weights are not requested.
    load_backbone = False if pretrained else pretrained_backbone

    trunk = mobilenetv3.mobilenet_v3_large(pretrained=load_backbone, dilated=True)
    net = _lraspp_mobilenetv3(trunk, num_classes)
    if not pretrained:
        return net

    checkpoint_key = "lraspp_mobilenet_v3_large_coco"
    _load_weights(checkpoint_key, net, model_urls.get(checkpoint_key), progress)
    return net

0 commit comments

Comments
 (0)