diff --git a/torchvision/models/_utils.py b/torchvision/models/_utils.py
index df5ab9a044c..ec2e287d974 100644
--- a/torchvision/models/_utils.py
+++ b/torchvision/models/_utils.py
@@ -1,7 +1,7 @@
 from collections import OrderedDict
 
 from torch import nn
-from typing import Dict
+from typing import Dict, Optional
 
 
 class IntermediateLayerGetter(nn.ModuleDict):
@@ -64,3 +64,19 @@ def forward(self, x):
             out_name = self.return_layers[name]
             out[out_name] = x
         return out
+
+
+def _make_divisible(v: float, divisor: int, min_value: Optional[int] = None) -> int:
+    """
+    This function is taken from the original tf repo.
+    It ensures that all layers have a channel number that is divisible by 8
+    It can be seen here:
+    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
+    """
+    if min_value is None:
+        min_value = divisor
+    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
+    # Make sure that round down does not go down by more than 10%.
+    if new_v < 0.9 * v:
+        new_v += divisor
+    return new_v
diff --git a/torchvision/models/detection/ssdlite.py b/torchvision/models/detection/ssdlite.py
index 08d48c68020..c78caad1be7 100644
--- a/torchvision/models/detection/ssdlite.py
+++ b/torchvision/models/detection/ssdlite.py
@@ -11,8 +11,8 @@
 from .anchor_utils import DefaultBoxGenerator
 from .backbone_utils import _validate_trainable_layers
 from .. import mobilenet
-from ..mobilenetv3 import ConvBNActivation
 from ..._internally_replaced_utils import load_state_dict_from_url
+from ...ops.misc import ConvNormActivation
 
 
 __all__ = ['ssdlite320_mobilenet_v3_large']
@@ -28,8 +28,8 @@ def _prediction_block(in_channels: int, out_channels: int, kernel_size: int,
                       norm_layer: Callable[..., nn.Module]) -> nn.Sequential:
     return nn.Sequential(
         # 3x3 depthwise with stride 1 and padding 1
-        ConvBNActivation(in_channels, in_channels, kernel_size=kernel_size, groups=in_channels,
-                         norm_layer=norm_layer, activation_layer=nn.ReLU6),
+        ConvNormActivation(in_channels, in_channels, kernel_size=kernel_size, groups=in_channels,
+                           norm_layer=norm_layer, activation_layer=nn.ReLU6),
 
         # 1x1 projection to output channels
         nn.Conv2d(in_channels, out_channels, 1)
@@ -41,16 +41,16 @@ def _extra_block(in_channels: int, out_channels: int, norm_layer: Callable[..., nn.Module]) -> nn.Sequential:
     activation = nn.ReLU6
     intermediate_channels = out_channels // 2
     return nn.Sequential(
         # 1x1 projection to half output channels
-        ConvBNActivation(in_channels, intermediate_channels, kernel_size=1,
-                         norm_layer=norm_layer, activation_layer=activation),
+        ConvNormActivation(in_channels, intermediate_channels, kernel_size=1,
+                           norm_layer=norm_layer, activation_layer=activation),
 
         # 3x3 depthwise with stride 2 and padding 1
-        ConvBNActivation(intermediate_channels, intermediate_channels, kernel_size=3, stride=2,
-                         groups=intermediate_channels, norm_layer=norm_layer, activation_layer=activation),
+        ConvNormActivation(intermediate_channels, intermediate_channels, kernel_size=3, stride=2,
+                           groups=intermediate_channels, norm_layer=norm_layer, activation_layer=activation),
 
         # 1x1 projection to output channels
-        ConvBNActivation(intermediate_channels, out_channels, kernel_size=1,
-                         norm_layer=norm_layer, activation_layer=activation),
+        ConvNormActivation(intermediate_channels, out_channels, kernel_size=1,
+                           norm_layer=norm_layer, activation_layer=activation),
     )
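
Reviewer note: since `_make_divisible` now lives in `torchvision/models/_utils.py`, a quick sanity check of its rounding behaviour may be useful. This is an illustrative sketch, not part of the patch; it imports the private helper directly:

from torchvision.models._utils import _make_divisible

# Rounds v to the nearest multiple of divisor, clamping at min_value
# (which defaults to divisor) and never shrinking v by more than 10%.
assert _make_divisible(37.4, 8) == 40  # plain round-to-nearest
assert _make_divisible(8.2, 8) == 8    # clamped at min_value = divisor
assert _make_divisible(19, 8) == 24    # 16 would drop >10%, so it rounds up instead
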
diff --git a/torchvision/models/efficientnet.py b/torchvision/models/efficientnet.py
index bad5b57b25b..4dd23c1ea45 100644
--- a/torchvision/models/efficientnet.py
+++ b/torchvision/models/efficientnet.py
@@ -4,14 +4,13 @@
 from functools import partial
 from torch import nn, Tensor
-from torch.nn import functional as F
 from typing import Any, Callable, List, Optional, Sequence
 
 from .._internally_replaced_utils import load_state_dict_from_url
+from ..ops.misc import ConvNormActivation, SqueezeExcitation
+from ._utils import _make_divisible
 from torchvision.ops import StochasticDepth
 
-from torchvision.models.mobilenetv2 import ConvBNActivation, _make_divisible
-
 
 __all__ = ["EfficientNet", "efficientnet_b0", "efficientnet_b1", "efficientnet_b2", "efficientnet_b3",
            "efficientnet_b4", "efficientnet_b5", "efficientnet_b6", "efficientnet_b7"]
@@ -31,32 +30,6 @@
 }
 
 
-class SqueezeExcitation(nn.Module):
-    def __init__(
-        self,
-        input_channels: int,
-        squeeze_channels: int,
-        activation: Callable[..., nn.Module] = nn.ReLU,
-        scale_activation: Callable[..., nn.Module] = nn.Sigmoid,
-    ) -> None:
-        super().__init__()
-        self.fc1 = nn.Conv2d(input_channels, squeeze_channels, 1)
-        self.fc2 = nn.Conv2d(squeeze_channels, input_channels, 1)
-        self.activation = activation()
-        self.scale_activation = scale_activation()
-
-    def _scale(self, input: Tensor) -> Tensor:
-        scale = F.adaptive_avg_pool2d(input, 1)
-        scale = self.fc1(scale)
-        scale = self.activation(scale)
-        scale = self.fc2(scale)
-        return self.scale_activation(scale)
-
-    def forward(self, input: Tensor) -> Tensor:
-        scale = self._scale(input)
-        return scale * input
-
-
 class MBConvConfig:
     # Stores information listed at Table 1 of the EfficientNet paper
     def __init__(self,
@@ -106,21 +79,21 @@ def __init__(self, cnf: MBConvConfig, stochastic_depth_prob: float, norm_layer: Callable[..., nn.Module],
         # expand
         expanded_channels = cnf.adjust_channels(cnf.input_channels, cnf.expand_ratio)
         if expanded_channels != cnf.input_channels:
-            layers.append(ConvBNActivation(cnf.input_channels, expanded_channels, kernel_size=1,
-                                           norm_layer=norm_layer, activation_layer=activation_layer))
+            layers.append(ConvNormActivation(cnf.input_channels, expanded_channels, kernel_size=1,
+                                             norm_layer=norm_layer, activation_layer=activation_layer))
 
         # depthwise
-        layers.append(ConvBNActivation(expanded_channels, expanded_channels, kernel_size=cnf.kernel,
-                                       stride=cnf.stride, groups=expanded_channels,
-                                       norm_layer=norm_layer, activation_layer=activation_layer))
+        layers.append(ConvNormActivation(expanded_channels, expanded_channels, kernel_size=cnf.kernel,
+                                         stride=cnf.stride, groups=expanded_channels,
+                                         norm_layer=norm_layer, activation_layer=activation_layer))
 
         # squeeze and excitation
         squeeze_channels = max(1, cnf.input_channels // 4)
         layers.append(se_layer(expanded_channels, squeeze_channels, activation=partial(nn.SiLU, inplace=True)))
 
         # project
-        layers.append(ConvBNActivation(expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer,
-                                       activation_layer=nn.Identity))
+        layers.append(ConvNormActivation(expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer,
+                                         activation_layer=None))
 
         self.block = nn.Sequential(*layers)
         self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row")
@@ -174,8 +147,8 @@ def __init__(
 
         # building first layer
         firstconv_output_channels = inverted_residual_setting[0].input_channels
-        layers.append(ConvBNActivation(3, firstconv_output_channels, kernel_size=3, stride=2, norm_layer=norm_layer,
-                                       activation_layer=nn.SiLU))
+        layers.append(ConvNormActivation(3, firstconv_output_channels, kernel_size=3, stride=2, norm_layer=norm_layer,
+                                         activation_layer=nn.SiLU))
 
         # building inverted residual blocks
         total_stage_blocks = sum([cnf.num_layers for cnf in inverted_residual_setting])
@@ -202,8 +175,8 @@ def __init__(
         # building last several layers
         lastconv_input_channels = inverted_residual_setting[-1].out_channels
         lastconv_output_channels = 4 * lastconv_input_channels
-        layers.append(ConvBNActivation(lastconv_input_channels, lastconv_output_channels, kernel_size=1,
-                                       norm_layer=norm_layer, activation_layer=nn.SiLU))
+        layers.append(ConvNormActivation(lastconv_input_channels, lastconv_output_channels, kernel_size=1,
+                                         norm_layer=norm_layer, activation_layer=nn.SiLU))
 
         self.features = nn.Sequential(*layers)
         self.avgpool = nn.AdaptiveAvgPool2d(1)
diff --git a/torchvision/models/mobilenetv2.py b/torchvision/models/mobilenetv2.py
index 0cfa4f371e3..9e68fbfc5c7 100644
--- a/torchvision/models/mobilenetv2.py
+++ b/torchvision/models/mobilenetv2.py
@@ -1,7 +1,12 @@
 import torch
+import warnings
+
+from functools import partial
 from torch import nn
 from torch import Tensor
 from .._internally_replaced_utils import load_state_dict_from_url
+from ..ops.misc import ConvNormActivation
+from ._utils import _make_divisible
 from typing import Callable, Any, Optional, List
 
 
@@ -13,50 +18,21 @@
 }
 
 
-def _make_divisible(v: float, divisor: int, min_value: Optional[int] = None) -> int:
-    """
-    This function is taken from the original tf repo.
-    It ensures that all layers have a channel number that is divisible by 8
-    It can be seen here:
-    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
-    """
-    if min_value is None:
-        min_value = divisor
-    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
-    # Make sure that round down does not go down by more than 10%.
-    if new_v < 0.9 * v:
-        new_v += divisor
-    return new_v
-
-
-class ConvBNActivation(nn.Sequential):
-    def __init__(
-        self,
-        in_planes: int,
-        out_planes: int,
-        kernel_size: int = 3,
-        stride: int = 1,
-        groups: int = 1,
-        norm_layer: Optional[Callable[..., nn.Module]] = None,
-        activation_layer: Optional[Callable[..., nn.Module]] = None,
-        dilation: int = 1,
-    ) -> None:
-        padding = (kernel_size - 1) // 2 * dilation
-        if norm_layer is None:
-            norm_layer = nn.BatchNorm2d
-        if activation_layer is None:
-            activation_layer = nn.ReLU6
-        super().__init__(
-            nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, dilation=dilation, groups=groups,
-                      bias=False),
-            norm_layer(out_planes),
-            activation_layer(inplace=True)
-        )
-        self.out_channels = out_planes
+# necessary for backwards compatibility
+class _DeprecatedConvBNAct(ConvNormActivation):
+    def __init__(self, *args, **kwargs):
+        warnings.warn(
+            "The ConvBNReLU/ConvBNActivation classes are deprecated and will be removed in future versions. "
" + "Use torchvision.ops.misc.ConvNormActivation instead.", FutureWarning) + if kwargs.get("norm_layer", None) is None: + kwargs["norm_layer"] = nn.BatchNorm2d + if kwargs.get("activation_layer", None) is None: + kwargs["activation_layer"] = nn.ReLU6 + super().__init__(*args, **kwargs) -# necessary for backwards compatibility -ConvBNReLU = ConvBNActivation +ConvBNReLU = _DeprecatedConvBNAct +ConvBNActivation = _DeprecatedConvBNAct class InvertedResidual(nn.Module): @@ -81,10 +57,12 @@ def __init__( layers: List[nn.Module] = [] if expand_ratio != 1: # pw - layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1, norm_layer=norm_layer)) + layers.append(ConvNormActivation(inp, hidden_dim, kernel_size=1, norm_layer=norm_layer, + activation_layer=nn.ReLU6)) layers.extend([ # dw - ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim, norm_layer=norm_layer), + ConvNormActivation(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim, norm_layer=norm_layer, + activation_layer=nn.ReLU6), # pw-linear nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), norm_layer(oup), @@ -154,7 +132,8 @@ def __init__( # building first layer input_channel = _make_divisible(input_channel * width_mult, round_nearest) self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest) - features: List[nn.Module] = [ConvBNReLU(3, input_channel, stride=2, norm_layer=norm_layer)] + features: List[nn.Module] = [ConvNormActivation(3, input_channel, stride=2, norm_layer=norm_layer, + activation_layer=nn.ReLU6)] # building inverted residual blocks for t, c, n, s in inverted_residual_setting: output_channel = _make_divisible(c * width_mult, round_nearest) @@ -163,7 +142,8 @@ def __init__( features.append(block(input_channel, output_channel, stride, expand_ratio=t, norm_layer=norm_layer)) input_channel = output_channel # building last several layers - features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1, norm_layer=norm_layer)) + features.append(ConvNormActivation(input_channel, self.last_channel, kernel_size=1, norm_layer=norm_layer, + activation_layer=nn.ReLU6)) # make it nn.Sequential self.features = nn.Sequential(*features) diff --git a/torchvision/models/mobilenetv3.py b/torchvision/models/mobilenetv3.py index 0485c9d61e5..537e2136bbb 100644 --- a/torchvision/models/mobilenetv3.py +++ b/torchvision/models/mobilenetv3.py @@ -6,8 +6,8 @@ from typing import Any, Callable, List, Optional, Sequence from .._internally_replaced_utils import load_state_dict_from_url -from .efficientnet import SqueezeExcitation as SElayer -from .mobilenetv2 import _make_divisible, ConvBNActivation +from ..ops.misc import ConvNormActivation, SqueezeExcitation as SElayer +from ._utils import _make_divisible __all__ = ["MobileNetV3", "mobilenet_v3_large", "mobilenet_v3_small"] @@ -28,7 +28,8 @@ def __init__(self, input_channels: int, squeeze_factor: int = 4): self.relu = self.activation delattr(self, 'activation') warnings.warn( - "This SqueezeExcitation class is deprecated and will be removed in future versions.", FutureWarning) + "This SqueezeExcitation class is deprecated and will be removed in future versions. 
" + "Use torchvision.ops.misc.SqueezeExcitation instead.", FutureWarning) class InvertedResidualConfig: @@ -64,21 +65,21 @@ def __init__(self, cnf: InvertedResidualConfig, norm_layer: Callable[..., nn.Mod # expand if cnf.expanded_channels != cnf.input_channels: - layers.append(ConvBNActivation(cnf.input_channels, cnf.expanded_channels, kernel_size=1, - norm_layer=norm_layer, activation_layer=activation_layer)) + layers.append(ConvNormActivation(cnf.input_channels, cnf.expanded_channels, kernel_size=1, + norm_layer=norm_layer, activation_layer=activation_layer)) # depthwise stride = 1 if cnf.dilation > 1 else cnf.stride - layers.append(ConvBNActivation(cnf.expanded_channels, cnf.expanded_channels, kernel_size=cnf.kernel, - stride=stride, dilation=cnf.dilation, groups=cnf.expanded_channels, - norm_layer=norm_layer, activation_layer=activation_layer)) + layers.append(ConvNormActivation(cnf.expanded_channels, cnf.expanded_channels, kernel_size=cnf.kernel, + stride=stride, dilation=cnf.dilation, groups=cnf.expanded_channels, + norm_layer=norm_layer, activation_layer=activation_layer)) if cnf.use_se: squeeze_channels = _make_divisible(cnf.expanded_channels // 4, 8) layers.append(se_layer(cnf.expanded_channels, squeeze_channels)) # project - layers.append(ConvBNActivation(cnf.expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer, - activation_layer=nn.Identity)) + layers.append(ConvNormActivation(cnf.expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer, + activation_layer=None)) self.block = nn.Sequential(*layers) self.out_channels = cnf.out_channels @@ -130,8 +131,8 @@ def __init__( # building first layer firstconv_output_channels = inverted_residual_setting[0].input_channels - layers.append(ConvBNActivation(3, firstconv_output_channels, kernel_size=3, stride=2, norm_layer=norm_layer, - activation_layer=nn.Hardswish)) + layers.append(ConvNormActivation(3, firstconv_output_channels, kernel_size=3, stride=2, norm_layer=norm_layer, + activation_layer=nn.Hardswish)) # building inverted residual blocks for cnf in inverted_residual_setting: @@ -140,8 +141,8 @@ def __init__( # building last several layers lastconv_input_channels = inverted_residual_setting[-1].out_channels lastconv_output_channels = 6 * lastconv_input_channels - layers.append(ConvBNActivation(lastconv_input_channels, lastconv_output_channels, kernel_size=1, - norm_layer=norm_layer, activation_layer=nn.Hardswish)) + layers.append(ConvNormActivation(lastconv_input_channels, lastconv_output_channels, kernel_size=1, + norm_layer=norm_layer, activation_layer=nn.Hardswish)) self.features = nn.Sequential(*layers) self.avgpool = nn.AdaptiveAvgPool2d(1) diff --git a/torchvision/models/quantization/mobilenetv2.py b/torchvision/models/quantization/mobilenetv2.py index f914fdc7815..2349afff447 100644 --- a/torchvision/models/quantization/mobilenetv2.py +++ b/torchvision/models/quantization/mobilenetv2.py @@ -5,9 +5,10 @@ from typing import Any -from torchvision.models.mobilenetv2 import InvertedResidual, ConvBNReLU, MobileNetV2, model_urls +from torchvision.models.mobilenetv2 import InvertedResidual, MobileNetV2, model_urls from torch.quantization import QuantStub, DeQuantStub, fuse_modules from .utils import _replace_relu, quantize_model +from ...ops.misc import ConvNormActivation __all__ = ['QuantizableMobileNetV2', 'mobilenet_v2'] @@ -55,7 +56,7 @@ def forward(self, x: Tensor) -> Tensor: def fuse_model(self) -> None: for m in self.modules(): - if type(m) == ConvBNReLU: + if type(m) == ConvNormActivation: 
                 fuse_modules(m, ['0', '1', '2'], inplace=True)
             if type(m) == QuantizableInvertedResidual:
                 m.fuse_model()
diff --git a/torchvision/models/quantization/mobilenetv3.py b/torchvision/models/quantization/mobilenetv3.py
index 8c64f137053..8655a9b0a45 100644
--- a/torchvision/models/quantization/mobilenetv3.py
+++ b/torchvision/models/quantization/mobilenetv3.py
@@ -1,8 +1,8 @@
 import torch
 from torch import nn, Tensor
 from ..._internally_replaced_utils import load_state_dict_from_url
-from ..efficientnet import SqueezeExcitation as SElayer
-from ..mobilenetv3 import InvertedResidual, InvertedResidualConfig, ConvBNActivation, MobileNetV3,\
+from ...ops.misc import ConvNormActivation, SqueezeExcitation
+from ..mobilenetv3 import InvertedResidual, InvertedResidualConfig, MobileNetV3,\
     model_urls, _mobilenet_v3_conf
 from torch.quantization import QuantStub, DeQuantStub, fuse_modules
 from typing import Any, List, Optional
@@ -17,7 +17,7 @@
 }
 
 
-class QuantizableSqueezeExcitation(SElayer):
+class QuantizableSqueezeExcitation(SqueezeExcitation):
     _version = 2
 
     def __init__(self, *args: Any, **kwargs: Any) -> None:
@@ -103,9 +103,9 @@ def forward(self, x: Tensor) -> Tensor:
 
     def fuse_model(self) -> None:
         for m in self.modules():
-            if type(m) == ConvBNActivation:
+            if type(m) == ConvNormActivation:
                 modules_to_fuse = ['0', '1']
-                if type(m[2]) == nn.ReLU:
+                if len(m) == 3 and type(m[2]) == nn.ReLU:
                     modules_to_fuse.append('2')
                 fuse_modules(m, modules_to_fuse, inplace=True)
             elif type(m) == QuantizableSqueezeExcitation:
diff --git a/torchvision/models/regnet.py b/torchvision/models/regnet.py
index bbab59c4074..0bd89f7799f 100644
--- a/torchvision/models/regnet.py
+++ b/torchvision/models/regnet.py
@@ -12,8 +12,8 @@
 from torch import nn, Tensor
 
 from .._internally_replaced_utils import load_state_dict_from_url
-from torchvision.models.mobilenetv2 import ConvBNActivation, _make_divisible
-from torchvision.models.efficientnet import SqueezeExcitation
+from ..ops.misc import ConvNormActivation, SqueezeExcitation
+from ._utils import _make_divisible
 
 
 __all__ = ["RegNet", "regnet_y_400mf", "regnet_y_800mf", "regnet_y_1_6gf",
@@ -32,7 +32,7 @@
 }
 
 
-class SimpleStemIN(ConvBNActivation):
+class SimpleStemIN(ConvNormActivation):
     """Simple stem for ImageNet: 3x3, BN, ReLU."""
 
     def __init__(
@@ -64,10 +64,10 @@ def __init__(
         w_b = int(round(width_out * bottleneck_multiplier))
         g = w_b // group_width
 
-        layers["a"] = ConvBNActivation(width_in, w_b, kernel_size=1, stride=1,
-                                       norm_layer=norm_layer, activation_layer=activation_layer)
-        layers["b"] = ConvBNActivation(w_b, w_b, kernel_size=3, stride=stride, groups=g,
-                                       norm_layer=norm_layer, activation_layer=activation_layer)
+        layers["a"] = ConvNormActivation(width_in, w_b, kernel_size=1, stride=1,
+                                         norm_layer=norm_layer, activation_layer=activation_layer)
+        layers["b"] = ConvNormActivation(w_b, w_b, kernel_size=3, stride=stride, groups=g,
+                                         norm_layer=norm_layer, activation_layer=activation_layer)
 
         if se_ratio:
             # The SE reduction ratio is defined with respect to the
@@ -79,8 +79,8 @@ def __init__(
                 activation=activation_layer,
             )
 
-        layers["c"] = ConvBNActivation(w_b, width_out, kernel_size=1, stride=1,
-                                       norm_layer=norm_layer, activation_layer=nn.Identity)
+        layers["c"] = ConvNormActivation(w_b, width_out, kernel_size=1, stride=1,
+                                         norm_layer=norm_layer, activation_layer=None)
         super().__init__(layers)
@@ -104,8 +104,8 @@ def __init__(
         self.proj = None
         should_proj = (width_in != width_out) or (stride != 1)
         if should_proj:
-            self.proj = ConvBNActivation(width_in, width_out, kernel_size=1,
-                                         stride=stride, norm_layer=norm_layer, activation_layer=nn.Identity)
+            self.proj = ConvNormActivation(width_in, width_out, kernel_size=1,
+                                           stride=stride, norm_layer=norm_layer, activation_layer=None)
         self.f = BottleneckTransform(
             width_in,
             width_out,
diff --git a/torchvision/ops/misc.py b/torchvision/ops/misc.py
index 7e43caa78d6..7ee8df3371e 100644
--- a/torchvision/ops/misc.py
+++ b/torchvision/ops/misc.py
@@ -11,7 +11,7 @@
 import warnings
 import torch
 from torch import Tensor
-from typing import List, Optional
+from typing import Callable, List, Optional
 
 
 class Conv2d(torch.nn.Conv2d):
@@ -97,3 +97,56 @@ def forward(self, x: Tensor) -> Tensor:
 
     def __repr__(self) -> str:
         return f"{self.__class__.__name__}({self.weight.shape[0]}, eps={self.eps})"
+
+
+class ConvNormActivation(torch.nn.Sequential):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: int = 3,
+        stride: int = 1,
+        padding: Optional[int] = None,
+        groups: int = 1,
+        norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm2d,
+        activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
+        dilation: int = 1,
+        inplace: bool = True,
+    ) -> None:
+        if padding is None:
+            padding = (kernel_size - 1) // 2 * dilation
+        layers = [torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding,
+                                  dilation=dilation, groups=groups, bias=norm_layer is None)]
+        if norm_layer is not None:
+            layers.append(norm_layer(out_channels))
+        if activation_layer is not None:
+            layers.append(activation_layer(inplace=inplace))
+        super().__init__(*layers)
+        self.out_channels = out_channels
+
+
+class SqueezeExcitation(torch.nn.Module):
+    def __init__(
+        self,
+        input_channels: int,
+        squeeze_channels: int,
+        activation: Callable[..., torch.nn.Module] = torch.nn.ReLU,
+        scale_activation: Callable[..., torch.nn.Module] = torch.nn.Sigmoid,
+    ) -> None:
+        super().__init__()
+        self.avgpool = torch.nn.AdaptiveAvgPool2d(1)
+        self.fc1 = torch.nn.Conv2d(input_channels, squeeze_channels, 1)
+        self.fc2 = torch.nn.Conv2d(squeeze_channels, input_channels, 1)
+        self.activation = activation()
+        self.scale_activation = scale_activation()
+
+    def _scale(self, input: Tensor) -> Tensor:
+        scale = self.avgpool(input)
+        scale = self.fc1(scale)
+        scale = self.activation(scale)
+        scale = self.fc2(scale)
+        return self.scale_activation(scale)
+
+    def forward(self, input: Tensor) -> Tensor:
+        scale = self._scale(input)
+        return scale * input
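
Reviewer note: replacing `activation_layer=nn.Identity` with `activation_layer=None` changes the module count of projection blocks. The old `ConvBNActivation` always held three submodules (the identity included), while `ConvNormActivation` simply omits the activation, which is why the quantized MobileNetV3 `fuse_model` above now guards with `len(m) == 3` before fusing a trailing ReLU. A sketch of the difference (shapes are made up for illustration):

import torch

from torchvision.ops.misc import ConvNormActivation

# Defaults: Conv2d + BatchNorm2d + ReLU -> three submodules, fusable as ['0', '1', '2'].
block = ConvNormActivation(3, 16)
assert len(block) == 3

# Projection-style block with activation_layer=None: only Conv2d + BatchNorm2d remain,
# so blindly fusing ['0', '1', '2'] would fail; hence the len(m) == 3 check.
proj = ConvNormActivation(16, 32, kernel_size=1, activation_layer=None)
assert len(proj) == 2

out = proj(block(torch.randn(1, 3, 32, 32)))
assert out.shape == (1, 32, 32, 32)
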
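Reviewer note: the consolidated `SqueezeExcitation` swaps `F.adaptive_avg_pool2d` for an `avgpool` submodule, keeping behaviour identical while making the pooling visible to module rewrites such as quantization. A short usage sketch with made-up channel counts:

import torch

from torchvision.ops.misc import SqueezeExcitation

# squeeze_channels is typically input_channels // 4, made divisible by 8 by the callers above
se = SqueezeExcitation(input_channels=64, squeeze_channels=16)
x = torch.randn(2, 64, 14, 14)
assert se(x).shape == x.shape  # per-channel rescaling preserves the input shape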