|
4 | 4 | import pytest
|
5 | 5 |
|
6 | 6 | import numpy as np
|
| 7 | +import os |
7 | 8 |
|
| 9 | +from PIL import Image |
8 | 10 | import torch
|
9 | 11 | from functools import lru_cache
|
10 | 12 | from torch import Tensor
|
@@ -1000,6 +1002,38 @@ def gen_iou_check(box, expected, tolerance=1e-4):
|
1000 | 1002 | gen_iou_check(box_tensor, expected, tolerance=0.002 if dtype == torch.float16 else 1e-3)
|
1001 | 1003 |
|
1002 | 1004 |
|
| 1005 | +class TestMasksToBoxes: |
| 1006 | + def test_masks_box(self): |
| 1007 | + def masks_box_check(masks, expected, tolerance=1e-4): |
| 1008 | + out = ops.masks_to_boxes(masks) |
| 1009 | + assert out.dtype == torch.float |
| 1010 | + torch.testing.assert_close(out, expected, rtol=0.0, check_dtype=False, atol=tolerance) |
| 1011 | + |
| 1012 | + # Check for int type boxes. |
| 1013 | + def _get_image(): |
| 1014 | + assets_directory = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets") |
| 1015 | + mask_path = os.path.join(assets_directory, "masks.tiff") |
| 1016 | + image = Image.open(mask_path) |
| 1017 | + return image |
| 1018 | + |
| 1019 | + def _create_masks(image, masks): |
| 1020 | + for index in range(image.n_frames): |
| 1021 | + image.seek(index) |
| 1022 | + frame = np.array(image) |
| 1023 | + masks[index] = torch.tensor(frame) |
| 1024 | + |
| 1025 | + return masks |
| 1026 | + |
| 1027 | + expected = torch.tensor([[127, 2, 165, 40], [2, 50, 44, 92], [56, 63, 98, 100], [139, 68, 175, 104], |
| 1028 | + [160, 112, 198, 145], [49, 138, 99, 182], [108, 148, 152, 213]], dtype=torch.float) |
| 1029 | + |
| 1030 | + image = _get_image() |
| 1031 | + for dtype in [torch.float16, torch.float32, torch.float64]: |
| 1032 | + masks = torch.zeros((image.n_frames, image.height, image.width), dtype=dtype) |
| 1033 | + masks = _create_masks(image, masks) |
| 1034 | + masks_box_check(masks, expected) |
| 1035 | + |
| 1036 | + |
1003 | 1037 | class TestStochasticDepth:
|
1004 | 1038 | @pytest.mark.parametrize('p', [0.2, 0.5, 0.8])
|
1005 | 1039 | @pytest.mark.parametrize('mode', ["batch", "row"])
|
|
0 commit comments