|
1 | 1 | import functools
|
2 | 2 | import itertools
|
3 | 3 | import math
|
4 |
| -from unittest.mock import patch, Mock |
5 | 4 |
|
6 | 5 | import numpy as np
|
7 | 6 | import pytest
|
@@ -1348,40 +1347,14 @@ def _compute_expected_bbox(bbox, output_size_):
|
1348 | 1347 | torch.testing.assert_close(output_boxes, expected_bboxes)
|
1349 | 1348 |
|
1350 | 1349 |
|
1351 |
| -def test_correctness_center_crop_segmentation_mask_on_fixed_input(device): |
1352 |
| - mask = torch.ones((1, 6, 6), dtype=torch.long, device=device) |
1353 |
| - mask[:, 1:5, 2:4] = 0 |
1354 |
| - |
1355 |
| - out_mask = F.center_crop_segmentation_mask(mask, [2]) |
1356 |
| - expected_mask = torch.zeros((1, 4, 2), dtype=torch.long, device=device) |
1357 |
| - torch.testing.assert_close(out_mask, expected_mask) |
1358 |
| - |
1359 |
| - |
1360 |
| -@pytest.mark.parametrize("output_size", [[4, 3], [4]]) |
1361 |
| -def test_correctness_center_crop_segmentation_mask(output_size): |
1362 |
| - def _compute_expected_segmentation_mask(): |
1363 |
| - _output_size = output_size if isinstance(output_size, tuple) else (output_size, output_size) |
1364 |
| - |
1365 |
| - _, h, w = mask.shape |
1366 |
| - left = w - _output_size[0] |
1367 |
| - top = h - _output_size[1] |
1368 |
| - |
1369 |
| - return mask[:, top : _output_size[1], left : _output_size[0]] |
1370 |
| - |
1371 |
| - mask = torch.randint(0, 2, shape=(1, 6, 6)) |
1372 |
| - actual = F.center_crop_segmentation_mask(mask, output_size) |
1373 |
| - |
1374 |
| - expected = _compute_expected_segmentation_mask() |
1375 |
| - assert expected == actual |
1376 |
| - |
1377 |
| - |
| 1350 | +@pytest.mark.parametrize("device", cpu_and_gpu()) |
1378 | 1351 | @pytest.mark.parametrize("output_size", [[4, 3], [4], [7, 7]])
|
1379 |
| -@patch("torchvision.prototype.transforms.functional._geometry.center_crop_image_tensor") |
1380 |
| -def test_correctness_center_crop_segmentation_mask_mock(center_crop_mock, output_size): |
1381 |
| - mask, expected = Mock(spec=torch.Tensor), Mock(spec=torch.Tensor) |
1382 |
| - center_crop_mock.return_value = expected |
| 1352 | +def test_correctness_center_crop_segmentation_mask(device, output_size): |
| 1353 | + def _compute_expected_segmentation_mask(mask, output_size): |
| 1354 | + return F.center_crop_image_tensor(mask, output_size) |
1383 | 1355 |
|
1384 |
| - out_mask = F.center_crop_segmentation_mask(mask, output_size) |
| 1356 | + mask = torch.randint(0, 2, size=(1, 6, 6), dtype=torch.long, device=device) |
| 1357 | + actual = F.center_crop_segmentation_mask(mask, output_size) |
1385 | 1358 |
|
1386 |
| - center_crop_mock.assert_called_once_with(img=mask, output_size=output_size) |
1387 |
| - assert expected is out_mask |
| 1359 | + expected = _compute_expected_segmentation_mask(mask, output_size) |
| 1360 | + torch.testing.assert_close(expected, actual) |
0 commit comments