@@ -83,6 +83,29 @@ def __call__(self, boxlists: List[Tensor]) -> Tensor:
83
83
return (target_lvls .to (torch .int64 ) - self .k_min ).to (torch .int64 )
84
84
85
85
86
+ def _convert_to_roi_format (boxes : List [Tensor ]) -> Tensor :
87
+ concat_boxes = torch .cat (boxes , dim = 0 )
88
+ device , dtype = concat_boxes .device , concat_boxes .dtype
89
+ ids = torch .cat (
90
+ [torch .full_like (b [:, :1 ], i , dtype = dtype , layout = torch .strided , device = device ) for i , b in enumerate (boxes )],
91
+ dim = 0 ,
92
+ )
93
+ rois = torch .cat ([ids , concat_boxes ], dim = 1 )
94
+ return rois
95
+
96
+
97
+ def _infer_scale (feature : Tensor , original_size : List [int ]) -> float :
98
+ # assumption: the scale is of the form 2 ** (-k), with k integer
99
+ size = feature .shape [- 2 :]
100
+ possible_scales : List [float ] = []
101
+ for s1 , s2 in zip (size , original_size ):
102
+ approx_scale = float (s1 ) / float (s2 )
103
+ scale = 2 ** float (torch .tensor (approx_scale ).log2 ().round ())
104
+ possible_scales .append (scale )
105
+ assert possible_scales [0 ] == possible_scales [1 ]
106
+ return possible_scales [0 ]
107
+
108
+
86
109
class MultiScaleRoIAlign (nn .Module ):
87
110
"""
88
111
Multi-scale RoIAlign pooling, which is useful for detection with or without FPN.
@@ -142,30 +165,6 @@ def __init__(
142
165
self .canonical_scale = canonical_scale
143
166
self .canonical_level = canonical_level
144
167
145
- def convert_to_roi_format (self , boxes : List [Tensor ]) -> Tensor :
146
- concat_boxes = torch .cat (boxes , dim = 0 )
147
- device , dtype = concat_boxes .device , concat_boxes .dtype
148
- ids = torch .cat (
149
- [
150
- torch .full_like (b [:, :1 ], i , dtype = dtype , layout = torch .strided , device = device )
151
- for i , b in enumerate (boxes )
152
- ],
153
- dim = 0 ,
154
- )
155
- rois = torch .cat ([ids , concat_boxes ], dim = 1 )
156
- return rois
157
-
158
- def infer_scale (self , feature : Tensor , original_size : List [int ]) -> float :
159
- # assumption: the scale is of the form 2 ** (-k), with k integer
160
- size = feature .shape [- 2 :]
161
- possible_scales : List [float ] = []
162
- for s1 , s2 in zip (size , original_size ):
163
- approx_scale = float (s1 ) / float (s2 )
164
- scale = 2 ** float (torch .tensor (approx_scale ).log2 ().round ())
165
- possible_scales .append (scale )
166
- assert possible_scales [0 ] == possible_scales [1 ]
167
- return possible_scales [0 ]
168
-
169
168
def setup_scales (
170
169
self ,
171
170
features : List [Tensor ],
@@ -179,7 +178,7 @@ def setup_scales(
179
178
max_y = max (shape [1 ], max_y )
180
179
original_input_shape = (max_x , max_y )
181
180
182
- scales = [self . infer_scale (feat , original_input_shape ) for feat in features ]
181
+ scales = [_infer_scale (feat , original_input_shape ) for feat in features ]
183
182
# get the levels in the feature map by leveraging the fact that the network always
184
183
# downsamples by a factor of 2 at each level.
185
184
lvl_min = - torch .log2 (torch .tensor (scales [0 ], dtype = torch .float32 )).item ()
@@ -216,7 +215,7 @@ def forward(
216
215
if k in self .featmap_names :
217
216
x_filtered .append (v )
218
217
num_levels = len (x_filtered )
219
- rois = self . convert_to_roi_format (boxes )
218
+ rois = _convert_to_roi_format (boxes )
220
219
if self .scales is None :
221
220
self .setup_scales (x_filtered , image_shapes )
222
221
0 commit comments