Pass custom scales on DefaultBoxGenerator and change default estimation #3766


Merged 2 commits on May 3, 2021
8 changes: 4 additions & 4 deletions test/test_models_detection_anchor_utils.py
@@ -76,10 +76,10 @@ def test_defaultbox_generator(self):
         dboxes = model(images, features)

         dboxes_output = torch.tensor([
-            [6.9750, 6.9750, 8.0250, 8.0250],
-            [6.7315, 6.7315, 8.2685, 8.2685],
-            [6.7575, 7.1288, 8.2425, 7.8712],
-            [7.1288, 6.7575, 7.8712, 8.2425]
+            [6.3750, 6.3750, 8.6250, 8.6250],
+            [4.7443, 4.7443, 10.2557, 10.2557],
+            [5.9090, 6.7045, 9.0910, 8.2955],
+            [6.7045, 5.9090, 8.2955, 9.0910]
         ])

         self.assertEqual(len(dboxes), 2)
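Note: the new expected values follow directly from the revised scale estimation. With a single feature map the generator now falls back to scales = [min_ratio, max_ratio] = [0.15, 0.9]. A minimal sketch (not part of the PR) that recomputes the four rows by hand, assuming the test's existing setup of a 15x15 image, one 1x1 feature map and aspect_ratios=[[2]], and the generator's usual width/height-pair construction:

import math

# Sketch: recompute the updated expected boxes by hand.
# Assumptions: 15x15 image, a single 1x1 feature map, aspect_ratios=[[2]],
# default min_ratio=0.15 / max_ratio=0.9, so the new code picks scales=[0.15, 0.9].
image_size = 15
s_k, s_k1 = 0.15, 0.9                       # scales[k] and scales[k + 1]
s_prime = math.sqrt(s_k * s_k1)             # extra square box, as in the SSD paper
wh_pairs = [(s_k, s_k), (s_prime, s_prime)]
for ar in (2,):                             # each aspect ratio adds two rectangles
    sq_ar = math.sqrt(ar)
    wh_pairs += [(s_k * sq_ar, s_k / sq_ar), (s_k / sq_ar, s_k * sq_ar)]

cx = cy = 0.5                               # a single 1x1 cell -> centered boxes
for w, h in wh_pairs:
    box = [cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2]
    print([round(v * image_size, 4) for v in box])
# Prints (approximately) the four rows of the updated dboxes_output tensor.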
38 changes: 24 additions & 14 deletions torchvision/models/detection/anchor_utils.py
@@ -138,17 +138,19 @@ class DefaultBoxGenerator(nn.Module):
     Args:
         aspect_ratios (List[List[int]]): A list with all the aspect ratios used in each feature map.
         min_ratio (float): The minimum scale :math:`\text{s}_{\text{min}}` of the default boxes used in the estimation
-            of the scales of each feature map.
+            of the scales of each feature map. It is used only if the ``scales`` parameter is not provided.
         max_ratio (float): The maximum scale :math:`\text{s}_{\text{max}}` of the default boxes used in the estimation
-            of the scales of each feature map.
+            of the scales of each feature map. It is used only if the ``scales`` parameter is not provided.
+        scales (List[float], optional): The scales of the default boxes. If not provided it will be estimated using
+            the ``min_ratio`` and ``max_ratio`` parameters.
         steps (List[int], optional): It's a hyper-parameter that affects the tiling of default boxes. If not provided
             it will be estimated from the data.
         clip (bool): Whether the standardized values of default boxes should be clipped between 0 and 1. The clipping
             is applied while the boxes are encoded in format ``(cx, cy, w, h)``.
     """

     def __init__(self, aspect_ratios: List[List[int]], min_ratio: float = 0.15, max_ratio: float = 0.9,
-                 steps: Optional[List[int]] = None, clip: bool = True):
+                 scales: Optional[List[float]] = None, steps: Optional[List[int]] = None, clip: bool = True):
         super().__init__()
         if steps is not None:
             assert len(aspect_ratios) == len(steps)
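For reference, a short usage sketch of the new constructor parameter (illustrative only; the aspect ratios and scale values are the ones from the ssd300_vgg16 change further down, and the s'_k detail is not shown in this diff but matches the generator's existing width/height-pair construction):

from torchvision.models.detection.anchor_utils import DefaultBoxGenerator

# Explicit scales override the min_ratio/max_ratio estimation entirely.
# Note: one scale per feature map plus one extra value, because the generator
# also builds an s'_k = sqrt(s_k * s_{k+1}) box for every map.
anchor_gen = DefaultBoxGenerator([[2], [2, 3], [2, 3], [2, 3], [2], [2]],
                                 scales=[0.07, 0.15, 0.33, 0.51, 0.69, 0.87, 1.05])

# Without scales, they are estimated from min_ratio/max_ratio (see the next hunk).
anchor_gen_est = DefaultBoxGenerator([[2], [2, 3], [2, 3], [2, 3], [2], [2]])
print(anchor_gen_est.scales)  # approximately [0.15, 0.3, 0.45, 0.6, 0.75, 0.9, 1.0]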
@@ -158,15 +160,15 @@ def __init__(self, aspect_ratios: List[List[int]], min_ratio: float = 0.15, max_
         num_outputs = len(aspect_ratios)

         # Estimation of default boxes scales
-        # Inspired from https://github.com/weiliu89/caffe/blob/ssd/examples/ssd/ssd_pascal.py#L311-L317
-        min_centile = int(100 * min_ratio)
-        max_centile = int(100 * max_ratio)
-        conv4_centile = min_centile // 2  # assume half of min_ratio as in paper
-        step = (max_centile - min_centile) // (num_outputs - 2)
-        centiles = [conv4_centile, min_centile]
-        for c in range(min_centile, max_centile + 1, step):
-            centiles.append(c + step)
-        self.scales = [c / 100 for c in centiles]
+        if scales is None:
+            if num_outputs > 1:
+                range_ratio = max_ratio - min_ratio
+                self.scales = [min_ratio + range_ratio * k / (num_outputs - 1.0) for k in range(num_outputs)]
+                self.scales.append(1.0)
+            else:
+                self.scales = [min_ratio, max_ratio]
+        else:
+            self.scales = scales

         self._wh_pairs = []
         for k in range(num_outputs):
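As a side-by-side sanity check (illustrative, not part of the PR), here are the two estimation schemes for the 6-feature-map SSD300 configuration: the removed centile-based code yields exactly the scales that ssd300_vgg16 now passes explicitly, while the new code spaces them linearly between min_ratio and max_ratio, as in the SSD paper's s_k formula.

# Sketch: compare the removed and the new estimation for the SSD300 setup
# (6 feature maps, default min_ratio=0.15, max_ratio=0.9).
num_outputs, min_ratio, max_ratio = 6, 0.15, 0.9

# Removed centile-based estimation (the "-" lines above):
min_c, max_c = int(100 * min_ratio), int(100 * max_ratio)
step = (max_c - min_c) // (num_outputs - 2)
centiles = [min_c // 2, min_c]
for c in range(min_c, max_c + 1, step):
    centiles.append(c + step)
print([c / 100 for c in centiles])  # [0.07, 0.15, 0.33, 0.51, 0.69, 0.87, 1.05]

# New linear estimation (the "+" lines above):
range_ratio = max_ratio - min_ratio
scales = [min_ratio + range_ratio * k / (num_outputs - 1.0) for k in range(num_outputs)]
scales.append(1.0)
print(scales)  # approximately [0.15, 0.3, 0.45, 0.6, 0.75, 0.9, 1.0]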
@@ -207,9 +209,17 @@ def forward(self, image_list: ImageList, feature_maps: List[Tensor]) -> List[Ten
         for k, f_k in enumerate(grid_sizes):
             # Now add the default boxes for each width-height pair
             for j in range(f_k[0]):
-                cy = (j + 0.5) / (float(f_k[0]) if self.steps is None else image_size[1] / self.steps[k])
+                if self.steps is not None:
+                    y_f_k = image_size[1] / self.steps[k]
+                else:
+                    y_f_k = float(f_k[0])
+                cy = (j + 0.5) / y_f_k
                 for i in range(f_k[1]):
-                    cx = (i + 0.5) / (float(f_k[1]) if self.steps is None else image_size[0] / self.steps[k])
+                    if self.steps is not None:
+                        x_f_k = image_size[0] / self.steps[k]
+                    else:
+                        x_f_k = float(f_k[1])
+                    cx = (i + 0.5) / x_f_k
                     default_boxes.extend([[cx, cy, w, h] for w, h in self._wh_pairs[k]])

         dboxes = []
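The forward() change above only unrolls the inline conditional into an explicit if/else; the tiling itself is unchanged. With steps given, box centers are spaced by image_size / step rather than by the feature-map size. A small illustration (assuming SSD300's first VGG feature map, 38x38 with step 8 on a 300x300 image):

# Illustration only: how the explicit branch picks the tiling denominator.
image_dim, f_k_dim, step = 300, 38, 8        # assumed SSD300 first feature map

y_f_k_with_step = image_dim / step           # 37.5  (steps provided)
y_f_k_without_step = float(f_k_dim)          # 38.0  (steps is None)

print([(j + 0.5) / y_f_k_with_step for j in range(3)])     # [0.0133..., 0.04, 0.0666...]
print([(j + 0.5) / y_f_k_without_step for j in range(3)])  # [0.0131..., 0.0394..., 0.0657...]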
4 changes: 3 additions & 1 deletion torchvision/models/detection/ssd.py
@@ -552,7 +552,9 @@ def ssd300_vgg16(pretrained: bool = False, progress: bool = True, num_classes: i
     pretrained_backbone = False

     backbone = _vgg_extractor("vgg16_features", False, progress, pretrained_backbone, trainable_backbone_layers, True)
-    anchor_generator = DefaultBoxGenerator([[2], [2, 3], [2, 3], [2, 3], [2], [2]], steps=[8, 16, 32, 64, 100, 300])
+    anchor_generator = DefaultBoxGenerator([[2], [2, 3], [2, 3], [2, 3], [2], [2]],
+                                           scales=[0.07, 0.15, 0.33, 0.51, 0.69, 0.87, 1.05],
+                                           steps=[8, 16, 32, 64, 100, 300])
     model = SSD(backbone, anchor_generator, (300, 300), num_classes,
                 image_mean=[0.48235, 0.45882, 0.40784], image_std=[1., 1., 1.], **kwargs)
     if pretrained:
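Since the default estimation changed, ssd300_vgg16 now pins the scales the published weights were trained with, so pretrained checkpoints keep seeing the same anchors. A quick hedged check (assuming the SSD module stores the generator as anchor_generator, and skipping weight downloads):

import torchvision

# Sketch: the detection model keeps the weights-compatible scales, because
# they are now passed explicitly instead of being re-estimated.
model = torchvision.models.detection.ssd300_vgg16(pretrained=False, pretrained_backbone=False)
print(model.anchor_generator.scales)
# [0.07, 0.15, 0.33, 0.51, 0.69, 0.87, 1.05]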