Skip to content

Commit b6953b2

Browse files
author
Federico Pozzi
committed
test: improvements
1 parent def456a commit b6953b2

File tree

1 file changed

+8
-35
lines changed

1 file changed

+8
-35
lines changed

test/test_prototype_transforms_functional.py

Lines changed: 8 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import functools
22
import itertools
33
import math
4-
from unittest.mock import patch, Mock
54

65
import numpy as np
76
import pytest
@@ -1348,40 +1347,14 @@ def _compute_expected_bbox(bbox, output_size_):
13481347
torch.testing.assert_close(output_boxes, expected_bboxes)
13491348

13501349

1351-
def test_correctness_center_crop_segmentation_mask_on_fixed_input(device):
1352-
mask = torch.ones((1, 6, 6), dtype=torch.long, device=device)
1353-
mask[:, 1:5, 2:4] = 0
1354-
1355-
out_mask = F.center_crop_segmentation_mask(mask, [2])
1356-
expected_mask = torch.zeros((1, 4, 2), dtype=torch.long, device=device)
1357-
torch.testing.assert_close(out_mask, expected_mask)
1358-
1359-
1360-
@pytest.mark.parametrize("output_size", [[4, 3], [4]])
1361-
def test_correctness_center_crop_segmentation_mask(output_size):
1362-
def _compute_expected_segmentation_mask():
1363-
_output_size = output_size if isinstance(output_size, tuple) else (output_size, output_size)
1364-
1365-
_, h, w = mask.shape
1366-
left = w - _output_size[0]
1367-
top = h - _output_size[1]
1368-
1369-
return mask[:, top : _output_size[1], left : _output_size[0]]
1370-
1371-
mask = torch.randint(0, 2, shape=(1, 6, 6))
1372-
actual = F.center_crop_segmentation_mask(mask, output_size)
1373-
1374-
expected = _compute_expected_segmentation_mask()
1375-
assert expected == actual
1376-
1377-
1350+
@pytest.mark.parametrize("device", cpu_and_gpu())
13781351
@pytest.mark.parametrize("output_size", [[4, 3], [4], [7, 7]])
1379-
@patch("torchvision.prototype.transforms.functional._geometry.center_crop_image_tensor")
1380-
def test_correctness_center_crop_segmentation_mask_mock(center_crop_mock, output_size):
1381-
mask, expected = Mock(spec=torch.Tensor), Mock(spec=torch.Tensor)
1382-
center_crop_mock.return_value = expected
1352+
def test_correctness_center_crop_segmentation_mask(device, output_size):
1353+
def _compute_expected_segmentation_mask(mask, output_size):
1354+
return F.center_crop_image_tensor(mask, output_size)
13831355

1384-
out_mask = F.center_crop_segmentation_mask(mask, output_size)
1356+
mask = torch.randint(0, 2, size=(1, 6, 6), dtype=torch.long, device=device)
1357+
actual = F.center_crop_segmentation_mask(mask, output_size)
13851358

1386-
center_crop_mock.assert_called_once_with(img=mask, output_size=output_size)
1387-
assert expected is out_mask
1359+
expected = _compute_expected_segmentation_mask(mask, output_size)
1360+
torch.testing.assert_close(expected, actual)

0 commit comments

Comments
 (0)