Skip to content

Commit ede7980

Browse files
committed
Fixing mypy
1 parent de1d2ad commit ede7980

File tree

2 files changed

+4
-2
lines changed

2 files changed

+4
-2
lines changed

torchvision/models/segmentation/deeplabv3.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@ def _deeplabv3_mobilenetv3(
133133
num_classes: int,
134134
aux: bool,
135135
) -> DeepLabV3:
136+
backbone = backbone.features
136137
# Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
137138
# The first and last blocks are always included because they are the C0 (conv1) and Cn.
138139
stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)] + [len(backbone) - 1]
@@ -231,7 +232,7 @@ def deeplabv3_mobilenet_v3_large(
231232
aux_loss = True
232233
pretrained_backbone = False
233234

234-
backbone = mobilenetv3.mobilenet_v3_large(pretrained=pretrained_backbone, dilated=True).features
235+
backbone = mobilenetv3.mobilenet_v3_large(pretrained=pretrained_backbone, dilated=True)
235236
model = _deeplabv3_mobilenetv3(backbone, num_classes, aux_loss)
236237

237238
if pretrained:

torchvision/models/segmentation/lraspp.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ def forward(self, input: Dict[str, Tensor]) -> Tensor:
8080

8181

8282
def _lraspp_mobilenetv3(backbone: mobilenetv3.MobileNetV3, num_classes: int) -> LRASPP:
83+
backbone = backbone.features
8384
# Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
8485
# The first and last blocks are always included because they are the C0 (conv1) and Cn.
8586
stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)] + [len(backbone) - 1]
@@ -107,7 +108,7 @@ def lraspp_mobilenet_v3_large(
107108
if pretrained:
108109
pretrained_backbone = False
109110

110-
backbone = mobilenetv3.mobilenet_v3_large(pretrained=pretrained_backbone, dilated=True).features
111+
backbone = mobilenetv3.mobilenet_v3_large(pretrained=pretrained_backbone, dilated=True)
111112
model = _lraspp_mobilenetv3(backbone, num_classes)
112113

113114
if pretrained:

0 commit comments

Comments
 (0)