from __future__ import division
+ import math
+ import unittest
+
import numpy as np
+
import torch
+ from torch import Tensor
from torch.autograd import gradcheck
-
+ from torch.jit.annotations import Tuple
+ from torch.nn.modules.utils import _pair
from torchvision import ops

- from itertools import product
- import unittest
-

- class RoIOpTester(object):
+ class OpTester(object):
    @classmethod
    def setUpClass(cls):
        cls.dtype = torch.float64
@@ -42,6 +45,14 @@ def test_backward_cuda_contiguous(self):
    def test_backward_cuda_non_contiguous(self):
        self._test_backward(device=torch.device('cuda'), contiguous=False)

+     def _test_forward(self, device, contiguous):
+         pass
+
+     def _test_backward(self, device, contiguous):
+         pass
+
+
+ class RoIOpTester(OpTester):
    def _test_forward(self, device, contiguous):
        pool_size = 5
        # n_channels % (pool_size ** 2) == 0 required for PS operations.
@@ -79,7 +90,6 @@ def func(z):
        self.assertTrue(gradcheck(func, (x,)))
        self.assertTrue(gradcheck(script_func, (x,)))
-         return

    def fn(*args, **kwargs):
        pass
@@ -98,7 +108,7 @@ def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwar
    def get_script_fn(self, rois, pool_size):
        @torch.jit.script
        def script_fn(input, rois, pool_size):
-             # type: (torch.Tensor, torch.Tensor, int) -> torch.Tensor
+             # type: (Tensor, Tensor, int) -> Tensor
            return ops.roi_pool(input, rois, pool_size, 1.0)[0]
        return lambda x: script_fn(x, rois, pool_size)

@@ -137,7 +147,7 @@ def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwar
    def get_script_fn(self, rois, pool_size):
        @torch.jit.script
        def script_fn(input, rois, pool_size):
-             # type: (torch.Tensor, torch.Tensor, int) -> torch.Tensor
+             # type: (Tensor, Tensor, int) -> Tensor
            return ops.ps_roi_pool(input, rois, pool_size, 1.0)[0]
        return lambda x: script_fn(x, rois, pool_size)

@@ -174,29 +184,35 @@ def get_slice(k, block):
        return y


- def bilinear_interpolate(data, height, width, y, x):
-     if y < -1.0 or y > height or x < -1.0 or x > width:
-         return 0.
+ def bilinear_interpolate(data, y, x, snap_border=False):
+     height, width = data.shape

-     y = min(max(0, y), height - 1)
-     x = min(max(0, x), width - 1)
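+     # Sample points that fall within one pixel outside the image can be
+     # snapped back onto the border; anything farther out contributes zero
+     # via the bounds check below (presumably mirroring the C++/CUDA kernels).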
+     if snap_border:
+         if -1 < y <= 0:
+             y = 0
+         elif height - 1 <= y < height:
+             y = height - 1

-     y_low = int(y)
-     y_high = min(y_low + 1, height - 1)
+         if -1 < x <= 0:
+             x = 0
+         elif width - 1 <= x < width:
+             x = width - 1

-     x_low = int(x)
-     x_high = min(x_low + 1, width - 1)
+     y_low = int(math.floor(y))
+     x_low = int(math.floor(x))
+     y_high = y_low + 1
+     x_high = x_low + 1
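+     # (y_low, x_low)..(y_high, x_high) are the four integer neighbours of the
+     # sample point; corners that fall outside the image are skipped in the
+     # loop below, i.e. treated as zeros.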

    wy_h = y - y_low
-     wy_l = 1 - wy_h
-
    wx_h = x - x_low
+     wy_l = 1 - wy_h
    wx_l = 1 - wx_h

    val = 0
-     for wx, x in zip((wx_l, wx_h), (x_low, x_high)):
-         for wy, y in zip((wy_l, wy_h), (y_low, y_high)):
-             val += wx * wy * data[y * width + x]
+     for wx, xp in zip((wx_l, wx_h), (x_low, x_high)):
+         for wy, yp in zip((wy_l, wy_h), (y_low, y_high)):
+             if 0 <= yp < height and 0 <= xp < width:
+                 val += wx * wy * data[yp, xp]
    return val

@@ -208,7 +224,7 @@ def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwar
    def get_script_fn(self, rois, pool_size):
        @torch.jit.script
        def script_fn(input, rois, pool_size):
-             # type: (torch.Tensor, torch.Tensor, int) -> torch.Tensor
+             # type: (Tensor, Tensor, int) -> Tensor
            return ops.roi_align(input, rois, pool_size, 1.0)[0]
        return lambda x: script_fn(x, rois, pool_size)

@@ -242,12 +258,7 @@ def expected_fn(self, in_data, rois, pool_h, pool_w, spatial_scale=1, sampling_r
                            y = start_h + (iy + 0.5) * bin_h / grid_h
                            for ix in range(0, grid_w):
                                x = start_w + (ix + 0.5) * bin_w / grid_w
-                                 val += bilinear_interpolate(
-                                     in_data[batch_idx, channel, :, :].flatten(),
-                                     in_data.size(-2),
-                                     in_data.size(-1),
-                                     y, x
-                                 )
+                                 val += bilinear_interpolate(in_data[batch_idx, channel, :, :], y, x, snap_border=True)
                        val /= grid_h * grid_w

                        out_data[r, channel, i, j] = val
@@ -262,7 +273,7 @@ def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwar
    def get_script_fn(self, rois, pool_size):
        @torch.jit.script
        def script_fn(input, rois, pool_size):
-             # type: (torch.Tensor, torch.Tensor, int) -> torch.Tensor
+             # type: (Tensor, Tensor, int) -> Tensor
            return ops.ps_roi_align(input, rois, pool_size, 1.0)[0]
        return lambda x: script_fn(x, rois, pool_size)

@@ -298,12 +309,7 @@ def expected_fn(self, in_data, rois, pool_h, pool_w, device, spatial_scale=1,
                            y = start_h + (iy + 0.5) * bin_h / grid_h
                            for ix in range(0, grid_w):
                                x = start_w + (ix + 0.5) * bin_w / grid_w
-                                 val += bilinear_interpolate(
-                                     in_data[batch_idx, c_in, :, :].flatten(),
-                                     in_data.size(-2),
-                                     in_data.size(-1),
-                                     y, x
-                                 )
+                                 val += bilinear_interpolate(in_data[batch_idx, c_in, :, :], y, x, snap_border=True)
                        val /= grid_h * grid_w

                        out_data[r, c_out, i, j] = val
@@ -376,5 +382,120 @@ def test_new_empty_tensor(self):
        assert out.dtype == input.dtype


+ class DeformConvTester(OpTester, unittest.TestCase):
+     def expected_fn(self, x, weight, offset, bias, stride=1, padding=0, dilation=1):
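+         # Reference implementation of deform_conv2d in plain Python, against
+         # which the compiled op's output is checked.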
+         stride_h, stride_w = _pair(stride)
+         pad_h, pad_w = _pair(padding)
+         dil_h, dil_w = _pair(dilation)
+         weight_h, weight_w = weight.shape[-2:]
+
+         n_batches, n_in_channels, in_h, in_w = x.shape
+         n_out_channels = weight.shape[0]
+
+         out_h = (in_h + 2 * pad_h - (dil_h * (weight_h - 1) + 1)) // stride_h + 1
+         out_w = (in_w + 2 * pad_w - (dil_w * (weight_w - 1) + 1)) // stride_w + 1
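+         # the usual convolution output-size arithmetic; the deformable
+         # variant changes where samples are taken, not the output geometry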
+
+         n_offset_grps = offset.shape[1] // (2 * weight_h * weight_w)
+         in_c_per_offset_grp = n_in_channels // n_offset_grps
+
+         n_weight_grps = n_in_channels // weight.shape[1]
+         in_c_per_weight_grp = weight.shape[1]
+         out_c_per_weight_grp = n_out_channels // n_weight_grps
+
+         out = torch.zeros(n_batches, n_out_channels, out_h, out_w, device=x.device, dtype=x.dtype)
+         for b in range(n_batches):
+             for c_out in range(n_out_channels):
+                 for i in range(out_h):
+                     for j in range(out_w):
+                         for di in range(weight_h):
+                             for dj in range(weight_w):
+                                 for c in range(in_c_per_weight_grp):
+                                     weight_grp = c_out // out_c_per_weight_grp
+                                     c_in = weight_grp * in_c_per_weight_grp + c
+
+                                     offset_grp = c_in // in_c_per_offset_grp
+                                     offset_idx = 2 * (offset_grp * (weight_h * weight_w) + di * weight_w + dj)
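+                                     # offset holds a (dy, dx) pair per offset
+                                     # group and kernel position: channel
+                                     # offset_idx is the vertical shift,
+                                     # offset_idx + 1 the horizontal one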
+
+                                     pi = stride_h * i - pad_h + dil_h * di + offset[b, offset_idx, i, j]
+                                     pj = stride_w * j - pad_w + dil_w * dj + offset[b, offset_idx + 1, i, j]
+
+                                     out[b, c_out, i, j] += (weight[c_out, c, di, dj] *
+                                                             bilinear_interpolate(x[b, c_in, :, :], pi, pj))
+         out += bias.view(1, n_out_channels, 1, 1)
+         return out
+
+     def get_fn_args(self, device, contiguous):
+         batch_sz = 1
+         n_in_channels = 6
+         n_out_channels = 2
+         n_weight_grps = 2
+         n_offset_grps = 3
+
+         stride = (2, 1)
+         pad = (1, 0)
+         dilation = (2, 1)
+
+         stride_h, stride_w = stride
+         pad_h, pad_w = pad
+         dil_h, dil_w = dilation
+         weight_h, weight_w = (3, 2)
+         in_h, in_w = (5, 4)
+
+         out_h = (in_h + 2 * pad_h - (dil_h * (weight_h - 1) + 1)) // stride_h + 1
+         out_w = (in_w + 2 * pad_w - (dil_w * (weight_w - 1) + 1)) // stride_w + 1
+
+         x = torch.rand(batch_sz, n_in_channels, in_h, in_w, device=device, dtype=self.dtype, requires_grad=True)
+
+         offset = torch.randn(batch_sz, n_offset_grps * 2 * weight_h * weight_w, out_h, out_w,
+                              device=device, dtype=self.dtype, requires_grad=True)
+
+         weight = torch.randn(n_out_channels, n_in_channels // n_weight_grps, weight_h, weight_w,
+                              device=device, dtype=self.dtype, requires_grad=True)
+
+         bias = torch.randn(n_out_channels, device=device, dtype=self.dtype, requires_grad=True)
+
+         if not contiguous:
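+             # round-trip permutes keep the values but leave the tensors
+             # non-contiguous in memory, exercising the ops' stride handling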
+             x = x.permute(0, 1, 3, 2).contiguous().permute(0, 1, 3, 2)
+             offset = offset.permute(1, 3, 0, 2).contiguous().permute(2, 0, 3, 1)
+             weight = weight.permute(3, 2, 0, 1).contiguous().permute(2, 3, 1, 0)
+
+         return x, weight, offset, bias, stride, pad, dilation
+
+     def _test_forward(self, device, contiguous):
+         x, _, offset, _, stride, padding, dilation = self.get_fn_args(device, contiguous)
+         in_channels = 6
+         out_channels = 2
+         kernel_size = (3, 2)
+         groups = 2
+         offset_groups = 3
+
+         layer = ops.DeformConv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding,
+                                  dilation=dilation, groups=groups, offset_groups=offset_groups).to(device=x.device,
+                                                                                                    dtype=x.dtype)
+         res = layer(x, offset)
+
+         weight = layer.weight.data
+         bias = layer.bias.data
+         expected = self.expected_fn(x, weight, offset, bias, stride=stride, padding=padding, dilation=dilation)
+
+         self.assertTrue(torch.allclose(res, expected), '\nres:\n{}\nexpected:\n{}'.format(res, expected))
+
+     def _test_backward(self, device, contiguous):
+         x, weight, offset, bias, stride, padding, dilation = self.get_fn_args(device, contiguous)
+
+         def func(x_, offset_, weight_, bias_):
+             return ops.deform_conv2d(x_, offset_, weight_, bias_, stride=stride, padding=padding, dilation=dilation)
+
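+         # nondet_tol gives gradcheck slack for run-to-run numeric noise,
+         # presumably from the nondeterministic accumulation order (atomic
+         # adds) of the CUDA backward kernels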
+         gradcheck(func, (x, offset, weight, bias), nondet_tol=1e-5)
+
+         @torch.jit.script
+         def script_func(x_, offset_, weight_, bias_, stride_, pad_, dilation_):
+             # type: (Tensor, Tensor, Tensor, Tensor, Tuple[int, int], Tuple[int, int], Tuple[int, int]) -> Tensor
+             return ops.deform_conv2d(x_, offset_, weight_, bias_, stride=stride_, padding=pad_, dilation=dilation_)
+
+         gradcheck(lambda z, off, wei, bi: script_func(z, off, wei, bi, stride, padding, dilation),
+                   (x, offset, weight, bias), nondet_tol=1e-5)
+
+

if __name__ == '__main__':
    unittest.main()