Skip to content

Commit 732fc0b

Browse files
authored
Adding multiweight support for inception prototype model (#4821)
* Moving original builder at the bottom of the page to use proper typing. * Adding multiweight support to inception. * Update doc.
1 parent 00b963a commit 732fc0b

File tree

4 files changed

+93
-38
lines changed

4 files changed

+93
-38
lines changed

references/classification/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ The weights of the Inception V3 model are ported from the original paper rather
3838
Since it expects tensors with a size of N x 3 x 299 x 299, to validate the model use the following command:
3939

4040
```
41-
torchrun --nproc_per_node=8 train.py --model inception_v3
41+
torchrun --nproc_per_node=8 train.py --model inception_v3\
4242
--val-resize-size 342 --val-crop-size 299 --train-crop-size 299 --test-only --pretrained
4343
```
4444

torchvision/models/inception.py

Lines changed: 37 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -26,43 +26,6 @@
2626
_InceptionOutputs = InceptionOutputs
2727

2828

29-
def inception_v3(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> "Inception3":
30-
r"""Inception v3 model architecture from
31-
`"Rethinking the Inception Architecture for Computer Vision" <http://arxiv.org/abs/1512.00567>`_.
32-
The required minimum input size of the model is 75x75.
33-
34-
.. note::
35-
**Important**: In contrast to the other models the inception_v3 expects tensors with a size of
36-
N x 3 x 299 x 299, so ensure your images are sized accordingly.
37-
38-
Args:
39-
pretrained (bool): If True, returns a model pre-trained on ImageNet
40-
progress (bool): If True, displays a progress bar of the download to stderr
41-
aux_logits (bool): If True, add an auxiliary branch that can improve training.
42-
Default: *True*
43-
transform_input (bool): If True, preprocesses the input according to the method with which it
44-
was trained on ImageNet. Default: *False*
45-
"""
46-
if pretrained:
47-
if "transform_input" not in kwargs:
48-
kwargs["transform_input"] = True
49-
if "aux_logits" in kwargs:
50-
original_aux_logits = kwargs["aux_logits"]
51-
kwargs["aux_logits"] = True
52-
else:
53-
original_aux_logits = True
54-
kwargs["init_weights"] = False # we are loading weights from a pretrained model
55-
model = Inception3(**kwargs)
56-
state_dict = load_state_dict_from_url(model_urls["inception_v3_google"], progress=progress)
57-
model.load_state_dict(state_dict)
58-
if not original_aux_logits:
59-
model.aux_logits = False
60-
model.AuxLogits = None
61-
return model
62-
63-
return Inception3(**kwargs)
64-
65-
6629
class Inception3(nn.Module):
6730
def __init__(
6831
self,
@@ -442,3 +405,40 @@ def forward(self, x: Tensor) -> Tensor:
442405
x = self.conv(x)
443406
x = self.bn(x)
444407
return F.relu(x, inplace=True)
408+
409+
410+
def inception_v3(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> Inception3:
411+
r"""Inception v3 model architecture from
412+
`"Rethinking the Inception Architecture for Computer Vision" <http://arxiv.org/abs/1512.00567>`_.
413+
The required minimum input size of the model is 75x75.
414+
415+
.. note::
416+
**Important**: In contrast to the other models the inception_v3 expects tensors with a size of
417+
N x 3 x 299 x 299, so ensure your images are sized accordingly.
418+
419+
Args:
420+
pretrained (bool): If True, returns a model pre-trained on ImageNet
421+
progress (bool): If True, displays a progress bar of the download to stderr
422+
aux_logits (bool): If True, add an auxiliary branch that can improve training.
423+
Default: *True*
424+
transform_input (bool): If True, preprocesses the input according to the method with which it
425+
was trained on ImageNet. Default: *False*
426+
"""
427+
if pretrained:
428+
if "transform_input" not in kwargs:
429+
kwargs["transform_input"] = True
430+
if "aux_logits" in kwargs:
431+
original_aux_logits = kwargs["aux_logits"]
432+
kwargs["aux_logits"] = True
433+
else:
434+
original_aux_logits = True
435+
kwargs["init_weights"] = False # we are loading weights from a pretrained model
436+
model = Inception3(**kwargs)
437+
state_dict = load_state_dict_from_url(model_urls["inception_v3_google"], progress=progress)
438+
model.load_state_dict(state_dict)
439+
if not original_aux_logits:
440+
model.aux_logits = False
441+
model.AuxLogits = None
442+
return model
443+
444+
return Inception3(**kwargs)

torchvision/prototype/models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from .densenet import *
33
from .efficientnet import *
44
from .googlenet import *
5+
from .inception import *
56
from .mnasnet import *
67
from .mobilenetv2 import *
78
from .mobilenetv3 import *
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
import warnings
2+
from functools import partial
3+
from typing import Any, Optional
4+
5+
from torchvision.transforms.functional import InterpolationMode
6+
7+
from ...models.inception import Inception3, InceptionOutputs, _InceptionOutputs
8+
from ..transforms.presets import ImageNetEval
9+
from ._api import Weights, WeightEntry
10+
from ._meta import _IMAGENET_CATEGORIES
11+
12+
13+
__all__ = ["Inception3", "InceptionOutputs", "_InceptionOutputs", "Inception3Weights", "inception_v3"]
14+
15+
16+
_common_meta = {"size": (299, 299), "categories": _IMAGENET_CATEGORIES, "interpolation": InterpolationMode.BILINEAR}
17+
18+
19+
class Inception3Weights(Weights):
20+
ImageNet1K_TFV1 = WeightEntry(
21+
url="https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth",
22+
transforms=partial(ImageNetEval, crop_size=299, resize_size=342),
23+
meta={
24+
**_common_meta,
25+
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#inception-v3",
26+
"acc@1": 77.294,
27+
"acc@5": 93.450,
28+
},
29+
)
30+
31+
32+
def inception_v3(weights: Optional[Inception3Weights] = None, progress: bool = True, **kwargs: Any) -> Inception3:
33+
if "pretrained" in kwargs:
34+
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
35+
weights = Inception3Weights.ImageNet1K_TFV1 if kwargs.pop("pretrained") else None
36+
weights = Inception3Weights.verify(weights)
37+
38+
original_aux_logits = kwargs.get("aux_logits", True)
39+
if weights is not None:
40+
if "transform_input" not in kwargs:
41+
kwargs["transform_input"] = True
42+
kwargs["aux_logits"] = True
43+
kwargs["init_weights"] = False
44+
kwargs["num_classes"] = len(weights.meta["categories"])
45+
46+
model = Inception3(**kwargs)
47+
48+
if weights is not None:
49+
model.load_state_dict(weights.state_dict(progress=progress))
50+
if not original_aux_logits:
51+
model.aux_logits = False
52+
model.AuxLogits = None
53+
54+
return model

0 commit comments

Comments
 (0)