
Commit 730c5e1

Add SSD architecture with VGG16 backbone (#3403)
* Early skeleton of API.
* Adding MultiFeatureMap and vgg16 backbone.
* Making vgg16 backbone same as paper.
* Making code generic to support all vggs.
* Moving vgg's extra layers to a separate class + L2 scaling.
* Adding header vgg layers.
* Fix maxpool patching.
* Refactoring code to allow for support of different backbones & sizes:
  - Skeleton for Default Boxes generator class
  - Dynamic estimation of configuration when possible
  - Addition of types
* Complete the implementation of DefaultBox generator.
* Replace randn with empty.
* Minor refactoring.
* Making clamping between 0 and 1 optional.
* Change xywh to xyxy encoding.
* Adding parameters and reusing objects in constructor.
* Temporarily inherit from Retina to avoid duplicate code.
* Implement forward methods + temp workarounds to inherit from retina.
* Inherit more methods from retinanet.
* Fix type error.
* Add regression loss.
* Fixing JIT issues.
* Change JIT workaround to minimize new code.
* Fixing initialization bug.
* Add classification loss.
* Update todos.
* Add weight loading support.
* Support SSD512.
* Change kernel_size to get output size 1x1.
* Add Xavier init and refactoring.
* Adding unit-tests and fixing JIT issues.
* Add a test for dbox generator.
* Remove unnecessary import.
* Workaround on GeneralizedRCNNTransform to support fixed-size input.
* Remove unnecessary random calls from the test.
* Remove more rand calls from the test.
* Change mapping and handling of empty labels.
* Fix JIT warnings.
* Speed up loss.
* Convert 0-1 dboxes to original size.
* Fix warning.
* Fix tests.
* Update comments.
* Fixing minor bugs.
* Introduce a custom DBoxMatcher.
* Minor refactoring.
* Move extra layer definition inside feature extractor.
* Handle no bias on init.
* Remove fixed image size limitation.
* Change initialization values for bias of classification head.
* Refactoring and update test file.
* Adding ResNet backbone.
* Minor refactoring.
* Remove inheritance of retina and general refactoring.
* SSD should fix the input size.
* Fixing messages and comments.
* Silently ignoring exception if test-only.
* Update comments.
* Update regression loss.
* Restore Xavier init everywhere, update the negative sampling method, change the clipping approach.
* Fixing tests.
* Refactor to move the losses from the Head to the SSD.
* Removing resnet50 ssd version.
* Adding support for best-performing backbone and its config.
* Refactor and clean up the API.
* Fix lint.
* Update todos and comments.
* Adding RandomHorizontalFlip and RandomIoUCrop transforms.
* Adding necessary checks to our transforms.
* Adding RandomZoomOut.
* Adding RandomPhotometricDistort.
* Moving Detection transforms to references.
* Update presets.
* Fix lint.
* Leave compose and object.
* Adding scaling for completeness.
* Adding params in the repr.
* Remove unnecessary import.
* Minor refactoring.
* Remove unnecessary call.
* Give better names to DBox* classes.
* Port num_anchors estimation in generator.
* Remove rescaling and fix presets.
* Add the ability to pass a custom head and refactoring.
* Fix lint.
* Fix unit-test.
* Update todos.
* Change mean values.
* Change the default parameter of SSD to train the full VGG16 and remove the catch of exception for eval only.
* Adding documentation.
* Adding weights and updating readmes.
* Update the model weights with a better-performing model.
* Adding doc for head.
* Restore import.
1 parent 7c35e13 commit 730c5e1

14 files changed: +1032 −57 lines

docs/source/models.rst

Lines changed: 14 additions & 5 deletions
@@ -381,17 +381,18 @@ Object Detection, Instance Segmentation and Person Keypoint Detection
 The models subpackage contains definitions for the following model
 architectures for detection:
 
--  `Faster R-CNN ResNet-50 FPN <https://arxiv.org/abs/1506.01497>`_
--  `Mask R-CNN ResNet-50 FPN <https://arxiv.org/abs/1703.06870>`_
+-  `Faster R-CNN <https://arxiv.org/abs/1506.01497>`_
+-  `Mask R-CNN <https://arxiv.org/abs/1703.06870>`_
+-  `RetinaNet <https://arxiv.org/abs/1708.02002>`_
+-  `SSD <https://arxiv.org/abs/1512.02325>`_
 
 The pre-trained models for detection, instance segmentation and
 keypoint detection are initialized with the classification models
 in torchvision.
 
 The models expect a list of ``Tensor[C, H, W]``, in the range ``0-1``.
-The models internally resize the images so that they have a minimum size
-of ``800``. This option can be changed by passing the option ``min_size``
-to the constructor of the models.
+The models internally resize the images but the behaviour varies depending
+on the model. Check the constructor of the models for more information.
 
 
 For object detection and instance segmentation, the pre-trained
@@ -425,6 +426,7 @@ Faster R-CNN ResNet-50 FPN 37.0 - -
 Faster R-CNN MobileNetV3-Large FPN     32.8    -        -
 Faster R-CNN MobileNetV3-Large 320 FPN 22.8    -        -
 RetinaNet ResNet-50 FPN                36.4    -        -
+SSD VGG16                              25.1    -        -
 Mask R-CNN ResNet-50 FPN               37.9    34.6     -
 ====================================== ======= ======== ===========
 
@@ -483,6 +485,7 @@ Faster R-CNN ResNet-50 FPN 0.2288 0.0590
 Faster R-CNN MobileNetV3-Large FPN     0.1020              0.0415             1.0
 Faster R-CNN MobileNetV3-Large 320 FPN 0.0978              0.0376             0.6
 RetinaNet ResNet-50 FPN                0.2514              0.0939             4.1
+SSD VGG16                              0.2093              0.0744             1.5
 Mask R-CNN ResNet-50 FPN               0.2728              0.0903             5.4
 Keypoint R-CNN ResNet-50 FPN           0.3789              0.1242             6.8
 ====================================== =================== ================== ===========
@@ -502,6 +505,12 @@ RetinaNet
 .. autofunction:: torchvision.models.detection.retinanet_resnet50_fpn
 
 
+SSD
+------------
+
+.. autofunction:: torchvision.models.detection.ssd300_vgg16
+
+
 Mask R-CNN
 ----------
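
To make the new entry concrete, here is a minimal inference sketch, assuming a torchvision build that includes this commit; it follows the input contract documented above (a list of `Tensor[C, H, W]` with values in `0-1`):

```python
# Minimal sketch, assuming a torchvision build that includes this commit.
import torch
from torchvision.models.detection import ssd300_vgg16

model = ssd300_vgg16(pretrained=True)  # COCO weights released with this PR
model.eval()

# Detection models take a list of Tensor[C, H, W] images with values in [0, 1].
images = [torch.rand(3, 300, 300)]
with torch.no_grad():
    predictions = model(images)

# One dict per input image, with 'boxes' (xyxy), 'labels' and 'scores'.
print(predictions[0]["boxes"].shape)
```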

references/detection/README.md

Lines changed: 8 additions & 0 deletions
@@ -48,6 +48,14 @@ python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\
     --lr-steps 16 22 --aspect-ratio-group-factor 3 --lr 0.01
 ```
 
+### SSD VGG16
+```
+python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\
+    --dataset coco --model ssd300_vgg16 --epochs 120\
+    --lr-steps 80 110 --aspect-ratio-group-factor 3 --lr 0.002 --batch-size 4\
+    --weight-decay 0.0005 --data-augmentation ssd
+```
+
 
 ### Mask R-CNN
 ```
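
For orientation: with `--nproc_per_node=8` and a per-GPU `--batch-size 4`, the effective global batch size of this recipe is 8 × 4 = 32 images per iteration.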

references/detection/presets.py

Lines changed: 16 additions & 6 deletions
@@ -2,12 +2,22 @@
 
 
 class DetectionPresetTrain:
-    def __init__(self, hflip_prob=0.5):
-        trans = [T.ToTensor()]
-        if hflip_prob > 0:
-            trans.append(T.RandomHorizontalFlip(hflip_prob))
-
-        self.transforms = T.Compose(trans)
+    def __init__(self, data_augmentation, hflip_prob=0.5, mean=(123., 117., 104.)):
+        if data_augmentation == 'hflip':
+            self.transforms = T.Compose([
+                T.RandomHorizontalFlip(p=hflip_prob),
+                T.ToTensor(),
+            ])
+        elif data_augmentation == 'ssd':
+            self.transforms = T.Compose([
+                T.RandomPhotometricDistort(),
+                T.RandomZoomOut(fill=list(mean)),
+                T.RandomIoUCrop(),
+                T.RandomHorizontalFlip(p=hflip_prob),
+                T.ToTensor(),
+            ])
+        else:
+            raise ValueError(f'Unknown data augmentation policy "{data_augmentation}"')
 
     def __call__(self, img, target):
         return self.transforms(img, target)
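
As a quick illustration of the new preset, here is a hypothetical standalone usage sketch (not part of the commit; it assumes a PIL image, a COCO-style target dict, and that `references/detection/presets.py` is importable as `presets`):

```python
# Hypothetical usage sketch for the new 'ssd' augmentation policy.
import torch
from PIL import Image

import presets  # references/detection/presets.py

preset = presets.DetectionPresetTrain(data_augmentation='ssd')
img = Image.new('RGB', (640, 480))
target = {
    "boxes": torch.tensor([[10., 20., 200., 300.]]),  # xyxy, absolute pixels
    "labels": torch.tensor([1]),
}
img, target = preset(img, target)  # tensor image + target adjusted by the pipeline
```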

references/detection/train.py

Lines changed: 6 additions & 4 deletions
@@ -47,8 +47,8 @@ def get_dataset(name, image_set, transform, data_path):
     return ds, num_classes
 
 
-def get_transform(train):
-    return presets.DetectionPresetTrain() if train else presets.DetectionPresetEval()
+def get_transform(train, data_augmentation):
+    return presets.DetectionPresetTrain(data_augmentation) if train else presets.DetectionPresetEval()
 
 
 def main(args):
@@ -60,8 +60,9 @@ def main(args):
     # Data loading code
     print("Loading data")
 
-    dataset, num_classes = get_dataset(args.dataset, "train", get_transform(train=True), args.data_path)
-    dataset_test, _ = get_dataset(args.dataset, "val", get_transform(train=False), args.data_path)
+    dataset, num_classes = get_dataset(args.dataset, "train", get_transform(True, args.data_augmentation),
+                                       args.data_path)
+    dataset_test, _ = get_dataset(args.dataset, "val", get_transform(False, args.data_augmentation), args.data_path)
 
     print("Creating data loaders")
     if args.distributed:
@@ -179,6 +180,7 @@ def main(args):
     parser.add_argument('--rpn-score-thresh', default=None, type=float, help='rpn score threshold for faster-rcnn')
     parser.add_argument('--trainable-backbone-layers', default=None, type=int,
                         help='number of trainable layers of backbone')
+    parser.add_argument('--data-augmentation', default="hflip", help='data augmentation policy (default: hflip)')
     parser.add_argument(
         "--test-only",
         dest="test_only",

references/detection/transforms.py

Lines changed: 210 additions & 20 deletions
@@ -1,6 +1,10 @@
-import random
+import torch
+import torchvision
 
+from torch import nn, Tensor
 from torchvision.transforms import functional as F
+from torchvision.transforms import transforms as T
+from typing import List, Tuple, Dict, Optional
 
 
 def _flip_coco_person_keypoints(kps, width):
@@ -23,27 +27,213 @@ def __call__(self, image, target):
         return image, target
 
 
-class RandomHorizontalFlip(object):
-    def __init__(self, prob):
-        self.prob = prob
-
-    def __call__(self, image, target):
-        if random.random() < self.prob:
-            height, width = image.shape[-2:]
-            image = image.flip(-1)
-            bbox = target["boxes"]
-            bbox[:, [0, 2]] = width - bbox[:, [2, 0]]
-            target["boxes"] = bbox
-            if "masks" in target:
-                target["masks"] = target["masks"].flip(-1)
-            if "keypoints" in target:
-                keypoints = target["keypoints"]
-                keypoints = _flip_coco_person_keypoints(keypoints, width)
-                target["keypoints"] = keypoints
+class RandomHorizontalFlip(T.RandomHorizontalFlip):
+    def forward(self, image: Tensor,
+                target: Optional[Dict[str, Tensor]] = None) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
+        if torch.rand(1) < self.p:
+            image = F.hflip(image)
+            if target is not None:
+                width, _ = F._get_image_size(image)
+                target["boxes"][:, [0, 2]] = width - target["boxes"][:, [2, 0]]
+                if "masks" in target:
+                    target["masks"] = target["masks"].flip(-1)
+                if "keypoints" in target:
+                    keypoints = target["keypoints"]
+                    keypoints = _flip_coco_person_keypoints(keypoints, width)
+                    target["keypoints"] = keypoints
         return image, target
 
 
-class ToTensor(object):
-    def __call__(self, image, target):
+class ToTensor(nn.Module):
+    def forward(self, image: Tensor,
+                target: Optional[Dict[str, Tensor]] = None) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
         image = F.to_tensor(image)
         return image, target
+
+
+class RandomIoUCrop(nn.Module):
+    def __init__(self, min_scale: float = 0.3, max_scale: float = 1.0, min_aspect_ratio: float = 0.5,
+                 max_aspect_ratio: float = 2.0, sampler_options: Optional[List[float]] = None, trials: int = 40):
+        super().__init__()
+        # Configuration similar to https://github.com/weiliu89/caffe/blob/ssd/examples/ssd/ssd_coco.py#L89-L174
+        self.min_scale = min_scale
+        self.max_scale = max_scale
+        self.min_aspect_ratio = min_aspect_ratio
+        self.max_aspect_ratio = max_aspect_ratio
+        if sampler_options is None:
+            sampler_options = [0.0, 0.1, 0.3, 0.5, 0.7, 0.9, 1.0]
+        self.options = sampler_options
+        self.trials = trials
+
+    def forward(self, image: Tensor,
+                target: Optional[Dict[str, Tensor]] = None) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
+        if target is None:
+            raise ValueError("The targets can't be None for this transform.")
+
+        if isinstance(image, torch.Tensor):
+            if image.ndimension() not in {2, 3}:
+                raise ValueError('image should be 2/3 dimensional. Got {} dimensions.'.format(image.ndimension()))
+            elif image.ndimension() == 2:
+                image = image.unsqueeze(0)
+
+        orig_w, orig_h = F._get_image_size(image)
+
+        while True:
+            # sample an option
+            idx = int(torch.randint(low=0, high=len(self.options), size=(1,)))
+            min_jaccard_overlap = self.options[idx]
+            if min_jaccard_overlap >= 1.0:  # a value larger than 1 encodes the leave as-is option
+                return image, target
+
+            for _ in range(self.trials):
+                # check the aspect ratio limitations
+                r = self.min_scale + (self.max_scale - self.min_scale) * torch.rand(2)
+                new_w = int(orig_w * r[0])
+                new_h = int(orig_h * r[1])
+                aspect_ratio = new_w / new_h
+                if not (self.min_aspect_ratio <= aspect_ratio <= self.max_aspect_ratio):
+                    continue
+
+                # check for 0 area crops
+                r = torch.rand(2)
+                left = int((orig_w - new_w) * r[0])
+                top = int((orig_h - new_h) * r[1])
+                right = left + new_w
+                bottom = top + new_h
+                if left == right or top == bottom:
+                    continue
+
+                # check for any valid boxes with centers within the crop area
+                cx = 0.5 * (target["boxes"][:, 0] + target["boxes"][:, 2])
+                cy = 0.5 * (target["boxes"][:, 1] + target["boxes"][:, 3])
+                is_within_crop_area = (left < cx) & (cx < right) & (top < cy) & (cy < bottom)
+                if not is_within_crop_area.any():
+                    continue
+
+                # check at least 1 box with jaccard limitations
+                boxes = target["boxes"][is_within_crop_area]
+                ious = torchvision.ops.boxes.box_iou(boxes, torch.tensor([[left, top, right, bottom]],
+                                                                         dtype=boxes.dtype, device=boxes.device))
+                if ious.max() < min_jaccard_overlap:
+                    continue
+
+                # keep only valid boxes and perform cropping
+                target["boxes"] = boxes
+                target["labels"] = target["labels"][is_within_crop_area]
+                target["boxes"][:, 0::2] -= left
+                target["boxes"][:, 1::2] -= top
+                target["boxes"][:, 0::2].clamp_(min=0, max=new_w)
+                target["boxes"][:, 1::2].clamp_(min=0, max=new_h)
+                image = F.crop(image, top, left, new_h, new_w)
+
+                return image, target
+
+
+class RandomZoomOut(nn.Module):
+    def __init__(self, fill: Optional[List[float]] = None, side_range: Tuple[float, float] = (1., 4.), p: float = 0.5):
+        super().__init__()
+        if fill is None:
+            fill = [0., 0., 0.]
+        self.fill = fill
+        self.side_range = side_range
+        if side_range[0] < 1. or side_range[0] > side_range[1]:
+            raise ValueError("Invalid canvas side range provided {}.".format(side_range))
+        self.p = p
+
+    @torch.jit.unused
+    def _get_fill_value(self, is_pil):
+        # type: (bool) -> int
+        # We fake the type to make it work on JIT
+        return tuple(int(x) for x in self.fill) if is_pil else 0
+
+    def forward(self, image: Tensor,
+                target: Optional[Dict[str, Tensor]] = None) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
+        if isinstance(image, torch.Tensor):
+            if image.ndimension() not in {2, 3}:
+                raise ValueError('image should be 2/3 dimensional. Got {} dimensions.'.format(image.ndimension()))
+            elif image.ndimension() == 2:
+                image = image.unsqueeze(0)
+
+        if torch.rand(1) < self.p:
+            return image, target
+
+        orig_w, orig_h = F._get_image_size(image)
+
+        r = self.side_range[0] + torch.rand(1) * (self.side_range[1] - self.side_range[0])
+        canvas_width = int(orig_w * r)
+        canvas_height = int(orig_h * r)
+
+        r = torch.rand(2)
+        left = int((canvas_width - orig_w) * r[0])
+        top = int((canvas_height - orig_h) * r[1])
+        right = canvas_width - (left + orig_w)
+        bottom = canvas_height - (top + orig_h)
+
+        if torch.jit.is_scripting():
+            fill = 0
+        else:
+            fill = self._get_fill_value(F._is_pil_image(image))
+
+        image = F.pad(image, [left, top, right, bottom], fill=fill)
+        if isinstance(image, torch.Tensor):
+            v = torch.tensor(self.fill, device=image.device, dtype=image.dtype).view(-1, 1, 1)
+            image[..., :top, :] = image[..., :, :left] = image[..., (top + orig_h):, :] = \
+                image[..., :, (left + orig_w):] = v
+
+        if target is not None:
+            target["boxes"][:, 0::2] += left
+            target["boxes"][:, 1::2] += top
+
+        return image, target
+
+
+class RandomPhotometricDistort(nn.Module):
+    def __init__(self, contrast: Tuple[float] = (0.5, 1.5), saturation: Tuple[float] = (0.5, 1.5),
+                 hue: Tuple[float] = (-0.05, 0.05), brightness: Tuple[float] = (0.875, 1.125), p: float = 0.5):
+        super().__init__()
+        self._brightness = T.ColorJitter(brightness=brightness)
+        self._contrast = T.ColorJitter(contrast=contrast)
+        self._hue = T.ColorJitter(hue=hue)
+        self._saturation = T.ColorJitter(saturation=saturation)
+        self.p = p
+
+    def forward(self, image: Tensor,
+                target: Optional[Dict[str, Tensor]] = None) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
+        if isinstance(image, torch.Tensor):
+            if image.ndimension() not in {2, 3}:
+                raise ValueError('image should be 2/3 dimensional. Got {} dimensions.'.format(image.ndimension()))
+            elif image.ndimension() == 2:
+                image = image.unsqueeze(0)
+
+        r = torch.rand(7)
+
+        if r[0] < self.p:
+            image = self._brightness(image)
+
+        contrast_before = r[1] < 0.5
+        if contrast_before:
+            if r[2] < self.p:
+                image = self._contrast(image)
+
+        if r[3] < self.p:
+            image = self._saturation(image)
+
+        if r[4] < self.p:
+            image = self._hue(image)
+
+        if not contrast_before:
+            if r[5] < self.p:
+                image = self._contrast(image)
+
+        if r[6] < self.p:
+            channels = F._get_image_num_channels(image)
+            permutation = torch.randperm(channels)
+
+            is_pil = F._is_pil_image(image)
+            if is_pil:
+                image = F.to_tensor(image)
+            image = image[..., permutation, :, :]
+            if is_pil:
+                image = F.to_pil_image(image)
+
+        return image, target
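
To make the cropping contract concrete, here is a hypothetical sanity-check sketch (not part of the commit; it assumes the module above is importable as `transforms`). Boxes whose centers fall outside the sampled crop are dropped, and the survivors are shifted into the crop's coordinate frame and clamped:

```python
# Hypothetical sanity check for RandomIoUCrop; not part of this commit.
import torch
from transforms import RandomIoUCrop  # references/detection/transforms.py

crop = RandomIoUCrop()
image = torch.rand(3, 480, 640)  # C, H, W with values in [0, 1]
target = {
    "boxes": torch.tensor([[100., 100., 300., 250.],
                           [400., 200., 600., 400.]]),  # xyxy
    "labels": torch.tensor([1, 2]),
}
image, target = crop(image, target)

# Whatever crop was sampled (including the leave-as-is option), the remaining
# boxes always lie inside the returned canvas.
new_h, new_w = image.shape[-2], image.shape[-1]
assert (target["boxes"][:, [0, 2]] <= new_w).all()
assert (target["boxes"][:, [1, 3]] <= new_h).all()
```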
One of the changed files is binary and is not shown.

test/test_models.py

Lines changed: 1 addition & 0 deletions
@@ -44,6 +44,7 @@ def get_available_video_models():
     "maskrcnn_resnet50_fpn": lambda x: x[1],
     "keypointrcnn_resnet50_fpn": lambda x: x[1],
     "retinanet_resnet50_fpn": lambda x: x[1],
+    "ssd300_vgg16": lambda x: x[1],
 }
 
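A note on the `lambda x: x[1]` unwrapper: it mirrors the entries for the other detection models, which (on our reading of this test harness) return a `(losses, detections)` tuple when scripted, so the test compares only the detections.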