@@ -1,14 +1,14 @@
 from __future__ import division
 import math
-from typing import Tuple
 import unittest
 
 import numpy as np
 
 import torch
+from torch import Tensor
 from torch.autograd import gradcheck
+from torch.jit.annotations import Tuple
 from torch.nn.modules.utils import _pair
-from torch import Tensor
 from torchvision import ops
 
 
@@ -374,44 +374,45 @@ def test_nms_cuda(self):
 
 
 class DeformConvTester(OpTester, unittest.TestCase):
-    def expected_fn(self, x, offsets, weights, *args, stride=1, pad=0, dilation=1):
+    def expected_fn(self, x, weight, offset, bias, stride=1, padding=0, dilation=1):
         stride_h, stride_w = _pair(stride)
-        pad_h, pad_w = _pair(pad)
+        pad_h, pad_w = _pair(padding)
         dil_h, dil_w = _pair(dilation)
-        weights_h, weights_w = weights.shape[-2:]
+        weight_h, weight_w = weight.shape[-2:]
 
         n_batches, n_in_channels, in_h, in_w = x.shape
-        n_out_channels = weights.shape[0]
+        n_out_channels = weight.shape[0]
 
-        out_h = (in_h + 2 * pad_h - (dil_h * (weights_h - 1) + 1)) // stride_h + 1
-        out_w = (in_w + 2 * pad_w - (dil_w * (weights_w - 1) + 1)) // stride_w + 1
+        out_h = (in_h + 2 * pad_h - (dil_h * (weight_h - 1) + 1)) // stride_h + 1
+        out_w = (in_w + 2 * pad_w - (dil_w * (weight_w - 1) + 1)) // stride_w + 1
 
-        n_offset_grps = offsets.shape[1] // (2 * weights_h * weights_w)
+        n_offset_grps = offset.shape[1] // (2 * weight_h * weight_w)
         in_c_per_offset_grp = n_in_channels // n_offset_grps
 
-        n_weight_grps = n_in_channels // weights.shape[1]
-        in_c_per_weight_grp = weights.shape[1]
+        n_weight_grps = n_in_channels // weight.shape[1]
+        in_c_per_weight_grp = weight.shape[1]
         out_c_per_weight_grp = n_out_channels // n_weight_grps
 
         out = torch.zeros(n_batches, n_out_channels, out_h, out_w, device=x.device, dtype=x.dtype)
         for b in range(n_batches):
             for c_out in range(n_out_channels):
                 for i in range(out_h):
                     for j in range(out_w):
-                        for di in range(weights_h):
-                            for dj in range(weights_w):
+                        for di in range(weight_h):
+                            for dj in range(weight_w):
                                 for c in range(in_c_per_weight_grp):
                                     weight_grp = c_out // out_c_per_weight_grp
                                     c_in = weight_grp * in_c_per_weight_grp + c
 
                                     offset_grp = c_in // in_c_per_offset_grp
-                                    offset_idx = 2 * (offset_grp * (weights_h * weights_w) + di * weights_w + dj)
+                                    offset_idx = 2 * (offset_grp * (weight_h * weight_w) + di * weight_w + dj)
 
-                                    pi = stride_h * i - pad_h + dil_h * di + offsets[b, offset_idx, i, j]
-                                    pj = stride_w * j - pad_w + dil_w * dj + offsets[b, offset_idx + 1, i, j]
+                                    pi = stride_h * i - pad_h + dil_h * di + offset[b, offset_idx, i, j]
+                                    pj = stride_w * j - pad_w + dil_w * dj + offset[b, offset_idx + 1, i, j]
 
-                                    out[b, c_out, i, j] += (weights[c_out, c, di, dj] *
+                                    out[b, c_out, i, j] += (weight[c_out, c, di, dj] *
                                                             bilinear_interpolate(x[b, c_in, :, :], pi, pj))
+        out += bias.view(1, n_out_channels, 1, 1)
         return out
 
     def get_fn_args(self, device, contiguous):
@@ -442,36 +443,50 @@ def get_fn_args(self, device, contiguous):
         weight = torch.randn(n_out_channels, n_in_channels // n_weight_grps, weight_h, weight_w,
                              device=device, dtype=self.dtype, requires_grad=True)
 
+        bias = torch.randn(n_out_channels, device=device, dtype=self.dtype, requires_grad=True)
+
         if not contiguous:
             x = x.permute(0, 1, 3, 2).contiguous().permute(0, 1, 3, 2)
             offset = offset.permute(1, 3, 0, 2).contiguous().permute(2, 0, 3, 1)
             weight = weight.permute(3, 2, 0, 1).contiguous().permute(2, 3, 1, 0)
 
-        return x, offset, weight, stride, pad, dilation
+        return x, weight, offset, bias, stride, pad, dilation
 
     def _test_forward(self, device, contiguous):
-        x, offset, weight, stride, pad, dilation = self.get_fn_args(device, contiguous)
+        x, _, _, _, stride, padding, dilation = self.get_fn_args(device, contiguous)
+        in_channels = 6
+        out_channels = 2
+        kernel_size = (3, 2)
+        groups = 2
+        offset_groups = 3
+
+        layer = ops.DeformConv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding,
+                                 dilation=dilation, groups=groups, offset_groups=offset_groups)
+        layer.offset_conv.weight.data = torch.randn_like(layer.offset_conv.weight.data)
+        res = layer(x)
 
-        res = ops.DeformConv(stride=stride, pad=pad, dilation=dilation)(x, offset, weight)
-        expected = self.expected_fn(x, offset, weight, stride=stride, pad=pad, dilation=dilation)
+        weight = layer.weight.data.to(device=x.device, dtype=x.dtype)
+        offset = layer.offset_conv.to(device=x.device, dtype=x.dtype)(x)
+        bias = layer.bias.data.to(device=x.device, dtype=x.dtype)
+        expected = self.expected_fn(x, weight, offset, bias, stride=stride, padding=padding, dilation=dilation)
 
-        self.assertTrue(torch.allclose(res, expected), '\nres:\n{}\nexpected:\n{}'.format(x, res, expected))
+        self.assertTrue(torch.allclose(res, expected), '\nres:\n{}\nexpected:\n{}'.format(res, expected))
 
     def _test_backward(self, device, contiguous):
-        x, offset, weight, stride, pad, dilation = self.get_fn_args(device, contiguous)
+        x, weight, offset, bias, stride, padding, dilation = self.get_fn_args(device, contiguous)
 
-        def func(x_, offset_, weight_):
-            return ops.deform_conv(x_, offset_, weight_, stride=stride, pad=pad, dilation=dilation)
+        def func(x_, weight_, offset_, bias_):
+            return ops.deform_conv2d(x_, weight_, offset_, bias_, stride=stride, padding=padding, dilation=dilation)
 
-        gradcheck(func, (x, offset, weight), nondet_tol=1e-5)
+        gradcheck(func, (x, weight, offset, bias), nondet_tol=1e-5)
 
         @torch.jit.script
-        def script_func(x_, offset_, weight_, stride_, pad_, dilation_):
-            # type: (Tensor, Tensor, Tensor, Tuple[int, int], Tuple[int, int], Tuple[int, int]) -> Tensor
-            return ops.deform_conv(x_, offset_, weight_, stride=stride_, pad=pad_, dilation=dilation_)
+        def script_func(x_, weight_, offset_, bias_, stride_, pad_, dilation_):
+            # type: (Tensor, Tensor, Tensor, Tensor, Tuple[int, int], Tuple[int, int], Tuple[int, int]) -> Tensor
+            return ops.deform_conv2d(x_, weight_, offset_, bias_, stride=stride_, padding=pad_, dilation=dilation_)
 
-        gradcheck(lambda z, off, wei: script_func(z, off, wei, stride, pad, dilation),
-                  (x, offset, weight), nondet_tol=1e-5)
+        gradcheck(lambda z, wei, off, bi: script_func(z, wei, off, bi, stride, padding, dilation),
+                  (x, weight, offset, bias), nondet_tol=1e-5)
 
 
 if __name__ == '__main__':
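
For reference, a minimal usage sketch of the op as this commit calls it: deform_conv2d takes (input, weight, offset, bias) in that order, with explicit stride/padding/dilation. The shapes below mirror the constants in _test_forward (in_channels=6, out_channels=2, kernel_size=(3, 2), groups=2, offset_groups=3); the tensor values and this snippet itself are illustrative and not part of the commit.

import torch
from torchvision import ops

# Shapes chosen to match the test constants above; values are random placeholders.
x = torch.randn(1, 6, 10, 10)      # (N, C_in, H, W)
weight = torch.randn(2, 3, 3, 2)   # (C_out, C_in // groups, kh, kw), groups = 2
bias = torch.randn(2)              # (C_out,)
# offset carries 2 coordinates per offset group per kernel location:
# 2 * offset_groups * kh * kw = 2 * 3 * 3 * 2 = 36 channels, at output resolution.
offset = torch.randn(1, 36, 8, 9)  # H_out = 10 - 3 + 1 = 8, W_out = 10 - 2 + 1 = 9

out = ops.deform_conv2d(x, weight, offset, bias, stride=1, padding=0, dilation=1)
print(out.shape)  # torch.Size([1, 2, 8, 9])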