[NOMERGE] Dropblock ResNet50 training #5483

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Status: Closed (wants to merge 44 commits)

Commits (44)
c33246e  Create dropblock.py (xiaohu2015, Feb 13, 2022)
ff34c8e  add dropblock2d (xiaohu2015, Feb 19, 2022)
2a86d77  fix pylint (xiaohu2015, Feb 19, 2022)
a90e036  refactor dropblock (xiaohu2015, Feb 20, 2022)
09f1396  add dropblock (xiaohu2015, Feb 20, 2022)
f279981  Rename dropblock.py to drop_block.py (xiaohu2015, Feb 20, 2022)
f8cb184  Merge branch 'pytorch:main' into main (xiaohu2015, Feb 20, 2022)
ade32f0  fix pylint (xiaohu2015, Feb 20, 2022)
2f7a10d  add dropblock (xiaohu2015, Feb 20, 2022)
29edef5  add dropblock3d (xiaohu2015, Feb 20, 2022)
9969e96  add drop_block3d (xiaohu2015, Feb 20, 2022)
bb5be85  Merge branch 'main' into main (xiaohu2015, Feb 20, 2022)
a6900f6  Merge branch 'pytorch:main' into main (xiaohu2015, Feb 21, 2022)
e5c505e  add dropblock (xiaohu2015, Feb 21, 2022)
5ba51be  Update drop_block.py (xiaohu2015, Feb 21, 2022)
2901eff  Update torchvision/ops/drop_block.py (xiaohu2015, Feb 21, 2022)
90f86f6  Update torchvision/ops/drop_block.py (xiaohu2015, Feb 21, 2022)
918c979  Update torchvision/ops/drop_block.py (xiaohu2015, Feb 21, 2022)
77ea0ab  Update torchvision/ops/drop_block.py (xiaohu2015, Feb 21, 2022)
fefa74e  Update drop_block.py (xiaohu2015, Feb 21, 2022)
f5b79ee  Update drop_block.py (xiaohu2015, Feb 21, 2022)
8c84c73  Merge branch 'pytorch:main' into main (xiaohu2015, Feb 21, 2022)
b45a9e6  import torch.fx (xiaohu2015, Feb 21, 2022)
c669853  fix lint (xiaohu2015, Feb 21, 2022)
892f1e5  fix lint (xiaohu2015, Feb 21, 2022)
7c5e909  Update drop_block.py (xiaohu2015, Feb 21, 2022)
d06bc24  improve dropblock (xiaohu2015, Feb 21, 2022)
fdac2f4  add dropblock (xiaohu2015, Feb 21, 2022)
aedd5f0  refactor dropblock (xiaohu2015, Feb 21, 2022)
af7305e  fix doc (xiaohu2015, Feb 21, 2022)
2dd89af  remove the limitation of block_size (xiaohu2015, Feb 22, 2022)
4f40274  Update torchvision/ops/drop_block.py (xiaohu2015, Feb 22, 2022)
b1f91e5  fix lint (xiaohu2015, Feb 22, 2022)
60cf559  fix lint (xiaohu2015, Feb 22, 2022)
2b3d9cc  add dropblock (xiaohu2015, Feb 22, 2022)
4019e7a  Fix linter (datumbox, Feb 22, 2022)
df0001a  Merge branch 'main' into main (xiaohu2015, Feb 23, 2022)
dcf9296  add dropblock random check (xiaohu2015, Feb 23, 2022)
84cd3dc  reduce test time (xiaohu2015, Feb 23, 2022)
b159f4d  Update test_ops.py (xiaohu2015, Feb 23, 2022)
ebea539  speed the dropblock test (xiaohu2015, Feb 23, 2022)
6ba5147  Merge branch 'main' into main (datumbox, Feb 23, 2022)
8d89128  fix lint (xiaohu2015, Feb 23, 2022)
4e76a42  Patch scripts for training dropblock resnet (datumbox, Feb 25, 2022)
4 changes: 4 additions & 0 deletions docs/source/ops.rst
@@ -20,6 +20,8 @@ Operators
     box_iou
     clip_boxes_to_image
     deform_conv2d
+    drop_block2d
+    drop_block3d
     generalized_box_iou
     generalized_box_iou_loss
     masks_to_boxes
@@ -47,3 +49,5 @@ Operators
     FrozenBatchNorm2d
     ConvNormActivation
     SqueezeExcitation
+    DropBlock2d
+    DropBlock3d
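
The diff for the new `torchvision/ops/drop_block.py` itself is not captured on this page. For orientation, here is a minimal sketch of the 2D technique the op implements (DropBlock, Ghiasi et al. 2018): sample block centers from a Bernoulli distribution with rate gamma chosen so that the expected dropped fraction is roughly `p`, dilate each center into a `block_size x block_size` zero region, and rescale the survivors. The function name is illustrative and this is not the PR's actual code; odd `block_size` is assumed, as in the tests:

```
import torch
import torch.nn.functional as F


def drop_block2d_sketch(x: torch.Tensor, p: float, block_size: int, training: bool = True) -> torch.Tensor:
    # Drop contiguous block_size x block_size regions instead of single units.
    if not training or p == 0.0:
        return x
    n, c, h, w = x.shape
    block_size = min(block_size, h, w)
    # Bernoulli rate for block centers (eq. 1 of the DropBlock paper), scaled so
    # that on average a fraction p of activations lands inside a dropped block.
    gamma = (p * h * w) / (block_size**2 * (h - block_size + 1) * (w - block_size + 1))
    # Sample centers only where a full block fits inside the feature map.
    noise = torch.empty((n, c, h - block_size + 1, w - block_size + 1), dtype=x.dtype, device=x.device)
    noise.bernoulli_(gamma)
    noise = F.pad(noise, [block_size // 2] * 4, value=0)
    # Max-pooling with a block_size kernel dilates each center into a full block.
    noise = F.max_pool2d(noise, stride=(1, 1), kernel_size=(block_size, block_size), padding=block_size // 2)
    mask = 1 - noise
    # Rescale surviving activations, as inverted dropout does.
    return x * mask * (mask.numel() / (1e-6 + mask.sum()))
```

With `p=0` the mask is all ones and the input passes through unchanged, which is exactly what `test_drop_block` asserts below.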
7 changes: 7 additions & 0 deletions references/classification/README.md
@@ -53,6 +53,13 @@ torchrun --nproc_per_node=8 train.py --model $MODEL
 
 Here `$MODEL` is one of `resnet18`, `resnet34`, `resnet50`, `resnet101` or `resnet152`.
 
+### ResNet with dropblock
+```
+torchrun --nproc_per_node=8 train.py --model resnet50 -b 128 --lr 0.4 --epochs 270
+```
+
+
+
 ### ResNext
 ```
 torchrun --nproc_per_node=8 train.py\
1 change: 1 addition & 0 deletions references/classification/train.py
@@ -288,6 +288,7 @@ def main(args):
             f"Invalid lr scheduler '{args.lr_scheduler}'. Only StepLR, CosineAnnealingLR and ExponentialLR "
             "are supported."
         )
+    main_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[125, 200, 250], gamma=0.1)
 
     if args.lr_warmup_epochs > 0:
         if args.lr_warmup_method == "linear":
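
Note: the patched line above unconditionally replaces whatever `--lr-scheduler` selected, which fits the [NOMERGE] tag on this PR; decaying the learning rate by 0.1 at epochs 125, 200 and 250 over a 270-epoch run matches the ResNet-50 schedule reported in the DropBlock paper.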
93 changes: 93 additions & 0 deletions test/test_ops.py
@@ -2,6 +2,7 @@
 import os
 from abc import ABC, abstractmethod
 from functools import lru_cache
+from itertools import product
 from typing import Callable, List, Tuple
 
 import numpy as np
@@ -57,6 +58,16 @@ def forward(self, a):
         self.layer(a)
 
 
+class DropBlockWrapper(nn.Module):
+    def __init__(self, obj):
+        super().__init__()
+        self.layer = obj
+        self.n_inputs = 1
+
+    def forward(self, a):
+        self.layer(a)
+
+
 class RoIOpTester(ABC):
     dtype = torch.float64
 
@@ -1357,5 +1368,87 @@ def test_split_normalization_params(self, norm_layer):
         assert len(params[1]) == 82
 
 
+class TestDropBlock:
+    @pytest.mark.parametrize("seed", range(10))
+    @pytest.mark.parametrize("dim", [2, 3])
+    @pytest.mark.parametrize("p", [0, 0.5])
+    @pytest.mark.parametrize("block_size", [5, 11])
+    @pytest.mark.parametrize("inplace", [True, False])
+    def test_drop_block(self, seed, dim, p, block_size, inplace):
+        torch.manual_seed(seed)
+        batch_size = 5
+        channels = 3
+        height = 11
+        width = height
+        depth = height
+        if dim == 2:
+            x = torch.ones(size=(batch_size, channels, height, width))
+            layer = ops.DropBlock2d(p=p, block_size=block_size, inplace=inplace)
+            feature_size = height * width
+        elif dim == 3:
+            x = torch.ones(size=(batch_size, channels, depth, height, width))
+            layer = ops.DropBlock3d(p=p, block_size=block_size, inplace=inplace)
+            feature_size = depth * height * width
+        layer.__repr__()
+
+        out = layer(x)
+        if p == 0:
+            assert out.equal(x)
+        if block_size == height:
+            for b, c in product(range(batch_size), range(channels)):
+                assert out[b, c].count_nonzero() in (0, feature_size)
+
+    @pytest.mark.parametrize("seed", range(10))
+    @pytest.mark.parametrize("dim", [2, 3])
+    @pytest.mark.parametrize("p", [0.1, 0.2])
+    @pytest.mark.parametrize("block_size", [3])
+    @pytest.mark.parametrize("inplace", [False])
+    def test_drop_block_random(self, seed, dim, p, block_size, inplace):
+        torch.manual_seed(seed)
+        batch_size = 5
+        channels = 3
+        height = 11
+        width = height
+        depth = height
+        if dim == 2:
+            x = torch.ones(size=(batch_size, channels, height, width))
+            layer = ops.DropBlock2d(p=p, block_size=block_size, inplace=inplace)
+        elif dim == 3:
+            x = torch.ones(size=(batch_size, channels, depth, height, width))
+            layer = ops.DropBlock3d(p=p, block_size=block_size, inplace=inplace)
+
+        trials = 250
+        num_samples = 0
+        counts = 0
+        cell_numel = torch.tensor(x.shape).prod()
+        for _ in range(trials):
+            with torch.no_grad():
+                out = layer(x)
+            non_zero_count = out.nonzero().size(0)
+            counts += cell_numel - non_zero_count
+            num_samples += cell_numel
+
+        assert abs(p - counts / num_samples) / p < 0.15
+
+    def make_obj(self, dim, p, block_size, inplace, wrap=False):
+        if dim == 2:
+            obj = ops.DropBlock2d(p, block_size, inplace)
+        elif dim == 3:
+            obj = ops.DropBlock3d(p, block_size, inplace)
+        return DropBlockWrapper(obj) if wrap else obj
+
+    @pytest.mark.parametrize("dim", (2, 3))
+    @pytest.mark.parametrize("p", [0, 1])
+    @pytest.mark.parametrize("block_size", [5, 7])
+    @pytest.mark.parametrize("inplace", [True, False])
+    def test_is_leaf_node(self, dim, p, block_size, inplace):
+        op_obj = self.make_obj(dim, p, block_size, inplace, wrap=True)
+        graph_node_names = get_graph_node_names(op_obj)
+
+        assert len(graph_node_names) == 2
+        assert len(graph_node_names[0]) == len(graph_node_names[1])
+        assert len(graph_node_names[0]) == 1 + op_obj.n_inputs
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
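
A couple of notes on the tests above. `DropBlockWrapper.forward` deliberately discards its output: the wrapper exists only so `get_graph_node_names` can trace the op as a leaf node in `test_is_leaf_node`. And the 15% tolerance in `test_drop_block_random` holds because the op picks its Bernoulli rate so that the expected fraction of zeroed activations is roughly `p`. To reproduce that check outside pytest, a standalone sketch (values copied from the test; it assumes a torchvision build that includes these ops, and relies on fresh modules defaulting to training mode):

```
import torch
from torchvision import ops

torch.manual_seed(0)
p, block_size = 0.2, 3
x = torch.ones(5, 3, 11, 11)
layer = ops.DropBlock2d(p=p, block_size=block_size)  # training mode by default

zeroed, total = 0, 0
for _ in range(250):
    with torch.no_grad():
        out = layer(x)
    zeroed += (out == 0).sum().item()  # dropped activations stay exactly zero
    total += out.numel()

rate = zeroed / total
print(f"empirical drop rate {rate:.3f} vs target p={p}")
assert abs(p - rate) / p < 0.15
```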
22 changes: 18 additions & 4 deletions torchvision/models/resnet.py
@@ -6,6 +6,7 @@
 
 from .._internally_replaced_utils import load_state_dict_from_url
 from ..utils import _log_api_usage_once
+from ..ops import DropBlock2d
 
 
 __all__ = [
@@ -122,6 +123,7 @@ def __init__(
         base_width: int = 64,
         dilation: int = 1,
         norm_layer: Optional[Callable[..., nn.Module]] = None,
+        p: float = 0.0,
     ) -> None:
         super().__init__()
         if norm_layer is None:
@@ -130,31 +132,40 @@
         # Both self.conv2 and self.downsample layers downsample the input when stride != 1
         self.conv1 = conv1x1(inplanes, width)
         self.bn1 = norm_layer(width)
+        # we won't be doing scheduled p
+        self.drop1 = DropBlock2d(p, 7)
         self.conv2 = conv3x3(width, width, stride, groups, dilation)
         self.bn2 = norm_layer(width)
+        self.drop2 = DropBlock2d(p, 7)
         self.conv3 = conv1x1(width, planes * self.expansion)
         self.bn3 = norm_layer(planes * self.expansion)
+        self.drop3 = DropBlock2d(p, 7)
         self.relu = nn.ReLU(inplace=True)
         self.downsample = downsample
+        self.drop4 = DropBlock2d(p, 7)
         self.stride = stride
 
     def forward(self, x: Tensor) -> Tensor:
         identity = x
 
+        # as in https://github.com/tensorflow/tpu/blob/b24729de804fdb751b06467d3dce0637fa652060/models/official/resnet/resnet_model.py#L545-L579
         out = self.conv1(x)
         out = self.bn1(out)
         out = self.relu(out)
+        out = self.drop1(out)
 
         out = self.conv2(out)
         out = self.bn2(out)
         out = self.relu(out)
+        out = self.drop2(out)
 
         out = self.conv3(out)
         out = self.bn3(out)
+        out = self.drop3(out)
 
         if self.downsample is not None:
             identity = self.downsample(x)
 
+        identity = self.drop4(identity)
         out += identity
         out = self.relu(out)
 
@@ -198,8 +209,9 @@ def __init__(
         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
         self.layer1 = self._make_layer(block, 64, layers[0])
         self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0])
-        self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1])
-        self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2])
+        # https://github.com/tensorflow/tpu/blob/b24729de804fdb751b06467d3dce0637fa652060/models/official/resnet/resnet_main.py#L393-L394
+        self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1], p=0.1 / 4)
+        self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2], p=0.1)
         self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
         self.fc = nn.Linear(512 * block.expansion, num_classes)
 
@@ -227,6 +239,7 @@ def _make_layer(
         blocks: int,
         stride: int = 1,
         dilate: bool = False,
+        p: float = 0.0,
     ) -> nn.Sequential:
         norm_layer = self._norm_layer
         downsample = None
@@ -243,7 +256,7 @@
         layers = []
         layers.append(
             block(
-                self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer
+                self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer, p
            )
        )
         self.inplanes = planes * block.expansion
@@ -256,6 +269,7 @@
                 base_width=self.base_width,
                 dilation=self.dilation,
                 norm_layer=norm_layer,
+                p=p
             )
         )
 
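
The `# we won't be doing scheduled p` comment refers to the DropBlock paper's linear schedule on the drop probability, which this patch skips in favour of fixed values (0.1 / 4 for layer3, 0.1 for layer4, following the referenced TPU code). For completeness, a hypothetical helper for such a schedule; the name, the `final_p` default and the call-once-per-epoch convention are all assumptions, not part of this PR:

```
from torchvision import ops


def set_drop_prob(model, epoch, total_epochs, final_p=0.1):
    # Linearly ramp the DropBlock drop probability from 0 to final_p over
    # training, mirroring the scheduled keep_prob in the DropBlock paper.
    p = final_p * min(1.0, epoch / total_epochs)
    for m in model.modules():
        if isinstance(m, ops.DropBlock2d):
            m.p = p
```

A real version would preserve the 4x ratio between the layer3 and layer4 drop probabilities rather than setting a single global value.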
5 changes: 5 additions & 0 deletions torchvision/ops/__init__.py
@@ -11,6 +11,7 @@
 )
 from .boxes import box_convert
 from .deform_conv import deform_conv2d, DeformConv2d
+from .drop_block import drop_block2d, DropBlock2d, drop_block3d, DropBlock3d
 from .feature_pyramid_network import FeaturePyramidNetwork
 from .focal_loss import sigmoid_focal_loss
 from .giou_loss import generalized_box_iou_loss
@@ -54,4 +55,8 @@
     "ConvNormActivation",
     "SqueezeExcitation",
     "generalized_box_iou_loss",
+    "drop_block2d",
+    "DropBlock2d",
+    "drop_block3d",
+    "DropBlock3d",
 ]
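
With these exports, the new ops are importable straight from `torchvision.ops`. A quick usage sketch (shapes illustrative; assumes a torchvision build that contains this code):

```
import torch
from torchvision import ops

x2 = torch.randn(2, 3, 16, 16)
y2 = ops.drop_block2d(x2, p=0.3, block_size=5, training=True)  # functional form

block = ops.DropBlock3d(p=0.3, block_size=5)
x3 = torch.randn(2, 3, 8, 16, 16)  # N, C, D, H, W
y3 = block(x3)  # module form; acts as identity in eval() mode
```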