Skip to content

Add EfficientNet Architecture in TorchVision #4293

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 23 commits into from
Aug 26, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 42 additions & 1 deletion docs/source/models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ architectures for image classification:
- `ResNeXt`_
- `Wide ResNet`_
- `MNASNet`_
- `EfficientNet`_

You can construct a model with random weights by calling its constructor:

Expand All @@ -47,6 +48,14 @@ You can construct a model with random weights by calling its constructor:
resnext50_32x4d = models.resnext50_32x4d()
wide_resnet50_2 = models.wide_resnet50_2()
mnasnet = models.mnasnet1_0()
efficientnet_b0 = models.efficientnet_b0()
efficientnet_b1 = models.efficientnet_b1()
efficientnet_b2 = models.efficientnet_b2()
efficientnet_b3 = models.efficientnet_b3()
efficientnet_b4 = models.efficientnet_b4()
efficientnet_b5 = models.efficientnet_b5()
efficientnet_b6 = models.efficientnet_b6()
efficientnet_b7 = models.efficientnet_b7()

We provide pre-trained models, using the PyTorch :mod:`torch.utils.model_zoo`.
These can be constructed by passing ``pretrained=True``:
Expand All @@ -68,6 +77,14 @@ These can be constructed by passing ``pretrained=True``:
resnext50_32x4d = models.resnext50_32x4d(pretrained=True)
wide_resnet50_2 = models.wide_resnet50_2(pretrained=True)
mnasnet = models.mnasnet1_0(pretrained=True)
efficientnet_b0 = models.efficientnet_b0(pretrained=True)
efficientnet_b1 = models.efficientnet_b1(pretrained=True)
efficientnet_b2 = models.efficientnet_b2(pretrained=True)
efficientnet_b3 = models.efficientnet_b3(pretrained=True)
efficientnet_b4 = models.efficientnet_b4(pretrained=True)
efficientnet_b5 = models.efficientnet_b5(pretrained=True)
efficientnet_b6 = models.efficientnet_b6(pretrained=True)
efficientnet_b7 = models.efficientnet_b7(pretrained=True)

Instancing a pre-trained model will download its weights to a cache directory.
This directory can be set using the `TORCH_MODEL_ZOO` environment variable. See
Expand Down Expand Up @@ -113,7 +130,10 @@ Unfortunately, the concrete `subset` that was used is lost. For more
information see `this discussion <https://github.com/pytorch/vision/issues/1439>`_
or `these experiments <https://github.com/pytorch/vision/pull/1965>`_.

ImageNet 1-crop error rates (224x224)
The sizes of the EfficientNet models depend on the variant. For the exact input sizes
`check here <https://github.com/pytorch/vision/blob/d2bfd639e46e1c5dc3c177f889dc7750c8d137c7/references/classification/train.py#L92-L93>`_

ImageNet 1-crop error rates

================================ ============= =============
Model Acc@1 Acc@5
Expand Down Expand Up @@ -151,6 +171,14 @@ Wide ResNet-50-2 78.468 94.086
Wide ResNet-101-2 78.848 94.284
MNASNet 1.0 73.456 91.510
MNASNet 0.5 67.734 87.490
EfficientNet-B0 77.692 93.532
EfficientNet-B1 78.642 94.186
EfficientNet-B2 80.608 95.310
EfficientNet-B3 82.008 96.054
EfficientNet-B4 83.384 96.594
EfficientNet-B5 83.444 96.628
EfficientNet-B6 84.008 96.916
EfficientNet-B7 84.122 96.908
================================ ============= =============


Expand All @@ -166,6 +194,7 @@ MNASNet 0.5 67.734 87.490
.. _MobileNetV3: https://arxiv.org/abs/1905.02244
.. _ResNeXt: https://arxiv.org/abs/1611.05431
.. _MNASNet: https://arxiv.org/abs/1807.11626
.. _EfficientNet: https://arxiv.org/abs/1905.11946

.. currentmodule:: torchvision.models

Expand Down Expand Up @@ -267,6 +296,18 @@ MNASNet
.. autofunction:: mnasnet1_0
.. autofunction:: mnasnet1_3

EfficientNet
------------

.. autofunction:: efficientnet_b0
.. autofunction:: efficientnet_b1
.. autofunction:: efficientnet_b2
.. autofunction:: efficientnet_b3
.. autofunction:: efficientnet_b4
.. autofunction:: efficientnet_b5
.. autofunction:: efficientnet_b6
.. autofunction:: efficientnet_b7

Quantized Models
----------------

Expand Down
2 changes: 2 additions & 0 deletions hubconf.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from torchvision.models.mobilenetv3 import mobilenet_v3_large, mobilenet_v3_small
from torchvision.models.mnasnet import mnasnet0_5, mnasnet0_75, mnasnet1_0, \
mnasnet1_3
from torchvision.models.efficientnet import efficientnet_b0, efficientnet_b1, efficientnet_b2, \
efficientnet_b3, efficientnet_b4, efficientnet_b5, efficientnet_b6, efficientnet_b7

# segmentation
from torchvision.models.segmentation import fcn_resnet50, fcn_resnet101, \
Expand Down
6 changes: 6 additions & 0 deletions references/classification/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,12 @@ Then we averaged the parameters of the last 3 checkpoints that improved the Acc@
and [#3354](https://github.com/pytorch/vision/pull/3354) for details.


### EfficientNet

The weights of the B0-B4 variants are ported from Ross Wightman's [timm repo](https://github.com/rwightman/pytorch-image-models/blob/01cb46a9a50e3ba4be167965b5764e9702f09b30/timm/models/efficientnet.py#L95-L108).

The weights of the B5-B7 variants are ported from Luke Melas' [EfficientNet-PyTorch repo](https://github.com/lukemelas/EfficientNet-PyTorch/blob/1039e009545d9329ea026c9f7541341439712b96/efficientnet_pytorch/utils.py#L562-L564).

## Mixed precision training
Automatic Mixed Precision (AMP) training on GPU for Pytorch can be enabled with the [NVIDIA Apex extension](https://github.com/NVIDIA/apex).

Expand Down
6 changes: 4 additions & 2 deletions references/classification/presets.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from torchvision.transforms import autoaugment, transforms
from torchvision.transforms.functional import InterpolationMode


class ClassificationPresetTrain:
Expand All @@ -24,10 +25,11 @@ def __call__(self, img):


class ClassificationPresetEval:
def __init__(self, crop_size, resize_size=256, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
def __init__(self, crop_size, resize_size=256, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225),
interpolation=InterpolationMode.BILINEAR):

self.transforms = transforms.Compose([
transforms.Resize(resize_size),
transforms.Resize(resize_size, interpolation=interpolation),
transforms.CenterCrop(crop_size),
transforms.ToTensor(),
transforms.Normalize(mean=mean, std=std),
Expand Down
17 changes: 15 additions & 2 deletions references/classification/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import torch.utils.data
from torch import nn
import torchvision
from torchvision.transforms.functional import InterpolationMode

import presets
import utils
Expand Down Expand Up @@ -82,7 +83,18 @@ def _get_cache_path(filepath):
def load_data(traindir, valdir, args):
# Data loading code
print("Loading data")
resize_size, crop_size = (342, 299) if args.model == 'inception_v3' else (256, 224)
resize_size, crop_size = 256, 224
interpolation = InterpolationMode.BILINEAR
if args.model == 'inception_v3':
resize_size, crop_size = 342, 299
elif args.model.startswith('efficientnet_'):
sizes = {
'b0': (256, 224), 'b1': (256, 240), 'b2': (288, 288), 'b3': (320, 300),
'b4': (384, 380), 'b5': (456, 456), 'b6': (528, 528), 'b7': (600, 600),
}
e_type = args.model.replace('efficientnet_', '')
resize_size, crop_size = sizes[e_type]
interpolation = InterpolationMode.BICUBIC
Comment on lines +86 to +97
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Discussed with @datumbox on chat, I think it would be good in the future to factor this out somewhere else, maybe as a set of custom preset transforms


print("Loading training data")
st = time.time()
Expand Down Expand Up @@ -113,7 +125,8 @@ def load_data(traindir, valdir, args):
else:
dataset_test = torchvision.datasets.ImageFolder(
valdir,
presets.ClassificationPresetEval(crop_size=crop_size, resize_size=resize_size))
presets.ClassificationPresetEval(crop_size=crop_size, resize_size=resize_size,
interpolation=interpolation))
if args.cache_dataset:
print("Saving dataset_test to {}".format(cache_path))
utils.mkdir(os.path.dirname(cache_path))
Expand Down
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
1 change: 1 addition & 0 deletions torchvision/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from .mobilenet import *
from .mnasnet import *
from .shufflenetv2 import *
from .efficientnet import *
from . import segmentation
from . import detection
from . import video
Expand Down
Loading