diff --git a/docs/source/models.rst b/docs/source/models.rst index 674ac052c8d..308ba75481b 100644 --- a/docs/source/models.rst +++ b/docs/source/models.rst @@ -10,6 +10,7 @@ architectures: - `SqueezeNet`_ - `DenseNet`_ - `Inception`_ v3 +- `GoogLeNet`_ You can construct a model with random weights by calling its constructor: @@ -22,6 +23,7 @@ You can construct a model with random weights by calling its constructor: squeezenet = models.squeezenet1_0() densenet = models.densenet161() inception = models.inception_v3() + googlenet = models.googlenet() We provide pre-trained models, using the PyTorch :mod:`torch.utils.model_zoo`. These can be constructed by passing ``pretrained=True``: @@ -35,6 +37,7 @@ These can be constructed by passing ``pretrained=True``: vgg16 = models.vgg16(pretrained=True) densenet = models.densenet161(pretrained=True) inception = models.inception_v3(pretrained=True) + googlenet = models.googlenet(pretrained=True) Instancing a pre-trained model will download its weights to a cache directory. This directory can be set using the `TORCH_MODEL_ZOO` environment variable. See @@ -84,6 +87,7 @@ Densenet-169 24.00 7.00 Densenet-201 22.80 6.43 Densenet-161 22.35 6.20 Inception v3 22.55 6.44 +GoogleNet 30.22 10.47 ================================ ============= ============= @@ -93,6 +97,7 @@ Inception v3 22.55 6.44 .. _SqueezeNet: https://arxiv.org/abs/1602.07360 .. _DenseNet: https://arxiv.org/abs/1608.06993 .. _Inception: https://arxiv.org/abs/1512.00567 +.. _GoogLeNet: https://arxiv.org/abs/1409.4842 .. currentmodule:: torchvision.models @@ -142,3 +147,8 @@ Inception v3 .. autofunction:: inception_v3 +GoogLeNet +------------ + +.. autofunction:: googlenet + diff --git a/torchvision/models/__init__.py b/torchvision/models/__init__.py index 079992e0269..7437c51597f 100644 --- a/torchvision/models/__init__.py +++ b/torchvision/models/__init__.py @@ -4,3 +4,4 @@ from .squeezenet import * from .inception import * from .densenet import * +from .googlenet import * diff --git a/torchvision/models/googlenet.py b/torchvision/models/googlenet.py new file mode 100644 index 00000000000..9f50d93147e --- /dev/null +++ b/torchvision/models/googlenet.py @@ -0,0 +1,183 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils import model_zoo + +__all__ = ['GoogLeNet', 'googlenet'] + +model_urls = { + # GoogLeNet ported from TensorFlow + 'googlenet': 'https://download.pytorch.org/models/googlenet-1378be20.pth', +} + + +def googlenet(pretrained=False, **kwargs): + r"""GoogLeNet (Inception v1) model architecture from + `"Going Deeper with Convolutions" `_. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + if pretrained: + if 'transform_input' not in kwargs: + kwargs['transform_input'] = True + kwargs['init_weights'] = False + model = GoogLeNet(**kwargs) + model.load_state_dict(model_zoo.load_url(model_urls['googlenet'])) + return model + + return GoogLeNet(**kwargs) + + +class GoogLeNet(nn.Module): + + def __init__(self, num_classes=1000, aux_logits=True, transform_input=False, init_weights=True): + super(GoogLeNet, self).__init__() + self.aux_logits = aux_logits + self.transform_input = transform_input + + self.conv1 = BasicConv2d(3, 64, kernel_size=7, stride=2, padding=3) + self.maxpool1 = nn.MaxPool2d(3, stride=2, ceil_mode=True) + self.conv2 = BasicConv2d(64, 64, kernel_size=1) + self.conv3 = BasicConv2d(64, 192, kernel_size=3, padding=1) + self.maxpool2 = nn.MaxPool2d(3, stride=2, ceil_mode=True) + + self.inception3a = Inception(192, 64, 96, 128, 16, 32, 32) + self.inception3b = Inception(256, 128, 128, 192, 32, 96, 64) + self.maxpool3 = nn.MaxPool2d(3, stride=2, ceil_mode=True) + + self.inception4a = Inception(480, 192, 96, 208, 16, 48, 64) + self.inception4b = Inception(512, 160, 112, 224, 24, 64, 64) + self.inception4c = Inception(512, 128, 128, 256, 24, 64, 64) + self.inception4d = Inception(512, 112, 144, 288, 32, 64, 64) + self.inception4e = Inception(528, 256, 160, 320, 32, 128, 128) + self.maxpool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.inception5a = Inception(832, 256, 160, 320, 32, 128, 128) + self.inception5b = Inception(832, 384, 192, 384, 48, 128, 128) + if aux_logits: + self.aux1 = InceptionAux(512, num_classes) + self.aux2 = InceptionAux(528, num_classes) + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.dropout = nn.Dropout(0.4) + self.fc = nn.Linear(1024, num_classes) + + if init_weights: + self._initialize_weights() + + def _initialize_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.constant_(m.bias, 0.2) + elif isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def forward(self, x): + if self.transform_input: + x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5 + x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5 + x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5 + x = torch.cat((x_ch0, x_ch1, x_ch2), 1) + + x = self.conv1(x) + x = self.maxpool1(x) + x = self.conv2(x) + x = self.conv3(x) + x = self.maxpool2(x) + + x = self.inception3a(x) + x = self.inception3b(x) + x = self.maxpool3(x) + x = self.inception4a(x) + if self.training and self.aux_logits: + aux1 = self.aux1(x) + + x = self.inception4b(x) + x = self.inception4c(x) + x = self.inception4d(x) + if self.training and self.aux_logits: + aux2 = self.aux2(x) + + x = self.inception4e(x) + x = self.maxpool4(x) + x = self.inception5a(x) + x = self.inception5b(x) + + x = self.avgpool(x) + x = x.view(x.size(0), -1) + x = self.dropout(x) + x = self.fc(x) + if self.training and self.aux_logits: + return aux1, aux2, x + return x + + +class Inception(nn.Module): + + def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj): + super(Inception, self).__init__() + + self.branch1 = BasicConv2d(in_channels, ch1x1, kernel_size=1) + + self.branch2 = nn.Sequential( + BasicConv2d(in_channels, ch3x3red, kernel_size=1), + BasicConv2d(ch3x3red, ch3x3, kernel_size=3, padding=1) + ) + + self.branch3 = nn.Sequential( + BasicConv2d(in_channels, ch5x5red, kernel_size=1), + BasicConv2d(ch5x5red, ch5x5, kernel_size=3, padding=1) + ) + + self.branch4 = nn.Sequential( + nn.MaxPool2d(kernel_size=3, stride=1, padding=1, ceil_mode=True), + BasicConv2d(in_channels, pool_proj, kernel_size=1) + ) + + def forward(self, x): + branch1 = self.branch1(x) + branch2 = self.branch2(x) + branch3 = self.branch3(x) + branch4 = self.branch4(x) + + outputs = [branch1, branch2, branch3, branch4] + return torch.cat(outputs, 1) + + +class InceptionAux(nn.Module): + + def __init__(self, in_channels, num_classes): + super(InceptionAux, self).__init__() + self.conv = BasicConv2d(in_channels, 128, kernel_size=1) + + self.fc1 = nn.Linear(2048, 1024) + self.fc2 = nn.Linear(1024, num_classes) + + def forward(self, x): + x = F.adaptive_avg_pool2d(x, (4, 4)) + + x = self.conv(x) + x = x.view(x.size(0), -1) + x = F.relu(self.fc1(x), inplace=True) + x = F.dropout(x, 0.7, training=self.training) + x = self.fc2(x) + + return x + + +class BasicConv2d(nn.Module): + + def __init__(self, in_channels, out_channels, **kwargs): + super(BasicConv2d, self).__init__() + self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs) + self.bn = nn.BatchNorm2d(out_channels, eps=0.001) + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + return F.relu(x, inplace=True)