Skip to content

Added annotation typing to googlenet #2858

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Oct 23, 2020
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 45 additions & 23 deletions torchvision/models/googlenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.jit.annotations import Optional, Tuple
from torch import Tensor
from .utils import load_state_dict_from_url
from typing import Optional, Tuple, List, Callable, Any

__all__ = ['GoogLeNet', 'googlenet', "GoogLeNetOutputs", "_GoogLeNetOutputs"]

Expand All @@ -23,7 +23,7 @@
_GoogLeNetOutputs = GoogLeNetOutputs


def googlenet(pretrained=False, progress=True, **kwargs):
def googlenet(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> "GoogLeNet":
r"""GoogLeNet (Inception v1) model architecture from
`"Going Deeper with Convolutions" <http://arxiv.org/abs/1409.4842>`_.

Expand Down Expand Up @@ -52,8 +52,8 @@ def googlenet(pretrained=False, progress=True, **kwargs):
model.load_state_dict(state_dict)
if not original_aux_logits:
model.aux_logits = False
model.aux1 = None
model.aux2 = None
model.aux1 = None # type: ignore[assignment]
model.aux2 = None # type: ignore[assignment]
return model

return GoogLeNet(**kwargs)
Expand All @@ -62,8 +62,14 @@ def googlenet(pretrained=False, progress=True, **kwargs):
class GoogLeNet(nn.Module):
__constants__ = ['aux_logits', 'transform_input']

def __init__(self, num_classes=1000, aux_logits=True, transform_input=False, init_weights=None,
blocks=None):
def __init__(
self,
num_classes: int = 1000,
aux_logits: bool = True,
transform_input: bool = False,
init_weights: Optional[bool] = None,
blocks: Optional[List[Callable[..., nn.Module]]] = None
) -> None:
super(GoogLeNet, self).__init__()
if blocks is None:
blocks = [BasicConv2d, Inception, InceptionAux]
Expand Down Expand Up @@ -104,8 +110,8 @@ def __init__(self, num_classes=1000, aux_logits=True, transform_input=False, ini
self.aux1 = inception_aux_block(512, num_classes)
self.aux2 = inception_aux_block(528, num_classes)
else:
self.aux1 = None
self.aux2 = None
self.aux1 = None # type: ignore[assignment]
self.aux2 = None # type: ignore[assignment]

self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.dropout = nn.Dropout(0.2)
Expand All @@ -114,7 +120,7 @@ def __init__(self, num_classes=1000, aux_logits=True, transform_input=False, ini
if init_weights:
self._initialize_weights()

def _initialize_weights(self):
def _initialize_weights(self) -> None:
for m in self.modules():
if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
import scipy.stats as stats
Expand All @@ -127,17 +133,15 @@ def _initialize_weights(self):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)

def _transform_input(self, x):
# type: (Tensor) -> Tensor
def _transform_input(self, x: Tensor) -> Tensor:
if self.transform_input:
x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
x = torch.cat((x_ch0, x_ch1, x_ch2), 1)
return x

def _forward(self, x):
# type: (Tensor) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]
def _forward(self, x: Tensor) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
# N x 3 x 224 x 224
x = self.conv1(x)
# N x 64 x 112 x 112
Expand Down Expand Up @@ -199,8 +203,7 @@ def eager_outputs(self, x: Tensor, aux2: Tensor, aux1: Optional[Tensor]) -> Goog
else:
return x # type: ignore[return-value]

def forward(self, x):
# type: (Tensor) -> GoogLeNetOutputs
def forward(self, x: Tensor) -> GoogLeNetOutputs:
x = self._transform_input(x)
x, aux1, aux2 = self._forward(x)
aux_defined = self.training and self.aux_logits
Expand All @@ -214,8 +217,17 @@ def forward(self, x):

class Inception(nn.Module):

def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj,
conv_block=None):
def __init__(
self,
in_channels: int,
ch1x1: int,
ch3x3red: int,
ch3x3: int,
ch5x5red: int,
ch5x5: int,
pool_proj: int,
conv_block: Optional[Callable[..., nn.Module]] = None
) -> None:
super(Inception, self).__init__()
if conv_block is None:
conv_block = BasicConv2d
Expand All @@ -238,7 +250,7 @@ def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_pr
conv_block(in_channels, pool_proj, kernel_size=1)
)

def _forward(self, x):
def _forward(self, x: Tensor) -> List[Tensor]:
branch1 = self.branch1(x)
branch2 = self.branch2(x)
branch3 = self.branch3(x)
Expand All @@ -247,14 +259,19 @@ def _forward(self, x):
outputs = [branch1, branch2, branch3, branch4]
return outputs

def forward(self, x):
def forward(self, x: Tensor) -> Tensor:
outputs = self._forward(x)
return torch.cat(outputs, 1)


class InceptionAux(nn.Module):

def __init__(self, in_channels, num_classes, conv_block=None):
def __init__(
self,
in_channels: int,
num_classes: int,
conv_block: Optional[Callable[..., nn.Module]] = None
) -> None:
super(InceptionAux, self).__init__()
if conv_block is None:
conv_block = BasicConv2d
Expand All @@ -263,7 +280,7 @@ def __init__(self, in_channels, num_classes, conv_block=None):
self.fc1 = nn.Linear(2048, 1024)
self.fc2 = nn.Linear(1024, num_classes)

def forward(self, x):
def forward(self, x: Tensor) -> Tensor:
# aux1: N x 512 x 14 x 14, aux2: N x 528 x 14 x 14
x = F.adaptive_avg_pool2d(x, (4, 4))
# aux1: N x 512 x 4 x 4, aux2: N x 528 x 4 x 4
Expand All @@ -283,12 +300,17 @@ def forward(self, x):

class BasicConv2d(nn.Module):

def __init__(self, in_channels, out_channels, **kwargs):
def __init__(
self,
in_channels: int,
out_channels: int,
**kwargs: Any
) -> None:
super(BasicConv2d, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
self.bn = nn.BatchNorm2d(out_channels, eps=0.001)

def forward(self, x):
def forward(self, x: Tensor) -> Tensor:
x = self.conv(x)
x = self.bn(x)
return F.relu(x, inplace=True)