From 6142ba91a6cf03c4b70310d82722403b1f3a4284 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 6 Apr 2022 15:12:29 +0100 Subject: [PATCH 1/9] Remove Dataset2 class --- test/builtin_dataset_mocks.py | 2 +- test/test_prototype_builtin_datasets.py | 2 +- torchvision/prototype/datasets/_api.py | 6 +- .../prototype/datasets/_builtin/caltech.py | 6 +- .../prototype/datasets/_builtin/celeba.py | 4 +- .../prototype/datasets/_builtin/cifar.py | 4 +- .../prototype/datasets/_builtin/clevr.py | 4 +- .../prototype/datasets/_builtin/coco.py | 4 +- .../prototype/datasets/_builtin/country211.py | 4 +- .../prototype/datasets/_builtin/cub200.py | 4 +- .../prototype/datasets/_builtin/dtd.py | 4 +- .../prototype/datasets/_builtin/eurosat.py | 4 +- .../prototype/datasets/_builtin/fer2013.py | 4 +- .../prototype/datasets/_builtin/food101.py | 4 +- .../prototype/datasets/_builtin/gtsrb.py | 4 +- .../prototype/datasets/_builtin/imagenet.py | 4 +- .../prototype/datasets/_builtin/mnist.py | 4 +- .../datasets/_builtin/oxford_iiit_pet.py | 4 +- .../prototype/datasets/_builtin/pcam.py | 4 +- .../prototype/datasets/_builtin/sbd.py | 4 +- .../prototype/datasets/_builtin/semeion.py | 4 +- .../datasets/_builtin/stanford_cars.py | 4 +- .../prototype/datasets/_builtin/svhn.py | 4 +- .../prototype/datasets/_builtin/usps.py | 4 +- .../prototype/datasets/_builtin/voc.py | 4 +- .../prototype/datasets/utils/__init__.py | 2 +- .../prototype/datasets/utils/_dataset.py | 67 +------------------ 27 files changed, 52 insertions(+), 117 deletions(-) diff --git a/test/builtin_dataset_mocks.py b/test/builtin_dataset_mocks.py index 3a1aac71e4f..768d286e890 100644 --- a/test/builtin_dataset_mocks.py +++ b/test/builtin_dataset_mocks.py @@ -68,7 +68,7 @@ def prepare(self, home, config): mock_info = self._parse_mock_info(self.mock_data_fn(root, config)) - with unittest.mock.patch.object(datasets.utils.Dataset2, "__init__"): + with unittest.mock.patch.object(datasets.utils.Dataset, "__init__"): required_file_names 
= { resource.file_name for resource in datasets.load(self.name, root=root, **config)._resources() } diff --git a/test/test_prototype_builtin_datasets.py b/test/test_prototype_builtin_datasets.py index 1f7ebf34826..badd78c1264 100644 --- a/test/test_prototype_builtin_datasets.py +++ b/test/test_prototype_builtin_datasets.py @@ -59,7 +59,7 @@ def test_smoke(self, test_home, dataset_mock, config): dataset = datasets.load(dataset_mock.name, **config) - if not isinstance(dataset, datasets.utils.Dataset2): + if not isinstance(dataset, datasets.utils.Dataset): raise AssertionError(f"Loading the dataset should return an Dataset, but got {type(dataset)} instead.") @parametrize_dataset_mocks(DATASET_MOCKS) diff --git a/torchvision/prototype/datasets/_api.py b/torchvision/prototype/datasets/_api.py index 8f8bb53deb4..407dc23f64b 100644 --- a/torchvision/prototype/datasets/_api.py +++ b/torchvision/prototype/datasets/_api.py @@ -2,12 +2,12 @@ from typing import Any, Dict, List, Callable, Type, Optional, Union, TypeVar from torchvision.prototype.datasets import home -from torchvision.prototype.datasets.utils import Dataset2 +from torchvision.prototype.datasets.utils import Dataset from torchvision.prototype.utils._internal import add_suggestion T = TypeVar("T") -D = TypeVar("D", bound=Type[Dataset2]) +D = TypeVar("D", bound=Type[Dataset]) BUILTIN_INFOS: Dict[str, Dict[str, Any]] = {} @@ -56,7 +56,7 @@ def info(name: str) -> Dict[str, Any]: return find(BUILTIN_INFOS, name) -def load(name: str, *, root: Optional[Union[str, pathlib.Path]] = None, **config: Any) -> Dataset2: +def load(name: str, *, root: Optional[Union[str, pathlib.Path]] = None, **config: Any) -> Dataset: dataset_cls = find(BUILTIN_DATASETS, name) if root is None: diff --git a/torchvision/prototype/datasets/_builtin/caltech.py b/torchvision/prototype/datasets/_builtin/caltech.py index 3701063504f..06a45c443ee 100644 --- a/torchvision/prototype/datasets/_builtin/caltech.py +++ 
b/torchvision/prototype/datasets/_builtin/caltech.py @@ -9,7 +9,7 @@ Filter, IterKeyZipper, ) -from torchvision.prototype.datasets.utils import Dataset2, DatasetInfo, HttpResource, OnlineResource +from torchvision.prototype.datasets.utils import Dataset, DatasetInfo, HttpResource, OnlineResource from torchvision.prototype.datasets.utils._internal import ( INFINITE_BUFFER_SIZE, read_mat, @@ -31,7 +31,7 @@ def _caltech101_info() -> Dict[str, Any]: @register_dataset("caltech101") -class Caltech101(Dataset2): +class Caltech101(Dataset): """ - **homepage**: http://www.vision.caltech.edu/Image_Datasets/Caltech101 - **dependencies**: @@ -161,7 +161,7 @@ def _caltech256_info() -> Dict[str, Any]: @register_dataset("caltech256") -class Caltech256(Dataset2): +class Caltech256(Dataset): """ - **homepage**: http://www.vision.caltech.edu/Image_Datasets/Caltech256 """ diff --git a/torchvision/prototype/datasets/_builtin/celeba.py b/torchvision/prototype/datasets/_builtin/celeba.py index 17a42082f3f..46ccf8de6f7 100644 --- a/torchvision/prototype/datasets/_builtin/celeba.py +++ b/torchvision/prototype/datasets/_builtin/celeba.py @@ -10,7 +10,7 @@ IterKeyZipper, ) from torchvision.prototype.datasets.utils import ( - Dataset2, + Dataset, GDriveResource, OnlineResource, ) @@ -68,7 +68,7 @@ def _info() -> Dict[str, Any]: @register_dataset(NAME) -class CelebA(Dataset2): +class CelebA(Dataset): """ - **homepage**: https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html """ diff --git a/torchvision/prototype/datasets/_builtin/cifar.py b/torchvision/prototype/datasets/_builtin/cifar.py index 9274aa543d4..2fe5a2e9035 100644 --- a/torchvision/prototype/datasets/_builtin/cifar.py +++ b/torchvision/prototype/datasets/_builtin/cifar.py @@ -10,7 +10,7 @@ Filter, Mapper, ) -from torchvision.prototype.datasets.utils import Dataset2, DatasetInfo, HttpResource, OnlineResource +from torchvision.prototype.datasets.utils import Dataset, DatasetInfo, HttpResource, OnlineResource from 
torchvision.prototype.datasets.utils._internal import hint_shuffling, path_comparator, hint_sharding, BUILTIN_DIR from torchvision.prototype.features import Label, Image @@ -29,7 +29,7 @@ def __iter__(self) -> Iterator[Tuple[np.ndarray, int]]: yield from iter(zip(image_arrays, category_idcs)) -class _CifarBase(Dataset2): +class _CifarBase(Dataset): _FILE_NAME: str _SHA256: str _LABELS_KEY: str diff --git a/torchvision/prototype/datasets/_builtin/clevr.py b/torchvision/prototype/datasets/_builtin/clevr.py index 9d322de084c..3a139787c6f 100644 --- a/torchvision/prototype/datasets/_builtin/clevr.py +++ b/torchvision/prototype/datasets/_builtin/clevr.py @@ -2,7 +2,7 @@ from typing import Any, Dict, List, Optional, Tuple, BinaryIO, Union from torchdata.datapipes.iter import IterDataPipe, Mapper, Filter, IterKeyZipper, Demultiplexer, JsonParser, UnBatcher -from torchvision.prototype.datasets.utils import Dataset2, HttpResource, OnlineResource +from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource from torchvision.prototype.datasets.utils._internal import ( INFINITE_BUFFER_SIZE, hint_sharding, @@ -24,7 +24,7 @@ def _info() -> Dict[str, Any]: @register_dataset(NAME) -class CLEVR(Dataset2): +class CLEVR(Dataset): """ - **homepage**: https://cs.stanford.edu/people/jcjohns/clevr/ """ diff --git a/torchvision/prototype/datasets/_builtin/coco.py b/torchvision/prototype/datasets/_builtin/coco.py index 75896a8db08..ce6d2bd29bd 100644 --- a/torchvision/prototype/datasets/_builtin/coco.py +++ b/torchvision/prototype/datasets/_builtin/coco.py @@ -19,7 +19,7 @@ DatasetInfo, HttpResource, OnlineResource, - Dataset2, + Dataset, ) from torchvision.prototype.datasets.utils._internal import ( MappingIterator, @@ -45,7 +45,7 @@ def _info() -> Dict[str, Any]: @register_dataset(NAME) -class Coco(Dataset2): +class Coco(Dataset): """ - **homepage**: https://cocodataset.org/ - **dependencies**: diff --git a/torchvision/prototype/datasets/_builtin/country211.py 
b/torchvision/prototype/datasets/_builtin/country211.py index 461cd71568f..a72e43100e5 100644 --- a/torchvision/prototype/datasets/_builtin/country211.py +++ b/torchvision/prototype/datasets/_builtin/country211.py @@ -2,7 +2,7 @@ from typing import Any, Dict, List, Tuple, Union from torchdata.datapipes.iter import IterDataPipe, Mapper, Filter -from torchvision.prototype.datasets.utils import Dataset2, DatasetInfo, HttpResource, OnlineResource +from torchvision.prototype.datasets.utils import Dataset, DatasetInfo, HttpResource, OnlineResource from torchvision.prototype.datasets.utils._internal import path_comparator, hint_sharding, hint_shuffling, BUILTIN_DIR from torchvision.prototype.features import EncodedImage, Label @@ -19,7 +19,7 @@ def _info() -> Dict[str, Any]: @register_dataset(NAME) -class Country211(Dataset2): +class Country211(Dataset): """ - **homepage**: https://github.com/openai/CLIP/blob/main/data/country211.md """ diff --git a/torchvision/prototype/datasets/_builtin/cub200.py b/torchvision/prototype/datasets/_builtin/cub200.py index 073a790092c..d68d6a1cd0f 100644 --- a/torchvision/prototype/datasets/_builtin/cub200.py +++ b/torchvision/prototype/datasets/_builtin/cub200.py @@ -14,7 +14,7 @@ CSVDictParser, ) from torchvision.prototype.datasets.utils import ( - Dataset2, + Dataset, DatasetInfo, HttpResource, OnlineResource, @@ -47,7 +47,7 @@ def _info() -> Dict[str, Any]: @register_dataset(NAME) -class CUB200(Dataset2): +class CUB200(Dataset): """ - **homepage**: http://www.vision.caltech.edu/visipedia/CUB-200.html """ diff --git a/torchvision/prototype/datasets/_builtin/dtd.py b/torchvision/prototype/datasets/_builtin/dtd.py index d7f07dc8b30..bacd2faefc0 100644 --- a/torchvision/prototype/datasets/_builtin/dtd.py +++ b/torchvision/prototype/datasets/_builtin/dtd.py @@ -4,7 +4,7 @@ from torchdata.datapipes.iter import IterDataPipe, Mapper, Filter, IterKeyZipper, Demultiplexer, LineReader, CSVParser from torchvision.prototype.datasets.utils import ( 
- Dataset2, + Dataset, DatasetInfo, HttpResource, OnlineResource, @@ -39,7 +39,7 @@ def _info() -> Dict[str, Any]: @register_dataset(NAME) -class DTD(Dataset2): +class DTD(Dataset): """DTD Dataset. homepage="https://www.robots.ox.ac.uk/~vgg/data/dtd/", """ diff --git a/torchvision/prototype/datasets/_builtin/eurosat.py b/torchvision/prototype/datasets/_builtin/eurosat.py index 00d6a04f320..ab31aaf6f42 100644 --- a/torchvision/prototype/datasets/_builtin/eurosat.py +++ b/torchvision/prototype/datasets/_builtin/eurosat.py @@ -2,7 +2,7 @@ from typing import Any, Dict, List, Tuple, Union from torchdata.datapipes.iter import IterDataPipe, Mapper -from torchvision.prototype.datasets.utils import Dataset2, HttpResource, OnlineResource +from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling from torchvision.prototype.features import EncodedImage, Label @@ -29,7 +29,7 @@ def _info() -> Dict[str, Any]: @register_dataset(NAME) -class EuroSAT(Dataset2): +class EuroSAT(Dataset): """EuroSAT Dataset. 
homepage="https://github.com/phelber/eurosat", """ diff --git a/torchvision/prototype/datasets/_builtin/fer2013.py b/torchvision/prototype/datasets/_builtin/fer2013.py index ca30b78e609..2fb708e6141 100644 --- a/torchvision/prototype/datasets/_builtin/fer2013.py +++ b/torchvision/prototype/datasets/_builtin/fer2013.py @@ -4,7 +4,7 @@ import torch from torchdata.datapipes.iter import IterDataPipe, Mapper, CSVDictParser from torchvision.prototype.datasets.utils import ( - Dataset2, + Dataset, OnlineResource, KaggleDownloadResource, ) @@ -25,7 +25,7 @@ def _info() -> Dict[str, Any]: @register_dataset(NAME) -class FER2013(Dataset2): +class FER2013(Dataset): """FER 2013 Dataset homepage="https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge" """ diff --git a/torchvision/prototype/datasets/_builtin/food101.py b/torchvision/prototype/datasets/_builtin/food101.py index c86b9aaea84..f9e923fc4e5 100644 --- a/torchvision/prototype/datasets/_builtin/food101.py +++ b/torchvision/prototype/datasets/_builtin/food101.py @@ -9,7 +9,7 @@ Demultiplexer, IterKeyZipper, ) -from torchvision.prototype.datasets.utils import Dataset2, DatasetInfo, HttpResource, OnlineResource +from torchvision.prototype.datasets.utils import Dataset, DatasetInfo, HttpResource, OnlineResource from torchvision.prototype.datasets.utils._internal import ( hint_shuffling, BUILTIN_DIR, @@ -34,7 +34,7 @@ def _info() -> Dict[str, Any]: @register_dataset(NAME) -class Food101(Dataset2): +class Food101(Dataset): """Food 101 dataset homepage="https://data.vision.ee.ethz.ch/cvl/datasets_extra/food-101", """ diff --git a/torchvision/prototype/datasets/_builtin/gtsrb.py b/torchvision/prototype/datasets/_builtin/gtsrb.py index fa29f3be780..01f754208e2 100644 --- a/torchvision/prototype/datasets/_builtin/gtsrb.py +++ b/torchvision/prototype/datasets/_builtin/gtsrb.py @@ -3,7 +3,7 @@ from torchdata.datapipes.iter import IterDataPipe, Mapper, Filter, CSVDictParser, Zipper, 
Demultiplexer from torchvision.prototype.datasets.utils import ( - Dataset2, + Dataset, OnlineResource, HttpResource, ) @@ -28,7 +28,7 @@ def _info() -> Dict[str, Any]: @register_dataset(NAME) -class GTSRB(Dataset2): +class GTSRB(Dataset): """GTSRB Dataset homepage="https://benchmark.ini.rub.de" diff --git a/torchvision/prototype/datasets/_builtin/imagenet.py b/torchvision/prototype/datasets/_builtin/imagenet.py index 56accca02b4..478b083a972 100644 --- a/torchvision/prototype/datasets/_builtin/imagenet.py +++ b/torchvision/prototype/datasets/_builtin/imagenet.py @@ -18,7 +18,7 @@ DatasetInfo, OnlineResource, ManualDownloadResource, - Dataset2, + Dataset, ) from torchvision.prototype.datasets.utils._internal import ( INFINITE_BUFFER_SIZE, @@ -53,7 +53,7 @@ class ImageNetDemux(enum.IntEnum): @register_dataset(NAME) -class ImageNet(Dataset2): +class ImageNet(Dataset): """ - **homepage**: https://www.image-net.org/ """ diff --git a/torchvision/prototype/datasets/_builtin/mnist.py b/torchvision/prototype/datasets/_builtin/mnist.py index 907faed49bd..e5537a1ef66 100644 --- a/torchvision/prototype/datasets/_builtin/mnist.py +++ b/torchvision/prototype/datasets/_builtin/mnist.py @@ -7,7 +7,7 @@ import torch from torchdata.datapipes.iter import IterDataPipe, Demultiplexer, Mapper, Zipper, Decompressor -from torchvision.prototype.datasets.utils import Dataset2, HttpResource, OnlineResource +from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE, hint_sharding, hint_shuffling from torchvision.prototype.features import Image, Label from torchvision.prototype.utils._internal import fromfile @@ -58,7 +58,7 @@ def __iter__(self) -> Iterator[torch.Tensor]: yield read(dtype=dtype, count=count).reshape(shape) -class _MNISTBase(Dataset2): +class _MNISTBase(Dataset): _URL_BASE: Union[str, Sequence[str]] @abc.abstractmethod diff --git 
a/torchvision/prototype/datasets/_builtin/oxford_iiit_pet.py b/torchvision/prototype/datasets/_builtin/oxford_iiit_pet.py index 0ea336a1421..26134722743 100644 --- a/torchvision/prototype/datasets/_builtin/oxford_iiit_pet.py +++ b/torchvision/prototype/datasets/_builtin/oxford_iiit_pet.py @@ -4,7 +4,7 @@ from torchdata.datapipes.iter import IterDataPipe, Mapper, Filter, IterKeyZipper, Demultiplexer, CSVDictParser from torchvision.prototype.datasets.utils import ( - Dataset2, + Dataset, DatasetInfo, HttpResource, OnlineResource, @@ -39,7 +39,7 @@ def _info() -> Dict[str, Any]: @register_dataset(NAME) -class OxfordIIITPet(Dataset2): +class OxfordIIITPet(Dataset): """Oxford IIIT Pet Dataset homepage="https://www.robots.ox.ac.uk/~vgg/data/pets/", """ diff --git a/torchvision/prototype/datasets/_builtin/pcam.py b/torchvision/prototype/datasets/_builtin/pcam.py index 1ae94da5665..14b100d5807 100644 --- a/torchvision/prototype/datasets/_builtin/pcam.py +++ b/torchvision/prototype/datasets/_builtin/pcam.py @@ -7,7 +7,7 @@ from torchdata.datapipes.iter import IterDataPipe, Mapper, Zipper from torchvision.prototype import features from torchvision.prototype.datasets.utils import ( - Dataset2, + Dataset, OnlineResource, GDriveResource, ) @@ -51,7 +51,7 @@ def _info() -> Dict[str, Any]: @register_dataset(NAME) -class PCAM(Dataset2): +class PCAM(Dataset): # TODO write proper docstring """PCAM Dataset diff --git a/torchvision/prototype/datasets/_builtin/sbd.py b/torchvision/prototype/datasets/_builtin/sbd.py index d062d78fe0a..5f6fa84711c 100644 --- a/torchvision/prototype/datasets/_builtin/sbd.py +++ b/torchvision/prototype/datasets/_builtin/sbd.py @@ -11,7 +11,7 @@ IterKeyZipper, LineReader, ) -from torchvision.prototype.datasets.utils import Dataset2, DatasetInfo, HttpResource, OnlineResource +from torchvision.prototype.datasets.utils import Dataset, DatasetInfo, HttpResource, OnlineResource from torchvision.prototype.datasets.utils._internal import ( INFINITE_BUFFER_SIZE, 
read_mat, @@ -37,7 +37,7 @@ def _info() -> Dict[str, Any]: @register_dataset(NAME) -class SBD(Dataset2): +class SBD(Dataset): """ - **homepage**: http://home.bharathh.info/pubs/codes/SBD/download.html - **dependencies**: diff --git a/torchvision/prototype/datasets/_builtin/semeion.py b/torchvision/prototype/datasets/_builtin/semeion.py index e3a802d3cee..c3f5927b65b 100644 --- a/torchvision/prototype/datasets/_builtin/semeion.py +++ b/torchvision/prototype/datasets/_builtin/semeion.py @@ -9,7 +9,7 @@ CSVParser, ) from torchvision.prototype.datasets.utils import ( - Dataset2, + Dataset, HttpResource, OnlineResource, ) @@ -27,7 +27,7 @@ def _info() -> Dict[str, Any]: @register_dataset(NAME) -class SEMEION(Dataset2): +class SEMEION(Dataset): """Semeion dataset homepage="https://archive.ics.uci.edu/ml/datasets/Semeion+Handwritten+Digit", """ diff --git a/torchvision/prototype/datasets/_builtin/stanford_cars.py b/torchvision/prototype/datasets/_builtin/stanford_cars.py index 85098eb34e5..6f0bde5d478 100644 --- a/torchvision/prototype/datasets/_builtin/stanford_cars.py +++ b/torchvision/prototype/datasets/_builtin/stanford_cars.py @@ -2,7 +2,7 @@ from typing import Any, Dict, List, Tuple, Iterator, BinaryIO, Union from torchdata.datapipes.iter import Filter, IterDataPipe, Mapper, Zipper -from torchvision.prototype.datasets.utils import Dataset2, DatasetInfo, HttpResource, OnlineResource +from torchvision.prototype.datasets.utils import Dataset, DatasetInfo, HttpResource, OnlineResource from torchvision.prototype.datasets.utils._internal import ( hint_sharding, hint_shuffling, @@ -37,7 +37,7 @@ def _info() -> Dict[str, Any]: @register_dataset(NAME) -class StanfordCars(Dataset2): +class StanfordCars(Dataset): """Stanford Cars dataset. 
homepage="https://ai.stanford.edu/~jkrause/cars/car_dataset.html", dependencies=scipy diff --git a/torchvision/prototype/datasets/_builtin/svhn.py b/torchvision/prototype/datasets/_builtin/svhn.py index 80c769f6377..175aa6c0a51 100644 --- a/torchvision/prototype/datasets/_builtin/svhn.py +++ b/torchvision/prototype/datasets/_builtin/svhn.py @@ -8,7 +8,7 @@ UnBatcher, ) from torchvision.prototype.datasets.utils import ( - Dataset2, + Dataset, HttpResource, OnlineResource, ) @@ -30,7 +30,7 @@ def _info() -> Dict[str, Any]: @register_dataset(NAME) -class SVHN(Dataset2): +class SVHN(Dataset): """SVHN Dataset. homepage="http://ufldl.stanford.edu/housenumbers/", dependencies = scipy diff --git a/torchvision/prototype/datasets/_builtin/usps.py b/torchvision/prototype/datasets/_builtin/usps.py index e1c9940ed86..e732f3b788a 100644 --- a/torchvision/prototype/datasets/_builtin/usps.py +++ b/torchvision/prototype/datasets/_builtin/usps.py @@ -3,7 +3,7 @@ import torch from torchdata.datapipes.iter import IterDataPipe, LineReader, Mapper, Decompressor -from torchvision.prototype.datasets.utils import Dataset2, OnlineResource, HttpResource +from torchvision.prototype.datasets.utils import Dataset, OnlineResource, HttpResource from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling from torchvision.prototype.features import Image, Label @@ -18,7 +18,7 @@ def _info() -> Dict[str, Any]: @register_dataset(NAME) -class USPS(Dataset2): +class USPS(Dataset): """USPS Dataset homepage="https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass.html#usps", """ diff --git a/torchvision/prototype/datasets/_builtin/voc.py b/torchvision/prototype/datasets/_builtin/voc.py index 1f5980bdc72..98f67b63a34 100644 --- a/torchvision/prototype/datasets/_builtin/voc.py +++ b/torchvision/prototype/datasets/_builtin/voc.py @@ -13,7 +13,7 @@ LineReader, ) from torchvision.datasets import VOCDetection -from torchvision.prototype.datasets.utils import DatasetInfo, 
OnlineResource, HttpResource, Dataset2 +from torchvision.prototype.datasets.utils import DatasetInfo, OnlineResource, HttpResource, Dataset from torchvision.prototype.datasets.utils._internal import ( path_accessor, getitem, @@ -38,7 +38,7 @@ def _info() -> Dict[str, Any]: @register_dataset(NAME) -class VOC(Dataset2): +class VOC(Dataset): """ - **homepage**: http://host.robots.ox.ac.uk/pascal/VOC/ """ diff --git a/torchvision/prototype/datasets/utils/__init__.py b/torchvision/prototype/datasets/utils/__init__.py index a16a839b594..885d1bd6184 100644 --- a/torchvision/prototype/datasets/utils/__init__.py +++ b/torchvision/prototype/datasets/utils/__init__.py @@ -1,4 +1,4 @@ from . import _internal # usort: skip -from ._dataset import DatasetConfig, DatasetInfo, Dataset, Dataset2 +from ._dataset import DatasetConfig, DatasetInfo, Dataset from ._query import SampleQuery from ._resource import OnlineResource, HttpResource, GDriveResource, ManualDownloadResource, KaggleDownloadResource diff --git a/torchvision/prototype/datasets/utils/_dataset.py b/torchvision/prototype/datasets/utils/_dataset.py index a6ec05c3ff4..89817921831 100644 --- a/torchvision/prototype/datasets/utils/_dataset.py +++ b/torchvision/prototype/datasets/utils/_dataset.py @@ -120,72 +120,7 @@ def __repr__(self) -> str: return make_repr(type(self).__name__, items) -class Dataset(abc.ABC): - def __init__(self) -> None: - self._info = self._make_info() - - @abc.abstractmethod - def _make_info(self) -> DatasetInfo: - pass - - @property - def info(self) -> DatasetInfo: - return self._info - - @property - def name(self) -> str: - return self.info.name - - @property - def default_config(self) -> DatasetConfig: - return self.info.default_config - - @property - def categories(self) -> Tuple[str, ...]: - return self.info.categories - - @abc.abstractmethod - def resources(self, config: DatasetConfig) -> List[OnlineResource]: - pass - - @abc.abstractmethod - def _make_datapipe( - self, - resource_dps:
List[IterDataPipe], - *, - config: DatasetConfig, - ) -> IterDataPipe[Dict[str, Any]]: - pass - - def supports_sharded(self) -> bool: - return False - - def load( - self, - root: Union[str, pathlib.Path], - *, - config: Optional[DatasetConfig] = None, - skip_integrity_check: bool = False, - ) -> IterDataPipe[Dict[str, Any]]: - if not config: - config = self.info.default_config - - if use_sharded_dataset() and self.supports_sharded(): - root = os.path.join(root, *config.values()) - dataset_size = self.info.extra["sizes"][config] - return _make_sharded_datapipe(root, dataset_size) # type: ignore[no-any-return] - - self.info.check_dependencies() - resource_dps = [ - resource.load(root, skip_integrity_check=skip_integrity_check) for resource in self.resources(config) - ] - return self._make_datapipe(resource_dps, config=config) - - def _generate_categories(self, root: pathlib.Path) -> Sequence[Union[str, Sequence[str]]]: - raise NotImplementedError - - -class Dataset2(IterDataPipe[Dict[str, Any]], abc.ABC): +class Dataset(IterDataPipe[Dict[str, Any]], abc.ABC): @staticmethod def _verify_str_arg( value: str, From 63ddc3a9a1218d96f7f437aeb46d5419cbe831f5 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 6 Apr 2022 15:18:25 +0100 Subject: [PATCH 2/9] Move read_categories_file out of DatasetInfo --- torchvision/prototype/datasets/_builtin/caltech.py | 6 +++--- torchvision/prototype/datasets/_builtin/cifar.py | 6 +++--- torchvision/prototype/datasets/_builtin/coco.py | 4 ++-- torchvision/prototype/datasets/_builtin/country211.py | 4 ++-- torchvision/prototype/datasets/_builtin/cub200.py | 4 ++-- torchvision/prototype/datasets/_builtin/dtd.py | 4 ++-- torchvision/prototype/datasets/_builtin/food101.py | 4 ++-- torchvision/prototype/datasets/_builtin/imagenet.py | 4 ++-- torchvision/prototype/datasets/_builtin/oxford_iiit_pet.py | 4 ++-- torchvision/prototype/datasets/_builtin/sbd.py | 4 ++-- torchvision/prototype/datasets/_builtin/stanford_cars.py | 4 ++-- 
torchvision/prototype/datasets/_builtin/voc.py | 4 ++-- torchvision/prototype/datasets/utils/__init__.py | 2 +- torchvision/prototype/datasets/utils/_dataset.py | 7 ++++--- 14 files changed, 31 insertions(+), 30 deletions(-) diff --git a/torchvision/prototype/datasets/_builtin/caltech.py b/torchvision/prototype/datasets/_builtin/caltech.py index 06a45c443ee..429007b72d3 100644 --- a/torchvision/prototype/datasets/_builtin/caltech.py +++ b/torchvision/prototype/datasets/_builtin/caltech.py @@ -9,7 +9,7 @@ Filter, IterKeyZipper, ) -from torchvision.prototype.datasets.utils import Dataset, DatasetInfo, HttpResource, OnlineResource +from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource, read_categories_file from torchvision.prototype.datasets.utils._internal import ( INFINITE_BUFFER_SIZE, read_mat, @@ -22,7 +22,7 @@ from .._api import register_dataset, register_info -CALTECH101_CATEGORIES, *_ = zip(*DatasetInfo.read_categories_file(BUILTIN_DIR / "caltech101.categories")) +CALTECH101_CATEGORIES, *_ = zip(*read_categories_file(BUILTIN_DIR / "caltech101.categories")) @register_info("caltech101") @@ -152,7 +152,7 @@ def _generate_categories(self) -> List[str]: return sorted({pathlib.Path(path).parent.name for path, _ in dp}) -CALTECH256_CATEGORIES, *_ = zip(*DatasetInfo.read_categories_file(BUILTIN_DIR / "caltech256.categories")) +CALTECH256_CATEGORIES, *_ = zip(*read_categories_file(BUILTIN_DIR / "caltech256.categories")) @register_info("caltech256") diff --git a/torchvision/prototype/datasets/_builtin/cifar.py b/torchvision/prototype/datasets/_builtin/cifar.py index 2fe5a2e9035..b505455ad1c 100644 --- a/torchvision/prototype/datasets/_builtin/cifar.py +++ b/torchvision/prototype/datasets/_builtin/cifar.py @@ -10,7 +10,7 @@ Filter, Mapper, ) -from torchvision.prototype.datasets.utils import Dataset, DatasetInfo, HttpResource, OnlineResource +from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource, 
read_categories_file from torchvision.prototype.datasets.utils._internal import hint_shuffling, path_comparator, hint_sharding, BUILTIN_DIR from torchvision.prototype.features import Label, Image @@ -92,7 +92,7 @@ def _generate_categories(self) -> List[str]: return cast(List[str], next(iter(dp))[self._CATEGORIES_KEY]) -CIFAR10_CATEGORIES, *_ = zip(*DatasetInfo.read_categories_file(BUILTIN_DIR / "cifar10.categories")) +CIFAR10_CATEGORIES, *_ = zip(*read_categories_file(BUILTIN_DIR / "cifar10.categories")) @register_info("cifar10") @@ -118,7 +118,7 @@ def _is_data_file(self, data: Tuple[str, Any]) -> bool: return path.name.startswith("data" if self._split == "train" else "test") -CIFAR100_CATEGORIES, *_ = zip(*DatasetInfo.read_categories_file(BUILTIN_DIR / "cifar100.categories")) +CIFAR100_CATEGORIES, *_ = zip(*read_categories_file(BUILTIN_DIR / "cifar100.categories")) @register_info("cifar100") diff --git a/torchvision/prototype/datasets/_builtin/coco.py b/torchvision/prototype/datasets/_builtin/coco.py index ce6d2bd29bd..137ebbb5307 100644 --- a/torchvision/prototype/datasets/_builtin/coco.py +++ b/torchvision/prototype/datasets/_builtin/coco.py @@ -16,10 +16,10 @@ UnBatcher, ) from torchvision.prototype.datasets.utils import ( - DatasetInfo, HttpResource, OnlineResource, Dataset, + read_categories_file, ) from torchvision.prototype.datasets.utils._internal import ( MappingIterator, @@ -40,7 +40,7 @@ @register_info(NAME) def _info() -> Dict[str, Any]: - categories, super_categories = zip(*DatasetInfo.read_categories_file(BUILTIN_DIR / f"{NAME}.categories")) + categories, super_categories = zip(*read_categories_file(BUILTIN_DIR / f"{NAME}.categories")) return dict(categories=categories, super_categories=super_categories) diff --git a/torchvision/prototype/datasets/_builtin/country211.py b/torchvision/prototype/datasets/_builtin/country211.py index a72e43100e5..b4d1a9804c1 100644 --- a/torchvision/prototype/datasets/_builtin/country211.py +++ 
b/torchvision/prototype/datasets/_builtin/country211.py @@ -2,7 +2,7 @@ from typing import Any, Dict, List, Tuple, Union from torchdata.datapipes.iter import IterDataPipe, Mapper, Filter -from torchvision.prototype.datasets.utils import Dataset, DatasetInfo, HttpResource, OnlineResource +from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource, read_categories_file from torchvision.prototype.datasets.utils._internal import path_comparator, hint_sharding, hint_shuffling, BUILTIN_DIR from torchvision.prototype.features import EncodedImage, Label @@ -10,7 +10,7 @@ NAME = "country211" -CATEGORIES, *_ = zip(*DatasetInfo.read_categories_file(BUILTIN_DIR / f"{NAME}.categories")) +CATEGORIES, *_ = zip(*read_categories_file(BUILTIN_DIR / f"{NAME}.categories")) @register_info(NAME) diff --git a/torchvision/prototype/datasets/_builtin/cub200.py b/torchvision/prototype/datasets/_builtin/cub200.py index d68d6a1cd0f..32dc84cb11e 100644 --- a/torchvision/prototype/datasets/_builtin/cub200.py +++ b/torchvision/prototype/datasets/_builtin/cub200.py @@ -15,9 +15,9 @@ ) from torchvision.prototype.datasets.utils import ( Dataset, - DatasetInfo, HttpResource, OnlineResource, + read_categories_file, ) from torchvision.prototype.datasets.utils._internal import ( INFINITE_BUFFER_SIZE, @@ -38,7 +38,7 @@ NAME = "cub200" -CATEGORIES, *_ = zip(*DatasetInfo.read_categories_file(BUILTIN_DIR / f"{NAME}.categories")) +CATEGORIES, *_ = zip(*read_categories_file(BUILTIN_DIR / f"{NAME}.categories")) @register_info(NAME) diff --git a/torchvision/prototype/datasets/_builtin/dtd.py b/torchvision/prototype/datasets/_builtin/dtd.py index bacd2faefc0..20b49114fd7 100644 --- a/torchvision/prototype/datasets/_builtin/dtd.py +++ b/torchvision/prototype/datasets/_builtin/dtd.py @@ -5,9 +5,9 @@ from torchdata.datapipes.iter import IterDataPipe, Mapper, Filter, IterKeyZipper, Demultiplexer, LineReader, CSVParser from torchvision.prototype.datasets.utils import ( Dataset, - 
DatasetInfo, HttpResource, OnlineResource, + read_categories_file, ) from torchvision.prototype.datasets.utils._internal import ( INFINITE_BUFFER_SIZE, @@ -33,7 +33,7 @@ class DTDDemux(enum.IntEnum): @register_info(NAME) def _info() -> Dict[str, Any]: - categories = DatasetInfo.read_categories_file(BUILTIN_DIR / f"{NAME}.categories") + categories = read_categories_file(BUILTIN_DIR / f"{NAME}.categories") categories = [c[0] for c in categories] return dict(categories=categories) diff --git a/torchvision/prototype/datasets/_builtin/food101.py b/torchvision/prototype/datasets/_builtin/food101.py index f9e923fc4e5..3c6e984b30d 100644 --- a/torchvision/prototype/datasets/_builtin/food101.py +++ b/torchvision/prototype/datasets/_builtin/food101.py @@ -9,7 +9,7 @@ Demultiplexer, IterKeyZipper, ) -from torchvision.prototype.datasets.utils import Dataset, DatasetInfo, HttpResource, OnlineResource +from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource, read_categories_file from torchvision.prototype.datasets.utils._internal import ( hint_shuffling, BUILTIN_DIR, @@ -28,7 +28,7 @@ @register_info(NAME) def _info() -> Dict[str, Any]: - categories = DatasetInfo.read_categories_file(BUILTIN_DIR / f"{NAME}.categories") + categories = read_categories_file(BUILTIN_DIR / f"{NAME}.categories") categories = [c[0] for c in categories] return dict(categories=categories) diff --git a/torchvision/prototype/datasets/_builtin/imagenet.py b/torchvision/prototype/datasets/_builtin/imagenet.py index 478b083a972..05e421a8f8e 100644 --- a/torchvision/prototype/datasets/_builtin/imagenet.py +++ b/torchvision/prototype/datasets/_builtin/imagenet.py @@ -15,10 +15,10 @@ Enumerator, ) from torchvision.prototype.datasets.utils import ( - DatasetInfo, OnlineResource, ManualDownloadResource, Dataset, + read_categories_file, ) from torchvision.prototype.datasets.utils._internal import ( INFINITE_BUFFER_SIZE, @@ -38,7 +38,7 @@ @register_info(NAME) def _info() -> Dict[str, 
Any]: - categories, wnids = zip(*DatasetInfo.read_categories_file(BUILTIN_DIR / f"{NAME}.categories")) + categories, wnids = zip(*read_categories_file(BUILTIN_DIR / f"{NAME}.categories")) return dict(categories=categories, wnids=wnids) diff --git a/torchvision/prototype/datasets/_builtin/oxford_iiit_pet.py b/torchvision/prototype/datasets/_builtin/oxford_iiit_pet.py index 26134722743..a617458bd94 100644 --- a/torchvision/prototype/datasets/_builtin/oxford_iiit_pet.py +++ b/torchvision/prototype/datasets/_builtin/oxford_iiit_pet.py @@ -5,9 +5,9 @@ from torchdata.datapipes.iter import IterDataPipe, Mapper, Filter, IterKeyZipper, Demultiplexer, CSVDictParser from torchvision.prototype.datasets.utils import ( Dataset, - DatasetInfo, HttpResource, OnlineResource, + read_categories_file, ) from torchvision.prototype.datasets.utils._internal import ( INFINITE_BUFFER_SIZE, @@ -33,7 +33,7 @@ class OxfordIIITPetDemux(enum.IntEnum): @register_info(NAME) def _info() -> Dict[str, Any]: - categories = DatasetInfo.read_categories_file(BUILTIN_DIR / f"{NAME}.categories") + categories = read_categories_file(BUILTIN_DIR / f"{NAME}.categories") categories = [c[0] for c in categories] return dict(categories=categories) diff --git a/torchvision/prototype/datasets/_builtin/sbd.py b/torchvision/prototype/datasets/_builtin/sbd.py index 5f6fa84711c..6a89dca47ea 100644 --- a/torchvision/prototype/datasets/_builtin/sbd.py +++ b/torchvision/prototype/datasets/_builtin/sbd.py @@ -11,7 +11,7 @@ IterKeyZipper, LineReader, ) -from torchvision.prototype.datasets.utils import Dataset, DatasetInfo, HttpResource, OnlineResource +from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource, read_categories_file from torchvision.prototype.datasets.utils._internal import ( INFINITE_BUFFER_SIZE, read_mat, @@ -28,7 +28,7 @@ NAME = "sbd" -CATEGORIES, *_ = zip(*DatasetInfo.read_categories_file(BUILTIN_DIR / f"{NAME}.categories")) +CATEGORIES, *_ = 
zip(*read_categories_file(BUILTIN_DIR / f"{NAME}.categories")) @register_info(NAME) diff --git a/torchvision/prototype/datasets/_builtin/stanford_cars.py b/torchvision/prototype/datasets/_builtin/stanford_cars.py index 6f0bde5d478..d116d220348 100644 --- a/torchvision/prototype/datasets/_builtin/stanford_cars.py +++ b/torchvision/prototype/datasets/_builtin/stanford_cars.py @@ -2,7 +2,7 @@ from typing import Any, Dict, List, Tuple, Iterator, BinaryIO, Union from torchdata.datapipes.iter import Filter, IterDataPipe, Mapper, Zipper -from torchvision.prototype.datasets.utils import Dataset, DatasetInfo, HttpResource, OnlineResource +from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource, read_categories_file from torchvision.prototype.datasets.utils._internal import ( hint_sharding, hint_shuffling, @@ -31,7 +31,7 @@ def __iter__(self) -> Iterator[Tuple[int, int, int, int, int, str]]: @register_info(NAME) def _info() -> Dict[str, Any]: - categories = DatasetInfo.read_categories_file(BUILTIN_DIR / f"{NAME}.categories") + categories = read_categories_file(BUILTIN_DIR / f"{NAME}.categories") categories = [c[0] for c in categories] return dict(categories=categories) diff --git a/torchvision/prototype/datasets/_builtin/voc.py b/torchvision/prototype/datasets/_builtin/voc.py index 98f67b63a34..d09cbc3595d 100644 --- a/torchvision/prototype/datasets/_builtin/voc.py +++ b/torchvision/prototype/datasets/_builtin/voc.py @@ -13,7 +13,7 @@ LineReader, ) from torchvision.datasets import VOCDetection -from torchvision.prototype.datasets.utils import DatasetInfo, OnlineResource, HttpResource, Dataset +from torchvision.prototype.datasets.utils import OnlineResource, HttpResource, Dataset, read_categories_file from torchvision.prototype.datasets.utils._internal import ( path_accessor, getitem, @@ -29,7 +29,7 @@ NAME = "voc" -CATEGORIES, *_ = zip(*DatasetInfo.read_categories_file(BUILTIN_DIR / f"{NAME}.categories")) +CATEGORIES, *_ = 
zip(*read_categories_file(BUILTIN_DIR / f"{NAME}.categories")) @register_info(NAME) diff --git a/torchvision/prototype/datasets/utils/__init__.py b/torchvision/prototype/datasets/utils/__init__.py index 885d1bd6184..bf81de7170f 100644 --- a/torchvision/prototype/datasets/utils/__init__.py +++ b/torchvision/prototype/datasets/utils/__init__.py @@ -1,4 +1,4 @@ from . import _internal # usort: skip -from ._dataset import DatasetConfig, DatasetInfo, Dataset, Dataset +from ._dataset import DatasetConfig, Dataset, read_categories_file from ._query import SampleQuery from ._resource import OnlineResource, HttpResource, GDriveResource, ManualDownloadResource, KaggleDownloadResource diff --git a/torchvision/prototype/datasets/utils/_dataset.py b/torchvision/prototype/datasets/utils/_dataset.py index 89817921831..45b28de4e96 100644 --- a/torchvision/prototype/datasets/utils/_dataset.py +++ b/torchvision/prototype/datasets/utils/_dataset.py @@ -22,6 +22,10 @@ class DatasetConfig(FrozenBunch): pass +def read_categories_file(path: pathlib.Path) -> List[List[str]]: + with open(path, newline="") as file: + return [row for row in csv.reader(file)] + class DatasetInfo: def __init__( self, @@ -66,9 +70,6 @@ def default_config(self) -> DatasetConfig: return self._configs[0] @staticmethod - def read_categories_file(path: pathlib.Path) -> List[List[str]]: - with open(path, newline="") as file: - return [row for row in csv.reader(file)] def make_config(self, **options: Any) -> DatasetConfig: if not self._valid_options and options: From 344bbf2b4ef3c52bd51809abff32ca788098b6a9 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 6 Apr 2022 15:30:00 +0100 Subject: [PATCH 3/9] Remove FrozenBunch and FrozenMapping --- test/test_prototype_datasets_api.py | 156 +----------------- .../prototype/datasets/utils/__init__.py | 2 +- .../prototype/datasets/utils/_dataset.py | 108 +----------- torchvision/prototype/utils/_internal.py | 82 --------- 4 files changed, 4 insertions(+), 344 deletions(-) 
diff --git a/test/test_prototype_datasets_api.py b/test/test_prototype_datasets_api.py index 70a2707d050..d001bb30dac 100644 --- a/test/test_prototype_datasets_api.py +++ b/test/test_prototype_datasets_api.py @@ -2,163 +2,11 @@ import pytest from torchvision.prototype import datasets -from torchvision.prototype.utils._internal import FrozenMapping, FrozenBunch def make_minimal_dataset_info(name="name", categories=None, **kwargs): - return datasets.utils.DatasetInfo(name, categories=categories or [], **kwargs) - - -class TestFrozenMapping: - @pytest.mark.parametrize( - ("args", "kwargs"), - [ - pytest.param((dict(foo="bar", baz=1),), dict(), id="from_dict"), - pytest.param((), dict(foo="bar", baz=1), id="from_kwargs"), - pytest.param((dict(foo="bar"),), dict(baz=1), id="mixed"), - ], - ) - def test_instantiation(self, args, kwargs): - FrozenMapping(*args, **kwargs) - - def test_unhashable_items(self): - with pytest.raises(TypeError, match="unhashable type"): - FrozenMapping(foo=[]) - - def test_getitem(self): - options = dict(foo="bar", baz=1) - config = FrozenMapping(options) - - for key, value in options.items(): - assert config[key] == value - - def test_getitem_unknown(self): - with pytest.raises(KeyError): - FrozenMapping()["unknown"] - - def test_iter(self): - options = dict(foo="bar", baz=1) - assert set(iter(FrozenMapping(options))) == set(options.keys()) - - def test_len(self): - options = dict(foo="bar", baz=1) - assert len(FrozenMapping(options)) == len(options) - - def test_immutable_setitem(self): - frozen_mapping = FrozenMapping() - - with pytest.raises(RuntimeError, match="immutable"): - frozen_mapping["foo"] = "bar" - - def test_immutable_delitem( - self, - ): - frozen_mapping = FrozenMapping(foo="bar") - - with pytest.raises(RuntimeError, match="immutable"): - del frozen_mapping["foo"] - - def test_eq(self): - options = dict(foo="bar", baz=1) - assert FrozenMapping(options) == FrozenMapping(options) - - def test_ne(self): - options1 = 
dict(foo="bar", baz=1) - options2 = options1.copy() - options2["baz"] += 1 - - assert FrozenMapping(options1) != FrozenMapping(options2) - - def test_repr(self): - options = dict(foo="bar", baz=1) - output = repr(FrozenMapping(options)) - - assert isinstance(output, str) - for key, value in options.items(): - assert str(key) in output and str(value) in output - - -class TestFrozenBunch: - def test_getattr(self): - options = dict(foo="bar", baz=1) - config = FrozenBunch(options) - - for key, value in options.items(): - assert getattr(config, key) == value - - def test_getattr_unknown(self): - with pytest.raises(AttributeError, match="no attribute 'unknown'"): - datasets.utils.DatasetConfig().unknown - - def test_immutable_setattr(self): - frozen_bunch = FrozenBunch() - - with pytest.raises(RuntimeError, match="immutable"): - frozen_bunch.foo = "bar" - - def test_immutable_delattr( - self, - ): - frozen_bunch = FrozenBunch(foo="bar") - - with pytest.raises(RuntimeError, match="immutable"): - del frozen_bunch.foo - - def test_repr(self): - options = dict(foo="bar", baz=1) - output = repr(FrozenBunch(options)) - - assert isinstance(output, str) - assert output.startswith("FrozenBunch") - for key, value in options.items(): - assert f"{key}={value}" in output - - -class TestDatasetInfo: - @pytest.fixture - def info(self): - return make_minimal_dataset_info(valid_options=dict(split=("train", "test"), foo=("bar", "baz"))) - - def test_default_config(self, info): - valid_options = info._valid_options - default_config = datasets.utils.DatasetConfig({key: values[0] for key, values in valid_options.items()}) - - assert info.default_config == default_config - - @pytest.mark.parametrize( - ("valid_options", "options", "expected_error_msg"), - [ - (dict(), dict(any_option=None), "does not take any options"), - (dict(split="train"), dict(unknown_option=None), "Unknown option 'unknown_option'"), - (dict(split="train"), dict(split="invalid_argument"), "Invalid argument 
'invalid_argument'"), - ], - ) - def test_make_config_invalid_inputs(self, info, valid_options, options, expected_error_msg): - info = make_minimal_dataset_info(valid_options=valid_options) - - with pytest.raises(ValueError, match=expected_error_msg): - info.make_config(**options) - - def test_check_dependencies(self): - dependency = "fake_dependency" - info = make_minimal_dataset_info(dependencies=(dependency,)) - with pytest.raises(ModuleNotFoundError, match=dependency): - info.check_dependencies() - - def test_repr(self, info): - output = repr(info) - - assert isinstance(output, str) - assert "DatasetInfo" in output - for key, value in info._valid_options.items(): - assert f"{key}={str(value)[1:-1]}" in output - - @pytest.mark.parametrize("optional_info", ("citation", "homepage", "license")) - def test_repr_optional_info(self, optional_info): - sentinel = "sentinel" - info = make_minimal_dataset_info(**{optional_info: sentinel}) - - assert f"{optional_info}={sentinel}" in repr(info) + # TODO: remove this? + return dict(categories=categories or [], **kwargs) class TestDataset: diff --git a/torchvision/prototype/datasets/utils/__init__.py b/torchvision/prototype/datasets/utils/__init__.py index bf81de7170f..4c9aa683616 100644 --- a/torchvision/prototype/datasets/utils/__init__.py +++ b/torchvision/prototype/datasets/utils/__init__.py @@ -1,4 +1,4 @@ from . 
import _internal # usort: skip -from ._dataset import DatasetConfig, Dataset, read_categories_file +from ._dataset import Dataset, read_categories_file from ._query import SampleQuery from ._resource import OnlineResource, HttpResource, GDriveResource, ManualDownloadResource, KaggleDownloadResource diff --git a/torchvision/prototype/datasets/utils/_dataset.py b/torchvision/prototype/datasets/utils/_dataset.py index 45b28de4e96..a9a28ad3c66 100644 --- a/torchvision/prototype/datasets/utils/_dataset.py +++ b/torchvision/prototype/datasets/utils/_dataset.py @@ -1,125 +1,19 @@ import abc import csv import importlib -import itertools -import os import pathlib -from typing import Any, Dict, List, Optional, Sequence, Union, Tuple, Collection, Iterator +from typing import Any, Dict, List, Optional, Sequence, Union, Collection, Iterator from torch.utils.data import IterDataPipe -from torchvision._utils import sequence_to_str from torchvision.datasets.utils import verify_str_arg -from torchvision.prototype.utils._internal import FrozenBunch, make_repr, add_suggestion -from .._home import use_sharded_dataset -from ._internal import BUILTIN_DIR, _make_sharded_datapipe from ._resource import OnlineResource -class DatasetConfig(FrozenBunch): - # This needs to be Frozen because we often pass configs as partial(func, config=config) - # and partial() requires the parameters to be hashable. 
- pass - - def read_categories_file(path: pathlib.Path) -> List[List[str]]: with open(path, newline="") as file: return [row for row in csv.reader(file)] -class DatasetInfo: - def __init__( - self, - name: str, - *, - dependencies: Collection[str] = (), - categories: Optional[Union[int, Sequence[str], str, pathlib.Path]] = None, - citation: Optional[str] = None, - homepage: Optional[str] = None, - license: Optional[str] = None, - valid_options: Optional[Dict[str, Sequence[Any]]] = None, - extra: Optional[Dict[str, Any]] = None, - ) -> None: - self.name = name.lower() - - self.dependecies = dependencies - - if categories is None: - path = BUILTIN_DIR / f"{self.name}.categories" - categories = path if path.exists() else [] - if isinstance(categories, int): - categories = [str(label) for label in range(categories)] - elif isinstance(categories, (str, pathlib.Path)): - path = pathlib.Path(categories).expanduser().resolve() - categories, *_ = zip(*self.read_categories_file(path)) - self.categories = tuple(categories) - - self.citation = citation - self.homepage = homepage - self.license = license - - self._valid_options = valid_options or dict() - self._configs = tuple( - DatasetConfig(**dict(zip(self._valid_options.keys(), combination))) - for combination in itertools.product(*self._valid_options.values()) - ) - - self.extra = FrozenBunch(extra or dict()) - - @property - def default_config(self) -> DatasetConfig: - return self._configs[0] - - @staticmethod - - def make_config(self, **options: Any) -> DatasetConfig: - if not self._valid_options and options: - raise ValueError( - f"Dataset {self.name} does not take any options, " - f"but got {sequence_to_str(list(options), separate_last=' and')}." 
- ) - - for name, arg in options.items(): - if name not in self._valid_options: - raise ValueError( - add_suggestion( - f"Unknown option '{name}' of dataset {self.name}.", - word=name, - possibilities=sorted(self._valid_options.keys()), - ) - ) - - valid_args = self._valid_options[name] - - if arg not in valid_args: - raise ValueError( - add_suggestion( - f"Invalid argument '{arg}' for option '{name}' of dataset {self.name}.", - word=arg, - possibilities=valid_args, - ) - ) - - return DatasetConfig(self.default_config, **options) - - def check_dependencies(self) -> None: - for dependency in self.dependecies: - try: - importlib.import_module(dependency) - except ModuleNotFoundError as error: - raise ModuleNotFoundError( - f"Dataset '{self.name}' depends on the third-party package '{dependency}'. " - f"Please install it, for example with `pip install {dependency}`." - ) from error - - def __repr__(self) -> str: - items = [("name", self.name)] - for key in ("citation", "homepage", "license"): - value = getattr(self, key) - if value is not None: - items.append((key, value)) - items.extend(sorted((key, sequence_to_str(value)) for key, value in self._valid_options.items())) - return make_repr(type(self).__name__, items) - class Dataset(IterDataPipe[Dict[str, Any]], abc.ABC): @staticmethod diff --git a/torchvision/prototype/utils/_internal.py b/torchvision/prototype/utils/_internal.py index fe5284394cb..2caceb6b186 100644 --- a/torchvision/prototype/utils/_internal.py +++ b/torchvision/prototype/utils/_internal.py @@ -2,20 +2,14 @@ import difflib import io import mmap -import os import os.path import platform -import textwrap from typing import ( Any, BinaryIO, Callable, - cast, Collection, - Iterable, Iterator, - Mapping, - NoReturn, Sequence, Tuple, TypeVar, @@ -30,9 +24,6 @@ __all__ = [ "add_suggestion", - "FrozenMapping", - "make_repr", - "FrozenBunch", "fromfile", "ReadOnlyTensorBuffer", "apply_recursively", @@ -60,82 +51,9 @@ def add_suggestion( return 
f"{msg.strip()} {hint}" -K = TypeVar("K") D = TypeVar("D") -class FrozenMapping(Mapping[K, D]): - def __init__(self, *args: Any, **kwargs: Any) -> None: - data = dict(*args, **kwargs) - self.__dict__["__data__"] = data - self.__dict__["__final_hash__"] = hash(tuple(data.items())) - - def __getitem__(self, item: K) -> D: - return cast(Mapping[K, D], self.__dict__["__data__"])[item] - - def __iter__(self) -> Iterator[K]: - return iter(self.__dict__["__data__"].keys()) - - def __len__(self) -> int: - return len(self.__dict__["__data__"]) - - def __immutable__(self) -> NoReturn: - raise RuntimeError(f"'{type(self).__name__}' object is immutable") - - def __setitem__(self, key: K, value: Any) -> NoReturn: - self.__immutable__() - - def __delitem__(self, key: K) -> NoReturn: - self.__immutable__() - - def __hash__(self) -> int: - return cast(int, self.__dict__["__final_hash__"]) - - def __eq__(self, other: Any) -> bool: - if not isinstance(other, FrozenMapping): - return NotImplemented - - return hash(self) == hash(other) - - def __repr__(self) -> str: - return repr(self.__dict__["__data__"]) - - -def make_repr(name: str, items: Iterable[Tuple[str, Any]]) -> str: - def to_str(sep: str) -> str: - return sep.join([f"{key}={value}" for key, value in items]) - - prefix = f"{name}(" - postfix = ")" - body = to_str(", ") - - line_length = int(os.environ.get("COLUMNS", 80)) - body_too_long = (len(prefix) + len(body) + len(postfix)) > line_length - multiline_body = len(str(body).splitlines()) > 1 - if not (body_too_long or multiline_body): - return prefix + body + postfix - - body = textwrap.indent(to_str(",\n"), " " * 2) - return f"{prefix}\n{body}\n{postfix}" - - -class FrozenBunch(FrozenMapping): - def __getattr__(self, name: str) -> Any: - try: - return self[name] - except KeyError as error: - raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") from error - - def __setattr__(self, key: Any, value: Any) -> NoReturn: - self.__immutable__() - - def 
__delattr__(self, item: Any) -> NoReturn: - self.__immutable__() - - def __repr__(self) -> str: - return make_repr(type(self).__name__, self.items()) - - def _read_mutable_buffer_fallback(file: BinaryIO, count: int, item_size: int) -> bytearray: # A plain file.read() will give a read-only bytes, so we convert it to bytearray to make it mutable return bytearray(file.read(-1 if count == -1 else count * item_size)) From 1908fd62eb38151890751d03c8e57d4325f92dbb Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 6 Apr 2022 15:53:47 +0100 Subject: [PATCH 4/9] Remove test_prototype_datasets_api.py and move missing dep test somewhere else --- test/test_prototype_datasets_api.py | 79 --------------------------- test/test_prototype_datasets_utils.py | 19 ++++++- 2 files changed, 18 insertions(+), 80 deletions(-) delete mode 100644 test/test_prototype_datasets_api.py diff --git a/test/test_prototype_datasets_api.py b/test/test_prototype_datasets_api.py deleted file mode 100644 index d001bb30dac..00000000000 --- a/test/test_prototype_datasets_api.py +++ /dev/null @@ -1,79 +0,0 @@ -import unittest.mock - -import pytest -from torchvision.prototype import datasets - - -def make_minimal_dataset_info(name="name", categories=None, **kwargs): - # TODO: remove this? 
- return dict(categories=categories or [], **kwargs) - - -class TestDataset: - class DatasetMock(datasets.utils.Dataset): - def __init__(self, info=None, *, resources=None): - self._info = info or make_minimal_dataset_info(valid_options=dict(split=("train", "test"))) - self.resources = unittest.mock.Mock(return_value=[]) if resources is None else lambda config: resources - self._make_datapipe = unittest.mock.Mock() - super().__init__() - - def _make_info(self): - return self._info - - def resources(self, config): - # This method is just defined to appease the ABC, but will be overwritten at instantiation - pass - - def _make_datapipe(self, resource_dps, *, config): - # This method is just defined to appease the ABC, but will be overwritten at instantiation - pass - - def test_name(self): - name = "sentinel" - dataset = self.DatasetMock(make_minimal_dataset_info(name=name)) - - assert dataset.name == name - - def test_default_config(self): - sentinel = "sentinel" - dataset = self.DatasetMock(info=make_minimal_dataset_info(valid_options=dict(split=(sentinel, "train")))) - - assert dataset.default_config == datasets.utils.DatasetConfig(split=sentinel) - - @pytest.mark.parametrize( - ("config", "kwarg"), - [ - pytest.param(*(datasets.utils.DatasetConfig(split="test"),) * 2, id="specific"), - pytest.param(DatasetMock().default_config, None, id="default"), - ], - ) - def test_load_config(self, config, kwarg): - dataset = self.DatasetMock() - - dataset.load("", config=kwarg) - - dataset.resources.assert_called_with(config) - - _, call_kwargs = dataset._make_datapipe.call_args - assert call_kwargs["config"] == config - - def test_missing_dependencies(self): - dependency = "fake_dependency" - dataset = self.DatasetMock(make_minimal_dataset_info(dependencies=(dependency,))) - with pytest.raises(ModuleNotFoundError, match=dependency): - dataset.load("root") - - def test_resources(self, mocker): - resource_mock = mocker.Mock(spec=["load"]) - sentinel = object() - 
resource_mock.load.return_value = sentinel - dataset = self.DatasetMock(resources=[resource_mock]) - - root = "root" - dataset.load(root) - - (call_args, _) = resource_mock.load.call_args - assert call_args[0] == root - - (call_args, _) = dataset._make_datapipe.call_args - assert call_args[0][0] is sentinel diff --git a/test/test_prototype_datasets_utils.py b/test/test_prototype_datasets_utils.py index bd857abf02f..d06a5cac421 100644 --- a/test/test_prototype_datasets_utils.py +++ b/test/test_prototype_datasets_utils.py @@ -5,7 +5,7 @@ import torch from datasets_utils import make_fake_flo_file from torchvision.datasets._optical_flow import _read_flo as read_flo_ref -from torchvision.prototype.datasets.utils import HttpResource, GDriveResource +from torchvision.prototype.datasets.utils import HttpResource, GDriveResource, Dataset from torchvision.prototype.datasets.utils._internal import read_flo, fromfile @@ -101,3 +101,20 @@ def preprocess_sentinel(path): assert redirected_resource.file_name == file_name assert redirected_resource.sha256 == sha256_sentinel assert redirected_resource._preprocess is preprocess_sentinel + + +def test_missing_dependency_error(): + class DummyDataset(Dataset): + def __init__(self): + super().__init__(root="root", dependencies=("fake_dependency",)) + + def _resources(self): + pass + + def _datapipe(self, resource_dps): + pass + def __len__(self): + pass + + with pytest.raises(ModuleNotFoundError, match="depends on the third-party package 'fake_dependency'"): + DummyDataset() From 836018582b412a0f3de1d23cd1c516b2c7d6d4cd Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 6 Apr 2022 15:54:29 +0100 Subject: [PATCH 5/9] ufmt --- test/test_prototype_datasets_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/test_prototype_datasets_utils.py b/test/test_prototype_datasets_utils.py index d06a5cac421..b1c95844574 100644 --- a/test/test_prototype_datasets_utils.py +++ b/test/test_prototype_datasets_utils.py @@ 
-113,8 +113,9 @@ def _resources(self): def _datapipe(self, resource_dps): pass + def __len__(self): pass - + with pytest.raises(ModuleNotFoundError, match="depends on the third-party package 'fake_dependency'"): DummyDataset() From 8f473d70fb98eb34ed06bdc3b026f92a87230378 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 6 Apr 2022 16:23:47 +0100 Subject: [PATCH 6/9] Let read_categories_file accept names instead of paths --- torchvision/prototype/datasets/_builtin/caltech.py | 8 ++++---- torchvision/prototype/datasets/_builtin/cifar.py | 13 +++++++++---- torchvision/prototype/datasets/_builtin/coco.py | 5 ++--- .../prototype/datasets/_builtin/country211.py | 11 ++++++++--- torchvision/prototype/datasets/_builtin/cub200.py | 5 ++--- torchvision/prototype/datasets/_builtin/dtd.py | 7 ++----- torchvision/prototype/datasets/_builtin/food101.py | 7 +++---- torchvision/prototype/datasets/_builtin/imagenet.py | 5 ++--- .../prototype/datasets/_builtin/oxford_iiit_pet.py | 7 ++----- torchvision/prototype/datasets/_builtin/sbd.py | 6 +++--- .../prototype/datasets/_builtin/stanford_cars.py | 8 +++----- torchvision/prototype/datasets/_builtin/voc.py | 6 +++--- torchvision/prototype/datasets/utils/__init__.py | 2 +- torchvision/prototype/datasets/utils/_dataset.py | 5 ----- torchvision/prototype/datasets/utils/_internal.py | 10 ++++++++++ 15 files changed, 54 insertions(+), 51 deletions(-) diff --git a/torchvision/prototype/datasets/_builtin/caltech.py b/torchvision/prototype/datasets/_builtin/caltech.py index 429007b72d3..20d05da852e 100644 --- a/torchvision/prototype/datasets/_builtin/caltech.py +++ b/torchvision/prototype/datasets/_builtin/caltech.py @@ -9,20 +9,20 @@ Filter, IterKeyZipper, ) -from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource, read_categories_file +from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource from torchvision.prototype.datasets.utils._internal import ( INFINITE_BUFFER_SIZE, 
read_mat, hint_sharding, hint_shuffling, - BUILTIN_DIR, + read_categories_file, ) from torchvision.prototype.features import Label, BoundingBox, _Feature, EncodedImage from .._api import register_dataset, register_info -CALTECH101_CATEGORIES, *_ = zip(*read_categories_file(BUILTIN_DIR / "caltech101.categories")) +CALTECH101_CATEGORIES, *_ = zip(*read_categories_file("caltech101")) @register_info("caltech101") @@ -152,7 +152,7 @@ def _generate_categories(self) -> List[str]: return sorted({pathlib.Path(path).parent.name for path, _ in dp}) -CALTECH256_CATEGORIES, *_ = zip(*read_categories_file(BUILTIN_DIR / "caltech256.categories")) +CALTECH256_CATEGORIES, *_ = zip(*read_categories_file("caltech256")) @register_info("caltech256") diff --git a/torchvision/prototype/datasets/_builtin/cifar.py b/torchvision/prototype/datasets/_builtin/cifar.py index b505455ad1c..33250342005 100644 --- a/torchvision/prototype/datasets/_builtin/cifar.py +++ b/torchvision/prototype/datasets/_builtin/cifar.py @@ -10,8 +10,13 @@ Filter, Mapper, ) -from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource, read_categories_file -from torchvision.prototype.datasets.utils._internal import hint_shuffling, path_comparator, hint_sharding, BUILTIN_DIR +from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource +from torchvision.prototype.datasets.utils._internal import ( + hint_shuffling, + path_comparator, + hint_sharding, + read_categories_file, +) from torchvision.prototype.features import Label, Image from .._api import register_dataset, register_info @@ -92,7 +97,7 @@ def _generate_categories(self) -> List[str]: return cast(List[str], next(iter(dp))[self._CATEGORIES_KEY]) -CIFAR10_CATEGORIES, *_ = zip(*read_categories_file(BUILTIN_DIR / "cifar10.categories")) +CIFAR10_CATEGORIES, *_ = zip(*read_categories_file("cifar10")) @register_info("cifar10") @@ -118,7 +123,7 @@ def _is_data_file(self, data: Tuple[str, Any]) -> bool: return 
path.name.startswith("data" if self._split == "train" else "test") -CIFAR100_CATEGORIES, *_ = zip(*read_categories_file(BUILTIN_DIR / "cifar100.categories")) +CIFAR100_CATEGORIES, *_ = zip(*read_categories_file("cifar100")) @register_info("cifar100") diff --git a/torchvision/prototype/datasets/_builtin/coco.py b/torchvision/prototype/datasets/_builtin/coco.py index 137ebbb5307..ff3b5f37c96 100644 --- a/torchvision/prototype/datasets/_builtin/coco.py +++ b/torchvision/prototype/datasets/_builtin/coco.py @@ -19,13 +19,12 @@ HttpResource, OnlineResource, Dataset, - read_categories_file, ) from torchvision.prototype.datasets.utils._internal import ( MappingIterator, INFINITE_BUFFER_SIZE, - BUILTIN_DIR, getitem, + read_categories_file, path_accessor, hint_sharding, hint_shuffling, @@ -40,7 +39,7 @@ @register_info(NAME) def _info() -> Dict[str, Any]: - categories, super_categories = zip(*read_categories_file(BUILTIN_DIR / f"{NAME}.categories")) + categories, super_categories = zip(*read_categories_file(NAME)) return dict(categories=categories, super_categories=super_categories) diff --git a/torchvision/prototype/datasets/_builtin/country211.py b/torchvision/prototype/datasets/_builtin/country211.py index b4d1a9804c1..8bcb5130db2 100644 --- a/torchvision/prototype/datasets/_builtin/country211.py +++ b/torchvision/prototype/datasets/_builtin/country211.py @@ -2,15 +2,20 @@ from typing import Any, Dict, List, Tuple, Union from torchdata.datapipes.iter import IterDataPipe, Mapper, Filter -from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource, read_categories_file -from torchvision.prototype.datasets.utils._internal import path_comparator, hint_sharding, hint_shuffling, BUILTIN_DIR +from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource +from torchvision.prototype.datasets.utils._internal import ( + path_comparator, + hint_sharding, + hint_shuffling, + read_categories_file, +) from torchvision.prototype.features 
import EncodedImage, Label from .._api import register_dataset, register_info NAME = "country211" -CATEGORIES, *_ = zip(*read_categories_file(BUILTIN_DIR / f"{NAME}.categories")) +CATEGORIES, *_ = zip(*read_categories_file(NAME)) @register_info(NAME) diff --git a/torchvision/prototype/datasets/_builtin/cub200.py b/torchvision/prototype/datasets/_builtin/cub200.py index 32dc84cb11e..1149bbe59ed 100644 --- a/torchvision/prototype/datasets/_builtin/cub200.py +++ b/torchvision/prototype/datasets/_builtin/cub200.py @@ -17,7 +17,6 @@ Dataset, HttpResource, OnlineResource, - read_categories_file, ) from torchvision.prototype.datasets.utils._internal import ( INFINITE_BUFFER_SIZE, @@ -26,8 +25,8 @@ hint_shuffling, getitem, path_comparator, + read_categories_file, path_accessor, - BUILTIN_DIR, ) from torchvision.prototype.features import Label, BoundingBox, _Feature, EncodedImage @@ -38,7 +37,7 @@ NAME = "cub200" -CATEGORIES, *_ = zip(*read_categories_file(BUILTIN_DIR / f"{NAME}.categories")) +CATEGORIES, *_ = zip(*read_categories_file(NAME)) @register_info(NAME) diff --git a/torchvision/prototype/datasets/_builtin/dtd.py b/torchvision/prototype/datasets/_builtin/dtd.py index 20b49114fd7..b082ada19ce 100644 --- a/torchvision/prototype/datasets/_builtin/dtd.py +++ b/torchvision/prototype/datasets/_builtin/dtd.py @@ -7,14 +7,13 @@ Dataset, HttpResource, OnlineResource, - read_categories_file, ) from torchvision.prototype.datasets.utils._internal import ( INFINITE_BUFFER_SIZE, hint_sharding, path_comparator, - BUILTIN_DIR, getitem, + read_categories_file, hint_shuffling, ) from torchvision.prototype.features import Label, EncodedImage @@ -33,9 +32,7 @@ class DTDDemux(enum.IntEnum): @register_info(NAME) def _info() -> Dict[str, Any]: - categories = read_categories_file(BUILTIN_DIR / f"{NAME}.categories") - categories = [c[0] for c in categories] - return dict(categories=categories) + return dict(categories=read_categories_file(NAME)) @register_dataset(NAME) diff --git 
a/torchvision/prototype/datasets/_builtin/food101.py b/torchvision/prototype/datasets/_builtin/food101.py index 3c6e984b30d..7343466d18b 100644 --- a/torchvision/prototype/datasets/_builtin/food101.py +++ b/torchvision/prototype/datasets/_builtin/food101.py @@ -9,7 +9,7 @@ Demultiplexer, IterKeyZipper, ) -from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource, read_categories_file +from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource from torchvision.prototype.datasets.utils._internal import ( hint_shuffling, BUILTIN_DIR, @@ -17,6 +17,7 @@ path_comparator, getitem, INFINITE_BUFFER_SIZE, + read_categories_file, ) from torchvision.prototype.features import Label, EncodedImage @@ -28,9 +29,7 @@ @register_info(NAME) def _info() -> Dict[str, Any]: - categories = read_categories_file(BUILTIN_DIR / f"{NAME}.categories") - categories = [c[0] for c in categories] - return dict(categories=categories) + return dict(categories=read_categories_file(NAME)) @register_dataset(NAME) diff --git a/torchvision/prototype/datasets/_builtin/imagenet.py b/torchvision/prototype/datasets/_builtin/imagenet.py index 05e421a8f8e..1307757cef6 100644 --- a/torchvision/prototype/datasets/_builtin/imagenet.py +++ b/torchvision/prototype/datasets/_builtin/imagenet.py @@ -18,15 +18,14 @@ OnlineResource, ManualDownloadResource, Dataset, - read_categories_file, ) from torchvision.prototype.datasets.utils._internal import ( INFINITE_BUFFER_SIZE, - BUILTIN_DIR, getitem, read_mat, hint_sharding, hint_shuffling, + read_categories_file, path_accessor, ) from torchvision.prototype.features import Label, EncodedImage @@ -38,7 +37,7 @@ @register_info(NAME) def _info() -> Dict[str, Any]: - categories, wnids = zip(*read_categories_file(BUILTIN_DIR / f"{NAME}.categories")) + categories, wnids = zip(*read_categories_file(NAME)) return dict(categories=categories, wnids=wnids) diff --git a/torchvision/prototype/datasets/_builtin/oxford_iiit_pet.py 
b/torchvision/prototype/datasets/_builtin/oxford_iiit_pet.py index a617458bd94..f7da02a4765 100644 --- a/torchvision/prototype/datasets/_builtin/oxford_iiit_pet.py +++ b/torchvision/prototype/datasets/_builtin/oxford_iiit_pet.py @@ -7,15 +7,14 @@ Dataset, HttpResource, OnlineResource, - read_categories_file, ) from torchvision.prototype.datasets.utils._internal import ( INFINITE_BUFFER_SIZE, hint_sharding, hint_shuffling, - BUILTIN_DIR, getitem, path_accessor, + read_categories_file, path_comparator, ) from torchvision.prototype.features import Label, EncodedImage @@ -33,9 +32,7 @@ class OxfordIIITPetDemux(enum.IntEnum): @register_info(NAME) def _info() -> Dict[str, Any]: - categories = read_categories_file(BUILTIN_DIR / f"{NAME}.categories") - categories = [c[0] for c in categories] - return dict(categories=categories) + return dict(categories=read_categories_file(NAME)) @register_dataset(NAME) diff --git a/torchvision/prototype/datasets/_builtin/sbd.py b/torchvision/prototype/datasets/_builtin/sbd.py index 6a89dca47ea..c6e5f45afbc 100644 --- a/torchvision/prototype/datasets/_builtin/sbd.py +++ b/torchvision/prototype/datasets/_builtin/sbd.py @@ -11,7 +11,7 @@ IterKeyZipper, LineReader, ) -from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource, read_categories_file +from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource from torchvision.prototype.datasets.utils._internal import ( INFINITE_BUFFER_SIZE, read_mat, @@ -20,7 +20,7 @@ path_comparator, hint_sharding, hint_shuffling, - BUILTIN_DIR, + read_categories_file, ) from torchvision.prototype.features import _Feature, EncodedImage @@ -28,7 +28,7 @@ NAME = "sbd" -CATEGORIES, *_ = zip(*read_categories_file(BUILTIN_DIR / f"{NAME}.categories")) +CATEGORIES, *_ = zip(*read_categories_file(NAME)) @register_info(NAME) diff --git a/torchvision/prototype/datasets/_builtin/stanford_cars.py b/torchvision/prototype/datasets/_builtin/stanford_cars.py index 
d116d220348..465d753c2e5 100644 --- a/torchvision/prototype/datasets/_builtin/stanford_cars.py +++ b/torchvision/prototype/datasets/_builtin/stanford_cars.py @@ -2,13 +2,13 @@ from typing import Any, Dict, List, Tuple, Iterator, BinaryIO, Union from torchdata.datapipes.iter import Filter, IterDataPipe, Mapper, Zipper -from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource, read_categories_file +from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource from torchvision.prototype.datasets.utils._internal import ( hint_sharding, hint_shuffling, path_comparator, read_mat, - BUILTIN_DIR, + read_categories_file, ) from torchvision.prototype.features import BoundingBox, EncodedImage, Label @@ -31,9 +31,7 @@ def __iter__(self) -> Iterator[Tuple[int, int, int, int, int, str]]: @register_info(NAME) def _info() -> Dict[str, Any]: - categories = read_categories_file(BUILTIN_DIR / f"{NAME}.categories") - categories = [c[0] for c in categories] - return dict(categories=categories) + return dict(categories=read_categories_file(NAME)) @register_dataset(NAME) diff --git a/torchvision/prototype/datasets/_builtin/voc.py b/torchvision/prototype/datasets/_builtin/voc.py index d09cbc3595d..2882e3434b8 100644 --- a/torchvision/prototype/datasets/_builtin/voc.py +++ b/torchvision/prototype/datasets/_builtin/voc.py @@ -13,7 +13,7 @@ LineReader, ) from torchvision.datasets import VOCDetection -from torchvision.prototype.datasets.utils import OnlineResource, HttpResource, Dataset, read_categories_file +from torchvision.prototype.datasets.utils import OnlineResource, HttpResource, Dataset from torchvision.prototype.datasets.utils._internal import ( path_accessor, getitem, @@ -21,7 +21,7 @@ path_comparator, hint_sharding, hint_shuffling, - BUILTIN_DIR, + read_categories_file, ) from torchvision.prototype.features import BoundingBox, Label, EncodedImage @@ -29,7 +29,7 @@ NAME = "voc" -CATEGORIES, *_ = 
zip(*read_categories_file(BUILTIN_DIR / f"{NAME}.categories")) +CATEGORIES, *_ = zip(*read_categories_file(NAME)) @register_info(NAME) diff --git a/torchvision/prototype/datasets/utils/__init__.py b/torchvision/prototype/datasets/utils/__init__.py index 4c9aa683616..e7ef72f28a9 100644 --- a/torchvision/prototype/datasets/utils/__init__.py +++ b/torchvision/prototype/datasets/utils/__init__.py @@ -1,4 +1,4 @@ from . import _internal # usort: skip -from ._dataset import Dataset, read_categories_file +from ._dataset import Dataset from ._query import SampleQuery from ._resource import OnlineResource, HttpResource, GDriveResource, ManualDownloadResource, KaggleDownloadResource diff --git a/torchvision/prototype/datasets/utils/_dataset.py b/torchvision/prototype/datasets/utils/_dataset.py index a9a28ad3c66..62cfa707f6c 100644 --- a/torchvision/prototype/datasets/utils/_dataset.py +++ b/torchvision/prototype/datasets/utils/_dataset.py @@ -10,11 +10,6 @@ from ._resource import OnlineResource -def read_categories_file(path: pathlib.Path) -> List[List[str]]: - with open(path, newline="") as file: - return [row for row in csv.reader(file)] - - class Dataset(IterDataPipe[Dict[str, Any]], abc.ABC): @staticmethod def _verify_str_arg( diff --git a/torchvision/prototype/datasets/utils/_internal.py b/torchvision/prototype/datasets/utils/_internal.py index fa48218fe02..21fa15e527b 100644 --- a/torchvision/prototype/datasets/utils/_internal.py +++ b/torchvision/prototype/datasets/utils/_internal.py @@ -1,3 +1,4 @@ +import csv import functools import pathlib import pickle @@ -9,6 +10,7 @@ Any, Tuple, TypeVar, + List, Iterator, Dict, IO, @@ -198,3 +200,11 @@ def hint_sharding(datapipe: IterDataPipe) -> ShardingFilter: def hint_shuffling(datapipe: IterDataPipe[D]) -> Shuffler[D]: return Shuffler(datapipe, buffer_size=INFINITE_BUFFER_SIZE).set_shuffle(False) + + +def read_categories_file(name: str) -> List[List[str]]: + path = BUILTIN_DIR / f"{name}.categories" + with open(path, 
newline="") as file: + rows = list(csv.reader(file)) + rows = [row[0] if len(row) == 1 else row for row in rows] + return rows From 70d443b845809c94fbb1fa18aed191c84d47bde6 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 6 Apr 2022 16:28:26 +0100 Subject: [PATCH 7/9] Mypy --- torchvision/prototype/datasets/_builtin/pcam.py | 2 +- torchvision/prototype/datasets/utils/_internal.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/torchvision/prototype/datasets/_builtin/pcam.py b/torchvision/prototype/datasets/_builtin/pcam.py index 14b100d5807..1ce9f45e673 100644 --- a/torchvision/prototype/datasets/_builtin/pcam.py +++ b/torchvision/prototype/datasets/_builtin/pcam.py @@ -130,5 +130,5 @@ def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, dp = hint_sharding(dp) return Mapper(dp, self._prepare_sample) - def __len__(self): + def __len__(self) -> int: return 262_144 if self._split == "train" else 32_768 diff --git a/torchvision/prototype/datasets/utils/_internal.py b/torchvision/prototype/datasets/utils/_internal.py index 21fa15e527b..007e91eb657 100644 --- a/torchvision/prototype/datasets/utils/_internal.py +++ b/torchvision/prototype/datasets/utils/_internal.py @@ -202,7 +202,7 @@ def hint_shuffling(datapipe: IterDataPipe[D]) -> Shuffler[D]: return Shuffler(datapipe, buffer_size=INFINITE_BUFFER_SIZE).set_shuffle(False) -def read_categories_file(name: str) -> List[List[str]]: +def read_categories_file(name: str) -> List[Union[str, Sequence[str]]]: path = BUILTIN_DIR / f"{name}.categories" with open(path, newline="") as file: rows = list(csv.reader(file)) From 565f834a83e36c6ad0250ea5079274c2f0cce0e5 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 6 Apr 2022 16:42:41 +0100 Subject: [PATCH 8/9] flake8 --- torchvision/prototype/datasets/_builtin/fer2013.py | 2 +- torchvision/prototype/datasets/_builtin/food101.py | 1 - torchvision/prototype/datasets/_builtin/pcam.py | 1 - 
torchvision/prototype/datasets/_builtin/semeion.py | 1 - torchvision/prototype/datasets/utils/_dataset.py | 1 - torchvision/prototype/utils/_internal.py | 1 - 6 files changed, 1 insertion(+), 6 deletions(-) diff --git a/torchvision/prototype/datasets/_builtin/fer2013.py b/torchvision/prototype/datasets/_builtin/fer2013.py index 2fb708e6141..c1a914c6f63 100644 --- a/torchvision/prototype/datasets/_builtin/fer2013.py +++ b/torchvision/prototype/datasets/_builtin/fer2013.py @@ -1,5 +1,5 @@ import pathlib -from typing import Any, Dict, List, cast, Union +from typing import Any, Dict, List, Union import torch from torchdata.datapipes.iter import IterDataPipe, Mapper, CSVDictParser diff --git a/torchvision/prototype/datasets/_builtin/food101.py b/torchvision/prototype/datasets/_builtin/food101.py index 7343466d18b..5100e5d8c74 100644 --- a/torchvision/prototype/datasets/_builtin/food101.py +++ b/torchvision/prototype/datasets/_builtin/food101.py @@ -12,7 +12,6 @@ from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource from torchvision.prototype.datasets.utils._internal import ( hint_shuffling, - BUILTIN_DIR, hint_sharding, path_comparator, getitem, diff --git a/torchvision/prototype/datasets/_builtin/pcam.py b/torchvision/prototype/datasets/_builtin/pcam.py index 1ce9f45e673..7cd31469139 100644 --- a/torchvision/prototype/datasets/_builtin/pcam.py +++ b/torchvision/prototype/datasets/_builtin/pcam.py @@ -2,7 +2,6 @@ import pathlib from collections import namedtuple from typing import Any, Dict, List, Optional, Tuple, Iterator, Union -from unicodedata import category from torchdata.datapipes.iter import IterDataPipe, Mapper, Zipper from torchvision.prototype import features diff --git a/torchvision/prototype/datasets/_builtin/semeion.py b/torchvision/prototype/datasets/_builtin/semeion.py index c3f5927b65b..5051bde4047 100644 --- a/torchvision/prototype/datasets/_builtin/semeion.py +++ b/torchvision/prototype/datasets/_builtin/semeion.py @@ 
-2,7 +2,6 @@ from typing import Any, Dict, List, Tuple, Union import torch -from pytest import skip from torchdata.datapipes.iter import ( IterDataPipe, Mapper, diff --git a/torchvision/prototype/datasets/utils/_dataset.py b/torchvision/prototype/datasets/utils/_dataset.py index 62cfa707f6c..528d0a0f25f 100644 --- a/torchvision/prototype/datasets/utils/_dataset.py +++ b/torchvision/prototype/datasets/utils/_dataset.py @@ -1,5 +1,4 @@ import abc -import csv import importlib import pathlib from typing import Any, Dict, List, Optional, Sequence, Union, Collection, Iterator diff --git a/torchvision/prototype/utils/_internal.py b/torchvision/prototype/utils/_internal.py index 2caceb6b186..233128880e3 100644 --- a/torchvision/prototype/utils/_internal.py +++ b/torchvision/prototype/utils/_internal.py @@ -2,7 +2,6 @@ import difflib import io import mmap -import os.path import platform from typing import ( Any, From 6b95df09f7f0f72f244d5794f7c7a1de0056ebcb Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 7 Apr 2022 08:29:49 +0200 Subject: [PATCH 9/9] fix category file reading --- torchvision/prototype/datasets/_builtin/caltech.py | 10 ++-------- torchvision/prototype/datasets/_builtin/cifar.py | 12 +++--------- .../prototype/datasets/_builtin/country211.py | 4 +--- torchvision/prototype/datasets/_builtin/cub200.py | 4 +--- torchvision/prototype/datasets/_builtin/sbd.py | 6 ++---- torchvision/prototype/datasets/_builtin/voc.py | 4 +--- 6 files changed, 10 insertions(+), 30 deletions(-) diff --git a/torchvision/prototype/datasets/_builtin/caltech.py b/torchvision/prototype/datasets/_builtin/caltech.py index 20d05da852e..7010ab9503d 100644 --- a/torchvision/prototype/datasets/_builtin/caltech.py +++ b/torchvision/prototype/datasets/_builtin/caltech.py @@ -22,12 +22,9 @@ from .._api import register_dataset, register_info -CALTECH101_CATEGORIES, *_ = zip(*read_categories_file("caltech101")) - - @register_info("caltech101") def _caltech101_info() -> Dict[str, Any]: - 
return dict(categories=CALTECH101_CATEGORIES) + return dict(categories=read_categories_file("caltech101")) @register_dataset("caltech101") @@ -152,12 +149,9 @@ def _generate_categories(self) -> List[str]: return sorted({pathlib.Path(path).parent.name for path, _ in dp}) -CALTECH256_CATEGORIES, *_ = zip(*read_categories_file("caltech256")) - - @register_info("caltech256") def _caltech256_info() -> Dict[str, Any]: - return dict(categories=CALTECH256_CATEGORIES) + return dict(categories=read_categories_file("caltech256")) @register_dataset("caltech256") diff --git a/torchvision/prototype/datasets/_builtin/cifar.py b/torchvision/prototype/datasets/_builtin/cifar.py index 33250342005..514938d6e5f 100644 --- a/torchvision/prototype/datasets/_builtin/cifar.py +++ b/torchvision/prototype/datasets/_builtin/cifar.py @@ -40,7 +40,7 @@ class _CifarBase(Dataset): _LABELS_KEY: str _META_FILE_NAME: str _CATEGORIES_KEY: str - # _categories: List[str] + _categories: List[str] def __init__( self, @@ -97,12 +97,9 @@ def _generate_categories(self) -> List[str]: return cast(List[str], next(iter(dp))[self._CATEGORIES_KEY]) -CIFAR10_CATEGORIES, *_ = zip(*read_categories_file("cifar10")) - - @register_info("cifar10") def _cifar10_info() -> Dict[str, Any]: - return dict(categories=CIFAR10_CATEGORIES) + return dict(categories=read_categories_file("cifar10")) @register_dataset("cifar10") @@ -123,12 +120,9 @@ def _is_data_file(self, data: Tuple[str, Any]) -> bool: return path.name.startswith("data" if self._split == "train" else "test") -CIFAR100_CATEGORIES, *_ = zip(*read_categories_file("cifar100")) - - @register_info("cifar100") def _cifar100_info() -> Dict[str, Any]: - return dict(categories=CIFAR10_CATEGORIES) + return dict(categories=read_categories_file("cifar100")) @register_dataset("cifar100") diff --git a/torchvision/prototype/datasets/_builtin/country211.py b/torchvision/prototype/datasets/_builtin/country211.py index 8bcb5130db2..012ecae19e2 100644 --- 
a/torchvision/prototype/datasets/_builtin/country211.py +++ b/torchvision/prototype/datasets/_builtin/country211.py @@ -15,12 +15,10 @@ NAME = "country211" -CATEGORIES, *_ = zip(*read_categories_file(NAME)) - @register_info(NAME) def _info() -> Dict[str, Any]: - return dict(categories=CATEGORIES) + return dict(categories=read_categories_file(NAME)) @register_dataset(NAME) diff --git a/torchvision/prototype/datasets/_builtin/cub200.py b/torchvision/prototype/datasets/_builtin/cub200.py index 1149bbe59ed..1e4db7cef73 100644 --- a/torchvision/prototype/datasets/_builtin/cub200.py +++ b/torchvision/prototype/datasets/_builtin/cub200.py @@ -37,12 +37,10 @@ NAME = "cub200" -CATEGORIES, *_ = zip(*read_categories_file(NAME)) - @register_info(NAME) def _info() -> Dict[str, Any]: - return dict(categories=CATEGORIES) + return dict(categories=read_categories_file(NAME)) @register_dataset(NAME) diff --git a/torchvision/prototype/datasets/_builtin/sbd.py b/torchvision/prototype/datasets/_builtin/sbd.py index c6e5f45afbc..0c806fe098c 100644 --- a/torchvision/prototype/datasets/_builtin/sbd.py +++ b/torchvision/prototype/datasets/_builtin/sbd.py @@ -28,12 +28,10 @@ NAME = "sbd" -CATEGORIES, *_ = zip(*read_categories_file(NAME)) - @register_info(NAME) def _info() -> Dict[str, Any]: - return dict(categories=CATEGORIES) + return dict(categories=read_categories_file(NAME)) @register_dataset(NAME) @@ -53,7 +51,7 @@ def __init__( ) -> None: self._split = self._verify_str_arg(split, "split", ("train", "val", "train_noval")) - self._categories = CATEGORIES + self._categories = _info()["categories"] super().__init__(root, dependencies=("scipy",), skip_integrity_check=skip_integrity_check) diff --git a/torchvision/prototype/datasets/_builtin/voc.py b/torchvision/prototype/datasets/_builtin/voc.py index 2882e3434b8..05a3c2e8622 100644 --- a/torchvision/prototype/datasets/_builtin/voc.py +++ b/torchvision/prototype/datasets/_builtin/voc.py @@ -29,12 +29,10 @@ NAME = "voc" -CATEGORIES, *_ = 
zip(*read_categories_file(NAME)) - @register_info(NAME) def _info() -> Dict[str, Any]: - return dict(categories=CATEGORIES) + return dict(categories=read_categories_file(NAME)) @register_dataset(NAME)