
Commit d32bc4b

Revamp prototype features and transforms (#5407)
* revamp prototype features (#5283)
* remove decoding from prototype datasets (#5287)
* remove decoder from prototype datasets
* remove unused imports
* cleanup
* fix readme
* use OneHotLabel in SEMEION
* improve voc implementation
* revert unrelated changes
* fix semeion mock data
* fix pcam
* readd functional transforms API to prototype (#5295)
* readd functional transforms
* cleanup
* add missing imports
* remove __torch_function__ dispatch
* readd repr
* readd empty line
* add test for scriptability
* remove function copy
* change import from functional tensor transforms to just functional
* fix import
* fix test
* fix prototype features and functional transforms after review (#5377)
* fix prototype functional transforms after review
* address features review
* make mypy more strict on prototype features
* make mypy more strict for prototype transforms
* fix annotation
* fix kernel tests
* add automatic feature type dispatch to functional transforms (#5323)
* add auto dispatch
* fix missing arguments error message
* remove pil kernel for erase
* automate feature specific parameter detection
* fix typos
* cleanup dispatcher call
* remove __torch_function__ from transform dispatch
* remove auto-generation
* revert unrelated changes
* remove implements decorator
* change register parameter order
* change order of transforms for readability
* add documentation for __torch_function__
* fix mypy
* inline check for support
* refactor kernel registering process
* refactor dispatch to be a regular decorator
* split kernels and dispatchers
* remove sentinels
* replace pass with ...
* appease mypy
* make single kernel dispatchers more concise
* make dispatcher signatures more generic
* make kernel checking more strict
* revert doc changes
* address Francisco's comments
* remove inplace
* rename kernel test module
* fix inplace
* remove special casing for pil and vanilla tensors
* address comments
* update docs
* cleanup features / transforms feature branch (#5406)
* mark candidates for removal
* align signature of resize_bounding_box with corresponding image kernel
* fix documentation of Feature
* remove interpolation mode and antialias option from resize_segmentation_mask
* remove or privatize functionality in features / datasets / transforms
1 parent f2f490b commit d32bc4b


59 files changed: +1780 -1966 lines changed
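The headline change in this stack is the automatic feature type dispatch added in #5323: a functional transform inspects the feature type of its input and routes to the kernel registered for that type. The diffs below only show supporting changes, so here is a minimal, self-contained sketch of the dispatch idea, assuming hypothetical names (Image, SegmentationMask, horizontal_flip*) rather than the actual prototype API:

# Sketch only: hypothetical feature types and kernels illustrating
# type-based dispatch; not the actual torchvision.prototype API.
from typing import Any, Callable, Dict, Type

import torch


class Image(torch.Tensor):
    pass


class SegmentationMask(torch.Tensor):
    pass


class Dispatcher:
    """Routes a call to the kernel registered for the input's feature type."""

    def __init__(self) -> None:
        self._kernels: Dict[Type[torch.Tensor], Callable] = {}

    def register(self, feature_type: Type[torch.Tensor]) -> Callable:
        def decorator(kernel: Callable) -> Callable:
            self._kernels[feature_type] = kernel
            return kernel

        return decorator

    def __call__(self, input: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
        for feature_type, kernel in self._kernels.items():
            if type(input) is feature_type:
                return kernel(input, *args, **kwargs)
        raise TypeError(f"no kernel registered for {type(input).__name__}")


horizontal_flip = Dispatcher()


@horizontal_flip.register(Image)
def horizontal_flip_image(image: Image) -> Image:
    # images flip along the width (last) axis
    return image.flip(-1).as_subclass(Image)


@horizontal_flip.register(SegmentationMask)
def horizontal_flip_segmentation_mask(mask: SegmentationMask) -> SegmentationMask:
    # masks flip the same way, but keep their own feature type
    return mask.flip(-1).as_subclass(SegmentationMask)


image = torch.rand(3, 4, 4).as_subclass(Image)
assert isinstance(horizontal_flip(image), Image)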

mypy.ini

Lines changed: 30 additions & 0 deletions
@@ -6,6 +6,36 @@ pretty = True
 allow_redefinition = True
 warn_redundant_casts = True
 
+[mypy-torchvision.prototype.features.*]
+
+; untyped definitions and calls
+disallow_untyped_defs = True
+
+; None and Optional handling
+no_implicit_optional = True
+
+; warnings
+warn_unused_ignores = True
+warn_return_any = True
+
+; miscellaneous strictness flags
+allow_redefinition = True
+
+[mypy-torchvision.prototype.transforms.*]
+
+; untyped definitions and calls
+disallow_untyped_defs = True
+
+; None and Optional handling
+no_implicit_optional = True
+
+; warnings
+warn_unused_ignores = True
+warn_return_any = True
+
+; miscellaneous strictness flags
+allow_redefinition = True
+
 [mypy-torchvision.prototype.datasets.*]
 
 ; untyped definitions and calls

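For context (not part of the diff): with disallow_untyped_defs = True, mypy flags any function in the covered packages that lacks complete annotations, and no_implicit_optional rejects a None default on a parameter not typed Optional. A small illustration of both, using hypothetical function names:

from typing import Optional

import torch


def erase(image, value):  # error: Function is missing a type annotation
    ...


def pad(image: torch.Tensor, fill: int = None) -> torch.Tensor:  # error under no_implicit_optional
    ...


def pad_ok(image: torch.Tensor, fill: Optional[int] = None) -> torch.Tensor:  # accepted
    ...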
test/builtin_dataset_mocks.py

Lines changed: 47 additions & 43 deletions
@@ -432,50 +432,52 @@ def caltech256(info, root, config):
 
 @register_mock
 def imagenet(info, root, config):
-    wnids = tuple(info.extra.wnid_to_category.keys())
-    if config.split == "train":
-        images_root = root / "ILSVRC2012_img_train"
+    from scipy.io import savemat
 
+    categories = info.categories
+    wnids = [info.extra.category_to_wnid[category] for category in categories]
+    if config.split == "train":
         num_samples = len(wnids)
+        archive_name = "ILSVRC2012_img_train.tar"
 
+        files = []
         for wnid in wnids:
-            files = create_image_folder(
-                root=images_root,
+            create_image_folder(
+                root=root,
                 name=wnid,
                 file_name_fn=lambda image_idx: f"{wnid}_{image_idx:04d}.JPEG",
                 num_examples=1,
             )
-            make_tar(images_root, f"{wnid}.tar", files[0].parent)
+            files.append(make_tar(root, f"{wnid}.tar"))
     elif config.split == "val":
         num_samples = 3
-        files = create_image_folder(
-            root=root,
-            name="ILSVRC2012_img_val",
-            file_name_fn=lambda image_idx: f"ILSVRC2012_val_{image_idx + 1:08d}.JPEG",
-            num_examples=num_samples,
-        )
-        images_root = files[0].parent
-    else:  # config.split == "test"
-        images_root = root / "ILSVRC2012_img_test_v10102019"
+        archive_name = "ILSVRC2012_img_val.tar"
+        files = [create_image_file(root, f"ILSVRC2012_val_{idx + 1:08d}.JPEG") for idx in range(num_samples)]
 
-        num_samples = 3
+        devkit_root = root / "ILSVRC2012_devkit_t12"
+        data_root = devkit_root / "data"
+        data_root.mkdir(parents=True)
 
-        create_image_folder(
-            root=images_root,
-            name="test",
-            file_name_fn=lambda image_idx: f"ILSVRC2012_test_{image_idx + 1:08d}.JPEG",
-            num_examples=num_samples,
-        )
-        make_tar(root, f"{images_root.name}.tar", images_root)
+        with open(data_root / "ILSVRC2012_validation_ground_truth.txt", "w") as file:
+            for label in torch.randint(0, len(wnids), (num_samples,)).tolist():
+                file.write(f"{label}\n")
+
+        num_children = 0
+        synsets = [
+            (idx, wnid, category, "", num_children, [], 0, 0)
+            for idx, (category, wnid) in enumerate(zip(categories, wnids), 1)
+        ]
+        num_children = 1
+        synsets.extend((0, "", "", "", num_children, [], 0, 0) for _ in range(5))
+        savemat(data_root / "meta.mat", dict(synsets=synsets))
+
+        make_tar(root, devkit_root.with_suffix(".tar.gz").name, compression="gz")
+    else:  # config.split == "test"
+        num_samples = 5
+        archive_name = "ILSVRC2012_img_test_v10102019.tar"
+        files = [create_image_file(root, f"ILSVRC2012_test_{idx + 1:08d}.JPEG") for idx in range(num_samples)]
 
-    devkit_root = root / "ILSVRC2012_devkit_t12"
-    devkit_root.mkdir()
-    data_root = devkit_root / "data"
-    data_root.mkdir()
-    with open(data_root / "ILSVRC2012_validation_ground_truth.txt", "w") as file:
-        for label in torch.randint(0, len(wnids), (num_samples,)).tolist():
-            file.write(f"{label}\n")
-    make_tar(root, f"{devkit_root}.tar.gz", devkit_root, compression="gz")
+    make_tar(root, archive_name, *files)
 
     return num_samples
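The new val branch synthesizes the devkit's meta.mat with scipy.io.savemat, which serializes a dict of Python values as MATLAB variables. A minimal standalone example of that call (file name and row data are illustrative):

from scipy.io import savemat

# Writes a .mat file containing one variable named "synsets"; each tuple
# becomes a row of the struct-like array the ImageNet devkit expects.
synsets = [(1, "n01440764", "tench", "", 0, [], 0, 0)]
savemat("meta.mat", dict(synsets=synsets))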
@@ -667,14 +669,15 @@ def sbd(info, root, config):
 @register_mock
 def semeion(info, root, config):
     num_samples = 3
+    num_categories = len(info.categories)
 
     images = torch.rand(num_samples, 256)
-    labels = one_hot(torch.randint(len(info.categories), size=(num_samples,)))
+    labels = one_hot(torch.randint(num_categories, size=(num_samples,)), num_classes=num_categories)
     with open(root / "semeion.data", "w") as fh:
         for image, one_hot_label in zip(images, labels):
             image_columns = " ".join([f"{pixel.item():.4f}" for pixel in image])
             labels_columns = " ".join([str(label.item()) for label in one_hot_label])
-            fh.write(f"{image_columns} {labels_columns}\n")
+            fh.write(f"{image_columns} {labels_columns} \n")
 
     return num_samples
 
@@ -729,32 +732,33 @@ def _make_detection_anns_folder(cls, root, name, *, file_name_fn, num_examples):
     def _make_detection_ann_file(cls, root, name):
         def add_child(parent, name, text=None):
             child = ET.SubElement(parent, name)
-            child.text = text
+            child.text = str(text)
             return child
 
         def add_name(obj, name="dog"):
             add_child(obj, "name", name)
-            return name
 
-        def add_bndbox(obj, bndbox=None):
-            if bndbox is None:
-                bndbox = {"xmin": "1", "xmax": "2", "ymin": "3", "ymax": "4"}
+        def add_size(obj):
+            obj = add_child(obj, "size")
+            size = {"width": 0, "height": 0, "depth": 3}
+            for name, text in size.items():
+                add_child(obj, name, text)
 
+        def add_bndbox(obj):
             obj = add_child(obj, "bndbox")
+            bndbox = {"xmin": 1, "xmax": 2, "ymin": 3, "ymax": 4}
             for name, text in bndbox.items():
                 add_child(obj, name, text)
 
-            return bndbox
-
         annotation = ET.Element("annotation")
+        add_size(annotation)
         obj = add_child(annotation, "object")
-        data = dict(name=add_name(obj), bndbox=add_bndbox(obj))
+        add_name(obj)
+        add_bndbox(obj)
 
         with open(root / name, "wb") as fh:
             fh.write(ET.tostring(annotation))
 
-        return data
-
     @classmethod
     def generate(cls, root, *, year, trainval):
         archive_folder = root
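The child.text = str(text) change is needed because xml.etree.ElementTree can only serialize string text; with the bndbox values now plain ints, assigning them directly would make ET.tostring raise. For example:

import xml.etree.ElementTree as ET

root = ET.Element("annotation")
child = ET.SubElement(root, "xmin")

child.text = 1            # ET.tostring(root) would raise: cannot serialize 1 (type int)
child.text = str(1)       # serializes cleanly as <xmin>1</xmin>
print(ET.tostring(root))  # b'<annotation><xmin>1</xmin></annotation>'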

test/test_prototype_builtin_datasets.py

Lines changed: 20 additions & 1 deletion
@@ -1,16 +1,23 @@
+import functools
 import io
 from pathlib import Path
 
 import pytest
 import torch
 from builtin_dataset_mocks import parametrize_dataset_mocks, DATASET_MOCKS
+from torch.testing._comparison import assert_equal, TensorLikePair, ObjectPair
 from torch.utils.data.datapipes.iter.grouping import ShardingFilterIterDataPipe as ShardingFilter
 from torch.utils.data.graph import traverse
 from torchdata.datapipes.iter import IterDataPipe, Shuffler
 from torchvision.prototype import transforms, datasets
 from torchvision.prototype.utils._internal import sequence_to_str
 
 
+assert_samples_equal = functools.partial(
+    assert_equal, pair_types=(TensorLikePair, ObjectPair), rtol=0, atol=0, equal_nan=True
+)
+
+
 @pytest.fixture
 def test_home(mocker, tmp_path):
     mocker.patch("torchvision.prototype.datasets._api.home", return_value=str(tmp_path))
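The helper pins the comparison to exact equality: functools.partial pre-binds rtol=0, atol=0, and the pair types, so every call compares tensors bit-for-bit, which is appropriate because a save/load round trip should be lossless. A generic illustration of the pattern (the comparison function here is hypothetical, standing in for assert_equal):

import functools


def compare(actual, expected, *, rtol: float, atol: float) -> None:
    # hypothetical stand-in for torch.testing's assert_equal
    assert abs(actual - expected) <= atol + rtol * abs(expected)


compare_exact = functools.partial(compare, rtol=0, atol=0)
compare_exact(1.0, 1.0)  # same as compare(1.0, 1.0, rtol=0, atol=0)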
@@ -92,6 +99,7 @@ def test_no_vanilla_tensors(self, test_home, dataset_mock, config):
                 f"{sequence_to_str(sorted(vanilla_tensors), separate_last='and ')} contained vanilla tensors."
             )
 
+    @pytest.mark.xfail
     @parametrize_dataset_mocks(DATASET_MOCKS)
     def test_transformable(self, test_home, dataset_mock, config):
         dataset_mock.prepare(test_home, config)
@@ -137,6 +145,17 @@ def scan(graph):
         if not any(type(dp) is annotation_dp_type for dp in scan(traverse(dataset))):
             raise AssertionError(f"The dataset doesn't contain a {annotation_dp_type.__name__}() datapipe.")
 
+    @parametrize_dataset_mocks(DATASET_MOCKS)
+    def test_save_load(self, test_home, dataset_mock, config):
+        dataset_mock.prepare(test_home, config)
+        dataset = datasets.load(dataset_mock.name, **config)
+        sample = next(iter(dataset))
+
+        with io.BytesIO() as buffer:
+            torch.save(sample, buffer)
+            buffer.seek(0)
+            assert_samples_equal(torch.load(buffer), sample)
+
 
 @parametrize_dataset_mocks(DATASET_MOCKS["qmnist"])
 class TestQMNIST:
@@ -171,5 +190,5 @@ def test_label_matches_path(self, test_home, dataset_mock, config):
         dataset = datasets.load(dataset_mock.name, **config)
 
         for sample in dataset:
-            label_from_path = int(Path(sample["image_path"]).parent.name)
+            label_from_path = int(Path(sample["path"]).parent.name)
             assert sample["label"] == label_from_path

test/test_prototype_datasets_api.py

Lines changed: 3 additions & 12 deletions
@@ -5,8 +5,8 @@
 from torchvision.prototype.utils._internal import FrozenMapping, FrozenBunch
 
 
-def make_minimal_dataset_info(name="name", type=datasets.utils.DatasetType.RAW, categories=None, **kwargs):
-    return datasets.utils.DatasetInfo(name, type=type, categories=categories or [], **kwargs)
+def make_minimal_dataset_info(name="name", categories=None, **kwargs):
+    return datasets.utils.DatasetInfo(name, categories=categories or [], **kwargs)
 
 
 class TestFrozenMapping:
@@ -176,7 +176,7 @@ def resources(self, config):
         # This method is just defined to appease the ABC, but will be overwritten at instantiation
         pass
 
-    def _make_datapipe(self, resource_dps, *, config, decoder):
+    def _make_datapipe(self, resource_dps, *, config):
         # This method is just defined to appease the ABC, but will be overwritten at instantiation
         pass
 
@@ -229,12 +229,3 @@ def test_resources(self, mocker):
 
         (call_args, _) = dataset._make_datapipe.call_args
         assert call_args[0][0] is sentinel
-
-    def test_decoder(self):
-        dataset = self.DatasetMock()
-
-        sentinel = object()
-        dataset.load("", decoder=sentinel)
-
-        (_, call_kwargs) = dataset._make_datapipe.call_args
-        assert call_kwargs["decoder"] is sentinel
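With the decoder argument gone (note the removal of test_decoder above), datasets.load takes only the dataset name plus config options, and samples carry encoded data and metadata rather than decoded images. A hedged usage sketch, consistent with the tests in this commit but not guaranteed beyond them:

from torchvision.prototype import datasets

# no decoder argument anymore; config options are plain keyword arguments
dataset = datasets.load("qmnist", split="train")
sample = next(iter(dataset))
print(sorted(sample.keys()))  # includes "label" and "path" per the qmnist test above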
