Skip to content

Commit e2aa359

Browse files
authored
Merge branch 'main' into transforms/fill
2 parents 1476b20 + 585b3b1 commit e2aa359

File tree

8 files changed

+107
-17
lines changed

8 files changed

+107
-17
lines changed

.circleci/unittest/linux/scripts/run_test.sh

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,5 @@ set -e
55
eval "$(./conda/bin/conda shell.bash hook)"
66
conda activate ./env
77

8-
export PYTORCH_TEST_WITH_SLOW='1'
98
python -m torch.utils.collect_env
109
pytest --junitxml=test-results/junit.xml -v --durations 20

.circleci/unittest/windows/scripts/run_test.sh

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,5 @@ conda activate ./env
88
this_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
99
source "$this_dir/set_cuda_envs.sh"
1010

11-
export PYTORCH_TEST_WITH_SLOW='1'
1211
python -m torch.utils.collect_env
1312
pytest --junitxml=test-results/junit.xml -v --durations 20

test/test_models.py

Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ def _assert_expected(output, name, prec=None, atol=None, rtol=None):
118118
torch.testing.assert_close(output, expected, rtol=rtol, atol=atol, check_dtype=False)
119119

120120

121-
def _check_jit_scriptable(nn_module, args, unwrapper=None, skip=False, eager_out=None):
121+
def _check_jit_scriptable(nn_module, args, unwrapper=None, eager_out=None):
122122
"""Check that a nn.Module's results in TorchScript match eager and that it can be exported"""
123123

124124
def get_export_import_copy(m):
@@ -129,20 +129,6 @@ def get_export_import_copy(m):
129129
imported = torch.jit.load(path)
130130
return imported
131131

132-
TEST_WITH_SLOW = os.getenv("PYTORCH_TEST_WITH_SLOW", "0") == "1"
133-
if not TEST_WITH_SLOW or skip:
134-
# TorchScript is not enabled, skip these tests
135-
msg = (
136-
f"The check_jit_scriptable test for {nn_module.__class__.__name__} was skipped. "
137-
"This test checks if the module's results in TorchScript "
138-
"match eager and that it can be exported. To run these "
139-
"tests make sure you set the environment variable "
140-
"PYTORCH_TEST_WITH_SLOW=1 and that the test is not "
141-
"manually skipped."
142-
)
143-
warnings.warn(msg, RuntimeWarning)
144-
return None
145-
146132
sm = torch.jit.script(nn_module)
147133

148134
if eager_out is None:

test/test_prototype_transforms.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,3 +243,56 @@ def test_features_bounding_box(self, p):
243243
assert_equal(expected, actual)
244244
assert actual.format == expected.format
245245
assert actual.image_size == expected.image_size
246+
247+
248+
@pytest.mark.parametrize("p", [0.0, 1.0])
249+
class TestRandomVerticalFlip:
250+
def input_expected_image_tensor(self, p, dtype=torch.float32):
251+
input = torch.tensor([[[1, 1], [0, 0]], [[1, 1], [0, 0]]], dtype=dtype)
252+
expected = torch.tensor([[[0, 0], [1, 1]], [[0, 0], [1, 1]]], dtype=dtype)
253+
254+
return input, expected if p == 1 else input
255+
256+
def test_simple_tensor(self, p):
257+
input, expected = self.input_expected_image_tensor(p)
258+
transform = transforms.RandomVerticalFlip(p=p)
259+
260+
actual = transform(input)
261+
262+
assert_equal(expected, actual)
263+
264+
def test_pil_image(self, p):
265+
input, expected = self.input_expected_image_tensor(p, dtype=torch.uint8)
266+
transform = transforms.RandomVerticalFlip(p=p)
267+
268+
actual = transform(to_pil_image(input))
269+
270+
assert_equal(expected, pil_to_tensor(actual))
271+
272+
def test_features_image(self, p):
273+
input, expected = self.input_expected_image_tensor(p)
274+
transform = transforms.RandomVerticalFlip(p=p)
275+
276+
actual = transform(features.Image(input))
277+
278+
assert_equal(features.Image(expected), actual)
279+
280+
def test_features_segmentation_mask(self, p):
281+
input, expected = self.input_expected_image_tensor(p)
282+
transform = transforms.RandomVerticalFlip(p=p)
283+
284+
actual = transform(features.SegmentationMask(input))
285+
286+
assert_equal(features.SegmentationMask(expected), actual)
287+
288+
def test_features_bounding_box(self, p):
289+
input = features.BoundingBox([0, 0, 5, 5], format=features.BoundingBoxFormat.XYXY, image_size=(10, 10))
290+
transform = transforms.RandomVerticalFlip(p=p)
291+
292+
actual = transform(input)
293+
294+
expected_image_tensor = torch.tensor([0, 5, 5, 10]) if p == 1.0 else input
295+
expected = features.BoundingBox.new_like(input, data=expected_image_tensor)
296+
assert_equal(expected, actual)
297+
assert actual.format == expected.format
298+
assert actual.image_size == expected.image_size

torchvision/prototype/transforms/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
TenCrop,
1616
BatchMultiCrop,
1717
RandomHorizontalFlip,
18+
RandomVerticalFlip,
1819
Pad,
1920
RandomZoomOut,
2021
)

torchvision/prototype/transforms/_geometry.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,36 @@ def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
4545
return input
4646

4747

48+
class RandomVerticalFlip(Transform):
49+
def __init__(self, p: float = 0.5) -> None:
50+
super().__init__()
51+
self.p = p
52+
53+
def forward(self, *inputs: Any) -> Any:
54+
sample = inputs if len(inputs) > 1 else inputs[0]
55+
if torch.rand(1) > self.p:
56+
return sample
57+
58+
return super().forward(sample)
59+
60+
def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
61+
if isinstance(input, features.Image):
62+
output = F.vertical_flip_image_tensor(input)
63+
return features.Image.new_like(input, output)
64+
elif isinstance(input, features.SegmentationMask):
65+
output = F.vertical_flip_segmentation_mask(input)
66+
return features.SegmentationMask.new_like(input, output)
67+
elif isinstance(input, features.BoundingBox):
68+
output = F.vertical_flip_bounding_box(input, format=input.format, image_size=input.image_size)
69+
return features.BoundingBox.new_like(input, output)
70+
elif isinstance(input, PIL.Image.Image):
71+
return F.vertical_flip_image_pil(input)
72+
elif is_simple_tensor(input):
73+
return F.vertical_flip_image_tensor(input)
74+
else:
75+
return input
76+
77+
4878
class Resize(Transform):
4979
def __init__(
5080
self,

torchvision/prototype/transforms/functional/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,8 @@
6363
perspective_image_pil,
6464
vertical_flip_image_tensor,
6565
vertical_flip_image_pil,
66+
vertical_flip_bounding_box,
67+
vertical_flip_segmentation_mask,
6668
five_crop_image_tensor,
6769
five_crop_image_pil,
6870
ten_crop_image_tensor,

torchvision/prototype/transforms/functional/_geometry.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,26 @@ def resize_bounding_box(bounding_box: torch.Tensor, size: List[int], image_size:
8181
vertical_flip_image_pil = _FP.vflip
8282

8383

84+
def vertical_flip_segmentation_mask(segmentation_mask: torch.Tensor) -> torch.Tensor:
85+
return vertical_flip_image_tensor(segmentation_mask)
86+
87+
88+
def vertical_flip_bounding_box(
89+
bounding_box: torch.Tensor, format: features.BoundingBoxFormat, image_size: Tuple[int, int]
90+
) -> torch.Tensor:
91+
shape = bounding_box.shape
92+
93+
bounding_box = convert_bounding_box_format(
94+
bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY
95+
).view(-1, 4)
96+
97+
bounding_box[:, [1, 3]] = image_size[0] - bounding_box[:, [3, 1]]
98+
99+
return convert_bounding_box_format(
100+
bounding_box, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False
101+
).view(shape)
102+
103+
84104
def _affine_parse_args(
85105
angle: float,
86106
translate: List[float],

0 commit comments

Comments
 (0)