
Added annotation typing to mnasnet #2856


Merged
merged 2 commits on Oct 22, 2020
48 changes: 31 additions & 17 deletions torchvision/models/mnasnet.py
@@ -1,8 +1,10 @@
 import warnings

 import torch
+from torch import Tensor
 import torch.nn as nn
 from .utils import load_state_dict_from_url
+from typing import Any, Dict, List

 __all__ = ['MNASNet', 'mnasnet0_5', 'mnasnet0_75', 'mnasnet1_0', 'mnasnet1_3']

@@ -22,8 +24,15 @@

 class _InvertedResidual(nn.Module):

-    def __init__(self, in_ch, out_ch, kernel_size, stride, expansion_factor,
-                 bn_momentum=0.1):
+    def __init__(
+        self,
+        in_ch: int,
+        out_ch: int,
+        kernel_size: int,
+        stride: int,
+        expansion_factor: int,
+        bn_momentum: float = 0.1
+    ):
         super(_InvertedResidual, self).__init__()
         assert stride in [1, 2]
         assert kernel_size in [3, 5]
@@ -43,15 +52,15 @@ def __init__(self, in_ch, out_ch, kernel_size, stride, expansion_factor,
             nn.Conv2d(mid_ch, out_ch, 1, bias=False),
             nn.BatchNorm2d(out_ch, momentum=bn_momentum))

-    def forward(self, input):
+    def forward(self, input: Tensor) -> Tensor:
         if self.apply_residual:
             return self.layers(input) + input
         else:
             return self.layers(input)


-def _stack(in_ch, out_ch, kernel_size, stride, exp_factor, repeats,
-           bn_momentum):
+def _stack(in_ch: int, out_ch: int, kernel_size: int, stride: int, exp_factor: int, repeats: int,
+           bn_momentum: float) -> nn.Sequential:
     """ Creates a stack of inverted residuals. """
     assert repeats >= 1
     # First one has no skip, because feature map size changes.
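For context, `apply_residual` is assigned in the collapsed part of `__init__`; in the torchvision source it is `in_ch == out_ch and stride == 1`, so the typed `forward` only adds the skip connection when the block preserves the feature map shape. A minimal sketch of that rule (`TinyResidual` is a hypothetical stand-in, not part of this diff):

import torch
from torch import Tensor, nn

class TinyResidual(nn.Module):
    def __init__(self, in_ch: int, out_ch: int, stride: int):
        super().__init__()
        self.layers = nn.Conv2d(in_ch, out_ch, 3, stride=stride, padding=1)
        # Skip connection only when input and output shapes match.
        self.apply_residual = in_ch == out_ch and stride == 1

    def forward(self, input: Tensor) -> Tensor:
        if self.apply_residual:
            return self.layers(input) + input
        return self.layers(input)

print(TinyResidual(8, 8, 1)(torch.rand(1, 8, 16, 16)).shape)   # residual path: (1, 8, 16, 16)
print(TinyResidual(8, 16, 2)(torch.rand(1, 8, 16, 16)).shape)  # plain path: (1, 16, 8, 8)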
@@ -65,7 +74,7 @@ def _stack(in_ch, out_ch, kernel_size, stride, exp_factor, repeats,
     return nn.Sequential(first, *remaining)


-def _round_to_multiple_of(val, divisor, round_up_bias=0.9):
+def _round_to_multiple_of(val: float, divisor: int, round_up_bias: float = 0.9) -> int:
     """ Asymmetric rounding to make `val` divisible by `divisor`. With default
     bias, will round up, unless the number is no more than 10% greater than the
     smaller divisible value, i.e. (83, 8) -> 80, but (84, 8) -> 88. """
@@ -74,7 +83,7 @@ def _round_to_multiple_of(val, divisor, round_up_bias=0.9):
     return new_val if new_val >= round_up_bias * val else new_val + divisor


-def _get_depths(alpha):
+def _get_depths(alpha: float) -> List[int]:
     """ Scales tensor depths as in reference MobileNet code, prefers rounding up
     rather than down. """
     depths = [32, 16, 24, 40, 80, 96, 192, 320]
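As a quick standalone check of the rounding behaviour documented above (the `max(divisor, ...)` line sits in a collapsed region of the diff and is reproduced here from the torchvision source for illustration):

from typing import List

def round_to_multiple_of(val: float, divisor: int, round_up_bias: float = 0.9) -> int:
    # Mirrors _round_to_multiple_of: round to the nearest multiple, then round
    # up unless val is no more than 10% above the smaller multiple.
    new_val = max(divisor, int(val + divisor / 2) // divisor * divisor)
    return new_val if new_val >= round_up_bias * val else new_val + divisor

def get_depths(alpha: float) -> List[int]:
    # Mirrors _get_depths: scale each base depth, then round to a multiple of 8.
    return [round_to_multiple_of(depth * alpha, 8) for depth in [32, 16, 24, 40, 80, 96, 192, 320]]

print(round_to_multiple_of(83, 8))  # 80 -- within 10% of 80, rounds down
print(round_to_multiple_of(84, 8))  # 88 -- more than 10% above 80, rounds up
print(get_depths(0.5))              # [16, 8, 16, 24, 40, 48, 96, 160]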
@@ -95,7 +104,12 @@ class MNASNet(torch.nn.Module):
     # Version 2 adds depth scaling in the initial stages of the network.
     _version = 2

-    def __init__(self, alpha, num_classes=1000, dropout=0.2):
+    def __init__(
+        self,
+        alpha: float,
+        num_classes: int = 1000,
+        dropout: float = 0.2
+    ):
         super(MNASNet, self).__init__()
         assert alpha > 0.0
         self.alpha = alpha
@@ -130,13 +144,13 @@ def __init__(self, alpha, num_classes=1000, dropout=0.2):
             nn.Linear(1280, num_classes))
         self._initialize_weights()

-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
         x = self.layers(x)
         # Equivalent to global avgpool and removing H and W dimensions.
         x = x.mean([2, 3])
         return self.classifier(x)

-    def _initialize_weights(self):
+    def _initialize_weights(self) -> None:
         for m in self.modules():
             if isinstance(m, nn.Conv2d):
                 nn.init.kaiming_normal_(m.weight, mode="fan_out",
@@ -151,8 +165,8 @@ def _initialize_weights(self):
                                         nonlinearity="sigmoid")
                 nn.init.zeros_(m.bias)

-    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
-                              missing_keys, unexpected_keys, error_msgs):
+    def _load_from_state_dict(self, state_dict: Dict, prefix: str, local_metadata: Dict, strict: bool,
+                              missing_keys: List[str], unexpected_keys: List[str], error_msgs: List[str]) -> None:
         version = local_metadata.get("version", None)
         assert version in [1, 2]
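`_load_from_state_dict` is the per-module hook that `nn.Module.load_state_dict` invokes, and `local_metadata["version"]` is populated from the `_version` class attribute recorded when the checkpoint was saved, which is what the version check above relies on. A minimal sketch of the pattern (`VersionedModule` is a hypothetical example, not from this diff):

import torch.nn as nn
from typing import Dict, List

class VersionedModule(nn.Module):
    _version = 2  # recorded into checkpoint metadata by state_dict()

    def _load_from_state_dict(self, state_dict: Dict, prefix: str, local_metadata: Dict,
                              strict: bool, missing_keys: List[str], unexpected_keys: List[str],
                              error_msgs: List[str]) -> None:
        version = local_metadata.get("version", None)
        if version == 1:
            # Rewrite legacy keys here, as mnasnet.py does for version-1 checkpoints.
            pass
        super(VersionedModule, self)._load_from_state_dict(
            state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs)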

@@ -192,7 +206,7 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                                               unexpected_keys, error_msgs)


-def _load_pretrained(model_name, model, progress):
+def _load_pretrained(model_name: str, model: nn.Module, progress: bool) -> None:
     if model_name not in _MODEL_URLS or _MODEL_URLS[model_name] is None:
         raise ValueError(
             "No checkpoint is available for model type {}".format(model_name))
@@ -201,7 +215,7 @@ def _load_pretrained(model_name, model, progress):
         load_state_dict_from_url(checkpoint_url, progress=progress))


-def mnasnet0_5(pretrained=False, progress=True, **kwargs):
+def mnasnet0_5(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MNASNet:
     """MNASNet with depth multiplier of 0.5 from
     `"MnasNet: Platform-Aware Neural Architecture Search for Mobile"
     <https://arxiv.org/pdf/1807.11626.pdf>`_.
@@ -215,7 +229,7 @@ def mnasnet0_5(pretrained=False, progress=True, **kwargs):
     return model


-def mnasnet0_75(pretrained=False, progress=True, **kwargs):
+def mnasnet0_75(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MNASNet:
     """MNASNet with depth multiplier of 0.75 from
     `"MnasNet: Platform-Aware Neural Architecture Search for Mobile"
     <https://arxiv.org/pdf/1807.11626.pdf>`_.
@@ -229,7 +243,7 @@ def mnasnet0_75(pretrained=False, progress=True, **kwargs):
     return model


-def mnasnet1_0(pretrained=False, progress=True, **kwargs):
+def mnasnet1_0(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MNASNet:
     """MNASNet with depth multiplier of 1.0 from
     `"MnasNet: Platform-Aware Neural Architecture Search for Mobile"
     <https://arxiv.org/pdf/1807.11626.pdf>`_.
@@ -243,7 +257,7 @@ def mnasnet1_0(pretrained=False, progress=True, **kwargs):
     return model


-def mnasnet1_3(pretrained=False, progress=True, **kwargs):
+def mnasnet1_3(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MNASNet:
     """MNASNet with depth multiplier of 1.3 from
     `"MnasNet: Platform-Aware Neural Architecture Search for Mobile"
     <https://arxiv.org/pdf/1807.11626.pdf>`_.
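The annotations change no runtime behaviour; the builders are called exactly as before, and static checkers such as mypy can now verify the calls. A minimal usage sketch, assuming a torchvision build that includes this change:

import torch
from torchvision.models import mnasnet1_0

model = mnasnet1_0(pretrained=False, num_classes=10)
model.eval()
with torch.no_grad():
    out = model(torch.rand(1, 3, 224, 224))
print(out.shape)  # torch.Size([1, 10])

# With the annotations in place, TorchScript compilation should also succeed:
scripted = torch.jit.script(model)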