1
- from typing import Any , List
1
+ from typing import List
2
2
3
3
import torch
4
4
from torch import nn
@@ -114,48 +114,25 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
114
114
115
115
116
116
def _deeplabv3_resnet (
117
- backbone_name : str ,
118
- pretrained : bool ,
119
- progress : bool ,
117
+ backbone : resnet .ResNet ,
120
118
num_classes : int ,
121
119
aux : bool ,
122
- pretrained_backbone : bool = True ,
123
120
) -> DeepLabV3 :
124
- if pretrained :
125
- aux = True
126
- pretrained_backbone = False
127
-
128
- backbone = resnet .__dict__ [backbone_name ](
129
- pretrained = pretrained_backbone , replace_stride_with_dilation = [False , True , True ]
130
- )
131
121
return_layers = {"layer4" : "out" }
132
122
if aux :
133
123
return_layers ["layer3" ] = "aux"
134
124
backbone = create_feature_extractor (backbone , return_layers )
135
125
136
126
aux_classifier = FCNHead (1024 , num_classes ) if aux else None
137
127
classifier = DeepLabHead (2048 , num_classes )
138
- model = DeepLabV3 (backbone , classifier , aux_classifier )
139
-
140
- if pretrained :
141
- arch = "deeplabv3_" + backbone_name + "_coco"
142
- _load_weights (arch , model , model_urls .get (arch , None ), progress )
143
- return model
128
+ return DeepLabV3 (backbone , classifier , aux_classifier )
144
129
145
130
146
131
def _deeplabv3_mobilenetv3 (
147
- backbone_name : str ,
148
- pretrained : bool ,
149
- progress : bool ,
132
+ backbone : mobilenetv3 .MobileNetV3 ,
150
133
num_classes : int ,
151
134
aux : bool ,
152
- pretrained_backbone : bool = True ,
153
135
) -> DeepLabV3 :
154
- if pretrained :
155
- aux = True
156
- pretrained_backbone = False
157
-
158
- backbone = mobilenetv3 .__dict__ [backbone_name ](pretrained = pretrained_backbone , dilated = True ).features
159
136
# Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
160
137
# The first and last blocks are always included because they are the C0 (conv1) and Cn.
161
138
stage_indices = [0 ] + [i for i , b in enumerate (backbone ) if getattr (b , "_is_cn" , False )] + [len (backbone ) - 1 ]
@@ -170,20 +147,15 @@ def _deeplabv3_mobilenetv3(
170
147
171
148
aux_classifier = FCNHead (aux_inplanes , num_classes ) if aux else None
172
149
classifier = DeepLabHead (out_inplanes , num_classes )
173
- model = DeepLabV3 (backbone , classifier , aux_classifier )
174
-
175
- if pretrained :
176
- arch = "deeplabv3_" + backbone_name + "_coco"
177
- _load_weights (arch , model , model_urls .get (arch , None ), progress )
178
- return model
150
+ return DeepLabV3 (backbone , classifier , aux_classifier )
179
151
180
152
181
153
def deeplabv3_resnet50 (
182
154
pretrained : bool = False ,
183
155
progress : bool = True ,
184
156
num_classes : int = 21 ,
185
157
aux_loss : bool = False ,
186
- ** kwargs : Any ,
158
+ pretrained_backbone : bool = True ,
187
159
) -> DeepLabV3 :
188
160
"""Constructs a DeepLabV3 model with a ResNet-50 backbone.
189
161
@@ -193,16 +165,27 @@ def deeplabv3_resnet50(
193
165
progress (bool): If True, displays a progress bar of the download to stderr
194
166
num_classes (int): number of output classes of the model (including the background)
195
167
aux_loss (bool): If True, it uses an auxiliary loss
168
+ pretrained_backbone (bool): If True, the backbone will be pre-trained.
196
169
"""
197
- return _deeplabv3_resnet ("resnet50" , pretrained , progress , num_classes , aux_loss , ** kwargs )
170
+ if pretrained :
171
+ aux_loss = True
172
+ pretrained_backbone = False
173
+
174
+ backbone = resnet .resnet50 (pretrained = pretrained_backbone , replace_stride_with_dilation = [False , True , True ])
175
+ model = _deeplabv3_resnet (backbone , num_classes , aux_loss )
176
+
177
+ if pretrained :
178
+ arch = "deeplabv3_resnet50_coco"
179
+ _load_weights (arch , model , model_urls .get (arch , None ), progress )
180
+ return model
198
181
199
182
200
183
def deeplabv3_resnet101 (
201
184
pretrained : bool = False ,
202
185
progress : bool = True ,
203
186
num_classes : int = 21 ,
204
187
aux_loss : bool = False ,
205
- ** kwargs : Any ,
188
+ pretrained_backbone : bool = True ,
206
189
) -> DeepLabV3 :
207
190
"""Constructs a DeepLabV3 model with a ResNet-101 backbone.
208
191
@@ -212,16 +195,27 @@ def deeplabv3_resnet101(
212
195
progress (bool): If True, displays a progress bar of the download to stderr
213
196
num_classes (int): The number of classes
214
197
aux_loss (bool): If True, include an auxiliary classifier
198
+ pretrained_backbone (bool): If True, the backbone will be pre-trained.
215
199
"""
216
- return _deeplabv3_resnet ("resnet101" , pretrained , progress , num_classes , aux_loss , ** kwargs )
200
+ if pretrained :
201
+ aux_loss = True
202
+ pretrained_backbone = False
203
+
204
+ backbone = resnet .resnet101 (pretrained = pretrained_backbone , replace_stride_with_dilation = [False , True , True ])
205
+ model = _deeplabv3_resnet (backbone , num_classes , aux_loss )
206
+
207
+ if pretrained :
208
+ arch = "deeplabv3_resnet101_coco"
209
+ _load_weights (arch , model , model_urls .get (arch , None ), progress )
210
+ return model
217
211
218
212
219
213
def deeplabv3_mobilenet_v3_large (
220
214
pretrained : bool = False ,
221
215
progress : bool = True ,
222
216
num_classes : int = 21 ,
223
217
aux_loss : bool = False ,
224
- ** kwargs : Any ,
218
+ pretrained_backbone : bool = True ,
225
219
) -> DeepLabV3 :
226
220
"""Constructs a DeepLabV3 model with a MobileNetV3-Large backbone.
227
221
@@ -231,5 +225,16 @@ def deeplabv3_mobilenet_v3_large(
231
225
progress (bool): If True, displays a progress bar of the download to stderr
232
226
num_classes (int): number of output classes of the model (including the background)
233
227
aux_loss (bool): If True, it uses an auxiliary loss
228
+ pretrained_backbone (bool): If True, the backbone will be pre-trained.
234
229
"""
235
- return _deeplabv3_mobilenetv3 ("mobilenet_v3_large" , pretrained , progress , num_classes , aux_loss , ** kwargs )
230
+ if pretrained :
231
+ aux_loss = True
232
+ pretrained_backbone = False
233
+
234
+ backbone = mobilenetv3 .mobilenet_v3_large (pretrained = pretrained_backbone , dilated = True ).features
235
+ model = _deeplabv3_mobilenetv3 (backbone , num_classes , aux_loss )
236
+
237
+ if pretrained :
238
+ arch = "deeplabv3_mobilenet_v3_large_coco"
239
+ _load_weights (arch , model , model_urls .get (arch , None ), progress )
240
+ return model
0 commit comments