|
1 |
| -from typing import List |
| 1 | +from typing import List, Optional |
2 | 2 |
|
3 | 3 | import torch
|
4 | 4 | from torch import nn
|
5 | 5 | from torch.nn import functional as F
|
6 | 6 |
|
7 |
| -from ._utils import _SimpleSegmentationModel |
| 7 | +from .. import mobilenetv3 |
| 8 | +from .. import resnet |
| 9 | +from ..feature_extraction import create_feature_extractor |
| 10 | +from ._utils import _SimpleSegmentationModel, _load_weights |
| 11 | +from .fcn import FCNHead |
8 | 12 |
|
9 | 13 |
|
10 |
| -__all__ = ["DeepLabV3"] |
| 14 | +__all__ = [ |
| 15 | + "DeepLabV3", |
| 16 | + "deeplabv3_resnet50", |
| 17 | + "deeplabv3_resnet101", |
| 18 | + "deeplabv3_mobilenet_v3_large", |
| 19 | +] |
| 20 | + |
| 21 | + |
| 22 | +model_urls = { |
| 23 | + "deeplabv3_resnet50_coco": "https://download.pytorch.org/models/deeplabv3_resnet50_coco-cd0a2569.pth", |
| 24 | + "deeplabv3_resnet101_coco": "https://download.pytorch.org/models/deeplabv3_resnet101_coco-586e9e4e.pth", |
| 25 | + "deeplabv3_mobilenet_v3_large_coco": "https://download.pytorch.org/models/deeplabv3_mobilenet_v3_large-fc3c493d.pth", |
| 26 | +} |
11 | 27 |
|
12 | 28 |
|
13 | 29 | class DeepLabV3(_SimpleSegmentationModel):
|
@@ -95,3 +111,131 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
|
95 | 111 | _res.append(conv(x))
|
96 | 112 | res = torch.cat(_res, dim=1)
|
97 | 113 | return self.project(res)
|
| 114 | + |
| 115 | + |
| 116 | +def _deeplabv3_resnet( |
| 117 | + backbone: resnet.ResNet, |
| 118 | + num_classes: int, |
| 119 | + aux: Optional[bool], |
| 120 | +) -> DeepLabV3: |
| 121 | + return_layers = {"layer4": "out"} |
| 122 | + if aux: |
| 123 | + return_layers["layer3"] = "aux" |
| 124 | + backbone = create_feature_extractor(backbone, return_layers) |
| 125 | + |
| 126 | + aux_classifier = FCNHead(1024, num_classes) if aux else None |
| 127 | + classifier = DeepLabHead(2048, num_classes) |
| 128 | + return DeepLabV3(backbone, classifier, aux_classifier) |
| 129 | + |
| 130 | + |
| 131 | +def _deeplabv3_mobilenetv3( |
| 132 | + backbone: mobilenetv3.MobileNetV3, |
| 133 | + num_classes: int, |
| 134 | + aux: Optional[bool], |
| 135 | +) -> DeepLabV3: |
| 136 | + backbone = backbone.features |
| 137 | + # Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks. |
| 138 | + # The first and last blocks are always included because they are the C0 (conv1) and Cn. |
| 139 | + stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)] + [len(backbone) - 1] |
| 140 | + out_pos = stage_indices[-1] # use C5 which has output_stride = 16 |
| 141 | + out_inplanes = backbone[out_pos].out_channels |
| 142 | + aux_pos = stage_indices[-4] # use C2 here which has output_stride = 8 |
| 143 | + aux_inplanes = backbone[aux_pos].out_channels |
| 144 | + return_layers = {str(out_pos): "out"} |
| 145 | + if aux: |
| 146 | + return_layers[str(aux_pos)] = "aux" |
| 147 | + backbone = create_feature_extractor(backbone, return_layers) |
| 148 | + |
| 149 | + aux_classifier = FCNHead(aux_inplanes, num_classes) if aux else None |
| 150 | + classifier = DeepLabHead(out_inplanes, num_classes) |
| 151 | + return DeepLabV3(backbone, classifier, aux_classifier) |
| 152 | + |
| 153 | + |
| 154 | +def deeplabv3_resnet50( |
| 155 | + pretrained: bool = False, |
| 156 | + progress: bool = True, |
| 157 | + num_classes: int = 21, |
| 158 | + aux_loss: Optional[bool] = None, |
| 159 | + pretrained_backbone: bool = True, |
| 160 | +) -> DeepLabV3: |
| 161 | + """Constructs a DeepLabV3 model with a ResNet-50 backbone. |
| 162 | +
|
| 163 | + Args: |
| 164 | + pretrained (bool): If True, returns a model pre-trained on COCO train2017 which |
| 165 | + contains the same classes as Pascal VOC |
| 166 | + progress (bool): If True, displays a progress bar of the download to stderr |
| 167 | + num_classes (int): number of output classes of the model (including the background) |
| 168 | + aux_loss (bool, optional): If True, it uses an auxiliary loss |
| 169 | + pretrained_backbone (bool): If True, the backbone will be pre-trained. |
| 170 | + """ |
| 171 | + if pretrained: |
| 172 | + aux_loss = True |
| 173 | + pretrained_backbone = False |
| 174 | + |
| 175 | + backbone = resnet.resnet50(pretrained=pretrained_backbone, replace_stride_with_dilation=[False, True, True]) |
| 176 | + model = _deeplabv3_resnet(backbone, num_classes, aux_loss) |
| 177 | + |
| 178 | + if pretrained: |
| 179 | + arch = "deeplabv3_resnet50_coco" |
| 180 | + _load_weights(arch, model, model_urls.get(arch, None), progress) |
| 181 | + return model |
| 182 | + |
| 183 | + |
| 184 | +def deeplabv3_resnet101( |
| 185 | + pretrained: bool = False, |
| 186 | + progress: bool = True, |
| 187 | + num_classes: int = 21, |
| 188 | + aux_loss: Optional[bool] = None, |
| 189 | + pretrained_backbone: bool = True, |
| 190 | +) -> DeepLabV3: |
| 191 | + """Constructs a DeepLabV3 model with a ResNet-101 backbone. |
| 192 | +
|
| 193 | + Args: |
| 194 | + pretrained (bool): If True, returns a model pre-trained on COCO train2017 which |
| 195 | + contains the same classes as Pascal VOC |
| 196 | + progress (bool): If True, displays a progress bar of the download to stderr |
| 197 | + num_classes (int): The number of classes |
| 198 | + aux_loss (bool, optional): If True, include an auxiliary classifier |
| 199 | + pretrained_backbone (bool): If True, the backbone will be pre-trained. |
| 200 | + """ |
| 201 | + if pretrained: |
| 202 | + aux_loss = True |
| 203 | + pretrained_backbone = False |
| 204 | + |
| 205 | + backbone = resnet.resnet101(pretrained=pretrained_backbone, replace_stride_with_dilation=[False, True, True]) |
| 206 | + model = _deeplabv3_resnet(backbone, num_classes, aux_loss) |
| 207 | + |
| 208 | + if pretrained: |
| 209 | + arch = "deeplabv3_resnet101_coco" |
| 210 | + _load_weights(arch, model, model_urls.get(arch, None), progress) |
| 211 | + return model |
| 212 | + |
| 213 | + |
| 214 | +def deeplabv3_mobilenet_v3_large( |
| 215 | + pretrained: bool = False, |
| 216 | + progress: bool = True, |
| 217 | + num_classes: int = 21, |
| 218 | + aux_loss: Optional[bool] = None, |
| 219 | + pretrained_backbone: bool = True, |
| 220 | +) -> DeepLabV3: |
| 221 | + """Constructs a DeepLabV3 model with a MobileNetV3-Large backbone. |
| 222 | +
|
| 223 | + Args: |
| 224 | + pretrained (bool): If True, returns a model pre-trained on COCO train2017 which |
| 225 | + contains the same classes as Pascal VOC |
| 226 | + progress (bool): If True, displays a progress bar of the download to stderr |
| 227 | + num_classes (int): number of output classes of the model (including the background) |
| 228 | + aux_loss (bool, optional): If True, it uses an auxiliary loss |
| 229 | + pretrained_backbone (bool): If True, the backbone will be pre-trained. |
| 230 | + """ |
| 231 | + if pretrained: |
| 232 | + aux_loss = True |
| 233 | + pretrained_backbone = False |
| 234 | + |
| 235 | + backbone = mobilenetv3.mobilenet_v3_large(pretrained=pretrained_backbone, dilated=True) |
| 236 | + model = _deeplabv3_mobilenetv3(backbone, num_classes, aux_loss) |
| 237 | + |
| 238 | + if pretrained: |
| 239 | + arch = "deeplabv3_mobilenet_v3_large_coco" |
| 240 | + _load_weights(arch, model, model_urls.get(arch, None), progress) |
| 241 | + return model |
0 commit comments