Skip to content

Commit ac1512b

Browse files
vfdev-5NicolasHug
andauthored
Added wrap_dataset_for_transforms_v2 into datasets and handled beta w… (#7279)
Co-authored-by: Nicolas Hug <[email protected]>
1 parent 56b0497 commit ac1512b

File tree

7 files changed

+88
-49
lines changed

7 files changed

+88
-49
lines changed

.github/workflows/test-linux-cpu.yml

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ jobs:
4141
# Create Conda Env
4242
conda create -yp ci_env python="${PYTHON_VERSION}" numpy libpng jpeg scipy
4343
conda activate /work/ci_env
44-
44+
4545
# Install PyTorch, Torchvision, and testing libraries
4646
set -ex
4747
conda install \
@@ -55,3 +55,9 @@ jobs:
5555
# Run Tests
5656
python3 -m torch.utils.collect_env
5757
python3 -m pytest --junitxml=test-results/junit.xml -v --durations 20
58+
59+
# Specific test for warnings on "from torchvision.datasets import wrap_dataset_for_transforms_v2"
60+
# We keep them separate to avoid any side effects due to warnings / imports.
61+
# TODO: Remove this and add proper tests (possibly using a sub-process solution as described
62+
# in https://github.com/pytorch/vision/pull/7269).
63+
python3 -m pytest -v test/check_v2_dataset_warnings.py

test/check_v2_dataset_warnings.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
import pytest
2+
3+
4+
def test_warns_if_imported_from_datasets(mocker):
5+
mocker.patch("torchvision._WARN_ABOUT_BETA_TRANSFORMS", return_value=True)
6+
7+
import torchvision
8+
9+
with pytest.warns(UserWarning, match=torchvision._BETA_TRANSFORMS_WARNING):
10+
from torchvision.datasets import wrap_dataset_for_transforms_v2
11+
12+
assert callable(wrap_dataset_for_transforms_v2)
13+
14+
15+
@pytest.mark.filterwarnings("error")
16+
def test_no_warns_if_imported_from_datasets():
17+
from torchvision.datasets import wrap_dataset_for_transforms_v2
18+
19+
assert callable(wrap_dataset_for_transforms_v2)

test/datasets_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -584,8 +584,8 @@ def test_transforms(self, config):
584584

585585
@test_all_configs
586586
def test_transforms_v2_wrapper(self, config):
587-
from torchvision.datapoints import wrap_dataset_for_transforms_v2
588587
from torchvision.datapoints._datapoint import Datapoint
588+
from torchvision.datasets import wrap_dataset_for_transforms_v2
589589

590590
try:
591591
with self.create_dataset(config) as (dataset, _):

test/test_datasets.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import pathlib
99
import pickle
1010
import random
11+
import re
1112
import shutil
1213
import string
1314
import unittest
@@ -3309,5 +3310,47 @@ def test_bad_input(self):
33093310
pass
33103311

33113312

3313+
class TestDatasetWrapper:
3314+
def test_unknown_type(self):
3315+
unknown_object = object()
3316+
with pytest.raises(
3317+
TypeError, match=re.escape("is meant for subclasses of `torchvision.datasets.VisionDataset`")
3318+
):
3319+
datasets.wrap_dataset_for_transforms_v2(unknown_object)
3320+
3321+
def test_unknown_dataset(self):
3322+
class MyVisionDataset(datasets.VisionDataset):
3323+
pass
3324+
3325+
dataset = MyVisionDataset("root")
3326+
3327+
with pytest.raises(TypeError, match="No wrapper exist"):
3328+
datasets.wrap_dataset_for_transforms_v2(dataset)
3329+
3330+
def test_missing_wrapper(self):
3331+
dataset = datasets.FakeData()
3332+
3333+
with pytest.raises(TypeError, match="please open an issue"):
3334+
datasets.wrap_dataset_for_transforms_v2(dataset)
3335+
3336+
def test_subclass(self, mocker):
3337+
from torchvision import datapoints
3338+
3339+
sentinel = object()
3340+
mocker.patch.dict(
3341+
datapoints._dataset_wrapper.WRAPPER_FACTORIES,
3342+
clear=False,
3343+
values={datasets.FakeData: lambda dataset: lambda idx, sample: sentinel},
3344+
)
3345+
3346+
class MyFakeData(datasets.FakeData):
3347+
pass
3348+
3349+
dataset = MyFakeData()
3350+
wrapped_dataset = datasets.wrap_dataset_for_transforms_v2(dataset)
3351+
3352+
assert wrapped_dataset[0] is sentinel
3353+
3354+
33123355
if __name__ == "__main__":
33133356
unittest.main()

test/test_prototype_datapoints.py

Lines changed: 1 addition & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,9 @@
1-
import re
2-
31
import pytest
42
import torch
53

64
from PIL import Image
75

8-
from torchvision import datapoints, datasets
6+
from torchvision import datapoints
97
from torchvision.prototype import datapoints as proto_datapoints
108

119

@@ -163,43 +161,3 @@ def test_bbox_instance(data, format):
163161
if isinstance(format, str):
164162
format = datapoints.BoundingBoxFormat.from_str(format.upper())
165163
assert bboxes.format == format
166-
167-
168-
class TestDatasetWrapper:
169-
def test_unknown_type(self):
170-
unknown_object = object()
171-
with pytest.raises(
172-
TypeError, match=re.escape("is meant for subclasses of `torchvision.datasets.VisionDataset`")
173-
):
174-
datapoints.wrap_dataset_for_transforms_v2(unknown_object)
175-
176-
def test_unknown_dataset(self):
177-
class MyVisionDataset(datasets.VisionDataset):
178-
pass
179-
180-
dataset = MyVisionDataset("root")
181-
182-
with pytest.raises(TypeError, match="No wrapper exist"):
183-
datapoints.wrap_dataset_for_transforms_v2(dataset)
184-
185-
def test_missing_wrapper(self):
186-
dataset = datasets.FakeData()
187-
188-
with pytest.raises(TypeError, match="please open an issue"):
189-
datapoints.wrap_dataset_for_transforms_v2(dataset)
190-
191-
def test_subclass(self, mocker):
192-
sentinel = object()
193-
mocker.patch.dict(
194-
datapoints._dataset_wrapper.WRAPPER_FACTORIES,
195-
clear=False,
196-
values={datasets.FakeData: lambda dataset: lambda idx, sample: sentinel},
197-
)
198-
199-
class MyFakeData(datasets.FakeData):
200-
pass
201-
202-
dataset = MyFakeData()
203-
wrapped_dataset = datapoints.wrap_dataset_for_transforms_v2(dataset)
204-
205-
assert wrapped_dataset[0] is sentinel

torchvision/datapoints/__init__.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,11 @@
1+
from torchvision import _BETA_TRANSFORMS_WARNING, _WARN_ABOUT_BETA_TRANSFORMS
2+
13
from ._bounding_box import BoundingBox, BoundingBoxFormat
24
from ._datapoint import _FillType, _FillTypeJIT, _InputType, _InputTypeJIT
35
from ._image import _ImageType, _ImageTypeJIT, _TensorImageType, _TensorImageTypeJIT, Image
46
from ._mask import Mask
57
from ._video import _TensorVideoType, _TensorVideoTypeJIT, _VideoType, _VideoTypeJIT, Video
68

7-
from ._dataset_wrapper import wrap_dataset_for_transforms_v2 # type: ignore[attr-defined] # usort: skip
8-
9-
from torchvision import _BETA_TRANSFORMS_WARNING, _WARN_ABOUT_BETA_TRANSFORMS
10-
119
if _WARN_ABOUT_BETA_TRANSFORMS:
1210
import warnings
1311

torchvision/datasets/__init__.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,3 +128,18 @@
128128
"InStereo2k",
129129
"ETH3DStereo",
130130
)
131+
132+
133+
# We override current module's attributes to handle the import:
134+
# from torchvision.datasets import wrap_dataset_for_transforms_v2
135+
# with beta state v2 warning from torchvision.datapoints
136+
# We also want to avoid raising the warning when importing other attributes
137+
# from torchvision.datasets
138+
# Ref: https://peps.python.org/pep-0562/
139+
def __getattr__(name):
140+
if name in ("wrap_dataset_for_transforms_v2",):
141+
from torchvision.datapoints._dataset_wrapper import wrap_dataset_for_transforms_v2
142+
143+
return wrap_dataset_for_transforms_v2
144+
145+
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

0 commit comments

Comments
 (0)