From c33246e7c8aa7e601d739883d3d40052cd2d3bec Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Sun, 13 Feb 2022 21:07:22 +0800 Subject: [PATCH 01/38] Create dropblock.py --- torchvision/ops/dropblock.py | 1 + 1 file changed, 1 insertion(+) create mode 100644 torchvision/ops/dropblock.py diff --git a/torchvision/ops/dropblock.py b/torchvision/ops/dropblock.py new file mode 100644 index 00000000000..aaa328524fc --- /dev/null +++ b/torchvision/ops/dropblock.py @@ -0,0 +1 @@ +import torch From ff34c8ebd852e9d48b530f1d5ecc52a307f9bd39 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Sat, 19 Feb 2022 23:04:10 +0800 Subject: [PATCH 02/38] add dropblock2d --- torchvision/ops/dropblock.py | 63 ++++++++++++++++++++++++++++++++++++ 1 file changed, 63 insertions(+) diff --git a/torchvision/ops/dropblock.py b/torchvision/ops/dropblock.py index aaa328524fc..1f6c51bfa94 100644 --- a/torchvision/ops/dropblock.py +++ b/torchvision/ops/dropblock.py @@ -1 +1,64 @@ import torch +import torch.nn.functional as F +from torch import nn, Tensor + +from ..utils import _log_api_usage_once + + +class DropBlock2d(nn.Module): + """ + Implements DropBlock2d from `"DropBlock: A regularization method for convolutional networks" + `. + + Args: + p (float): Probability of an element to be dropped. + block_size (int): Size of the block to drop. + inplace (bool): If set to ``True``, will do this operation in-place. Default: ``False`` + """ + + def __init__(self, p: float, block_size: int, inplace: bool = False) -> None: + super(DropBlock2d, self).__init__() + _log_api_usage_once(self) + + if p < 0.0 or p > 1.0: + raise ValueError(f"drop probability has to be between 0 and 1, but got {p}") + self.p = p + self.block_size = block_size + self.inplace = inplace + + def forward(self, input: Tensor) -> Tensor: + """ + Args: + input (Tensor): Input feature map on which some areas will be randomly + dropped. + Returns: + Tensor: The tensor after DropBlock layer. + """ + if not self.training: + return input + + N, C, H, W = input.size() + # compute the gamma of Bernoulli distribution + gamma = (self.p * H * W) / ((self.block_size ** 2) * ((H - self.block_size + 1) * \ + (W - self.block_size + 1))) + mask_shape = (N, C, H - self.block_size + 1, W - self.block_size + 1) + mask = torch.bernoulli(torch.full(mask_shape, gamma, device=input.device)) + + mask = F.pad(mask, [self.block_size // 2] * 4, value=0) + mask = F.max_pool2d( + input=mask, + stride=(1, 1), + kernel_size=(self.block_size, self.block_size), + padding=self.block_size // 2) + mask = 1 - mask + normalize_scale = mask.numel() / (1e-6 + mask.sum()) + if self.inplace: + input.mul_(mask * normalize_scale) + else: + input = input * mask * normalize_scale + return input + + def __repr__(self) -> str: + s = f"{self.__class__.__name__}(p={self.p}, block_size={self.block_size}, " + \ + f"inplace={self.inplace})" + return s From 2a86d775e8cef38fb4e65f3d8bfc7f1fd41f3b87 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Sat, 19 Feb 2022 23:11:31 +0800 Subject: [PATCH 03/38] fix pylint --- torchvision/ops/dropblock.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/torchvision/ops/dropblock.py b/torchvision/ops/dropblock.py index 1f6c51bfa94..8eea84ec272 100644 --- a/torchvision/ops/dropblock.py +++ b/torchvision/ops/dropblock.py @@ -36,20 +36,17 @@ def forward(self, input: Tensor) -> Tensor: """ if not self.training: return input - + N, C, H, W = input.size() # compute the gamma of Bernoulli distribution - gamma = (self.p * H * W) / ((self.block_size ** 2) * ((H - self.block_size + 1) * \ - (W - self.block_size + 1))) + gamma = (self.p * H * W) / ((self.block_size ** 2) * ((H - self.block_size + 1) * (W - self.block_size + 1))) mask_shape = (N, C, H - self.block_size + 1, W - self.block_size + 1) mask = torch.bernoulli(torch.full(mask_shape, gamma, device=input.device)) mask = F.pad(mask, [self.block_size // 2] * 4, value=0) mask = F.max_pool2d( - input=mask, - stride=(1, 1), - kernel_size=(self.block_size, self.block_size), - padding=self.block_size // 2) + input=mask, stride=(1, 1), kernel_size=(self.block_size, self.block_size), padding=self.block_size + ) mask = 1 - mask normalize_scale = mask.numel() / (1e-6 + mask.sum()) if self.inplace: @@ -59,6 +56,5 @@ def forward(self, input: Tensor) -> Tensor: return input def __repr__(self) -> str: - s = f"{self.__class__.__name__}(p={self.p}, block_size={self.block_size}, " + \ - f"inplace={self.inplace})" + s = f"{self.__class__.__name__}(p={self.p}, block_size={self.block_size}, inplace={self.inplace})" return s From a90e036244ad8cd9e8254e44b7cd59e4b4981045 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Sun, 20 Feb 2022 11:05:47 +0800 Subject: [PATCH 04/38] refactor dropblock --- torchvision/ops/dropblock.py | 74 +++++++++++++++++++++++------------- 1 file changed, 47 insertions(+), 27 deletions(-) diff --git a/torchvision/ops/dropblock.py b/torchvision/ops/dropblock.py index 8eea84ec272..81b3c39e405 100644 --- a/torchvision/ops/dropblock.py +++ b/torchvision/ops/dropblock.py @@ -5,26 +5,65 @@ from ..utils import _log_api_usage_once -class DropBlock2d(nn.Module): +def drop_block2d(input: Tensor, p: float, block_size: int, inplace: bool = False, eps: float = 1e-06, + training: bool = True) -> Tensor: """ Implements DropBlock2d from `"DropBlock: A regularization method for convolutional networks" `. Args: + input (Tensor[N, C, H, W]): The input tensor or 4-dimensions with the first one + being its batch i.e. a batch with ``N`` rows. p (float): Probability of an element to be dropped. block_size (int): Size of the block to drop. - inplace (bool): If set to ``True``, will do this operation in-place. Default: ``False`` + inplace (bool): If set to ``True``, will do this operation in-place. Default: ``False``. + eps (float): A value added to the denominator for numerical stability. Default: 1e-6. + training (bool): apply dropblock if is ``True``. Default: ``True` + Returns: + Tensor[N, ...]: The randomly zeroed tensor after dropblock. + """ + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + _log_api_usage_once(drop_block2d) + if p < 0.0 or p > 1.0: + raise ValueError(f"drop probability has to be between 0 and 1, but got {p}") + if not training or p == 0.0: + return input + + N, C, H, W = input.size() + # compute the gamma of Bernoulli distribution + gamma = (p * H * W) / ((block_size ** 2) * ((H - block_size + 1) * (W - block_size + 1))) + noise = torch.empty((N, C, H - block_size + 1, W - block_size + 1), dtype=input.dtype, + device=input.device) + noise.bernoulli_(gamma) + + noise = F.pad(noise, [block_size // 2] * 4, value=0) + noise = F.max_pool2d( + noise, stride=(1, 1), kernel_size=(block_size, block_size), padding=block_size // 2 + ) + noise = 1 - noise + normalize_scale = noise.numel() / (eps + noise.sum()) + if inplace: + input.mul_(noise).mul_(normalize_scale) + else: + input = input * noise * normalize_scale + return input + + +torch.fx.wrap("drop_block2d") + + +class DropBlock2d(nn.Module): + """ + See :func:`drop_block2d`. """ - def __init__(self, p: float, block_size: int, inplace: bool = False) -> None: - super(DropBlock2d, self).__init__() - _log_api_usage_once(self) + def __init__(self, p: float, block_size: int, inplace: bool = False, eps: float = 1e-06) -> None: + super().__init__() - if p < 0.0 or p > 1.0: - raise ValueError(f"drop probability has to be between 0 and 1, but got {p}") self.p = p self.block_size = block_size self.inplace = inplace + self.eps = eps def forward(self, input: Tensor) -> Tensor: """ @@ -34,26 +73,7 @@ def forward(self, input: Tensor) -> Tensor: Returns: Tensor: The tensor after DropBlock layer. """ - if not self.training: - return input - - N, C, H, W = input.size() - # compute the gamma of Bernoulli distribution - gamma = (self.p * H * W) / ((self.block_size ** 2) * ((H - self.block_size + 1) * (W - self.block_size + 1))) - mask_shape = (N, C, H - self.block_size + 1, W - self.block_size + 1) - mask = torch.bernoulli(torch.full(mask_shape, gamma, device=input.device)) - - mask = F.pad(mask, [self.block_size // 2] * 4, value=0) - mask = F.max_pool2d( - input=mask, stride=(1, 1), kernel_size=(self.block_size, self.block_size), padding=self.block_size - ) - mask = 1 - mask - normalize_scale = mask.numel() / (1e-6 + mask.sum()) - if self.inplace: - input.mul_(mask * normalize_scale) - else: - input = input * mask * normalize_scale - return input + return drop_block2d(input, self.p, self.block_size, self.inplace, self.eps, self.training) def __repr__(self) -> str: s = f"{self.__class__.__name__}(p={self.p}, block_size={self.block_size}, inplace={self.inplace})" From 09f13965a33f944306b45dde2a2d2d9544a2b9ae Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Sun, 20 Feb 2022 11:07:11 +0800 Subject: [PATCH 05/38] add dropblock --- torchvision/ops/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchvision/ops/__init__.py b/torchvision/ops/__init__.py index 8ba10080c1f..f733627d3e7 100644 --- a/torchvision/ops/__init__.py +++ b/torchvision/ops/__init__.py @@ -21,6 +21,7 @@ from .roi_align import roi_align, RoIAlign from .roi_pool import roi_pool, RoIPool from .stochastic_depth import stochastic_depth, StochasticDepth +from .drop_block import drop_block2d, DropBlock2d _register_custom_op() From f2799816768c016731ef62515dc6b9aae13112e7 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Sun, 20 Feb 2022 11:07:34 +0800 Subject: [PATCH 06/38] Rename dropblock.py to drop_block.py --- torchvision/ops/{dropblock.py => drop_block.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename torchvision/ops/{dropblock.py => drop_block.py} (100%) diff --git a/torchvision/ops/dropblock.py b/torchvision/ops/drop_block.py similarity index 100% rename from torchvision/ops/dropblock.py rename to torchvision/ops/drop_block.py From ade32f06f2f569114f71f8fc415651bd53080ce8 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Sun, 20 Feb 2022 11:17:21 +0800 Subject: [PATCH 07/38] fix pylint --- torchvision/ops/drop_block.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/torchvision/ops/drop_block.py b/torchvision/ops/drop_block.py index 81b3c39e405..93e4f50f44d 100644 --- a/torchvision/ops/drop_block.py +++ b/torchvision/ops/drop_block.py @@ -5,8 +5,9 @@ from ..utils import _log_api_usage_once -def drop_block2d(input: Tensor, p: float, block_size: int, inplace: bool = False, eps: float = 1e-06, - training: bool = True) -> Tensor: +def drop_block2d( + input: Tensor, p: float, block_size: int, inplace: bool = False, eps: float = 1e-06, training: bool = True +) -> Tensor: """ Implements DropBlock2d from `"DropBlock: A regularization method for convolutional networks" `. @@ -32,18 +33,15 @@ def drop_block2d(input: Tensor, p: float, block_size: int, inplace: bool = False N, C, H, W = input.size() # compute the gamma of Bernoulli distribution gamma = (p * H * W) / ((block_size ** 2) * ((H - block_size + 1) * (W - block_size + 1))) - noise = torch.empty((N, C, H - block_size + 1, W - block_size + 1), dtype=input.dtype, - device=input.device) + noise = torch.empty((N, C, H - block_size + 1, W - block_size + 1), dtype=input.dtype, device=input.device) noise.bernoulli_(gamma) noise = F.pad(noise, [block_size // 2] * 4, value=0) - noise = F.max_pool2d( - noise, stride=(1, 1), kernel_size=(block_size, block_size), padding=block_size // 2 - ) + noise = F.max_pool2d(noise, stride=(1, 1), kernel_size=(block_size, block_size), padding=block_size // 2) noise = 1 - noise normalize_scale = noise.numel() / (eps + noise.sum()) if inplace: - input.mul_(noise).mul_(normalize_scale) + input.mul_(noise).mul_(normalize_scale) else: input = input * noise * normalize_scale return input From 2f7a10dff58b9d2e12a6eb0c0c96ee8cdc67bba2 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Sun, 20 Feb 2022 11:18:01 +0800 Subject: [PATCH 08/38] add dropblock --- torchvision/ops/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/ops/__init__.py b/torchvision/ops/__init__.py index f733627d3e7..4b113b5ac7c 100644 --- a/torchvision/ops/__init__.py +++ b/torchvision/ops/__init__.py @@ -11,6 +11,7 @@ ) from .boxes import box_convert from .deform_conv import deform_conv2d, DeformConv2d +from .drop_block import drop_block2d, DropBlock2d from .feature_pyramid_network import FeaturePyramidNetwork from .focal_loss import sigmoid_focal_loss from .giou_loss import generalized_box_iou_loss @@ -21,7 +22,6 @@ from .roi_align import roi_align, RoIAlign from .roi_pool import roi_pool, RoIPool from .stochastic_depth import stochastic_depth, StochasticDepth -from .drop_block import drop_block2d, DropBlock2d _register_custom_op() From 29edef564958a85e8429fb80c8e8e9af85864817 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Sun, 20 Feb 2022 20:07:22 +0800 Subject: [PATCH 09/38] add dropblock3d --- torchvision/ops/drop_block.py | 65 +++++++++++++++++++++++++++++++++++ 1 file changed, 65 insertions(+) diff --git a/torchvision/ops/drop_block.py b/torchvision/ops/drop_block.py index 93e4f50f44d..9a994c6d144 100644 --- a/torchvision/ops/drop_block.py +++ b/torchvision/ops/drop_block.py @@ -50,6 +50,53 @@ def drop_block2d( torch.fx.wrap("drop_block2d") +def drop_block3d(input: Tensor, p: float, block_size: int, inplace: bool = False, eps: float = 1e-06, + training: bool = True) -> Tensor: + """ + Implements DropBlock3d from `"DropBlock: A regularization method for convolutional networks" + `. + + Args: + input (Tensor[N, C, D, H, W]): The input tensor or 5-dimensions with the first one + being its batch i.e. a batch with ``N`` rows. + p (float): Probability of an element to be dropped. + block_size (int): Size of the block to drop. + inplace (bool): If set to ``True``, will do this operation in-place. Default: ``False``. + eps (float): A value added to the denominator for numerical stability. Default: 1e-6. + training (bool): apply dropblock if is ``True``. Default: ``True` + Returns: + Tensor[N, ...]: The randomly zeroed tensor after dropblock. + """ + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + _log_api_usage_once(drop_block3d) + if p < 0.0 or p > 1.0: + raise ValueError(f"drop probability has to be between 0 and 1, but got {p}") + if not training or p == 0.0: + return input + + N, C, D, H, W = input.size() + # compute the gamma of Bernoulli distribution + gamma = (p * D * H * W) / ((block_size ** 3) * ((D - block_size + 1) * (H - block_size + 1) * (W - block_size + 1))) + noise = torch.empty((N, C, D - block_size + 1, H - block_size + 1, W - block_size + 1), dtype=input.dtype, + device=input.device) + noise.bernoulli_(gamma) + + noise = F.pad(noise, [block_size // 2] * 6, value=0) + noise = F.max_pool3d( + noise, stride=(1, 1, 1), kernel_size=(block_size, block_size, block_size), padding=block_size // 2 + ) + noise = 1 - noise + normalize_scale = noise.numel() / (eps + noise.sum()) + if inplace: + input.mul_(noise).mul_(normalize_scale) + else: + input = input * noise * normalize_scale + return input + + +torch.fx.wrap("drop_block3d") + + class DropBlock2d(nn.Module): """ See :func:`drop_block2d`. @@ -76,3 +123,21 @@ def forward(self, input: Tensor) -> Tensor: def __repr__(self) -> str: s = f"{self.__class__.__name__}(p={self.p}, block_size={self.block_size}, inplace={self.inplace})" return s + +class DropBlock3d(DropBlock2d): + """ + See :func:`drop_block3d`. + """ + + def __init__(self, p: float, block_size: int, inplace: bool = False, eps: float = 1e-06) -> None: + super().__init__(p, block_size, inplace, eps) + + def forward(self, input: Tensor) -> Tensor: + """ + Args: + input (Tensor): Input feature map on which some areas will be randomly + dropped. + Returns: + Tensor: The tensor after DropBlock layer. + """ + return drop_block3d(input, self.p, self.block_size, self.inplace, self.eps, self.training) From 9969e966282f80204f008aa71af7802ec4a8ca88 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Sun, 20 Feb 2022 20:07:55 +0800 Subject: [PATCH 10/38] add drop_block3d --- torchvision/ops/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/ops/__init__.py b/torchvision/ops/__init__.py index 4b113b5ac7c..8a837fe9e79 100644 --- a/torchvision/ops/__init__.py +++ b/torchvision/ops/__init__.py @@ -11,7 +11,7 @@ ) from .boxes import box_convert from .deform_conv import deform_conv2d, DeformConv2d -from .drop_block import drop_block2d, DropBlock2d +from .drop_block import drop_block2d, DropBlock2d, drop_block3d, DropBlock3d, from .feature_pyramid_network import FeaturePyramidNetwork from .focal_loss import sigmoid_focal_loss from .giou_loss import generalized_box_iou_loss From e5c505e98301927081cef18667dca20ba1bbc6ff Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Mon, 21 Feb 2022 10:25:44 +0800 Subject: [PATCH 11/38] add dropblock --- torchvision/ops/__init__.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/torchvision/ops/__init__.py b/torchvision/ops/__init__.py index 8a837fe9e79..bc92616b702 100644 --- a/torchvision/ops/__init__.py +++ b/torchvision/ops/__init__.py @@ -11,7 +11,7 @@ ) from .boxes import box_convert from .deform_conv import deform_conv2d, DeformConv2d -from .drop_block import drop_block2d, DropBlock2d, drop_block3d, DropBlock3d, +from .drop_block import drop_block2d, DropBlock2d, drop_block3d, DropBlock3d from .feature_pyramid_network import FeaturePyramidNetwork from .focal_loss import sigmoid_focal_loss from .giou_loss import generalized_box_iou_loss @@ -55,4 +55,8 @@ "ConvNormActivation", "SqueezeExcitation", "generalized_box_iou_loss", + "drop_block2d", + "DropBlock2d", + "drop_block3d", + "DropBlock3d" ] From 5ba51bede4f2a307e0209cab8a091da8f87a5ac6 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Mon, 21 Feb 2022 10:26:13 +0800 Subject: [PATCH 12/38] Update drop_block.py --- torchvision/ops/drop_block.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchvision/ops/drop_block.py b/torchvision/ops/drop_block.py index 9a994c6d144..38565610920 100644 --- a/torchvision/ops/drop_block.py +++ b/torchvision/ops/drop_block.py @@ -123,7 +123,8 @@ def forward(self, input: Tensor) -> Tensor: def __repr__(self) -> str: s = f"{self.__class__.__name__}(p={self.p}, block_size={self.block_size}, inplace={self.inplace})" return s - + + class DropBlock3d(DropBlock2d): """ See :func:`drop_block3d`. From 2901effa534c95798f5cdced7ee9081d2dfd7c5f Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Mon, 21 Feb 2022 18:56:46 +0800 Subject: [PATCH 13/38] Update torchvision/ops/drop_block.py Co-authored-by: Vasilis Vryniotis --- torchvision/ops/drop_block.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/ops/drop_block.py b/torchvision/ops/drop_block.py index 38565610920..02ce98ddc7d 100644 --- a/torchvision/ops/drop_block.py +++ b/torchvision/ops/drop_block.py @@ -124,7 +124,7 @@ def __repr__(self) -> str: s = f"{self.__class__.__name__}(p={self.p}, block_size={self.block_size}, inplace={self.inplace})" return s - +torch.fx.wrap("drop_block3d") class DropBlock3d(DropBlock2d): """ See :func:`drop_block3d`. From 90f86f6890f48e0d85a2151f227e4e79d00de874 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Mon, 21 Feb 2022 18:56:52 +0800 Subject: [PATCH 14/38] Update torchvision/ops/drop_block.py Co-authored-by: Vasilis Vryniotis --- torchvision/ops/drop_block.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchvision/ops/drop_block.py b/torchvision/ops/drop_block.py index 02ce98ddc7d..b13f7016233 100644 --- a/torchvision/ops/drop_block.py +++ b/torchvision/ops/drop_block.py @@ -47,7 +47,6 @@ def drop_block2d( return input -torch.fx.wrap("drop_block2d") def drop_block3d(input: Tensor, p: float, block_size: int, inplace: bool = False, eps: float = 1e-06, From 918c979d5602a3334783f9f8272d0963e18f490b Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Mon, 21 Feb 2022 18:57:14 +0800 Subject: [PATCH 15/38] Update torchvision/ops/drop_block.py Co-authored-by: Vasilis Vryniotis --- torchvision/ops/drop_block.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/ops/drop_block.py b/torchvision/ops/drop_block.py index b13f7016233..c9bfd4cc057 100644 --- a/torchvision/ops/drop_block.py +++ b/torchvision/ops/drop_block.py @@ -93,7 +93,7 @@ def drop_block3d(input: Tensor, p: float, block_size: int, inplace: bool = False return input -torch.fx.wrap("drop_block3d") +torch.fx.wrap("drop_block2d") class DropBlock2d(nn.Module): From 77ea0ab34c8d5ee048a5ff39ead4cef56a3db020 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Mon, 21 Feb 2022 18:58:16 +0800 Subject: [PATCH 16/38] Update torchvision/ops/drop_block.py Co-authored-by: Vasilis Vryniotis --- torchvision/ops/drop_block.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/ops/drop_block.py b/torchvision/ops/drop_block.py index c9bfd4cc057..e95aff887f9 100644 --- a/torchvision/ops/drop_block.py +++ b/torchvision/ops/drop_block.py @@ -87,7 +87,7 @@ def drop_block3d(input: Tensor, p: float, block_size: int, inplace: bool = False noise = 1 - noise normalize_scale = noise.numel() / (eps + noise.sum()) if inplace: - input.mul_(noise).mul_(normalize_scale) + input.mul_(noise).mul_(normalize_scale) else: input = input * noise * normalize_scale return input From fefa74e0c7cbc589a48a850f3d4b8e4e437ea2f0 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Mon, 21 Feb 2022 19:00:00 +0800 Subject: [PATCH 17/38] Update drop_block.py --- torchvision/ops/drop_block.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/torchvision/ops/drop_block.py b/torchvision/ops/drop_block.py index e95aff887f9..bcee2cf5091 100644 --- a/torchvision/ops/drop_block.py +++ b/torchvision/ops/drop_block.py @@ -123,7 +123,10 @@ def __repr__(self) -> str: s = f"{self.__class__.__name__}(p={self.p}, block_size={self.block_size}, inplace={self.inplace})" return s + torch.fx.wrap("drop_block3d") + + class DropBlock3d(DropBlock2d): """ See :func:`drop_block3d`. From f5b79ee6313f74d30ac44b9f1c882bb8493b58e0 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Mon, 21 Feb 2022 19:00:22 +0800 Subject: [PATCH 18/38] Update drop_block.py --- torchvision/ops/drop_block.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/torchvision/ops/drop_block.py b/torchvision/ops/drop_block.py index bcee2cf5091..75e10673c7d 100644 --- a/torchvision/ops/drop_block.py +++ b/torchvision/ops/drop_block.py @@ -47,8 +47,6 @@ def drop_block2d( return input - - def drop_block3d(input: Tensor, p: float, block_size: int, inplace: bool = False, eps: float = 1e-06, training: bool = True) -> Tensor: """ From b45a9e6ecb334572554e1239f688cf8bda9ee50f Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Mon, 21 Feb 2022 19:15:06 +0800 Subject: [PATCH 19/38] import torch.fx --- torchvision/ops/drop_block.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchvision/ops/drop_block.py b/torchvision/ops/drop_block.py index 75e10673c7d..b3a6d6f7399 100644 --- a/torchvision/ops/drop_block.py +++ b/torchvision/ops/drop_block.py @@ -1,4 +1,5 @@ import torch +import torch.fx import torch.nn.functional as F from torch import nn, Tensor From c669853892d23da01cae27fa22899b5bc9154fc6 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Mon, 21 Feb 2022 19:33:37 +0800 Subject: [PATCH 20/38] fix lint --- torchvision/ops/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/ops/__init__.py b/torchvision/ops/__init__.py index bc92616b702..c1756c58a54 100644 --- a/torchvision/ops/__init__.py +++ b/torchvision/ops/__init__.py @@ -58,5 +58,5 @@ "drop_block2d", "DropBlock2d", "drop_block3d", - "DropBlock3d" + "DropBlock3d", ] From 892f1e52dd8e23a8ea54de2191e654ae61557afd Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Mon, 21 Feb 2022 19:36:00 +0800 Subject: [PATCH 21/38] fix lint --- torchvision/ops/drop_block.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/torchvision/ops/drop_block.py b/torchvision/ops/drop_block.py index b3a6d6f7399..48662c6a6bb 100644 --- a/torchvision/ops/drop_block.py +++ b/torchvision/ops/drop_block.py @@ -45,11 +45,11 @@ def drop_block2d( input.mul_(noise).mul_(normalize_scale) else: input = input * noise * normalize_scale - return input + return input -def drop_block3d(input: Tensor, p: float, block_size: int, inplace: bool = False, eps: float = 1e-06, - training: bool = True) -> Tensor: +def drop_block3d(input: Tensor, p: float, block_size: int, inplace: bool = False, eps: float = 1e-06, training: bool = True +) -> Tensor: """ Implements DropBlock3d from `"DropBlock: A regularization method for convolutional networks" `. @@ -75,8 +75,9 @@ def drop_block3d(input: Tensor, p: float, block_size: int, inplace: bool = False N, C, D, H, W = input.size() # compute the gamma of Bernoulli distribution gamma = (p * D * H * W) / ((block_size ** 3) * ((D - block_size + 1) * (H - block_size + 1) * (W - block_size + 1))) - noise = torch.empty((N, C, D - block_size + 1, H - block_size + 1, W - block_size + 1), dtype=input.dtype, - device=input.device) + noise = torch.empty( + (N, C, D - block_size + 1, H - block_size + 1, W - block_size + 1), dtype=input.dtype, device=input.device + ) noise.bernoulli_(gamma) noise = F.pad(noise, [block_size // 2] * 6, value=0) @@ -89,7 +90,7 @@ def drop_block3d(input: Tensor, p: float, block_size: int, inplace: bool = False input.mul_(noise).mul_(normalize_scale) else: input = input * noise * normalize_scale - return input + return input torch.fx.wrap("drop_block2d") From 7c5e909fabe9c1f1029d56656b7f1528228d7fcf Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Mon, 21 Feb 2022 22:08:09 +0800 Subject: [PATCH 22/38] Update drop_block.py --- torchvision/ops/drop_block.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/torchvision/ops/drop_block.py b/torchvision/ops/drop_block.py index 48662c6a6bb..9a3eba76671 100644 --- a/torchvision/ops/drop_block.py +++ b/torchvision/ops/drop_block.py @@ -26,12 +26,11 @@ def drop_block2d( """ if not torch.jit.is_scripting() and not torch.jit.is_tracing(): _log_api_usage_once(drop_block2d) - if p < 0.0 or p > 1.0: - raise ValueError(f"drop probability has to be between 0 and 1, but got {p}") if not training or p == 0.0: return input N, C, H, W = input.size() + block_size = min(blcok_size, W, H) # compute the gamma of Bernoulli distribution gamma = (p * H * W) / ((block_size ** 2) * ((H - block_size + 1) * (W - block_size + 1))) noise = torch.empty((N, C, H - block_size + 1, W - block_size + 1), dtype=input.dtype, device=input.device) @@ -67,12 +66,11 @@ def drop_block3d(input: Tensor, p: float, block_size: int, inplace: bool = False """ if not torch.jit.is_scripting() and not torch.jit.is_tracing(): _log_api_usage_once(drop_block3d) - if p < 0.0 or p > 1.0: - raise ValueError(f"drop probability has to be between 0 and 1, but got {p}") if not training or p == 0.0: return input N, C, D, H, W = input.size() + block_size = min(blcok_size, D, H, W) # compute the gamma of Bernoulli distribution gamma = (p * D * H * W) / ((block_size ** 3) * ((D - block_size + 1) * (H - block_size + 1) * (W - block_size + 1))) noise = torch.empty( @@ -90,7 +88,7 @@ def drop_block3d(input: Tensor, p: float, block_size: int, inplace: bool = False input.mul_(noise).mul_(normalize_scale) else: input = input * noise * normalize_scale - return input + return input torch.fx.wrap("drop_block2d") @@ -103,7 +101,11 @@ class DropBlock2d(nn.Module): def __init__(self, p: float, block_size: int, inplace: bool = False, eps: float = 1e-06) -> None: super().__init__() - + + if p < 0.0 or p > 1.0: + raise ValueError(f"drop probability has to be between 0 and 1, but got {p}") + if block_size % 2 == 0: + raise ValueError(f"block size has to be an odd number, but got {block_size}") self.p = p self.block_size = block_size self.inplace = inplace From d06bc2455a978e67bd2544c6db5be6b43b69c037 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Mon, 21 Feb 2022 22:09:29 +0800 Subject: [PATCH 23/38] improve dropblock --- torchvision/ops/drop_block.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/ops/drop_block.py b/torchvision/ops/drop_block.py index 9a3eba76671..1c57fe9ccfc 100644 --- a/torchvision/ops/drop_block.py +++ b/torchvision/ops/drop_block.py @@ -70,7 +70,7 @@ def drop_block3d(input: Tensor, p: float, block_size: int, inplace: bool = False return input N, C, D, H, W = input.size() - block_size = min(blcok_size, D, H, W) + block_size = min(block_size, D, H, W) # compute the gamma of Bernoulli distribution gamma = (p * D * H * W) / ((block_size ** 3) * ((D - block_size + 1) * (H - block_size + 1) * (W - block_size + 1))) noise = torch.empty( From fdac2f45a732a880224d4674230faf1d2e6d08ac Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Mon, 21 Feb 2022 22:13:22 +0800 Subject: [PATCH 24/38] add dropblock --- docs/source/ops.rst | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/docs/source/ops.rst b/docs/source/ops.rst index 2a960474205..e3373b98317 100644 --- a/docs/source/ops.rst +++ b/docs/source/ops.rst @@ -20,6 +20,8 @@ Operators box_iou clip_boxes_to_image deform_conv2d + drop_block2d + drop_block3d generalized_box_iou generalized_box_iou_loss masks_to_boxes @@ -47,3 +49,5 @@ Operators FrozenBatchNorm2d ConvNormActivation SqueezeExcitation + DropBlock2d + DropBlock3d From aedd5f0e00e6b19b21f156cd5a98ded789bb6352 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Mon, 21 Feb 2022 22:40:22 +0800 Subject: [PATCH 25/38] refactor dropblock --- torchvision/ops/drop_block.py | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/torchvision/ops/drop_block.py b/torchvision/ops/drop_block.py index 1c57fe9ccfc..e993b1e2874 100644 --- a/torchvision/ops/drop_block.py +++ b/torchvision/ops/drop_block.py @@ -26,6 +26,12 @@ def drop_block2d( """ if not torch.jit.is_scripting() and not torch.jit.is_tracing(): _log_api_usage_once(drop_block2d) + if p < 0.0 or p > 1.0: + raise ValueError(f"drop probability has to be between 0 and 1, but got {p}") + if block_size % 2 == 0: + raise ValueError(f"block size has to be an odd number, but got {block_size}") + if input.ndim != 4: + raise ValueError(f"input should be 4 dimensional. Got {input.ndim} dimensions.") if not training or p == 0.0: return input @@ -47,7 +53,8 @@ def drop_block2d( return input -def drop_block3d(input: Tensor, p: float, block_size: int, inplace: bool = False, eps: float = 1e-06, training: bool = True +def drop_block3d( + input: Tensor, p: float, block_size: int, inplace: bool = False, eps: float = 1e-06, training: bool = True ) -> Tensor: """ Implements DropBlock3d from `"DropBlock: A regularization method for convolutional networks" @@ -66,6 +73,12 @@ def drop_block3d(input: Tensor, p: float, block_size: int, inplace: bool = False """ if not torch.jit.is_scripting() and not torch.jit.is_tracing(): _log_api_usage_once(drop_block3d) + if p < 0.0 or p > 1.0: + raise ValueError(f"drop probability has to be between 0 and 1, but got {p}") + if block_size % 2 == 0: + raise ValueError(f"block size has to be an odd number, but got {block_size}") + if input.ndim != 5: + raise ValueError(f"input should be 5 dimensional. Got {input.ndim} dimensions.") if not training or p == 0.0: return input @@ -101,11 +114,7 @@ class DropBlock2d(nn.Module): def __init__(self, p: float, block_size: int, inplace: bool = False, eps: float = 1e-06) -> None: super().__init__() - - if p < 0.0 or p > 1.0: - raise ValueError(f"drop probability has to be between 0 and 1, but got {p}") - if block_size % 2 == 0: - raise ValueError(f"block size has to be an odd number, but got {block_size}") + self.p = p self.block_size = block_size self.inplace = inplace From af7305ec404927f498606cb0e4223df019c26b83 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Mon, 21 Feb 2022 22:59:29 +0800 Subject: [PATCH 26/38] fix doc --- torchvision/ops/drop_block.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/torchvision/ops/drop_block.py b/torchvision/ops/drop_block.py index e993b1e2874..cf27db0eeb2 100644 --- a/torchvision/ops/drop_block.py +++ b/torchvision/ops/drop_block.py @@ -20,9 +20,10 @@ def drop_block2d( block_size (int): Size of the block to drop. inplace (bool): If set to ``True``, will do this operation in-place. Default: ``False``. eps (float): A value added to the denominator for numerical stability. Default: 1e-6. - training (bool): apply dropblock if is ``True``. Default: ``True` + training (bool): apply dropblock if is ``True``. Default: ``True`` + Returns: - Tensor[N, ...]: The randomly zeroed tensor after dropblock. + Tensor[N, C, H, W]: The randomly zeroed tensor after dropblock. """ if not torch.jit.is_scripting() and not torch.jit.is_tracing(): _log_api_usage_once(drop_block2d) @@ -67,9 +68,10 @@ def drop_block3d( block_size (int): Size of the block to drop. inplace (bool): If set to ``True``, will do this operation in-place. Default: ``False``. eps (float): A value added to the denominator for numerical stability. Default: 1e-6. - training (bool): apply dropblock if is ``True``. Default: ``True` + training (bool): apply dropblock if is ``True``. Default: ``True`` + Returns: - Tensor[N, ...]: The randomly zeroed tensor after dropblock. + Tensor[N, C, D, H, W]: The randomly zeroed tensor after dropblock. """ if not torch.jit.is_scripting() and not torch.jit.is_tracing(): _log_api_usage_once(drop_block3d) From 2dd89afaef65eecaea9451bd65ad44ba2614113d Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Tue, 22 Feb 2022 10:43:27 +0800 Subject: [PATCH 27/38] remove the limitation of block_size --- torchvision/ops/drop_block.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/torchvision/ops/drop_block.py b/torchvision/ops/drop_block.py index cf27db0eeb2..03982e2c1d2 100644 --- a/torchvision/ops/drop_block.py +++ b/torchvision/ops/drop_block.py @@ -21,16 +21,14 @@ def drop_block2d( inplace (bool): If set to ``True``, will do this operation in-place. Default: ``False``. eps (float): A value added to the denominator for numerical stability. Default: 1e-6. training (bool): apply dropblock if is ``True``. Default: ``True`` - + Returns: Tensor[N, C, H, W]: The randomly zeroed tensor after dropblock. """ if not torch.jit.is_scripting() and not torch.jit.is_tracing(): _log_api_usage_once(drop_block2d) if p < 0.0 or p > 1.0: - raise ValueError(f"drop probability has to be between 0 and 1, but got {p}") - if block_size % 2 == 0: - raise ValueError(f"block size has to be an odd number, but got {block_size}") + raise ValueError(f"drop probability has to be between 0 and 1, but got {p}") if input.ndim != 4: raise ValueError(f"input should be 4 dimensional. Got {input.ndim} dimensions.") if not training or p == 0.0: @@ -76,9 +74,7 @@ def drop_block3d( if not torch.jit.is_scripting() and not torch.jit.is_tracing(): _log_api_usage_once(drop_block3d) if p < 0.0 or p > 1.0: - raise ValueError(f"drop probability has to be between 0 and 1, but got {p}") - if block_size % 2 == 0: - raise ValueError(f"block size has to be an odd number, but got {block_size}") + raise ValueError(f"drop probability has to be between 0 and 1, but got {p}") if input.ndim != 5: raise ValueError(f"input should be 5 dimensional. Got {input.ndim} dimensions.") if not training or p == 0.0: From 4f40274e223ab6cf56baec199ca2b0537b7e1f05 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Tue, 22 Feb 2022 10:44:15 +0800 Subject: [PATCH 28/38] Update torchvision/ops/drop_block.py Co-authored-by: Vasilis Vryniotis --- torchvision/ops/drop_block.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/ops/drop_block.py b/torchvision/ops/drop_block.py index 03982e2c1d2..2aa562ab7c8 100644 --- a/torchvision/ops/drop_block.py +++ b/torchvision/ops/drop_block.py @@ -35,7 +35,7 @@ def drop_block2d( return input N, C, H, W = input.size() - block_size = min(blcok_size, W, H) + block_size = min(block_size, W, H) # compute the gamma of Bernoulli distribution gamma = (p * H * W) / ((block_size ** 2) * ((H - block_size + 1) * (W - block_size + 1))) noise = torch.empty((N, C, H - block_size + 1, W - block_size + 1), dtype=input.dtype, device=input.device) From b1f91e5e6f0a88ac79af15bf146925c7f40cf56b Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Tue, 22 Feb 2022 14:24:11 +0800 Subject: [PATCH 29/38] fix lint --- torchvision/ops/drop_block.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torchvision/ops/drop_block.py b/torchvision/ops/drop_block.py index 2aa562ab7c8..185760b96b9 100644 --- a/torchvision/ops/drop_block.py +++ b/torchvision/ops/drop_block.py @@ -20,7 +20,7 @@ def drop_block2d( block_size (int): Size of the block to drop. inplace (bool): If set to ``True``, will do this operation in-place. Default: ``False``. eps (float): A value added to the denominator for numerical stability. Default: 1e-6. - training (bool): apply dropblock if is ``True``. Default: ``True`` + training (bool): apply dropblock if is ``True``. Default: ``True``. Returns: Tensor[N, C, H, W]: The randomly zeroed tensor after dropblock. @@ -28,7 +28,7 @@ def drop_block2d( if not torch.jit.is_scripting() and not torch.jit.is_tracing(): _log_api_usage_once(drop_block2d) if p < 0.0 or p > 1.0: - raise ValueError(f"drop probability has to be between 0 and 1, but got {p}") + raise ValueError(f"drop probability has to be between 0 and 1, but got {p}.") if input.ndim != 4: raise ValueError(f"input should be 4 dimensional. Got {input.ndim} dimensions.") if not training or p == 0.0: @@ -66,7 +66,7 @@ def drop_block3d( block_size (int): Size of the block to drop. inplace (bool): If set to ``True``, will do this operation in-place. Default: ``False``. eps (float): A value added to the denominator for numerical stability. Default: 1e-6. - training (bool): apply dropblock if is ``True``. Default: ``True`` + training (bool): apply dropblock if is ``True``. Default: ``True``. Returns: Tensor[N, C, D, H, W]: The randomly zeroed tensor after dropblock. @@ -74,7 +74,7 @@ def drop_block3d( if not torch.jit.is_scripting() and not torch.jit.is_tracing(): _log_api_usage_once(drop_block3d) if p < 0.0 or p > 1.0: - raise ValueError(f"drop probability has to be between 0 and 1, but got {p}") + raise ValueError(f"drop probability has to be between 0 and 1, but got {p}.") if input.ndim != 5: raise ValueError(f"input should be 5 dimensional. Got {input.ndim} dimensions.") if not training or p == 0.0: From 60cf55919e974263bf938cb157d71a1516b5b26c Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Tue, 22 Feb 2022 14:38:36 +0800 Subject: [PATCH 30/38] fix lint --- torchvision/ops/drop_block.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/ops/drop_block.py b/torchvision/ops/drop_block.py index 185760b96b9..a798677f60f 100644 --- a/torchvision/ops/drop_block.py +++ b/torchvision/ops/drop_block.py @@ -28,7 +28,7 @@ def drop_block2d( if not torch.jit.is_scripting() and not torch.jit.is_tracing(): _log_api_usage_once(drop_block2d) if p < 0.0 or p > 1.0: - raise ValueError(f"drop probability has to be between 0 and 1, but got {p}.") + raise ValueError(f"drop probability has to be between 0 and 1, but got {p}.") if input.ndim != 4: raise ValueError(f"input should be 4 dimensional. Got {input.ndim} dimensions.") if not training or p == 0.0: From 2b3d9cc50cb5a088c4cb65108ff44fdc2ad32d90 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Tue, 22 Feb 2022 16:00:55 +0800 Subject: [PATCH 31/38] add dropblock --- test/test_ops.py | 61 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 61 insertions(+) diff --git a/test/test_ops.py b/test/test_ops.py index 6b35b4f0091..41bf6eacc42 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -3,6 +3,7 @@ from abc import ABC, abstractmethod from functools import lru_cache from typing import Callable, List, Tuple +from itertools import product import numpy as np import pytest @@ -57,6 +58,16 @@ def forward(self, a): self.layer(a) +class DropBlockWrapper(nn.Module): + def __init__(self, obj): + super().__init__() + self.layer = obj + self.n_inputs = 1 + + def forward(self, a): + self.layer(a) + + class RoIOpTester(ABC): dtype = torch.float64 @@ -1355,6 +1366,56 @@ def test_split_normalization_params(self, norm_layer): assert len(params[0]) == 92 assert len(params[1]) == 82 + + +class TestDropBlock: + @pytest.mark.parametrize("seed", range(10)) + @pytest.mark.parametrize("dim", (2, 3)) + @pytest.mark.parametrize("p", (0, 0.2, 0.5, 0.8)) + @pytest.mark.parametrize("block_size", [5, 7, 9, 11]) + @pytest.mark.parametrize("inplace", [True, False]) + def test_drop_block(self, seed, dim, p, block_size, inplace): + torch.manual_seed(seed) + batch_size = 5 + channels = 3 + height = 11 + width = height + depth = height + if dim == 2: + x = torch.ones(size=(batch_size, channels, height, width)) + layer = ops.DropBlock2d(p=p, block_size=block_size, inplace=inplace) + feature_size = height * width + elif dim == 3: + x = torch.ones(size=(batch_size, channels, depth, height, width)) + layer = ops.DropBlock3d(p=p, block_size=block_size, inplace=inplace) + feature_size = depth * height * width + layer.__repr__() + + out = layer(x) + if p == 0: + assert out.equal(x) + if block_size == height: + for b, c in product(range(batch_size), range(channels)): + assert out[b, c].count_nonzero() in (0, feature_size) + + def make_obj(self, dim, p, block_size, inplace, wrap=False): + if dim == 2: + obj = ops.DropBlock2d(p, block_size, inplace) + elif dim == 3: + obj = ops.DropBlock3d(p, block_size, inplace) + return DropBlockWrapper(obj) if wrap else obj + + @pytest.mark.parametrize("dim", (2, 3)) + @pytest.mark.parametrize("p", (0, 1)) + @pytest.mark.parametrize("block_size", [5, 7, 9, 11]) + @pytest.mark.parametrize("inplace", [True, False]) + def test_is_leaf_node(self, dim, p, block_size, inplace): + op_obj = self.make_obj(dim, p, block_size, inplace, wrap=True) + graph_node_names = get_graph_node_names(op_obj) + + assert len(graph_node_names) == 2 + assert len(graph_node_names[0]) == len(graph_node_names[1]) + assert len(graph_node_names[0]) == 1 + op_obj.n_inputs if __name__ == "__main__": From 4019e7a17f478603040f58ecc218982f4ffc38d1 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Tue, 22 Feb 2022 10:57:23 +0000 Subject: [PATCH 32/38] Fix linter --- test/test_ops.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index 41bf6eacc42..b71aaaca490 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -2,8 +2,8 @@ import os from abc import ABC, abstractmethod from functools import lru_cache -from typing import Callable, List, Tuple from itertools import product +from typing import Callable, List, Tuple import numpy as np import pytest @@ -1366,8 +1366,8 @@ def test_split_normalization_params(self, norm_layer): assert len(params[0]) == 92 assert len(params[1]) == 82 - - + + class TestDropBlock: @pytest.mark.parametrize("seed", range(10)) @pytest.mark.parametrize("dim", (2, 3)) From dcf9296d479360c27d1a83e6b0dd639d1fe6a329 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Wed, 23 Feb 2022 14:54:25 +0800 Subject: [PATCH 33/38] add dropblock random check --- test/test_ops.py | 38 +++++++++++++++++++++++++++++++++++--- 1 file changed, 35 insertions(+), 3 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index b71aaaca490..4f209c8d162 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1370,8 +1370,8 @@ def test_split_normalization_params(self, norm_layer): class TestDropBlock: @pytest.mark.parametrize("seed", range(10)) - @pytest.mark.parametrize("dim", (2, 3)) - @pytest.mark.parametrize("p", (0, 0.2, 0.5, 0.8)) + @pytest.mark.parametrize("dim", [2, 3]) + @pytest.mark.parametrize("p", [0, 0.2, 0.5, 0.8]) @pytest.mark.parametrize("block_size", [5, 7, 9, 11]) @pytest.mark.parametrize("inplace", [True, False]) def test_drop_block(self, seed, dim, p, block_size, inplace): @@ -1398,6 +1398,38 @@ def test_drop_block(self, seed, dim, p, block_size, inplace): for b, c in product(range(batch_size), range(channels)): assert out[b, c].count_nonzero() in (0, feature_size) + @pytest.mark.parametrize("seed", range(10)) + @pytest.mark.parametrize("dim", [2, 3]) + @pytest.mark.parametrize("p", [0.1, 0.2]) + @pytest.mark.parametrize("block_size", [3, 5]) + @pytest.mark.parametrize("inplace", [False,]) + def test_drop_block_random(self, seed, dim, p, block_size, inplace): + torch.manual_seed(seed) + batch_size = 5 + channels = 3 + height = 16 + width = height + depth = height + if dim == 2: + x = torch.ones(size=(batch_size, channels, height, width)) + layer = DropBlock2d(p=p, block_size=block_size, inplace=inplace) + elif dim == 3: + x = torch.ones(size=(batch_size, channels, depth, height, width)) + layer = DropBlock3d(p=p, block_size=block_size, inplace=inplace) + + trials = 250 + num_samples = 0 + counts = 0 + cell_numel = torch.tensor(x.shape).prod() + for _ in range(trials): + with torch.no_grad(): + out = layer(x) + non_zero_count = out.nonzero().size(0) + counts += cell_numel - non_zero_count + num_samples += cell_numel + + assert abs(p - counts / num_samples) / p < 0.15 + def make_obj(self, dim, p, block_size, inplace, wrap=False): if dim == 2: obj = ops.DropBlock2d(p, block_size, inplace) @@ -1406,7 +1438,7 @@ def make_obj(self, dim, p, block_size, inplace, wrap=False): return DropBlockWrapper(obj) if wrap else obj @pytest.mark.parametrize("dim", (2, 3)) - @pytest.mark.parametrize("p", (0, 1)) + @pytest.mark.parametrize("p", [0, 1]) @pytest.mark.parametrize("block_size", [5, 7, 9, 11]) @pytest.mark.parametrize("inplace", [True, False]) def test_is_leaf_node(self, dim, p, block_size, inplace): From 84cd3dc5d368c5bf74b5851d50755b5d52bdddf6 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Wed, 23 Feb 2022 15:10:22 +0800 Subject: [PATCH 34/38] reduce test time --- test/test_ops.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index 4f209c8d162..1a85b1236d4 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1401,13 +1401,13 @@ def test_drop_block(self, seed, dim, p, block_size, inplace): @pytest.mark.parametrize("seed", range(10)) @pytest.mark.parametrize("dim", [2, 3]) @pytest.mark.parametrize("p", [0.1, 0.2]) - @pytest.mark.parametrize("block_size", [3, 5]) + @pytest.mark.parametrize("block_size", [3,]) @pytest.mark.parametrize("inplace", [False,]) def test_drop_block_random(self, seed, dim, p, block_size, inplace): torch.manual_seed(seed) batch_size = 5 channels = 3 - height = 16 + height = 11 width = height depth = height if dim == 2: From b159f4d56295e073f968652e219ce4e29fa2fe71 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Wed, 23 Feb 2022 15:54:09 +0800 Subject: [PATCH 35/38] Update test_ops.py --- test/test_ops.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index 1a85b1236d4..fe567d52492 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1412,10 +1412,10 @@ def test_drop_block_random(self, seed, dim, p, block_size, inplace): depth = height if dim == 2: x = torch.ones(size=(batch_size, channels, height, width)) - layer = DropBlock2d(p=p, block_size=block_size, inplace=inplace) + layer = ops.DropBlock2d(p=p, block_size=block_size, inplace=inplace) elif dim == 3: x = torch.ones(size=(batch_size, channels, depth, height, width)) - layer = DropBlock3d(p=p, block_size=block_size, inplace=inplace) + layer = ops.DropBlock3d(p=p, block_size=block_size, inplace=inplace) trials = 250 num_samples = 0 From ebea5396d8f06dea15de6b6cecaa42d2bc20636f Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Wed, 23 Feb 2022 20:41:07 +0800 Subject: [PATCH 36/38] speed the dropblock test --- test/test_ops.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index fe567d52492..b1ef1c6293e 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1371,8 +1371,8 @@ def test_split_normalization_params(self, norm_layer): class TestDropBlock: @pytest.mark.parametrize("seed", range(10)) @pytest.mark.parametrize("dim", [2, 3]) - @pytest.mark.parametrize("p", [0, 0.2, 0.5, 0.8]) - @pytest.mark.parametrize("block_size", [5, 7, 9, 11]) + @pytest.mark.parametrize("p", [0, 0.5]) + @pytest.mark.parametrize("block_size", [5, 11]) @pytest.mark.parametrize("inplace", [True, False]) def test_drop_block(self, seed, dim, p, block_size, inplace): torch.manual_seed(seed) @@ -1439,7 +1439,7 @@ def make_obj(self, dim, p, block_size, inplace, wrap=False): @pytest.mark.parametrize("dim", (2, 3)) @pytest.mark.parametrize("p", [0, 1]) - @pytest.mark.parametrize("block_size", [5, 7, 9, 11]) + @pytest.mark.parametrize("block_size", [5, 7]) @pytest.mark.parametrize("inplace", [True, False]) def test_is_leaf_node(self, dim, p, block_size, inplace): op_obj = self.make_obj(dim, p, block_size, inplace, wrap=True) From 8d89128f93f8bc9c22b95b3146bd5e9f98e3db44 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Wed, 23 Feb 2022 22:57:14 +0800 Subject: [PATCH 37/38] fix lint --- test/test_ops.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index b1ef1c6293e..ad9aaefee52 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1401,8 +1401,8 @@ def test_drop_block(self, seed, dim, p, block_size, inplace): @pytest.mark.parametrize("seed", range(10)) @pytest.mark.parametrize("dim", [2, 3]) @pytest.mark.parametrize("p", [0.1, 0.2]) - @pytest.mark.parametrize("block_size", [3,]) - @pytest.mark.parametrize("inplace", [False,]) + @pytest.mark.parametrize("block_size", [3]) + @pytest.mark.parametrize("inplace", [False]) def test_drop_block_random(self, seed, dim, p, block_size, inplace): torch.manual_seed(seed) batch_size = 5 From 4e76a423bc0ef8c7065125f60745f4745da44885 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 25 Feb 2022 13:21:48 +0000 Subject: [PATCH 38/38] Patch scripts for training dropblock resnet --- references/classification/README.md | 7 +++++++ references/classification/train.py | 1 + torchvision/models/resnet.py | 22 ++++++++++++++++++---- 3 files changed, 26 insertions(+), 4 deletions(-) diff --git a/references/classification/README.md b/references/classification/README.md index e75336f23ca..1e0a253e914 100644 --- a/references/classification/README.md +++ b/references/classification/README.md @@ -53,6 +53,13 @@ torchrun --nproc_per_node=8 train.py --model $MODEL Here `$MODEL` is one of `resnet18`, `resnet34`, `resnet50`, `resnet101` or `resnet152`. +### ResNet with dropblock +``` +torchrun --nproc_per_node=8 train.py --model resnet50 -b 128 --lr 0.4 --epochs 270 +``` + + + ### ResNext ``` torchrun --nproc_per_node=8 train.py\ diff --git a/references/classification/train.py b/references/classification/train.py index b00a11fcac3..3a1c8674856 100644 --- a/references/classification/train.py +++ b/references/classification/train.py @@ -288,6 +288,7 @@ def main(args): f"Invalid lr scheduler '{args.lr_scheduler}'. Only StepLR, CosineAnnealingLR and ExponentialLR " "are supported." ) + main_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[125, 200, 250], gamma=0.1) if args.lr_warmup_epochs > 0: if args.lr_warmup_method == "linear": diff --git a/torchvision/models/resnet.py b/torchvision/models/resnet.py index b0bb8d13ade..43a6d19e9d5 100644 --- a/torchvision/models/resnet.py +++ b/torchvision/models/resnet.py @@ -6,6 +6,7 @@ from .._internally_replaced_utils import load_state_dict_from_url from ..utils import _log_api_usage_once +from ..ops import DropBlock2d __all__ = [ @@ -122,6 +123,7 @@ def __init__( base_width: int = 64, dilation: int = 1, norm_layer: Optional[Callable[..., nn.Module]] = None, + p: float = 0.0, ) -> None: super().__init__() if norm_layer is None: @@ -130,31 +132,40 @@ def __init__( # Both self.conv2 and self.downsample layers downsample the input when stride != 1 self.conv1 = conv1x1(inplanes, width) self.bn1 = norm_layer(width) + # we won't be doing scheduled p + self.drop1 = DropBlock2d(p, 7) self.conv2 = conv3x3(width, width, stride, groups, dilation) self.bn2 = norm_layer(width) + self.drop2 = DropBlock2d(p, 7) self.conv3 = conv1x1(width, planes * self.expansion) self.bn3 = norm_layer(planes * self.expansion) + self.drop3 = DropBlock2d(p, 7) self.relu = nn.ReLU(inplace=True) self.downsample = downsample + self.drop4 = DropBlock2d(p, 7) self.stride = stride def forward(self, x: Tensor) -> Tensor: identity = x - + # as in https://github.com/tensorflow/tpu/blob/b24729de804fdb751b06467d3dce0637fa652060/models/official/resnet/resnet_model.py#L545-L579 out = self.conv1(x) out = self.bn1(out) out = self.relu(out) + out = self.drop1(out) out = self.conv2(out) out = self.bn2(out) out = self.relu(out) + out = self.drop2(out) out = self.conv3(out) out = self.bn3(out) + out = self.drop3(out) if self.downsample is not None: identity = self.downsample(x) + identity = self.drop4(identity) out += identity out = self.relu(out) @@ -198,8 +209,9 @@ def __init__( self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) self.layer1 = self._make_layer(block, 64, layers[0]) self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0]) - self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1]) - self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2]) + # https://github.com/tensorflow/tpu/blob/b24729de804fdb751b06467d3dce0637fa652060/models/official/resnet/resnet_main.py#L393-L394 + self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1], p=0.1 / 4) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2], p=0.1) self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) self.fc = nn.Linear(512 * block.expansion, num_classes) @@ -227,6 +239,7 @@ def _make_layer( blocks: int, stride: int = 1, dilate: bool = False, + p: float = 0.0, ) -> nn.Sequential: norm_layer = self._norm_layer downsample = None @@ -243,7 +256,7 @@ def _make_layer( layers = [] layers.append( block( - self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer + self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer, p ) ) self.inplanes = planes * block.expansion @@ -256,6 +269,7 @@ def _make_layer( base_width=self.base_width, dilation=self.dilation, norm_layer=norm_layer, + p=p ) )