
Commit 2c94831

Merge branch 'main' into port-eurosat

2 parents: a481830 + 64e7460

File tree: 11 files changed (+222, -133 lines)

test/test_prototype_transforms_functional.py (28 additions, 18 deletions)
@@ -199,21 +199,31 @@ def resize_bounding_box():
         yield SampleInput(bounding_box, size=size, image_size=bounding_box.image_size)
 
 
-class TestKernelsCommon:
-    @pytest.mark.parametrize("functional_info", FUNCTIONAL_INFOS, ids=lambda functional_info: functional_info.name)
-    def test_scriptable(self, functional_info):
-        jit.script(functional_info.functional)
-
-    @pytest.mark.parametrize(
-        ("functional_info", "sample_input"),
-        [
-            pytest.param(functional_info, sample_input, id=f"{functional_info.name}-{idx}")
-            for functional_info in FUNCTIONAL_INFOS
-            for idx, sample_input in enumerate(functional_info.sample_inputs())
-        ],
-    )
-    def test_eager_vs_scripted(self, functional_info, sample_input):
-        eager = functional_info(sample_input)
-        scripted = jit.script(functional_info.functional)(*sample_input.args, **sample_input.kwargs)
-
-        torch.testing.assert_close(eager, scripted)
+@pytest.mark.parametrize(
+    "kernel",
+    [
+        pytest.param(kernel, id=name)
+        for name, kernel in F.__dict__.items()
+        if not name.startswith("_")
+        and callable(kernel)
+        and any(feature_type in name for feature_type in {"image", "segmentation_mask", "bounding_box", "label"})
+        and "pil" not in name
+    ],
+)
+def test_scriptable(kernel):
+    jit.script(kernel)
+
+
+@pytest.mark.parametrize(
+    ("functional_info", "sample_input"),
+    [
+        pytest.param(functional_info, sample_input, id=f"{functional_info.name}-{idx}")
+        for functional_info in FUNCTIONAL_INFOS
+        for idx, sample_input in enumerate(functional_info.sample_inputs())
+    ],
+)
+def test_eager_vs_scripted(functional_info, sample_input):
+    eager = functional_info(sample_input)
+    scripted = jit.script(functional_info.functional)(*sample_input.args, **sample_input.kwargs)
+
+    torch.testing.assert_close(eager, scripted)
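
For reference, a minimal sketch of the kernel collection that the new test_scriptable parametrization performs; this is an illustration run outside the test suite, and the exact set of collected names depends on the installed torchvision build:

from torchvision.prototype.transforms import functional as F

# Public callables whose name mentions a feature type and that are not PIL
# kernels; PIL-backed kernels cannot be scripted with torch.jit.script.
scriptable_kernels = {
    name: kernel
    for name, kernel in F.__dict__.items()
    if not name.startswith("_")
    and callable(kernel)
    and any(t in name for t in {"image", "segmentation_mask", "bounding_box", "label"})
    and "pil" not in name
}

print(sorted(scriptable_kernels))  # e.g. erase_image_tensor, resize_bounding_box, ...

Compared with the old TestKernelsCommon.test_scriptable, which only covered kernels registered in FUNCTIONAL_INFOS, this parametrization scripts every matching public kernel in the functional namespace.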

torchvision/prototype/transforms/_augment.py (26 additions, 14 deletions)
@@ -7,7 +7,7 @@
 from torchvision.prototype import features
 from torchvision.prototype.transforms import Transform, functional as F
 
-from ._utils import query_image, get_image_dimensions
+from ._utils import query_image, get_image_dimensions, has_all, has_any
 
 
 class RandomErasing(Transform):
@@ -33,7 +33,6 @@ def __init__(
             raise ValueError("Scale should be between 0 and 1")
         if p < 0 or p > 1:
             raise ValueError("Random erasing probability should be between 0 and 1")
-        # TODO: deprecate p in favor of wrapping the transform in a RandomApply
         self.p = p
         self.scale = scale
         self.ratio = ratio
@@ -88,9 +87,7 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:
         return dict(zip("ijhwv", (i, j, h, w, v)))
 
     def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
-        if isinstance(input, (features.BoundingBox, features.SegmentationMask)):
-            raise TypeError(f"{type(input).__name__}'s are not supported by {type(self).__name__}()")
-        elif isinstance(input, features.Image):
+        if isinstance(input, features.Image):
             output = F.erase_image_tensor(input, **params)
             return features.Image.new_like(input, output)
         elif isinstance(input, torch.Tensor):
@@ -99,10 +96,13 @@ def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
             return input
 
     def forward(self, *inputs: Any) -> Any:
-        if torch.rand(1) >= self.p:
-            return inputs if len(inputs) > 1 else inputs[0]
+        sample = inputs if len(inputs) > 1 else inputs[0]
+        if has_any(sample, features.BoundingBox, features.SegmentationMask):
+            raise TypeError(f"BoundingBox'es and SegmentationMask's are not supported by {type(self).__name__}()")
+        elif torch.rand(1) >= self.p:
+            return sample
 
-        return super().forward(*inputs)
+        return super().forward(sample)
 
 
 class RandomMixup(Transform):
@@ -115,9 +115,7 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:
         return dict(lam=float(self._dist.sample(())))
 
     def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
-        if isinstance(input, (features.BoundingBox, features.SegmentationMask)):
-            raise TypeError(f"{type(input).__name__}'s are not supported by {type(self).__name__}()")
-        elif isinstance(input, features.Image):
+        if isinstance(input, features.Image):
             output = F.mixup_image_tensor(input, **params)
             return features.Image.new_like(input, output)
         elif isinstance(input, features.OneHotLabel):
@@ -126,6 +124,14 @@ def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
         else:
             return input
 
+    def forward(self, *inputs: Any) -> Any:
+        sample = inputs if len(inputs) > 1 else inputs[0]
+        if has_any(sample, features.BoundingBox, features.SegmentationMask):
+            raise TypeError(f"BoundingBox'es and SegmentationMask's are not supported by {type(self).__name__}()")
+        elif not has_all(sample, features.Image, features.OneHotLabel):
+            raise TypeError(f"{type(self).__name__}() is only defined for Image's *and* OneHotLabel's.")
+        return super().forward(sample)
+
 
 class RandomCutmix(Transform):
     def __init__(self, *, alpha: float) -> None:
@@ -157,13 +163,19 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:
         return dict(box=box, lam_adjusted=lam_adjusted)
 
     def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
-        if isinstance(input, (features.BoundingBox, features.SegmentationMask)):
-            raise TypeError(f"{type(input).__name__}'s are not supported by {type(self).__name__}()")
-        elif isinstance(input, features.Image):
+        if isinstance(input, features.Image):
             output = F.cutmix_image_tensor(input, box=params["box"])
             return features.Image.new_like(input, output)
         elif isinstance(input, features.OneHotLabel):
            output = F.cutmix_one_hot_label(input, lam_adjusted=params["lam_adjusted"])
            return features.OneHotLabel.new_like(input, output)
         else:
             return input
+
+    def forward(self, *inputs: Any) -> Any:
+        sample = inputs if len(inputs) > 1 else inputs[0]
+        if has_any(sample, features.BoundingBox, features.SegmentationMask):
+            raise TypeError(f"BoundingBox'es and SegmentationMask's are not supported by {type(self).__name__}()")
+        elif not has_all(sample, features.Image, features.OneHotLabel):
+            raise TypeError(f"{type(self).__name__}() is only defined for Image's *and* OneHotLabel's.")
+        return super().forward(sample)
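
The has_any / has_all helpers imported from ._utils are not shown in this diff. As a rough sketch of the contract the new forward() methods rely on (an assumption for illustration, not torchvision's actual implementation), they flatten the sample and check which feature types occur in it:

from typing import Any, Type

from torch.utils._pytree import tree_flatten


def has_any(sample: Any, *types: Type) -> bool:
    # True if at least one leaf of the (possibly nested) sample is an instance
    # of any of the given feature types.
    flat_sample, _ = tree_flatten(sample)
    return any(isinstance(obj, types) for obj in flat_sample)


def has_all(sample: Any, *types: Type) -> bool:
    # True only if every given feature type occurs somewhere in the sample.
    flat_sample, _ = tree_flatten(sample)
    return all(any(isinstance(obj, t) for obj in flat_sample) for t in types)

With the checks moved from _transform() into forward(), unsupported samples are rejected up front: a sample containing a BoundingBox or SegmentationMask makes RandomErasing, RandomMixup, and RandomCutmix raise before any per-item dispatch, and the mixup/cutmix transforms additionally require both an Image and a OneHotLabel in the sample.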
