from __future__ import division
+ import math
+ from typing import Tuple
+ import unittest
+
import numpy as np
+
import torch
from torch.autograd import gradcheck
-
+ from torch.nn.modules.utils import _pair
+ from torch import Tensor
from torchvision import ops

- from itertools import product
- import unittest
-

- class RoIOpTester(object):
+ class OpTester(object):
    @classmethod
    def setUpClass(cls):
        cls.dtype = torch.float64
@@ -42,6 +45,14 @@ def test_backward_cuda_contiguous(self):
    def test_backward_cuda_non_contiguous(self):
        self._test_backward(device=torch.device('cuda'), contiguous=False)

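+     # No-op defaults so every OpTester subclass inherits the full test
+     # matrix; the concrete testers below override these with real checks.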
+     def _test_forward(self, device, contiguous):
+         pass
+
+     def _test_backward(self, device, contiguous):
+         pass
+
+
+ class RoIOpTester(OpTester):
    def _test_forward(self, device, contiguous):
        pool_size = 5
        # n_channels % (pool_size ** 2) == 0 required for PS operations.
@@ -79,7 +90,6 @@ def func(z):

        self.assertTrue(gradcheck(func, (x,)))
        self.assertTrue(gradcheck(script_func, (x,)))
-         return

    def fn(*args, **kwargs):
        pass
@@ -98,7 +108,7 @@ def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwar
    def get_script_fn(self, rois, pool_size):
        @torch.jit.script
        def script_fn(input, rois, pool_size):
-             # type: (torch.Tensor, torch.Tensor, int) -> torch.Tensor
+             # type: (Tensor, Tensor, int) -> Tensor
            return ops.roi_pool(input, rois, pool_size, 1.0)[0]
        return lambda x: script_fn(x, rois, pool_size)
@@ -137,7 +147,7 @@ def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwar
    def get_script_fn(self, rois, pool_size):
        @torch.jit.script
        def script_fn(input, rois, pool_size):
-             # type: (torch.Tensor, torch.Tensor, int) -> torch.Tensor
+             # type: (Tensor, Tensor, int) -> Tensor
            return ops.ps_roi_pool(input, rois, pool_size, 1.0)[0]
        return lambda x: script_fn(x, rois, pool_size)
@@ -174,29 +184,35 @@ def get_slice(k, block):
        return y


- def bilinear_interpolate(data, height, width, y, x):
-     if y < -1.0 or y > height or x < -1.0 or x > width:
-         return 0.
+ def bilinear_interpolate(data, y, x, snap_border=False):
+     height, width = data.shape

-     y = min(max(0, y), height - 1)
-     x = min(max(0, x), width - 1)
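+     # snap_border mirrors the behaviour the RoIAlign reference checks below
+     # rely on: sampling points less than one pixel outside the image are
+     # snapped onto the border rather than interpolated against zero padding.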
+     if snap_border:
+         if -1 < y <= 0:
+             y = 0
+         elif height - 1 <= y < height:
+             y = height - 1

-     y_low = int(y)
-     y_high = min(y_low + 1, height - 1)
+         if -1 < x <= 0:
+             x = 0
+         elif width - 1 <= x < width:
+             x = width - 1

-     x_low = int(x)
-     x_high = min(x_low + 1, width - 1)
+     y_low = int(math.floor(y))
+     x_low = int(math.floor(x))
+     y_high = y_low + 1
+     x_high = x_low + 1

    wy_h = y - y_low
-     wy_l = 1 - wy_h
-
    wx_h = x - x_low
+     wy_l = 1 - wy_h
    wx_l = 1 - wx_h

    val = 0
-     for wx, x in zip((wx_l, wx_h), (x_low, x_high)):
-         for wy, y in zip((wy_l, wy_h), (y_low, y_high)):
-             val += wx * wy * data[y * width + x]
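+     # Accumulate the four corner contributions; neighbours that fall
+     # outside the image contribute zero (implicit zero padding).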
+     for wx, xp in zip((wx_l, wx_h), (x_low, x_high)):
+         for wy, yp in zip((wy_l, wy_h), (y_low, y_high)):
+             if 0 <= yp < height and 0 <= xp < width:
+                 val += wx * wy * data[yp, xp]
    return val

@@ -208,7 +224,7 @@ def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwar
    def get_script_fn(self, rois, pool_size):
        @torch.jit.script
        def script_fn(input, rois, pool_size):
-             # type: (torch.Tensor, torch.Tensor, int) -> torch.Tensor
+             # type: (Tensor, Tensor, int) -> Tensor
            return ops.roi_align(input, rois, pool_size, 1.0)[0]
        return lambda x: script_fn(x, rois, pool_size)
@@ -242,12 +258,7 @@ def expected_fn(self, in_data, rois, pool_h, pool_w, spatial_scale=1, sampling_r
                        y = start_h + (iy + 0.5) * bin_h / grid_h
                        for ix in range(0, grid_w):
                            x = start_w + (ix + 0.5) * bin_w / grid_w
-                             val += bilinear_interpolate(
-                                 in_data[batch_idx, channel, :, :].flatten(),
-                                 in_data.size(-2),
-                                 in_data.size(-1),
-                                 y, x
-                             )
+                             val += bilinear_interpolate(in_data[batch_idx, channel, :, :], y, x, snap_border=True)
                    val /= grid_h * grid_w

                    out_data[r, channel, i, j] = val
@@ -262,7 +273,7 @@ def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwar
    def get_script_fn(self, rois, pool_size):
        @torch.jit.script
        def script_fn(input, rois, pool_size):
-             # type: (torch.Tensor, torch.Tensor, int) -> torch.Tensor
+             # type: (Tensor, Tensor, int) -> Tensor
            return ops.ps_roi_align(input, rois, pool_size, 1.0)[0]
        return lambda x: script_fn(x, rois, pool_size)
@@ -298,12 +309,7 @@ def expected_fn(self, in_data, rois, pool_h, pool_w, device, spatial_scale=1,
                        y = start_h + (iy + 0.5) * bin_h / grid_h
                        for ix in range(0, grid_w):
                            x = start_w + (ix + 0.5) * bin_w / grid_w
-                             val += bilinear_interpolate(
-                                 in_data[batch_idx, c_in, :, :].flatten(),
-                                 in_data.size(-2),
-                                 in_data.size(-1),
-                                 y, x
-                             )
+                             val += bilinear_interpolate(in_data[batch_idx, c_in, :, :], y, x, snap_border=True)
                    val /= grid_h * grid_w

                    out_data[r, c_out, i, j] = val
@@ -367,5 +373,106 @@ def test_nms_cuda(self):
            self.assertTrue(torch.allclose(r_cpu, r_cuda.cpu()), err_msg.format(iou))


+ class DeformConvTester(OpTester, unittest.TestCase):
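+     # Reference implementation of deformable convolution with explicit
+     # Python loops, against which ops.deform_conv / ops.DeformConv is checked.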
+     def expected_fn(self, x, offsets, weights, *args, stride=1, pad=0, dilation=1):
+         stride_h, stride_w = _pair(stride)
+         pad_h, pad_w = _pair(pad)
+         dil_h, dil_w = _pair(dilation)
+         weights_h, weights_w = weights.shape[-2:]
+
+         n_batches, n_in_channels, in_h, in_w = x.shape
+         n_out_channels = weights.shape[0]
+
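+         # Standard convolution output size:
+         # floor((in + 2 * pad - (dilation * (kernel - 1) + 1)) / stride) + 1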
+         out_h = (in_h + 2 * pad_h - (dil_h * (weights_h - 1) + 1)) // stride_h + 1
+         out_w = (in_w + 2 * pad_w - (dil_w * (weights_w - 1) + 1)) // stride_w + 1
+
+         n_offset_grps = offsets.shape[1] // (2 * weights_h * weights_w)
+         in_c_per_offset_grp = n_in_channels // n_offset_grps
+
+         n_weight_grps = n_in_channels // weights.shape[1]
+         in_c_per_weight_grp = weights.shape[1]
+         out_c_per_weight_grp = n_out_channels // n_weight_grps
+
+         out = torch.zeros(n_batches, n_out_channels, out_h, out_w, device=x.device, dtype=x.dtype)
+         for b in range(n_batches):
+             for c_out in range(n_out_channels):
+                 for i in range(out_h):
+                     for j in range(out_w):
+                         for di in range(weights_h):
+                             for dj in range(weights_w):
+                                 for c in range(in_c_per_weight_grp):
+                                     weight_grp = c_out // out_c_per_weight_grp
+                                     c_in = weight_grp * in_c_per_weight_grp + c
+
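+                                     # offsets packs one (dy, dx) pair per
+                                     # offset group and kernel position
+                                     # along the channel dimension.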
+                                     offset_grp = c_in // in_c_per_offset_grp
+                                     offset_idx = 2 * (offset_grp * (weights_h * weights_w) + di * weights_w + dj)
+
+                                     pi = stride_h * i - pad_h + dil_h * di + offsets[b, offset_idx, i, j]
+                                     pj = stride_w * j - pad_w + dil_w * dj + offsets[b, offset_idx + 1, i, j]
+
+                                     out[b, c_out, i, j] += (weights[c_out, c, di, dj] *
+                                                             bilinear_interpolate(x[b, c_in, :, :], pi, pj))
+         return out
+
+     def get_fn_args(self, device, contiguous):
+         batch_sz = 1
+         n_in_channels = 6
+         n_out_channels = 2
+         n_weight_grps = 2
+         n_offset_grps = 3
+
+         stride = (2, 1)
+         pad = (1, 0)
+         dilation = (2, 1)
+
+         stride_h, stride_w = stride
+         pad_h, pad_w = pad
+         dil_h, dil_w = dilation
+         weight_h, weight_w = (3, 2)
+         in_h, in_w = (5, 4)
+
+         out_h = (in_h + 2 * pad_h - (dil_h * (weight_h - 1) + 1)) // stride_h + 1
+         out_w = (in_w + 2 * pad_w - (dil_w * (weight_w - 1) + 1)) // stride_w + 1
+
+         x = torch.rand(batch_sz, n_in_channels, in_h, in_w, device=device, dtype=self.dtype, requires_grad=True)
+
+         offset = torch.randn(batch_sz, n_offset_grps * 2 * weight_h * weight_w, out_h, out_w,
+                              device=device, dtype=self.dtype, requires_grad=True)
+
+         weight = torch.randn(n_out_channels, n_in_channels // n_weight_grps, weight_h, weight_w,
+                              device=device, dtype=self.dtype, requires_grad=True)
+
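+         # Permuting twice round-trips shape and values but leaves each
+         # tensor with a non-contiguous memory layout.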
+         if not contiguous:
+             x = x.permute(0, 1, 3, 2).contiguous().permute(0, 1, 3, 2)
+             offset = offset.permute(1, 3, 0, 2).contiguous().permute(2, 0, 3, 1)
+             weight = weight.permute(3, 2, 0, 1).contiguous().permute(2, 3, 1, 0)
+
+         return x, offset, weight, stride, pad, dilation
+
+     def _test_forward(self, device, contiguous):
+         x, offset, weight, stride, pad, dilation = self.get_fn_args(device, contiguous)
+
+         res = ops.DeformConv(stride=stride, pad=pad, dilation=dilation)(x, offset, weight)
+         expected = self.expected_fn(x, offset, weight, stride=stride, pad=pad, dilation=dilation)
+
+         self.assertTrue(torch.allclose(res, expected), '\nres:\n{}\nexpected:\n{}'.format(res, expected))
+
+     def _test_backward(self, device, contiguous):
+         x, offset, weight, stride, pad, dilation = self.get_fn_args(device, contiguous)
+
+         def func(x_, offset_, weight_):
+             return ops.deform_conv(x_, offset_, weight_, stride=stride, pad=pad, dilation=dilation)
+
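+         # nondet_tol tolerates the nondeterminism of the CUDA backward,
+         # which accumulates gradients with atomic adds.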
+         gradcheck(func, (x, offset, weight), nondet_tol=1e-5)
+
+         @torch.jit.script
+         def script_func(x_, offset_, weight_, stride_, pad_, dilation_):
+             # type: (Tensor, Tensor, Tensor, Tuple[int, int], Tuple[int, int], Tuple[int, int]) -> Tensor
+             return ops.deform_conv(x_, offset_, weight_, stride=stride_, pad=pad_, dilation=dilation_)
+
+         gradcheck(lambda z, off, wei: script_func(z, off, wei, stride, pad, dilation),
+                   (x, offset, weight), nondet_tol=1e-5)
+
+
if __name__ == '__main__':
    unittest.main()