Skip to content

Commit 4b2ad55

Browse files
authored
Refactor poolers (#4951)
* refactoring methods from MultiScaleRoIAlign
1 parent 9841a90 commit 4b2ad55

File tree

1 file changed

+25
-26
lines changed

1 file changed

+25
-26
lines changed

torchvision/ops/poolers.py

Lines changed: 25 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,29 @@ def __call__(self, boxlists: List[Tensor]) -> Tensor:
8383
return (target_lvls.to(torch.int64) - self.k_min).to(torch.int64)
8484

8585

86+
def _convert_to_roi_format(boxes: List[Tensor]) -> Tensor:
87+
concat_boxes = torch.cat(boxes, dim=0)
88+
device, dtype = concat_boxes.device, concat_boxes.dtype
89+
ids = torch.cat(
90+
[torch.full_like(b[:, :1], i, dtype=dtype, layout=torch.strided, device=device) for i, b in enumerate(boxes)],
91+
dim=0,
92+
)
93+
rois = torch.cat([ids, concat_boxes], dim=1)
94+
return rois
95+
96+
97+
def _infer_scale(feature: Tensor, original_size: List[int]) -> float:
98+
# assumption: the scale is of the form 2 ** (-k), with k integer
99+
size = feature.shape[-2:]
100+
possible_scales: List[float] = []
101+
for s1, s2 in zip(size, original_size):
102+
approx_scale = float(s1) / float(s2)
103+
scale = 2 ** float(torch.tensor(approx_scale).log2().round())
104+
possible_scales.append(scale)
105+
assert possible_scales[0] == possible_scales[1]
106+
return possible_scales[0]
107+
108+
86109
class MultiScaleRoIAlign(nn.Module):
87110
"""
88111
Multi-scale RoIAlign pooling, which is useful for detection with or without FPN.
@@ -142,30 +165,6 @@ def __init__(
142165
self.canonical_scale = canonical_scale
143166
self.canonical_level = canonical_level
144167

145-
def convert_to_roi_format(self, boxes: List[Tensor]) -> Tensor:
146-
concat_boxes = torch.cat(boxes, dim=0)
147-
device, dtype = concat_boxes.device, concat_boxes.dtype
148-
ids = torch.cat(
149-
[
150-
torch.full_like(b[:, :1], i, dtype=dtype, layout=torch.strided, device=device)
151-
for i, b in enumerate(boxes)
152-
],
153-
dim=0,
154-
)
155-
rois = torch.cat([ids, concat_boxes], dim=1)
156-
return rois
157-
158-
def infer_scale(self, feature: Tensor, original_size: List[int]) -> float:
159-
# assumption: the scale is of the form 2 ** (-k), with k integer
160-
size = feature.shape[-2:]
161-
possible_scales: List[float] = []
162-
for s1, s2 in zip(size, original_size):
163-
approx_scale = float(s1) / float(s2)
164-
scale = 2 ** float(torch.tensor(approx_scale).log2().round())
165-
possible_scales.append(scale)
166-
assert possible_scales[0] == possible_scales[1]
167-
return possible_scales[0]
168-
169168
def setup_scales(
170169
self,
171170
features: List[Tensor],
@@ -179,7 +178,7 @@ def setup_scales(
179178
max_y = max(shape[1], max_y)
180179
original_input_shape = (max_x, max_y)
181180

182-
scales = [self.infer_scale(feat, original_input_shape) for feat in features]
181+
scales = [_infer_scale(feat, original_input_shape) for feat in features]
183182
# get the levels in the feature map by leveraging the fact that the network always
184183
# downsamples by a factor of 2 at each level.
185184
lvl_min = -torch.log2(torch.tensor(scales[0], dtype=torch.float32)).item()
@@ -216,7 +215,7 @@ def forward(
216215
if k in self.featmap_names:
217216
x_filtered.append(v)
218217
num_levels = len(x_filtered)
219-
rois = self.convert_to_roi_format(boxes)
218+
rois = _convert_to_roi_format(boxes)
220219
if self.scales is None:
221220
self.setup_scales(x_filtered, image_shapes)
222221

0 commit comments

Comments (0)