Commit 052edce

Added annotation typing to shufflenet (#2864)
* style: Added annotation typing for shufflenet
* fix: Removed duplicate type hint
* refactor: Removed un-necessary import
* fix: Fixed constructor typing
* style: Added black formatting on depthwise_conv
* style: Fixed stage typing in shufflenet
1 parent 3852b41 commit 052edce
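
The commit replaces the module's TorchScript-style type comments with inline Python annotations. A minimal before/after sketch of the pattern, using the channel_shuffle signature taken from the diff below (illustrative only, not the full change):

import torch
from torch import Tensor

# Before: the signature was typed via a TorchScript-style comment
def channel_shuffle(x, groups):
    # type: (torch.Tensor, int) -> torch.Tensor
    ...

# After: the same signature with inline annotations
def channel_shuffle(x: Tensor, groups: int) -> Tensor:
    ...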

File tree

1 file changed: 36 additions, 13 deletions

torchvision/models/shufflenetv2.py

Lines changed: 36 additions & 13 deletions
@@ -1,6 +1,8 @@
 import torch
+from torch import Tensor
 import torch.nn as nn
 from .utils import load_state_dict_from_url
+from typing import Callable, Any, List
 
 
 __all__ = [
@@ -16,8 +18,7 @@
 }
 
 
-def channel_shuffle(x, groups):
-    # type: (torch.Tensor, int) -> torch.Tensor
+def channel_shuffle(x: Tensor, groups: int) -> Tensor:
     batchsize, num_channels, height, width = x.data.size()
     channels_per_group = num_channels // groups
 
@@ -34,7 +35,12 @@ def channel_shuffle(x, groups):
 
 
 class InvertedResidual(nn.Module):
-    def __init__(self, inp, oup, stride):
+    def __init__(
+        self,
+        inp: int,
+        oup: int,
+        stride: int
+    ) -> None:
         super(InvertedResidual, self).__init__()
 
         if not (1 <= stride <= 3):
@@ -68,10 +74,17 @@ def __init__(self, inp, oup, stride):
         )
 
     @staticmethod
-    def depthwise_conv(i, o, kernel_size, stride=1, padding=0, bias=False):
+    def depthwise_conv(
+        i: int,
+        o: int,
+        kernel_size: int,
+        stride: int = 1,
+        padding: int = 0,
+        bias: bool = False
+    ) -> nn.Conv2d:
         return nn.Conv2d(i, o, kernel_size, stride, padding, bias=bias, groups=i)
 
-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
         if self.stride == 1:
             x1, x2 = x.chunk(2, dim=1)
             out = torch.cat((x1, self.branch2(x2)), dim=1)
@@ -84,7 +97,13 @@ def forward(self, x):
 
 
 class ShuffleNetV2(nn.Module):
-    def __init__(self, stages_repeats, stages_out_channels, num_classes=1000, inverted_residual=InvertedResidual):
+    def __init__(
+        self,
+        stages_repeats: List[int],
+        stages_out_channels: List[int],
+        num_classes: int = 1000,
+        inverted_residual: Callable[..., nn.Module] = InvertedResidual
+    ) -> None:
         super(ShuffleNetV2, self).__init__()
 
         if len(stages_repeats) != 3:
@@ -104,6 +123,10 @@ def __init__(self, stages_repeats, stages_out_channels, num_classes=1000, inverted_residual=InvertedResidual):
 
         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
 
+        # Static annotations for mypy
+        self.stage2: nn.Sequential
+        self.stage3: nn.Sequential
+        self.stage4: nn.Sequential
         stage_names = ['stage{}'.format(i) for i in [2, 3, 4]]
         for name, repeats, output_channels in zip(
                 stage_names, stages_repeats, self._stage_out_channels[1:]):
@@ -122,7 +145,7 @@ def __init__(self, stages_repeats, stages_out_channels, num_classes=1000, inverted_residual=InvertedResidual):
 
         self.fc = nn.Linear(output_channels, num_classes)
 
-    def _forward_impl(self, x):
+    def _forward_impl(self, x: Tensor) -> Tensor:
         # See note [TorchScript super()]
         x = self.conv1(x)
         x = self.maxpool(x)
@@ -134,11 +157,11 @@ def _forward_impl(self, x):
         x = self.fc(x)
         return x
 
-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
         return self._forward_impl(x)
 
 
-def _shufflenetv2(arch, pretrained, progress, *args, **kwargs):
+def _shufflenetv2(arch: str, pretrained: bool, progress: bool, *args: Any, **kwargs: Any) -> ShuffleNetV2:
     model = ShuffleNetV2(*args, **kwargs)
 
     if pretrained:
@@ -152,7 +175,7 @@ def _shufflenetv2(arch, pretrained, progress, *args, **kwargs):
     return model
 
 
-def shufflenet_v2_x0_5(pretrained=False, progress=True, **kwargs):
+def shufflenet_v2_x0_5(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ShuffleNetV2:
     """
     Constructs a ShuffleNetV2 with 0.5x output channels, as described in
     `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
@@ -166,7 +189,7 @@ def shufflenet_v2_x0_5(pretrained=False, progress=True, **kwargs):
                          [4, 8, 4], [24, 48, 96, 192, 1024], **kwargs)
 
 
-def shufflenet_v2_x1_0(pretrained=False, progress=True, **kwargs):
+def shufflenet_v2_x1_0(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ShuffleNetV2:
     """
     Constructs a ShuffleNetV2 with 1.0x output channels, as described in
     `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
@@ -180,7 +203,7 @@ def shufflenet_v2_x1_0(pretrained=False, progress=True, **kwargs):
                          [4, 8, 4], [24, 116, 232, 464, 1024], **kwargs)
 
 
-def shufflenet_v2_x1_5(pretrained=False, progress=True, **kwargs):
+def shufflenet_v2_x1_5(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ShuffleNetV2:
     """
     Constructs a ShuffleNetV2 with 1.5x output channels, as described in
     `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
@@ -194,7 +217,7 @@ def shufflenet_v2_x1_5(pretrained=False, progress=True, **kwargs):
                          [4, 8, 4], [24, 176, 352, 704, 1024], **kwargs)
 
 
-def shufflenet_v2_x2_0(pretrained=False, progress=True, **kwargs):
+def shufflenet_v2_x2_0(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ShuffleNetV2:
     """
     Constructs a ShuffleNetV2 with 2.0x output channels, as described in
     `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
