diff --git a/torchvision/models/mnasnet.py b/torchvision/models/mnasnet.py
index c2479a828c8..34fce9c07cc 100644
--- a/torchvision/models/mnasnet.py
+++ b/torchvision/models/mnasnet.py
@@ -1,8 +1,10 @@
 import warnings
 
 import torch
+from torch import Tensor
 import torch.nn as nn
 
 from .utils import load_state_dict_from_url
+from typing import Any, Dict, List
 
 __all__ = ['MNASNet', 'mnasnet0_5', 'mnasnet0_75', 'mnasnet1_0', 'mnasnet1_3']
@@ -22,8 +24,15 @@
 
 class _InvertedResidual(nn.Module):
 
-    def __init__(self, in_ch, out_ch, kernel_size, stride, expansion_factor,
-                 bn_momentum=0.1):
+    def __init__(
+        self,
+        in_ch: int,
+        out_ch: int,
+        kernel_size: int,
+        stride: int,
+        expansion_factor: int,
+        bn_momentum: float = 0.1
+    ):
         super(_InvertedResidual, self).__init__()
         assert stride in [1, 2]
         assert kernel_size in [3, 5]
@@ -43,15 +52,15 @@ def __init__(self, in_ch, out_ch, kernel_size, stride, expansion_factor,
             nn.Conv2d(mid_ch, out_ch, 1, bias=False),
             nn.BatchNorm2d(out_ch, momentum=bn_momentum))
 
-    def forward(self, input):
+    def forward(self, input: Tensor) -> Tensor:
         if self.apply_residual:
             return self.layers(input) + input
         else:
             return self.layers(input)
 
 
-def _stack(in_ch, out_ch, kernel_size, stride, exp_factor, repeats,
-           bn_momentum):
+def _stack(in_ch: int, out_ch: int, kernel_size: int, stride: int, exp_factor: int, repeats: int,
+           bn_momentum: float) -> nn.Sequential:
     """ Creates a stack of inverted residuals. """
     assert repeats >= 1
     # First one has no skip, because feature map size changes.
@@ -65,7 +74,7 @@ def _stack(in_ch, out_ch, kernel_size, stride, exp_factor, repeats,
     return nn.Sequential(first, *remaining)
 
 
-def _round_to_multiple_of(val, divisor, round_up_bias=0.9):
+def _round_to_multiple_of(val: float, divisor: int, round_up_bias: float = 0.9) -> int:
     """ Asymmetric rounding to make `val` divisible by `divisor`. With default
     bias, will round up, unless the number is no more than 10% greater than the
     smaller divisible value, i.e. (83, 8) -> 80, but (84, 8) -> 88. """
@@ -74,7 +83,7 @@ def _round_to_multiple_of(val, divisor, round_up_bias=0.9):
     return new_val if new_val >= round_up_bias * val else new_val + divisor
 
 
-def _get_depths(alpha):
+def _get_depths(alpha: float) -> List[int]:
     """ Scales tensor depths as in reference MobileNet code, prefers rouding up
     rather than down. """
     depths = [32, 16, 24, 40, 80, 96, 192, 320]
@@ -95,7 +104,12 @@ class MNASNet(torch.nn.Module):
     # Version 2 adds depth scaling in the initial stages of the network.
     _version = 2
 
-    def __init__(self, alpha, num_classes=1000, dropout=0.2):
+    def __init__(
+        self,
+        alpha: float,
+        num_classes: int = 1000,
+        dropout: float = 0.2
+    ):
         super(MNASNet, self).__init__()
         assert alpha > 0.0
         self.alpha = alpha
@@ -130,13 +144,13 @@ def __init__(self, alpha, num_classes=1000, dropout=0.2):
             nn.Linear(1280, num_classes))
         self._initialize_weights()
 
-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
         x = self.layers(x)
         # Equivalent to global avgpool and removing H and W dimensions.
         x = x.mean([2, 3])
         return self.classifier(x)
 
-    def _initialize_weights(self):
+    def _initialize_weights(self) -> None:
         for m in self.modules():
             if isinstance(m, nn.Conv2d):
                 nn.init.kaiming_normal_(m.weight, mode="fan_out",
@@ -151,8 +165,8 @@ def _initialize_weights(self):
                                          nonlinearity="sigmoid")
                 nn.init.zeros_(m.bias)
 
-    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
-                              missing_keys, unexpected_keys, error_msgs):
+    def _load_from_state_dict(self, state_dict: Dict, prefix: str, local_metadata: Dict, strict: bool,
+                              missing_keys: List[str], unexpected_keys: List[str], error_msgs: List[str]) -> None:
         version = local_metadata.get("version", None)
         assert version in [1, 2]
 
@@ -192,7 +206,7 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
             unexpected_keys, error_msgs)
 
 
-def _load_pretrained(model_name, model, progress):
+def _load_pretrained(model_name: str, model: nn.Module, progress: bool) -> None:
     if model_name not in _MODEL_URLS or _MODEL_URLS[model_name] is None:
         raise ValueError(
             "No checkpoint is available for model type {}".format(model_name))
@@ -201,7 +215,7 @@ def _load_pretrained(model_name, model, progress):
         load_state_dict_from_url(checkpoint_url, progress=progress))
 
 
-def mnasnet0_5(pretrained=False, progress=True, **kwargs):
+def mnasnet0_5(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MNASNet:
     """MNASNet with depth multiplier of 0.5 from
     `"MnasNet: Platform-Aware Neural Architecture Search for Mobile"
     <https://arxiv.org/pdf/1807.11626.pdf>`_.
@@ -215,7 +229,7 @@ def mnasnet0_5(pretrained=False, progress=True, **kwargs):
     return model
 
 
-def mnasnet0_75(pretrained=False, progress=True, **kwargs):
+def mnasnet0_75(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MNASNet:
     """MNASNet with depth multiplier of 0.75 from
     `"MnasNet: Platform-Aware Neural Architecture Search for Mobile"
     <https://arxiv.org/pdf/1807.11626.pdf>`_.
@@ -229,7 +243,7 @@ def mnasnet0_75(pretrained=False, progress=True, **kwargs):
     return model
 
 
-def mnasnet1_0(pretrained=False, progress=True, **kwargs):
+def mnasnet1_0(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MNASNet:
     """MNASNet with depth multiplier of 1.0 from
     `"MnasNet: Platform-Aware Neural Architecture Search for Mobile"
     <https://arxiv.org/pdf/1807.11626.pdf>`_.
@@ -243,7 +257,7 @@ def mnasnet1_0(pretrained=False, progress=True, **kwargs):
     return model
 
 
-def mnasnet1_3(pretrained=False, progress=True, **kwargs):
+def mnasnet1_3(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MNASNet:
     """MNASNet with depth multiplier of 1.3 from
     `"MnasNet: Platform-Aware Neural Architecture Search for Mobile"
     <https://arxiv.org/pdf/1807.11626.pdf>`_.
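
For context, a minimal sketch (not part of the patch) of a call site that the annotations above make checkable with mypy; the model variant, input shape, and variable names are illustrative only:

    # Illustrative only; assumes a torchvision build with this patch applied.
    import torch
    from torchvision.models import mnasnet1_0

    model = mnasnet1_0(pretrained=False)   # annotated to return MNASNet
    x = torch.rand(1, 3, 224, 224)         # MNASNet.forward(x: Tensor) -> Tensor
    logits = model(x)
    print(logits.shape)                    # torch.Size([1, 1000])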