Skip to content

Commit 37a9ee5

Browse files
authored
Add EfficientNet Architecture in TorchVision (#4293)
* Adding code skeleton * Adding MBConvConfig. * Extend SqueezeExcitation to support custom min_value and activation. * Implement MBConv. * Replace stochastic_depth with operator. * Adding the rest of the EfficientNet implementation * Update torchvision/models/efficientnet.py * Replacing 1st activation of SE with SiLU. * Adding efficientnet_b3. * Replace mobilenetv3 assets with custom. * Switch to standard sigmoid and reconfiguring BN. * Reconfiguration of efficientnet. * Add repr * Add weights. * Update weights. * Adding B5-B7 weights. * Update docs and hubconf. * Fix doc link. * Fix typo on comment.
1 parent d004d77 commit 37a9ee5

16 files changed

+441
-7
lines changed

docs/source/models.rst

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ architectures for image classification:
2727
- `ResNeXt`_
2828
- `Wide ResNet`_
2929
- `MNASNet`_
30+
- `EfficientNet`_
3031

3132
You can construct a model with random weights by calling its constructor:
3233

@@ -47,6 +48,14 @@ You can construct a model with random weights by calling its constructor:
4748
resnext50_32x4d = models.resnext50_32x4d()
4849
wide_resnet50_2 = models.wide_resnet50_2()
4950
mnasnet = models.mnasnet1_0()
51+
efficientnet_b0 = models.efficientnet_b0()
52+
efficientnet_b1 = models.efficientnet_b1()
53+
efficientnet_b2 = models.efficientnet_b2()
54+
efficientnet_b3 = models.efficientnet_b3()
55+
efficientnet_b4 = models.efficientnet_b4()
56+
efficientnet_b5 = models.efficientnet_b5()
57+
efficientnet_b6 = models.efficientnet_b6()
58+
efficientnet_b7 = models.efficientnet_b7()
5059
5160
We provide pre-trained models, using the PyTorch :mod:`torch.utils.model_zoo`.
5261
These can be constructed by passing ``pretrained=True``:
@@ -68,6 +77,14 @@ These can be constructed by passing ``pretrained=True``:
6877
resnext50_32x4d = models.resnext50_32x4d(pretrained=True)
6978
wide_resnet50_2 = models.wide_resnet50_2(pretrained=True)
7079
mnasnet = models.mnasnet1_0(pretrained=True)
80+
efficientnet_b0 = models.efficientnet_b0(pretrained=True)
81+
efficientnet_b1 = models.efficientnet_b1(pretrained=True)
82+
efficientnet_b2 = models.efficientnet_b2(pretrained=True)
83+
efficientnet_b3 = models.efficientnet_b3(pretrained=True)
84+
efficientnet_b4 = models.efficientnet_b4(pretrained=True)
85+
efficientnet_b5 = models.efficientnet_b5(pretrained=True)
86+
efficientnet_b6 = models.efficientnet_b6(pretrained=True)
87+
efficientnet_b7 = models.efficientnet_b7(pretrained=True)
7188
7289
Instancing a pre-trained model will download its weights to a cache directory.
7390
This directory can be set using the `TORCH_MODEL_ZOO` environment variable. See
@@ -113,7 +130,10 @@ Unfortunately, the concrete `subset` that was used is lost. For more
113130
information see `this discussion <https://github.com/pytorch/vision/issues/1439>`_
114131
or `these experiments <https://github.com/pytorch/vision/pull/1965>`_.
115132

116-
ImageNet 1-crop error rates (224x224)
133+
The sizes of the EfficientNet models depend on the variant. For the exact input sizes
134+
`check here <https://github.com/pytorch/vision/blob/d2bfd639e46e1c5dc3c177f889dc7750c8d137c7/references/classification/train.py#L92-L93>`_
135+
136+
ImageNet 1-crop error rates
117137

118138
================================ ============= =============
119139
Model Acc@1 Acc@5
@@ -151,6 +171,14 @@ Wide ResNet-50-2 78.468 94.086
151171
Wide ResNet-101-2 78.848 94.284
152172
MNASNet 1.0 73.456 91.510
153173
MNASNet 0.5 67.734 87.490
174+
EfficientNet-B0 77.692 93.532
175+
EfficientNet-B1 78.642 94.186
176+
EfficientNet-B2 80.608 95.310
177+
EfficientNet-B3 82.008 96.054
178+
EfficientNet-B4 83.384 96.594
179+
EfficientNet-B5 83.444 96.628
180+
EfficientNet-B6 84.008 96.916
181+
EfficientNet-B7 84.122 96.908
154182
================================ ============= =============
155183

156184

@@ -166,6 +194,7 @@ MNASNet 0.5 67.734 87.490
166194
.. _MobileNetV3: https://arxiv.org/abs/1905.02244
167195
.. _ResNeXt: https://arxiv.org/abs/1611.05431
168196
.. _MNASNet: https://arxiv.org/abs/1807.11626
197+
.. _EfficientNet: https://arxiv.org/abs/1905.11946
169198

170199
.. currentmodule:: torchvision.models
171200

@@ -267,6 +296,18 @@ MNASNet
267296
.. autofunction:: mnasnet1_0
268297
.. autofunction:: mnasnet1_3
269298

299+
EfficientNet
300+
------------
301+
302+
.. autofunction:: efficientnet_b0
303+
.. autofunction:: efficientnet_b1
304+
.. autofunction:: efficientnet_b2
305+
.. autofunction:: efficientnet_b3
306+
.. autofunction:: efficientnet_b4
307+
.. autofunction:: efficientnet_b5
308+
.. autofunction:: efficientnet_b6
309+
.. autofunction:: efficientnet_b7
310+
270311
Quantized Models
271312
----------------
272313

hubconf.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
from torchvision.models.mobilenetv3 import mobilenet_v3_large, mobilenet_v3_small
1616
from torchvision.models.mnasnet import mnasnet0_5, mnasnet0_75, mnasnet1_0, \
1717
mnasnet1_3
18+
from torchvision.models.efficientnet import efficientnet_b0, efficientnet_b1, efficientnet_b2, \
19+
efficientnet_b3, efficientnet_b4, efficientnet_b5, efficientnet_b6, efficientnet_b7
1820

1921
# segmentation
2022
from torchvision.models.segmentation import fcn_resnet50, fcn_resnet101, \

references/classification/README.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,12 @@ Then we averaged the parameters of the last 3 checkpoints that improved the Acc@
6868
and [#3354](https://github.com/pytorch/vision/pull/3354) for details.
6969

7070

71+
### EfficientNet
72+
73+
The weights of the B0-B4 variants are ported from Ross Wightman's [timm repo](https://github.com/rwightman/pytorch-image-models/blob/01cb46a9a50e3ba4be167965b5764e9702f09b30/timm/models/efficientnet.py#L95-L108).
74+
75+
The weights of the B5-B7 variants are ported from Luke Melas' [EfficientNet-PyTorch repo](https://github.com/lukemelas/EfficientNet-PyTorch/blob/1039e009545d9329ea026c9f7541341439712b96/efficientnet_pytorch/utils.py#L562-L564).
76+
7177
## Mixed precision training
7278
Automatic Mixed Precision (AMP) training on GPU for Pytorch can be enabled with the [NVIDIA Apex extension](https://github.com/NVIDIA/apex).
7379

references/classification/presets.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from torchvision.transforms import autoaugment, transforms
2+
from torchvision.transforms.functional import InterpolationMode
23

34

45
class ClassificationPresetTrain:
@@ -24,10 +25,11 @@ def __call__(self, img):
2425

2526

2627
class ClassificationPresetEval:
27-
def __init__(self, crop_size, resize_size=256, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
28+
def __init__(self, crop_size, resize_size=256, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225),
29+
interpolation=InterpolationMode.BILINEAR):
2830

2931
self.transforms = transforms.Compose([
30-
transforms.Resize(resize_size),
32+
transforms.Resize(resize_size, interpolation=interpolation),
3133
transforms.CenterCrop(crop_size),
3234
transforms.ToTensor(),
3335
transforms.Normalize(mean=mean, std=std),

references/classification/train.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import torch.utils.data
77
from torch import nn
88
import torchvision
9+
from torchvision.transforms.functional import InterpolationMode
910

1011
import presets
1112
import utils
@@ -82,7 +83,18 @@ def _get_cache_path(filepath):
8283
def load_data(traindir, valdir, args):
8384
# Data loading code
8485
print("Loading data")
85-
resize_size, crop_size = (342, 299) if args.model == 'inception_v3' else (256, 224)
86+
resize_size, crop_size = 256, 224
87+
interpolation = InterpolationMode.BILINEAR
88+
if args.model == 'inception_v3':
89+
resize_size, crop_size = 342, 299
90+
elif args.model.startswith('efficientnet_'):
91+
sizes = {
92+
'b0': (256, 224), 'b1': (256, 240), 'b2': (288, 288), 'b3': (320, 300),
93+
'b4': (384, 380), 'b5': (456, 456), 'b6': (528, 528), 'b7': (600, 600),
94+
}
95+
e_type = args.model.replace('efficientnet_', '')
96+
resize_size, crop_size = sizes[e_type]
97+
interpolation = InterpolationMode.BICUBIC
8698

8799
print("Loading training data")
88100
st = time.time()
@@ -113,7 +125,8 @@ def load_data(traindir, valdir, args):
113125
else:
114126
dataset_test = torchvision.datasets.ImageFolder(
115127
valdir,
116-
presets.ClassificationPresetEval(crop_size=crop_size, resize_size=resize_size))
128+
presets.ClassificationPresetEval(crop_size=crop_size, resize_size=resize_size,
129+
interpolation=interpolation))
117130
if args.cache_dataset:
118131
print("Saving dataset_test to {}".format(cache_path))
119132
utils.mkdir(os.path.dirname(cache_path))
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.

torchvision/models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from .mobilenet import *
99
from .mnasnet import *
1010
from .shufflenetv2 import *
11+
from .efficientnet import *
1112
from . import segmentation
1213
from . import detection
1314
from . import video

0 commit comments

Comments
 (0)