
Add typing annotations to models/quantization #4232


Merged: 28 commits from add_typing3 into main, Aug 31, 2021.

Commits (28):
b2f6615 fix (oke-aditya, May 20, 2021)
4fb038d Merge branch 'master' of https://github.com/pytorch/vision (oke-aditya, May 20, 2021)
deda5d7 Merge branch 'master' of https://github.com/pytorch/vision (oke-aditya, May 21, 2021)
5490821 Merge branch 'master' of https://github.com/pytorch/vision (oke-aditya, May 21, 2021)
4cfc220 Merge branch 'master' of https://github.com/pytorch/vision (oke-aditya, May 21, 2021)
6306746 Merge branch 'master' of https://github.com/pytorch/vision (oke-aditya, Jul 24, 2021)
e8c93cf Merge branch 'master' of https://github.com/pytorch/vision (oke-aditya, Jul 28, 2021)
6871ccc Merge branch 'master' of https://github.com/pytorch/vision (oke-aditya, Jul 28, 2021)
80060bf Merge branch 'master' of https://github.com/pytorch/vision (oke-aditya, Jul 29, 2021)
3b5f0ca Merge branch 'master' of https://github.com/pytorch/vision (oke-aditya, Jul 30, 2021)
304fa9e add typings (oke-aditya, Jul 30, 2021)
b4b35b3 Merge branch 'master' of https://github.com/pytorch/vision into add_t… (oke-aditya, Jul 30, 2021)
a20ee80 fixup some more types (oke-aditya, Jul 30, 2021)
78ddc54 Type more (oke-aditya, Jul 30, 2021)
1e0de0c remove mypy ignore (oke-aditya, Aug 2, 2021)
4bc3a80 add missing typings (oke-aditya, Aug 2, 2021)
3c1bd67 fix a few mypy errors (oke-aditya, Aug 2, 2021)
479c64c fix mypy errors (oke-aditya, Aug 2, 2021)
ee1b93e fix mypy (oke-aditya, Aug 2, 2021)
03b3ba4 Merge branch 'master' of https://github.com/pytorch/vision into add_t… (oke-aditya, Aug 5, 2021)
73f28df ignore types (oke-aditya, Aug 5, 2021)
d059d04 Merge branch 'master' of https://github.com/pytorch/vision into add_t… (oke-aditya, Aug 16, 2021)
2528d13 fixup annotation (oke-aditya, Aug 16, 2021)
10d4d67 Merge branch 'master' of https://github.com/pytorch/vision into add_t… (oke-aditya, Aug 17, 2021)
2bba207 fix remaining types (oke-aditya, Aug 17, 2021)
a6b8528 cleanup #TODO comments (pmeier, Aug 18, 2021)
cf5f6de Merge branch 'main' into add_typing3 (datumbox, Aug 31, 2021)
57d8970 Merge branch 'main' into add_typing3 (datumbox, Aug 31, 2021)
4 changes: 0 additions & 4 deletions mypy.ini
@@ -20,10 +20,6 @@ ignore_errors=True

 ignore_errors = True
 
-[mypy-torchvision.models.quantization.*]
-
-ignore_errors = True
-
 [mypy-torchvision.ops.*]
 
 ignore_errors = True
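With this override removed, mypy now type-checks torchvision.models.quantization along with the rest of the code base. A minimal sketch of what that buys (the bad call is illustrative, not taken from this diff):

```python
# With the annotated factory signatures below, mypy can reject bad call
# sites in torchvision.models.quantization that it previously ignored.
from torchvision.models.quantization import googlenet

model = googlenet(pretrained=False, progress=True, quantize=False)  # OK: matches the annotations
# googlenet(pretrained="yes")  # mypy would now flag: "str" is not assignable to "bool"
```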
49 changes: 30 additions & 19 deletions torchvision/models/quantization/googlenet.py
@@ -2,6 +2,8 @@
 import torch
 import torch.nn as nn
 from torch.nn import functional as F
+from typing import Any
+from torch import Tensor
 
 from ..._internally_replaced_utils import load_state_dict_from_url
 from torchvision.models.googlenet import (
@@ -18,7 +20,13 @@
 }
 
 
-def googlenet(pretrained=False, progress=True, quantize=False, **kwargs):
+def googlenet(
+    pretrained: bool = False,
+    progress: bool = True,
+    quantize: bool = False,
+    **kwargs: Any,
+) -> "QuantizableGoogLeNet":
+
     r"""GoogLeNet (Inception v1) model architecture from
     `"Going Deeper with Convolutions" <http://arxiv.org/abs/1409.4842>`_.
 
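The return type is written as the string "QuantizableGoogLeNet" because the class is defined further down in the same module, so the bare name is not yet bound when the def statement is evaluated. A self-contained sketch of the pattern, with hypothetical names:

```python
from typing import Any


def make_widget(**kwargs: Any) -> "Widget":
    # "Widget" must be quoted: this annotation is evaluated when the def
    # executes at import time, before the class statement below has run.
    return Widget(**kwargs)


class Widget:
    def __init__(self, **kwargs: Any) -> None:
        self.options = kwargs
```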
@@ -70,48 +78,51 @@ def googlenet(pretrained=False, progress=True, quantize=False, **kwargs):

     if not original_aux_logits:
         model.aux_logits = False
-        model.aux1 = None
-        model.aux2 = None
+        model.aux1 = None  # type: ignore[assignment]
+        model.aux2 = None  # type: ignore[assignment]
     return model
 
 
 class QuantizableBasicConv2d(BasicConv2d):
 
-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
         super(QuantizableBasicConv2d, self).__init__(*args, **kwargs)
         self.relu = nn.ReLU()
 
-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
         x = self.conv(x)
         x = self.bn(x)
         x = self.relu(x)
         return x
 
-    def fuse_model(self):
+    def fuse_model(self) -> None:
         torch.quantization.fuse_modules(self, ["conv", "bn", "relu"], inplace=True)
 
 
 class QuantizableInception(Inception):
 
-    def __init__(self, *args, **kwargs):
-        super(QuantizableInception, self).__init__(
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        super(QuantizableInception, self).__init__(  # type: ignore[misc]
             conv_block=QuantizableBasicConv2d, *args, **kwargs)
         self.cat = nn.quantized.FloatFunctional()
 
-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
         outputs = self._forward(x)
         return self.cat.cat(outputs, 1)
 
 
 class QuantizableInceptionAux(InceptionAux):
-
-    def __init__(self, *args, **kwargs):
-        super(QuantizableInceptionAux, self).__init__(
-            conv_block=QuantizableBasicConv2d, *args, **kwargs)
+    # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        super(QuantizableInceptionAux, self).__init__(  # type: ignore[misc]
+            conv_block=QuantizableBasicConv2d,
+            *args,
+            **kwargs
+        )
         self.relu = nn.ReLU()
         self.dropout = nn.Dropout(0.7)
 
-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
         # aux1: N x 512 x 14 x 14, aux2: N x 528 x 14 x 14
         x = F.adaptive_avg_pool2d(x, (4, 4))
         # aux1: N x 512 x 4 x 4, aux2: N x 528 x 4 x 4
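Two suppressions recur in this hunk. ignore[assignment] marks the deliberate None assignments to aux1/aux2, whose declared types are module classes, and ignore[misc] covers the super().__init__ calls that pass the conv_block keyword before *args, a mix mypy rejects even though Python accepts it; the linked TODO tracks a cleaner fix. One possible alternative, an assumption rather than what this PR does, is to fold the keyword into **kwargs first:

```python
from typing import Any

from torchvision.models.googlenet import BasicConv2d, Inception


class QuantizableInceptionSketch(Inception):
    # Hypothetical variant that avoids the ignore[misc]: merge the forced
    # conv_block into **kwargs instead of mixing a keyword with *args.
    # Caveat: this silently overrides any conv_block the caller passed.
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        kwargs["conv_block"] = BasicConv2d  # stand-in for QuantizableBasicConv2d
        super().__init__(*args, **kwargs)
```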
@@ -130,17 +141,17 @@ def forward(self, x):


 class QuantizableGoogLeNet(GoogLeNet):
-
-    def __init__(self, *args, **kwargs):
-        super(QuantizableGoogLeNet, self).__init__(
+    # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        super(QuantizableGoogLeNet, self).__init__(  # type: ignore[misc]
             blocks=[QuantizableBasicConv2d, QuantizableInception, QuantizableInceptionAux],
             *args,
             **kwargs
         )
         self.quant = torch.quantization.QuantStub()
         self.dequant = torch.quantization.DeQuantStub()
 
-    def forward(self, x):
+    def forward(self, x: Tensor) -> GoogLeNetOutputs:
         x = self._transform_input(x)
         x = self.quant(x)
         x, aux1, aux2 = self._forward(x)
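The quant/dequant stubs typed here follow the standard eager-mode pattern: QuantStub and DeQuantStub mark the float-to-quantized boundaries that torch.quantization.convert later replaces with real (de)quantize ops. A minimal sketch of the same sandwich structure:

```python
import torch
import torch.nn as nn


class TinyQuantizable(nn.Module):
    # Minimal sketch of the quant -> compute -> dequant structure used by
    # the quantizable models; the Linear layer is a placeholder.
    def __init__(self) -> None:
        super().__init__()
        self.quant = torch.quantization.QuantStub()      # float -> quantized boundary
        self.fc = nn.Linear(4, 2)
        self.dequant = torch.quantization.DeQuantStub()  # quantized -> float boundary

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.quant(x)
        x = self.fc(x)
        return self.dequant(x)
```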
@@ -153,7 +164,7 @@
         else:
             return self.eager_outputs(x, aux2, aux1)
 
-    def fuse_model(self):
+    def fuse_model(self) -> None:
         r"""Fuse conv/bn/relu modules in googlenet model
 
         Fuse conv+bn+relu/ conv+relu/conv+bn modules to prepare for quantization.
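For context on the fuse_model annotations above: fuse_model delegates to torch.quantization.fuse_modules, which folds adjacent conv/bn/relu children into a single fused module so quantization observes one op instead of three. A small self-contained sketch (the Sequential and its child names are illustrative):

```python
import torch
import torch.nn as nn

# Child names must match the list passed to fuse_modules, exactly as
# QuantizableBasicConv2d.fuse_model does with ["conv", "bn", "relu"].
block = nn.Sequential()
block.add_module("conv", nn.Conv2d(3, 8, kernel_size=3))
block.add_module("bn", nn.BatchNorm2d(8))
block.add_module("relu", nn.ReLU())

block.eval()  # conv+bn folding requires eval mode
torch.quantization.fuse_modules(block, ["conv", "bn", "relu"], inplace=True)
print(type(block.conv))  # fused ConvReLU2d; bn is folded into the conv weights
```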
94 changes: 69 additions & 25 deletions torchvision/models/quantization/inception.py
@@ -3,6 +3,9 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch import Tensor
+from typing import Any, List
+
 from torchvision.models import inception as inception_module
 from torchvision.models.inception import InceptionOutputs
 from ..._internally_replaced_utils import load_state_dict_from_url
@@ -22,7 +25,13 @@
 }
 
 
-def inception_v3(pretrained=False, progress=True, quantize=False, **kwargs):
+def inception_v3(
+    pretrained: bool = False,
+    progress: bool = True,
+    quantize: bool = False,
+    **kwargs: Any,
+) -> "QuantizableInception3":
+
     r"""Inception v3 model architecture from
     `"Rethinking the Inception Architecture for Computer Vision" <http://arxiv.org/abs/1512.00567>`_.
 
@@ -84,68 +93,93 @@ def inception_v3(pretrained=False, progress=True, quantize=False, **kwargs):


 class QuantizableBasicConv2d(inception_module.BasicConv2d):
-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
         super(QuantizableBasicConv2d, self).__init__(*args, **kwargs)
         self.relu = nn.ReLU()
 
-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
         x = self.conv(x)
         x = self.bn(x)
         x = self.relu(x)
         return x
 
-    def fuse_model(self):
+    def fuse_model(self) -> None:
         torch.quantization.fuse_modules(self, ["conv", "bn", "relu"], inplace=True)
 
 
 class QuantizableInceptionA(inception_module.InceptionA):
-    def __init__(self, *args, **kwargs):
-        super(QuantizableInceptionA, self).__init__(conv_block=QuantizableBasicConv2d, *args, **kwargs)
+    # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        super(QuantizableInceptionA, self).__init__(  # type: ignore[misc]
+            conv_block=QuantizableBasicConv2d,
+            *args,
+            **kwargs
+        )
         self.myop = nn.quantized.FloatFunctional()
 
-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
         outputs = self._forward(x)
         return self.myop.cat(outputs, 1)
 
 
 class QuantizableInceptionB(inception_module.InceptionB):
-    def __init__(self, *args, **kwargs):
-        super(QuantizableInceptionB, self).__init__(conv_block=QuantizableBasicConv2d, *args, **kwargs)
+    # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        super(QuantizableInceptionB, self).__init__(  # type: ignore[misc]
+            conv_block=QuantizableBasicConv2d,
+            *args,
+            **kwargs
+        )
         self.myop = nn.quantized.FloatFunctional()
 
-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
         outputs = self._forward(x)
         return self.myop.cat(outputs, 1)
 
 
 class QuantizableInceptionC(inception_module.InceptionC):
-    def __init__(self, *args, **kwargs):
-        super(QuantizableInceptionC, self).__init__(conv_block=QuantizableBasicConv2d, *args, **kwargs)
+    # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        super(QuantizableInceptionC, self).__init__(  # type: ignore[misc]
+            conv_block=QuantizableBasicConv2d,
+            *args,
+            **kwargs
+        )
         self.myop = nn.quantized.FloatFunctional()
 
-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
         outputs = self._forward(x)
         return self.myop.cat(outputs, 1)
 
 
 class QuantizableInceptionD(inception_module.InceptionD):
-    def __init__(self, *args, **kwargs):
-        super(QuantizableInceptionD, self).__init__(conv_block=QuantizableBasicConv2d, *args, **kwargs)
+    # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        super(QuantizableInceptionD, self).__init__(  # type: ignore[misc]
+            conv_block=QuantizableBasicConv2d,
+            *args,
+            **kwargs
+        )
         self.myop = nn.quantized.FloatFunctional()
 
-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
         outputs = self._forward(x)
         return self.myop.cat(outputs, 1)
 
 
 class QuantizableInceptionE(inception_module.InceptionE):
-    def __init__(self, *args, **kwargs):
-        super(QuantizableInceptionE, self).__init__(conv_block=QuantizableBasicConv2d, *args, **kwargs)
+    # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        super(QuantizableInceptionE, self).__init__(  # type: ignore[misc]
+            conv_block=QuantizableBasicConv2d,
+            *args,
+            **kwargs
+        )
         self.myop1 = nn.quantized.FloatFunctional()
         self.myop2 = nn.quantized.FloatFunctional()
         self.myop3 = nn.quantized.FloatFunctional()
 
-    def _forward(self, x):
+    def _forward(self, x: Tensor) -> List[Tensor]:
         branch1x1 = self.branch1x1(x)
 
         branch3x3 = self.branch3x3_1(x)
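The myop* attributes threaded through these classes are nn.quantized.FloatFunctional instances: module wrappers around tensor ops such as cat and add. Wrapping the op in a module lets the quantization machinery attach observers and learn scale/zero-point for it, which a bare torch.cat call would never get. A short sketch of the float-mode behavior:

```python
import torch
import torch.nn as nn

ff = nn.quantized.FloatFunctional()
a = torch.randn(1, 3)
b = torch.randn(1, 3)

out = ff.cat([a, b], dim=1)  # same result as torch.cat([a, b], dim=1)
tot = ff.add(a, b)           # same result as a + b
print(out.shape, tot.shape)  # torch.Size([1, 6]) torch.Size([1, 3])
```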
@@ -166,18 +200,28 @@ def _forward(self, x):
         outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
         return outputs
 
-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
         outputs = self._forward(x)
         return self.myop3.cat(outputs, 1)
 
 
 class QuantizableInceptionAux(inception_module.InceptionAux):
-    def __init__(self, *args, **kwargs):
-        super(QuantizableInceptionAux, self).__init__(conv_block=QuantizableBasicConv2d, *args, **kwargs)
+    # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        super(QuantizableInceptionAux, self).__init__(  # type: ignore[misc]
+            conv_block=QuantizableBasicConv2d,
+            *args,
+            **kwargs
+        )
 
 
 class QuantizableInception3(inception_module.Inception3):
-    def __init__(self, num_classes=1000, aux_logits=True, transform_input=False):
+    def __init__(
+        self,
+        num_classes: int = 1000,
+        aux_logits: bool = True,
+        transform_input: bool = False,
+    ) -> None:
         super(QuantizableInception3, self).__init__(
             num_classes=num_classes,
             aux_logits=aux_logits,
@@ -195,7 +239,7 @@ def __init__(self, num_classes=1000, aux_logits=True, transform_input=False):
         self.quant = torch.quantization.QuantStub()
         self.dequant = torch.quantization.DeQuantStub()
 
-    def forward(self, x):
+    def forward(self, x: Tensor) -> InceptionOutputs:
         x = self._transform_input(x)
         x = self.quant(x)
         x, aux = self._forward(x)
@@ -208,7 +252,7 @@ def forward(self, x):
         else:
             return self.eager_outputs(x, aux)
 
-    def fuse_model(self):
+    def fuse_model(self) -> None:
         r"""Fuse conv/bn/relu modules in inception model
 
         Fuse conv+bn+relu/ conv+relu/conv+bn modules to prepare for quantization.
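Note that forward is annotated -> InceptionOutputs even though eval mode returns a plain Tensor: InceptionOutputs is a NamedTuple of (logits, aux_logits), and the eager_outputs indirection exists so TorchScript sees one declared return type while eager eval code receives the bare tensor. A rough stand-in for the shape of that type (illustrative, not the torchvision definition):

```python
from typing import NamedTuple, Optional

from torch import Tensor


class InceptionOutputsSketch(NamedTuple):
    # Mirrors the role of torchvision.models.inception.InceptionOutputs:
    # training mode populates both heads; aux_logits may be None in eval.
    logits: Tensor
    aux_logits: Optional[Tensor]
```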
24 changes: 17 additions & 7 deletions torchvision/models/quantization/mobilenetv2.py
@@ -1,5 +1,10 @@
 from torch import nn
+from torch import Tensor
 
 from ..._internally_replaced_utils import load_state_dict_from_url
 
+from typing import Any
+
 from torchvision.models.mobilenetv2 import InvertedResidual, ConvBNReLU, MobileNetV2, model_urls
 from torch.quantization import QuantStub, DeQuantStub, fuse_modules
 from .utils import _replace_relu, quantize_model
@@ -14,24 +19,24 @@


 class QuantizableInvertedResidual(InvertedResidual):
-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
         super(QuantizableInvertedResidual, self).__init__(*args, **kwargs)
         self.skip_add = nn.quantized.FloatFunctional()
 
-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
         if self.use_res_connect:
             return self.skip_add.add(x, self.conv(x))
         else:
             return self.conv(x)
 
-    def fuse_model(self):
+    def fuse_model(self) -> None:
         for idx in range(len(self.conv)):
             if type(self.conv[idx]) == nn.Conv2d:
                 fuse_modules(self.conv, [str(idx), str(idx + 1)], inplace=True)
 
 
 class QuantizableMobileNetV2(MobileNetV2):
-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
         """
         MobileNet V2 main class
 
@@ -42,21 +47,26 @@ def __init__(self, *args, **kwargs):
         self.quant = QuantStub()
         self.dequant = DeQuantStub()
 
-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
         x = self.quant(x)
         x = self._forward_impl(x)
         x = self.dequant(x)
         return x
 
-    def fuse_model(self):
+    def fuse_model(self) -> None:
         for m in self.modules():
             if type(m) == ConvBNReLU:
                 fuse_modules(m, ['0', '1', '2'], inplace=True)
             if type(m) == QuantizableInvertedResidual:
                 m.fuse_model()
 
 
-def mobilenet_v2(pretrained=False, progress=True, quantize=False, **kwargs):
+def mobilenet_v2(
+    pretrained: bool = False,
+    progress: bool = True,
+    quantize: bool = False,
+    **kwargs: Any,
+) -> QuantizableMobileNetV2:
     """
     Constructs a MobileNetV2 architecture from
     `"MobileNetV2: Inverted Residuals and Linear Bottlenecks"
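Taken together, the annotated fuse_model and forward methods support the usual post-training static quantization flow. A hedged end-to-end usage sketch (eager-mode API of the torch releases current at the time of this PR; the fbgemm backend and the random calibration input are assumptions for illustration):

```python
import torch
from torchvision.models.quantization import mobilenet_v2

model = mobilenet_v2(pretrained=False, quantize=False).eval()
model.fuse_model()                                               # fold conv+bn(+relu)
model.qconfig = torch.quantization.get_default_qconfig("fbgemm")
torch.quantization.prepare(model, inplace=True)                  # insert observers
model(torch.randn(1, 3, 224, 224))                               # calibration pass
torch.quantization.convert(model, inplace=True)                  # swap in quantized modules
print(model(torch.randn(1, 3, 224, 224)).shape)                  # torch.Size([1, 1000])
```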