Moving common layers to ops #4504

Merged: 9 commits, Sep 30, 2021
18 changes: 17 additions & 1 deletion torchvision/models/_utils.py
@@ -1,7 +1,7 @@
from collections import OrderedDict

from torch import nn
from typing import Dict
from typing import Dict, Optional


class IntermediateLayerGetter(nn.ModuleDict):
@@ -64,3 +64,19 @@ def forward(self, x):
out_name = self.return_layers[name]
out[out_name] = x
return out


def _make_divisible(v: float, divisor: int, min_value: Optional[int] = None) -> int:
"""
This function is taken from the original tf repo.
It ensures that all layers have a channel number that is divisible by 8
It can be seen here:
https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
"""
if min_value is None:
min_value = divisor
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
# Make sure that round down does not go down by more than 10%.
if new_v < 0.9 * v:
new_v += divisor
return new_v
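For reference, a minimal sketch of how this helper rounds channel counts (the inputs below are illustrative, not taken from the PR; the import path assumes the new location in torchvision/models/_utils.py):

from torchvision.models._utils import _make_divisible

assert _make_divisible(32, 8) == 32    # already a multiple of 8: unchanged
assert _make_divisible(37.5, 8) == 40  # rounded to the nearest multiple of 8
assert _make_divisible(7, 8) == 8      # clamped up to min_value (defaults to divisor)
assert _make_divisible(91, 8) == 88    # 88 >= 0.9 * 91, so no extra bump is needed
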
18 changes: 9 additions & 9 deletions torchvision/models/detection/ssdlite.py
@@ -11,8 +11,8 @@
from .anchor_utils import DefaultBoxGenerator
from .backbone_utils import _validate_trainable_layers
from .. import mobilenet
from ..mobilenetv3 import ConvBNActivation
from ..._internally_replaced_utils import load_state_dict_from_url
from ...ops.misc import ConvNormActivation


__all__ = ['ssdlite320_mobilenet_v3_large']
@@ -28,8 +28,8 @@ def _prediction_block(in_channels: int, out_channels: int, kernel_size: int,
norm_layer: Callable[..., nn.Module]) -> nn.Sequential:
return nn.Sequential(
# 3x3 depthwise with stride 1 and padding 1
ConvBNActivation(in_channels, in_channels, kernel_size=kernel_size, groups=in_channels,
norm_layer=norm_layer, activation_layer=nn.ReLU6),
ConvNormActivation(in_channels, in_channels, kernel_size=kernel_size, groups=in_channels,
norm_layer=norm_layer, activation_layer=nn.ReLU6),

# 1x1 projection to output channels
nn.Conv2d(in_channels, out_channels, 1)
Expand All @@ -41,16 +41,16 @@ def _extra_block(in_channels: int, out_channels: int, norm_layer: Callable[...,
intermediate_channels = out_channels // 2
return nn.Sequential(
# 1x1 projection to half output channels
ConvBNActivation(in_channels, intermediate_channels, kernel_size=1,
norm_layer=norm_layer, activation_layer=activation),
ConvNormActivation(in_channels, intermediate_channels, kernel_size=1,
norm_layer=norm_layer, activation_layer=activation),

# 3x3 depthwise with stride 2 and padding 1
ConvBNActivation(intermediate_channels, intermediate_channels, kernel_size=3, stride=2,
groups=intermediate_channels, norm_layer=norm_layer, activation_layer=activation),
ConvNormActivation(intermediate_channels, intermediate_channels, kernel_size=3, stride=2,
groups=intermediate_channels, norm_layer=norm_layer, activation_layer=activation),

# 1x1 projection to output channels
ConvBNActivation(intermediate_channels, out_channels, kernel_size=1,
norm_layer=norm_layer, activation_layer=activation),
ConvNormActivation(intermediate_channels, out_channels, kernel_size=1,
norm_layer=norm_layer, activation_layer=activation),
)


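As a quick sanity check of the rewritten _extra_block, the sketch below builds one block and verifies its output shape (hedged: the channel counts and input size are arbitrary, not the ones SSDLite actually uses):

import torch
from torch import nn
from torchvision.models.detection.ssdlite import _extra_block

# 1x1 projection down to 256 channels, 3x3 depthwise with stride 2, 1x1 back up to 512
block = _extra_block(in_channels=672, out_channels=512, norm_layer=nn.BatchNorm2d)
out = block(torch.rand(1, 672, 20, 20))
assert out.shape == (1, 512, 10, 10)  # the stride-2 depthwise halves the spatial size
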
53 changes: 13 additions & 40 deletions torchvision/models/efficientnet.py
@@ -4,14 +4,13 @@

from functools import partial
from torch import nn, Tensor
from torch.nn import functional as F
from typing import Any, Callable, List, Optional, Sequence

from .._internally_replaced_utils import load_state_dict_from_url
from ..ops.misc import ConvNormActivation, SqueezeExcitation
from ._utils import _make_divisible
from torchvision.ops import StochasticDepth

from torchvision.models.mobilenetv2 import ConvBNActivation, _make_divisible


__all__ = ["EfficientNet", "efficientnet_b0", "efficientnet_b1", "efficientnet_b2", "efficientnet_b3",
"efficientnet_b4", "efficientnet_b5", "efficientnet_b6", "efficientnet_b7"]
@@ -31,32 +30,6 @@
}


class SqueezeExcitation(nn.Module):
def __init__(
self,
input_channels: int,
squeeze_channels: int,
activation: Callable[..., nn.Module] = nn.ReLU,
scale_activation: Callable[..., nn.Module] = nn.Sigmoid,
) -> None:
super().__init__()
self.fc1 = nn.Conv2d(input_channels, squeeze_channels, 1)
self.fc2 = nn.Conv2d(squeeze_channels, input_channels, 1)
self.activation = activation()
self.scale_activation = scale_activation()

def _scale(self, input: Tensor) -> Tensor:
scale = F.adaptive_avg_pool2d(input, 1)
scale = self.fc1(scale)
scale = self.activation(scale)
scale = self.fc2(scale)
return self.scale_activation(scale)

def forward(self, input: Tensor) -> Tensor:
scale = self._scale(input)
return scale * input


class MBConvConfig:
# Stores information listed at Table 1 of the EfficientNet paper
def __init__(self,
@@ -106,21 +79,21 @@ def __init__(self, cnf: MBConvConfig, stochastic_depth_prob: float, norm_layer:
# expand
expanded_channels = cnf.adjust_channels(cnf.input_channels, cnf.expand_ratio)
if expanded_channels != cnf.input_channels:
layers.append(ConvBNActivation(cnf.input_channels, expanded_channels, kernel_size=1,
norm_layer=norm_layer, activation_layer=activation_layer))
layers.append(ConvNormActivation(cnf.input_channels, expanded_channels, kernel_size=1,
norm_layer=norm_layer, activation_layer=activation_layer))

# depthwise
layers.append(ConvBNActivation(expanded_channels, expanded_channels, kernel_size=cnf.kernel,
stride=cnf.stride, groups=expanded_channels,
norm_layer=norm_layer, activation_layer=activation_layer))
layers.append(ConvNormActivation(expanded_channels, expanded_channels, kernel_size=cnf.kernel,
stride=cnf.stride, groups=expanded_channels,
norm_layer=norm_layer, activation_layer=activation_layer))

# squeeze and excitation
squeeze_channels = max(1, cnf.input_channels // 4)
layers.append(se_layer(expanded_channels, squeeze_channels, activation=partial(nn.SiLU, inplace=True)))

# project
layers.append(ConvBNActivation(expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer,
activation_layer=nn.Identity))
layers.append(ConvNormActivation(expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer,
activation_layer=None))

self.block = nn.Sequential(*layers)
self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row")
@@ -174,8 +147,8 @@ def __init__(

# building first layer
firstconv_output_channels = inverted_residual_setting[0].input_channels
layers.append(ConvBNActivation(3, firstconv_output_channels, kernel_size=3, stride=2, norm_layer=norm_layer,
activation_layer=nn.SiLU))
layers.append(ConvNormActivation(3, firstconv_output_channels, kernel_size=3, stride=2, norm_layer=norm_layer,
activation_layer=nn.SiLU))

# building inverted residual blocks
total_stage_blocks = sum([cnf.num_layers for cnf in inverted_residual_setting])
@@ -202,8 +175,8 @@ def __init__(
# building last several layers
lastconv_input_channels = inverted_residual_setting[-1].out_channels
lastconv_output_channels = 4 * lastconv_input_channels
layers.append(ConvBNActivation(lastconv_input_channels, lastconv_output_channels, kernel_size=1,
norm_layer=norm_layer, activation_layer=nn.SiLU))
layers.append(ConvNormActivation(lastconv_input_channels, lastconv_output_channels, kernel_size=1,
norm_layer=norm_layer, activation_layer=nn.SiLU))

self.features = nn.Sequential(*layers)
self.avgpool = nn.AdaptiveAvgPool2d(1)
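The local SqueezeExcitation class above is dropped in favour of the shared op, which implements the same squeeze (global average pool), fc1, activation, fc2, sigmoid gating. A minimal usage sketch (channel counts are illustrative):

import torch
from functools import partial
from torch import nn
from torchvision.ops.misc import SqueezeExcitation

se = SqueezeExcitation(input_channels=96, squeeze_channels=24,
                       activation=partial(nn.SiLU, inplace=True))
x = torch.rand(2, 96, 14, 14)
assert se(x).shape == x.shape  # SE only rescales channels; the shape is unchanged
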
72 changes: 26 additions & 46 deletions torchvision/models/mobilenetv2.py
@@ -1,7 +1,12 @@
import torch
import warnings

from functools import partial
from torch import nn
from torch import Tensor
from .._internally_replaced_utils import load_state_dict_from_url
from ..ops.misc import ConvNormActivation
from ._utils import _make_divisible
from typing import Callable, Any, Optional, List


@@ -13,50 +18,21 @@
}


def _make_divisible(v: float, divisor: int, min_value: Optional[int] = None) -> int:
"""
This function is taken from the original tf repo.
It ensures that all layers have a channel number that is divisible by 8
It can be seen here:
https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
"""
if min_value is None:
min_value = divisor
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
# Make sure that round down does not go down by more than 10%.
if new_v < 0.9 * v:
new_v += divisor
return new_v


class ConvBNActivation(nn.Sequential):
def __init__(
self,
in_planes: int,
out_planes: int,
kernel_size: int = 3,
stride: int = 1,
groups: int = 1,
norm_layer: Optional[Callable[..., nn.Module]] = None,
activation_layer: Optional[Callable[..., nn.Module]] = None,
dilation: int = 1,
) -> None:
padding = (kernel_size - 1) // 2 * dilation
if norm_layer is None:
norm_layer = nn.BatchNorm2d
if activation_layer is None:
activation_layer = nn.ReLU6
super().__init__(
nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, dilation=dilation, groups=groups,
bias=False),
norm_layer(out_planes),
activation_layer(inplace=True)
)
self.out_channels = out_planes
# necessary for backwards compatibility
class _DeprecatedConvBNAct(ConvNormActivation):
def __init__(self, *args, **kwargs):
warnings.warn(
"The ConvBNReLU/ConvBNActivation classes are deprecated and will be removed in future versions. "
"Use torchvision.ops.misc.ConvNormActivation instead.", FutureWarning)
if kwargs.get("norm_layer", None) is None:
kwargs["norm_layer"] = nn.BatchNorm2d
if kwargs.get("activation_layer", None) is None:
kwargs["activation_layer"] = nn.ReLU6
super().__init__(*args, **kwargs)


# necessary for backwards compatibility
ConvBNReLU = ConvBNActivation
ConvBNReLU = _DeprecatedConvBNAct
ConvBNActivation = _DeprecatedConvBNAct
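
As a hedged illustration of the shim above (not part of the diff): the old name still builds the same Conv2d -> BatchNorm2d -> ReLU6 stack, but now emits a FutureWarning pointing at ConvNormActivation.

import warnings
from torch import nn
from torchvision.models.mobilenetv2 import ConvBNReLU  # deprecated alias defined above
from torchvision.ops.misc import ConvNormActivation

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    old = ConvBNReLU(3, 16, stride=2)
new = ConvNormActivation(3, 16, stride=2, norm_layer=nn.BatchNorm2d,
                         activation_layer=nn.ReLU6)

assert any(issubclass(w.category, FutureWarning) for w in caught)
assert [type(m) for m in old] == [type(m) for m in new]  # Conv2d, BatchNorm2d, ReLU6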


class InvertedResidual(nn.Module):
@@ -81,10 +57,12 @@ def __init__(
layers: List[nn.Module] = []
if expand_ratio != 1:
# pw
layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1, norm_layer=norm_layer))
layers.append(ConvNormActivation(inp, hidden_dim, kernel_size=1, norm_layer=norm_layer,
activation_layer=nn.ReLU6))
layers.extend([
# dw
ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim, norm_layer=norm_layer),
ConvNormActivation(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim, norm_layer=norm_layer,
activation_layer=nn.ReLU6),
# pw-linear
nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
norm_layer(oup),
@@ -154,7 +132,8 @@ def __init__(
# building first layer
input_channel = _make_divisible(input_channel * width_mult, round_nearest)
self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
features: List[nn.Module] = [ConvBNReLU(3, input_channel, stride=2, norm_layer=norm_layer)]
features: List[nn.Module] = [ConvNormActivation(3, input_channel, stride=2, norm_layer=norm_layer,
activation_layer=nn.ReLU6)]
# building inverted residual blocks
for t, c, n, s in inverted_residual_setting:
output_channel = _make_divisible(c * width_mult, round_nearest)
@@ -163,7 +142,8 @@ def __init__(
features.append(block(input_channel, output_channel, stride, expand_ratio=t, norm_layer=norm_layer))
input_channel = output_channel
# building last several layers
features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1, norm_layer=norm_layer))
features.append(ConvNormActivation(input_channel, self.last_channel, kernel_size=1, norm_layer=norm_layer,
activation_layer=nn.ReLU6))
# make it nn.Sequential
self.features = nn.Sequential(*features)

29 changes: 15 additions & 14 deletions torchvision/models/mobilenetv3.py
@@ -6,8 +6,8 @@
from typing import Any, Callable, List, Optional, Sequence

from .._internally_replaced_utils import load_state_dict_from_url
from .efficientnet import SqueezeExcitation as SElayer
from .mobilenetv2 import _make_divisible, ConvBNActivation
from ..ops.misc import ConvNormActivation, SqueezeExcitation as SElayer
from ._utils import _make_divisible


__all__ = ["MobileNetV3", "mobilenet_v3_large", "mobilenet_v3_small"]
@@ -28,7 +28,8 @@ def __init__(self, input_channels: int, squeeze_factor: int = 4):
self.relu = self.activation
delattr(self, 'activation')
warnings.warn(
"This SqueezeExcitation class is deprecated and will be removed in future versions.", FutureWarning)
"This SqueezeExcitation class is deprecated and will be removed in future versions. "
"Use torchvision.ops.misc.SqueezeExcitation instead.", FutureWarning)


class InvertedResidualConfig:
@@ -64,21 +65,21 @@ def __init__(self, cnf: InvertedResidualConfig, norm_layer: Callable[..., nn.Mod

# expand
if cnf.expanded_channels != cnf.input_channels:
layers.append(ConvBNActivation(cnf.input_channels, cnf.expanded_channels, kernel_size=1,
norm_layer=norm_layer, activation_layer=activation_layer))
layers.append(ConvNormActivation(cnf.input_channels, cnf.expanded_channels, kernel_size=1,
norm_layer=norm_layer, activation_layer=activation_layer))

# depthwise
stride = 1 if cnf.dilation > 1 else cnf.stride
layers.append(ConvBNActivation(cnf.expanded_channels, cnf.expanded_channels, kernel_size=cnf.kernel,
stride=stride, dilation=cnf.dilation, groups=cnf.expanded_channels,
norm_layer=norm_layer, activation_layer=activation_layer))
layers.append(ConvNormActivation(cnf.expanded_channels, cnf.expanded_channels, kernel_size=cnf.kernel,
stride=stride, dilation=cnf.dilation, groups=cnf.expanded_channels,
norm_layer=norm_layer, activation_layer=activation_layer))
if cnf.use_se:
squeeze_channels = _make_divisible(cnf.expanded_channels // 4, 8)
layers.append(se_layer(cnf.expanded_channels, squeeze_channels))

# project
layers.append(ConvBNActivation(cnf.expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer,
activation_layer=nn.Identity))
layers.append(ConvNormActivation(cnf.expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer,
activation_layer=None))

self.block = nn.Sequential(*layers)
self.out_channels = cnf.out_channels
@@ -130,8 +131,8 @@ def __init__(

# building first layer
firstconv_output_channels = inverted_residual_setting[0].input_channels
layers.append(ConvBNActivation(3, firstconv_output_channels, kernel_size=3, stride=2, norm_layer=norm_layer,
activation_layer=nn.Hardswish))
layers.append(ConvNormActivation(3, firstconv_output_channels, kernel_size=3, stride=2, norm_layer=norm_layer,
activation_layer=nn.Hardswish))

# building inverted residual blocks
for cnf in inverted_residual_setting:
@@ -140,8 +141,8 @@ def __init__(
# building last several layers
lastconv_input_channels = inverted_residual_setting[-1].out_channels
lastconv_output_channels = 6 * lastconv_input_channels
layers.append(ConvBNActivation(lastconv_input_channels, lastconv_output_channels, kernel_size=1,
norm_layer=norm_layer, activation_layer=nn.Hardswish))
layers.append(ConvNormActivation(lastconv_input_channels, lastconv_output_channels, kernel_size=1,
norm_layer=norm_layer, activation_layer=nn.Hardswish))

self.features = nn.Sequential(*layers)
self.avgpool = nn.AdaptiveAvgPool2d(1)
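One behavioral detail in the projection layers above: with ConvNormActivation, passing activation_layer=None omits the activation module entirely, whereas the old code appended an nn.Identity. A minimal sketch (channel counts are illustrative):

from torch import nn
from torchvision.ops.misc import ConvNormActivation

proj = ConvNormActivation(96, 40, kernel_size=1, norm_layer=nn.BatchNorm2d,
                          activation_layer=None)
assert [type(m) for m in proj] == [nn.Conv2d, nn.BatchNorm2d]  # no trailing activation
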
5 changes: 3 additions & 2 deletions torchvision/models/quantization/mobilenetv2.py
@@ -5,9 +5,10 @@

from typing import Any

from torchvision.models.mobilenetv2 import InvertedResidual, ConvBNReLU, MobileNetV2, model_urls
from torchvision.models.mobilenetv2 import InvertedResidual, MobileNetV2, model_urls
from torch.quantization import QuantStub, DeQuantStub, fuse_modules
from .utils import _replace_relu, quantize_model
from ...ops.misc import ConvNormActivation


__all__ = ['QuantizableMobileNetV2', 'mobilenet_v2']
@@ -55,7 +56,7 @@ def forward(self, x: Tensor) -> Tensor:

def fuse_model(self) -> None:
for m in self.modules():
if type(m) == ConvBNReLU:
if type(m) == ConvNormActivation:
fuse_modules(m, ['0', '1', '2'], inplace=True)
if type(m) == QuantizableInvertedResidual:
m.fuse_model()
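
fuse_model now matches on ConvNormActivation instead of the old ConvBNReLU alias. A rough usage sketch (the keyword arguments are illustrative; fusion requires eval mode):

import torch
from torchvision.models.quantization import mobilenet_v2

model = mobilenet_v2(pretrained=False, quantize=False)
model.eval()
model.fuse_model()  # folds each Conv2d/BatchNorm2d/ReLU triple into a single fused module
out = model(torch.rand(1, 3, 224, 224))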