Skip to content

Commit de1d2ad

Browse files
committed
Passing backbones instead of backbone names to builders.
1 parent f534046 commit de1d2ad

File tree

3 files changed: +86 additions, -80 deletions (lines changed)

torchvision/models/segmentation/deeplabv3.py

Lines changed: 44 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, List
1+
from typing import List
22

33
import torch
44
from torch import nn
@@ -114,48 +114,25 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
114114

115115

116116
def _deeplabv3_resnet(
117-
backbone_name: str,
118-
pretrained: bool,
119-
progress: bool,
117+
backbone: resnet.ResNet,
120118
num_classes: int,
121119
aux: bool,
122-
pretrained_backbone: bool = True,
123120
) -> DeepLabV3:
124-
if pretrained:
125-
aux = True
126-
pretrained_backbone = False
127-
128-
backbone = resnet.__dict__[backbone_name](
129-
pretrained=pretrained_backbone, replace_stride_with_dilation=[False, True, True]
130-
)
131121
return_layers = {"layer4": "out"}
132122
if aux:
133123
return_layers["layer3"] = "aux"
134124
backbone = create_feature_extractor(backbone, return_layers)
135125

136126
aux_classifier = FCNHead(1024, num_classes) if aux else None
137127
classifier = DeepLabHead(2048, num_classes)
138-
model = DeepLabV3(backbone, classifier, aux_classifier)
139-
140-
if pretrained:
141-
arch = "deeplabv3_" + backbone_name + "_coco"
142-
_load_weights(arch, model, model_urls.get(arch, None), progress)
143-
return model
128+
return DeepLabV3(backbone, classifier, aux_classifier)
144129

145130

146131
def _deeplabv3_mobilenetv3(
147-
backbone_name: str,
148-
pretrained: bool,
149-
progress: bool,
132+
backbone: mobilenetv3.MobileNetV3,
150133
num_classes: int,
151134
aux: bool,
152-
pretrained_backbone: bool = True,
153135
) -> DeepLabV3:
154-
if pretrained:
155-
aux = True
156-
pretrained_backbone = False
157-
158-
backbone = mobilenetv3.__dict__[backbone_name](pretrained=pretrained_backbone, dilated=True).features
159136
# Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
160137
# The first and last blocks are always included because they are the C0 (conv1) and Cn.
161138
stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)] + [len(backbone) - 1]
@@ -170,20 +147,15 @@ def _deeplabv3_mobilenetv3(
170147

171148
aux_classifier = FCNHead(aux_inplanes, num_classes) if aux else None
172149
classifier = DeepLabHead(out_inplanes, num_classes)
173-
model = DeepLabV3(backbone, classifier, aux_classifier)
174-
175-
if pretrained:
176-
arch = "deeplabv3_" + backbone_name + "_coco"
177-
_load_weights(arch, model, model_urls.get(arch, None), progress)
178-
return model
150+
return DeepLabV3(backbone, classifier, aux_classifier)
179151

180152

181153
def deeplabv3_resnet50(
182154
pretrained: bool = False,
183155
progress: bool = True,
184156
num_classes: int = 21,
185157
aux_loss: bool = False,
186-
**kwargs: Any,
158+
pretrained_backbone: bool = True,
187159
) -> DeepLabV3:
188160
"""Constructs a DeepLabV3 model with a ResNet-50 backbone.
189161
@@ -193,16 +165,27 @@ def deeplabv3_resnet50(
193165
progress (bool): If True, displays a progress bar of the download to stderr
194166
num_classes (int): number of output classes of the model (including the background)
195167
aux_loss (bool): If True, it uses an auxiliary loss
168+
pretrained_backbone (bool): If True, the backbone will be pre-trained.
196169
"""
197-
return _deeplabv3_resnet("resnet50", pretrained, progress, num_classes, aux_loss, **kwargs)
170+
if pretrained:
171+
aux_loss = True
172+
pretrained_backbone = False
173+
174+
backbone = resnet.resnet50(pretrained=pretrained_backbone, replace_stride_with_dilation=[False, True, True])
175+
model = _deeplabv3_resnet(backbone, num_classes, aux_loss)
176+
177+
if pretrained:
178+
arch = "deeplabv3_resnet50_coco"
179+
_load_weights(arch, model, model_urls.get(arch, None), progress)
180+
return model
198181

199182

200183
def deeplabv3_resnet101(
201184
pretrained: bool = False,
202185
progress: bool = True,
203186
num_classes: int = 21,
204187
aux_loss: bool = False,
205-
**kwargs: Any,
188+
pretrained_backbone: bool = True,
206189
) -> DeepLabV3:
207190
"""Constructs a DeepLabV3 model with a ResNet-101 backbone.
208191
@@ -212,16 +195,27 @@ def deeplabv3_resnet101(
212195
progress (bool): If True, displays a progress bar of the download to stderr
213196
num_classes (int): The number of classes
214197
aux_loss (bool): If True, include an auxiliary classifier
198+
pretrained_backbone (bool): If True, the backbone will be pre-trained.
215199
"""
216-
return _deeplabv3_resnet("resnet101", pretrained, progress, num_classes, aux_loss, **kwargs)
200+
if pretrained:
201+
aux_loss = True
202+
pretrained_backbone = False
203+
204+
backbone = resnet.resnet101(pretrained=pretrained_backbone, replace_stride_with_dilation=[False, True, True])
205+
model = _deeplabv3_resnet(backbone, num_classes, aux_loss)
206+
207+
if pretrained:
208+
arch = "deeplabv3_resnet101_coco"
209+
_load_weights(arch, model, model_urls.get(arch, None), progress)
210+
return model
217211

218212

219213
def deeplabv3_mobilenet_v3_large(
220214
pretrained: bool = False,
221215
progress: bool = True,
222216
num_classes: int = 21,
223217
aux_loss: bool = False,
224-
**kwargs: Any,
218+
pretrained_backbone: bool = True,
225219
) -> DeepLabV3:
226220
"""Constructs a DeepLabV3 model with a MobileNetV3-Large backbone.
227221
@@ -231,5 +225,16 @@ def deeplabv3_mobilenet_v3_large(
231225
progress (bool): If True, displays a progress bar of the download to stderr
232226
num_classes (int): number of output classes of the model (including the background)
233227
aux_loss (bool): If True, it uses an auxiliary loss
228+
pretrained_backbone (bool): If True, the backbone will be pre-trained.
234229
"""
235-
return _deeplabv3_mobilenetv3("mobilenet_v3_large", pretrained, progress, num_classes, aux_loss, **kwargs)
230+
if pretrained:
231+
aux_loss = True
232+
pretrained_backbone = False
233+
234+
backbone = mobilenetv3.mobilenet_v3_large(pretrained=pretrained_backbone, dilated=True).features
235+
model = _deeplabv3_mobilenetv3(backbone, num_classes, aux_loss)
236+
237+
if pretrained:
238+
arch = "deeplabv3_mobilenet_v3_large_coco"
239+
_load_weights(arch, model, model_urls.get(arch, None), progress)
240+
return model

torchvision/models/segmentation/fcn.py

Lines changed: 28 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
from typing import Any
2-
31
from torch import nn
42

53
from .. import resnet
@@ -48,41 +46,26 @@ def __init__(self, in_channels: int, channels: int) -> None:
4846

4947

5048
def _fcn_resnet(
51-
backbone_name: str,
52-
pretrained: bool,
53-
progress: bool,
49+
backbone: resnet.ResNet,
5450
num_classes: int,
5551
aux: bool,
56-
pretrained_backbone: bool = True,
5752
) -> FCN:
58-
if pretrained:
59-
aux = True
60-
pretrained_backbone = False
61-
62-
backbone = resnet.__dict__[backbone_name](
63-
pretrained=pretrained_backbone, replace_stride_with_dilation=[False, True, True]
64-
)
6553
return_layers = {"layer4": "out"}
6654
if aux:
6755
return_layers["layer3"] = "aux"
6856
backbone = create_feature_extractor(backbone, return_layers)
6957

7058
aux_classifier = FCNHead(1024, num_classes) if aux else None
7159
classifier = FCNHead(2048, num_classes)
72-
model = FCN(backbone, classifier, aux_classifier)
73-
74-
if pretrained:
75-
arch = "fcn_" + backbone_name + "_coco"
76-
_load_weights(arch, model, model_urls.get(arch, None), progress)
77-
return model
60+
return FCN(backbone, classifier, aux_classifier)
7861

7962

8063
def fcn_resnet50(
8164
pretrained: bool = False,
8265
progress: bool = True,
8366
num_classes: int = 21,
8467
aux_loss: bool = False,
85-
**kwargs: Any,
68+
pretrained_backbone: bool = True,
8669
) -> FCN:
8770
"""Constructs a Fully-Convolutional Network model with a ResNet-50 backbone.
8871
@@ -92,16 +75,27 @@ def fcn_resnet50(
9275
progress (bool): If True, displays a progress bar of the download to stderr
9376
num_classes (int): number of output classes of the model (including the background)
9477
aux_loss (bool): If True, it uses an auxiliary loss
78+
pretrained_backbone (bool): If True, the backbone will be pre-trained.
9579
"""
96-
return _fcn_resnet("resnet50", pretrained, progress, num_classes, aux_loss, **kwargs)
80+
if pretrained:
81+
aux_loss = True
82+
pretrained_backbone = False
83+
84+
backbone = resnet.resnet50(pretrained=pretrained_backbone, replace_stride_with_dilation=[False, True, True])
85+
model = _fcn_resnet(backbone, num_classes, aux_loss)
86+
87+
if pretrained:
88+
arch = "fcn_resnet50_coco"
89+
_load_weights(arch, model, model_urls.get(arch, None), progress)
90+
return model
9791

9892

9993
def fcn_resnet101(
10094
pretrained: bool = False,
10195
progress: bool = True,
10296
num_classes: int = 21,
10397
aux_loss: bool = False,
104-
**kwargs: Any,
98+
pretrained_backbone: bool = True,
10599
) -> FCN:
106100
"""Constructs a Fully-Convolutional Network model with a ResNet-101 backbone.
107101
@@ -111,5 +105,16 @@ def fcn_resnet101(
111105
progress (bool): If True, displays a progress bar of the download to stderr
112106
num_classes (int): number of output classes of the model (including the background)
113107
aux_loss (bool): If True, it uses an auxiliary loss
108+
pretrained_backbone (bool): If True, the backbone will be pre-trained.
114109
"""
115-
return _fcn_resnet("resnet101", pretrained, progress, num_classes, aux_loss, **kwargs)
110+
if pretrained:
111+
aux_loss = True
112+
pretrained_backbone = False
113+
114+
backbone = resnet.resnet101(pretrained=pretrained_backbone, replace_stride_with_dilation=[False, True, True])
115+
model = _fcn_resnet(backbone, num_classes, aux_loss)
116+
117+
if pretrained:
118+
arch = "fcn_resnet101_coco"
119+
_load_weights(arch, model, model_urls.get(arch, None), progress)
120+
return model

torchvision/models/segmentation/lraspp.py

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from collections import OrderedDict
2-
from typing import Any, Dict
2+
from typing import Dict
33

44
from torch import nn, Tensor
55
from torch.nn import functional as F
@@ -79,13 +79,7 @@ def forward(self, input: Dict[str, Tensor]) -> Tensor:
7979
return self.low_classifier(low) + self.high_classifier(x)
8080

8181

82-
def _lraspp_mobilenetv3(
83-
backbone_name: str, pretrained: bool, progress: bool, num_classes: int, pretrained_backbone: bool = True
84-
) -> LRASPP:
85-
if pretrained:
86-
pretrained_backbone = False
87-
88-
backbone = mobilenetv3.__dict__[backbone_name](pretrained=pretrained_backbone, dilated=True).features
82+
def _lraspp_mobilenetv3(backbone: mobilenetv3.MobileNetV3, num_classes: int) -> LRASPP:
8983
# Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
9084
# The first and last blocks are always included because they are the C0 (conv1) and Cn.
9185
stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)] + [len(backbone) - 1]
@@ -95,16 +89,11 @@ def _lraspp_mobilenetv3(
9589
high_channels = backbone[high_pos].out_channels
9690
backbone = create_feature_extractor(backbone, {str(low_pos): "low", str(high_pos): "high"})
9791

98-
model = LRASPP(backbone, low_channels, high_channels, num_classes)
99-
100-
if pretrained:
101-
arch = "lraspp_" + backbone_name + "_coco"
102-
_load_weights(arch, model, model_urls.get(arch, None), progress)
103-
return model
92+
return LRASPP(backbone, low_channels, high_channels, num_classes)
10493

10594

10695
def lraspp_mobilenet_v3_large(
107-
pretrained: bool = False, progress: bool = True, num_classes: int = 21, **kwargs: Any
96+
pretrained: bool = False, progress: bool = True, num_classes: int = 21, pretrained_backbone: bool = True
10897
) -> LRASPP:
10998
"""Constructs a Lite R-ASPP Network model with a MobileNetV3-Large backbone.
11099
@@ -113,8 +102,15 @@ def lraspp_mobilenet_v3_large(
113102
contains the same classes as Pascal VOC
114103
progress (bool): If True, displays a progress bar of the download to stderr
115104
num_classes (int): number of output classes of the model (including the background)
105+
pretrained_backbone (bool): If True, the backbone will be pre-trained.
116106
"""
117-
if kwargs.pop("aux_loss", False):
118-
raise NotImplementedError("This model does not use auxiliary loss")
107+
if pretrained:
108+
pretrained_backbone = False
109+
110+
backbone = mobilenetv3.mobilenet_v3_large(pretrained=pretrained_backbone, dilated=True).features
111+
model = _lraspp_mobilenetv3(backbone, num_classes)
119112

120-
return _lraspp_mobilenetv3("mobilenet_v3_large", pretrained, progress, num_classes, **kwargs)
113+
if pretrained:
114+
arch = "lraspp_mobilenet_v3_large_coco"
115+
_load_weights(arch, model, model_urls.get(arch, None), progress)
116+
return model

0 commit comments

Comments (0)