From c742860bf74141ebd7e4943228687852914848c7 Mon Sep 17 00:00:00 2001 From: saswatpp Date: Mon, 27 Dec 2021 23:56:42 +0530 Subject: [PATCH 1/5] dataset class added --- docs/source/datasets.rst | 1 + torchvision/datasets/__init__.py | 2 + torchvision/datasets/sun397.py | 122 +++++++++++++++++++++++++++++++ 3 files changed, 125 insertions(+) create mode 100644 torchvision/datasets/sun397.py diff --git a/docs/source/datasets.rst b/docs/source/datasets.rst index 3a2872a6388..110f3e3a514 100644 --- a/docs/source/datasets.rst +++ b/docs/source/datasets.rst @@ -67,6 +67,7 @@ You can also create your own datasets using the provided :ref:`base classes `_. + The SUN397 is a dataset for scene recognition consisting of 397 categories with 108'754 images. + The dataset also provides 10 paritions for training and testing, with each partition + consisting of 50 images per class. + + Args: + root (string): Root directory of the dataset. + split (string, optional): The dataset split, supports ``"train"`` (default) and ``"test"``. + parition (string, integer, optional): A valid partition can be an integer from 1 to 10 or ``"all"`` + for the entire dataset. + transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed + version. E.g, ``transforms.RandomCrop``. + target_transform (callable, optional): A function/transform that takes in the target and transforms it. + """ + + _URL = "http://vision.princeton.edu/projects/2010/SUN/SUN397.tar.gz" + _FILENAME = "SUN397.tar.gz" + _MD5 = "8ca2778205c41d23104230ba66911c7a" + _PARTITIONS_URL = "https://vision.princeton.edu/projects/2010/SUN/download/Partitions.zip" + _PARTITIONS_FILENAME = "Partitions.zip" + + def __init__( + self, + root: str, + split: str = "train", + partition: Union[int,str] = 1, + download: bool = True, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + ) -> None: + super().__init__(root, transform=transform, target_transform=target_transform) + self.split = split + self.partition = partition + self.data_dir = Path(self.root) / "SUN397" + + if download: + self._download() + + if not self._check_exists(): + raise RuntimeError("Dataset not found. You can use download=True to download it") + + self._labels = [] + self._image_files = [] + with open(self.data_dir / f"ClassName.txt", "r") as f: + classes = f.read().splitlines() + + for idx,c in enumerate(classes): + classes[idx] = c[3:] + + self.classes = classes + self.class_to_idx = dict(zip(self.classes, range(len(self.classes)))) + + if isinstance(self.partition,int): + if self.partition<0 or self.partition>10: + raise RuntimeError("Enter a valid integer partition from 1 to 10 or \"all\" ") + + splitname = "Training" if self.split is "train" else "Testing" + zero = "0" if self.partition<10 else "" + + with open(self.data_dir / f"{splitname}_{zero}{self.partition}.txt", "r") as f: + pathlist = f.read().splitlines() + + for p in pathlist: + self._labels.append(self.class_to_idx[p[3:-25]]) + self._image_files.append(self.data_dir.joinpath(*p.split("/"))) + + else: + if self.partition is not "all": + raise RuntimeError("Enter a valid integer partition from 1 to 10 or \"all\" ") + else: + for path, _, files in os.walk(self.data_dir): + for file in files: + if(file[:3]=="sun"): + self._image_files.append(Path(path)/file) + self._labels.append(Path(path).relative_to(self.data_dir).as_posix()[2:]) + + + def __len__(self) -> int: + return len(self._image_files) + + def __getitem__(self, idx) -> Tuple[Any, Any]: + image_file, label = self._image_files[idx], self._labels[idx] + image = PIL.Image.open(image_file).convert("RGB") + + if self.transform: + image = self.transform(image) + + if self.target_transform: + label = self.target_transform(label) + + return image, label + + def _check_exists(self) -> bool: + file = Path(self.root) / self._FILENAME + if not check_integrity(file, self._MD5): + return False + elif self._PARTITIONS_FILENAME not in os.listdir(self.data_dir): + return False + else: + return True + + def extra_repr(self) -> str: + return "Split: {split}".format(**self.__dict__) + + def _download(self) -> None: + file = Path(self.root) / self._FILENAME + if self._FILENAME not in os.listdir(self.data_dir) or not check_integrity(file, self._MD5): + download_and_extract_archive(self._URL, download_root=self.root, md5=self._MD5) + + if self._PARTITIONS_FILENAME not in os.listdir(self.data_dir): + download_and_extract_archive(self._PARTITIONS_URL, download_root=self.data_dir) \ No newline at end of file From 66ddf2f5f50ba045512fed8e2243c96e47f80f70 Mon Sep 17 00:00:00 2001 From: saswatpp Date: Sat, 1 Jan 2022 01:51:54 +0530 Subject: [PATCH 2/5] fix code format --- test/test_datasets.py | 40 +++++++++++++++ torchvision/datasets/sun397.py | 92 ++++++++++++++-------------------- 2 files changed, 79 insertions(+), 53 deletions(-) diff --git a/test/test_datasets.py b/test/test_datasets.py index edda3aeaf64..889a7fdba0f 100644 --- a/test/test_datasets.py +++ b/test/test_datasets.py @@ -2205,5 +2205,45 @@ def inject_fake_data(self, tmpdir: str, config): return len(sampled_classes * n_samples_per_class) +class SUN397TestCase(datasets_utils.ImageDatasetTestCase): + DATASET_CLASS = datasets.SUN397 + FEATURE_TYPES = (PIL.Image.Image, int) + + ADDITIONAL_CONFIGS = datasets_utils.combinations_grid( + split=("train", "test"), + partition=(1, 10, None), + ) + + def inject_fake_data(self, tmpdir: str, config): + data_dir = pathlib.Path(tmpdir) / "SUN397" + + data_dir.mkdir(parents=True) + + num_images_per_class = 5 + sampled_classes = ("abbey", "airplane_cabin", "airport_terminal") + im_paths = [] + for cls in sampled_classes: + image_folder = data_dir / cls[0] + im_paths.append( + datasets_utils.create_image_folder( + image_folder, + image_folder / cls, + file_name_fn=lambda idx: f"sun_{idx}.jpg", + num_examples=num_images_per_class, + ) + ) + + with open(data_dir / "ClassName.txt", "a") as file: + file.write("/" + cls[0] + "/" + cls + "\n") + + if config["partition"] is not None: + with open(data_dir / f"{config['split'].title()}ing_{config['partition']:02d}.txt", "w") as file: + for f_pathlist in im_paths: + for f_path in f_pathlist: + file.write("/" + str(pathlib.Path(f_path).relative_to(data_dir).as_posix()) + "\n") + + return len(sampled_classes * num_images_per_class) + + if __name__ == "__main__": unittest.main() diff --git a/torchvision/datasets/sun397.py b/torchvision/datasets/sun397.py index bf2717810e6..4be9e0e95c9 100644 --- a/torchvision/datasets/sun397.py +++ b/torchvision/datasets/sun397.py @@ -1,19 +1,19 @@ import os from pathlib import Path -from typing import Any, Tuple, Callable, Optional, Union +from typing import Any, Tuple, Callable, Optional import PIL.Image -from .utils import verify_str_arg, download_and_extract_archive, check_integrity +from .utils import verify_str_arg, download_and_extract_archive from .vision import VisionDataset class SUN397(VisionDataset): """`The SUN397 Data Set `_. The SUN397 is a dataset for scene recognition consisting of 397 categories with 108'754 images. - The dataset also provides 10 paritions for training and testing, with each partition - consisting of 50 images per class. - + The dataset also provides 10 paritions for training and testing, with each partition + consisting of 50 images per class. + Args: root (string): Root directory of the dataset. split (string, optional): The dataset split, supports ``"train"`` (default) and ``"test"``. @@ -34,57 +34,51 @@ def __init__( self, root: str, split: str = "train", - partition: Union[int,str] = 1, + partition: Optional[int] = 1, download: bool = True, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, ) -> None: super().__init__(root, transform=transform, target_transform=target_transform) - self.split = split + self.split = verify_str_arg(split, "split", ("train", "test")) self.partition = partition self.data_dir = Path(self.root) / "SUN397" + if self.partition is not None: + if self.partition < 0 or self.partition > 10: + raise RuntimeError("Enter a valid integer partition from 1 to 10 or None, for entire dataset") + if download: self._download() if not self._check_exists(): raise RuntimeError("Dataset not found. You can use download=True to download it") - + + with open(self.data_dir / "ClassName.txt", "r") as f: + self.classes = [c[3:].strip() for c in f] + + self.class_to_idx = dict(zip(self.classes, range(len(self.classes)))) + self._labels = [] self._image_files = [] - with open(self.data_dir / f"ClassName.txt", "r") as f: - classes = f.read().splitlines() - - for idx,c in enumerate(classes): - classes[idx] = c[3:] + if self.partition is not None: + with open(self.data_dir / f"{self.split.title()}ing_{self.partition:02d}.txt", "r") as f: + self._image_files, self._labels = zip( + *( + ( + self.data_dir.joinpath(*posix_rel_path.split("/")), + self.class_to_idx["/".join(Path(posix_rel_path).parts[1:-1])], + ) + for posix_rel_path in (line.strip()[1:] for line in f) + ) + ) - self.classes = classes - self.class_to_idx = dict(zip(self.classes, range(len(self.classes)))) - - if isinstance(self.partition,int): - if self.partition<0 or self.partition>10: - raise RuntimeError("Enter a valid integer partition from 1 to 10 or \"all\" ") - - splitname = "Training" if self.split is "train" else "Testing" - zero = "0" if self.partition<10 else "" - - with open(self.data_dir / f"{splitname}_{zero}{self.partition}.txt", "r") as f: - pathlist = f.read().splitlines() - - for p in pathlist: - self._labels.append(self.class_to_idx[p[3:-25]]) - self._image_files.append(self.data_dir.joinpath(*p.split("/"))) - else: - if self.partition is not "all": - raise RuntimeError("Enter a valid integer partition from 1 to 10 or \"all\" ") - else: - for path, _, files in os.walk(self.data_dir): - for file in files: - if(file[:3]=="sun"): - self._image_files.append(Path(path)/file) - self._labels.append(Path(path).relative_to(self.data_dir).as_posix()[2:]) - + for path, _, files in os.walk(self.data_dir): + for file in files: + if file[:3] == "sun": + self._image_files.append(Path(path) / file) + self._labels.append(Path(path).relative_to(self.data_dir).as_posix()[2:]) def __len__(self) -> int: return len(self._image_files) @@ -102,21 +96,13 @@ def __getitem__(self, idx) -> Tuple[Any, Any]: return image, label def _check_exists(self) -> bool: - file = Path(self.root) / self._FILENAME - if not check_integrity(file, self._MD5): - return False - elif self._PARTITIONS_FILENAME not in os.listdir(self.data_dir): - return False - else: - return True - + return self.data_dir.exists() and self.data_dir.is_dir() + def extra_repr(self) -> str: return "Split: {split}".format(**self.__dict__) def _download(self) -> None: - file = Path(self.root) / self._FILENAME - if self._FILENAME not in os.listdir(self.data_dir) or not check_integrity(file, self._MD5): - download_and_extract_archive(self._URL, download_root=self.root, md5=self._MD5) - - if self._PARTITIONS_FILENAME not in os.listdir(self.data_dir): - download_and_extract_archive(self._PARTITIONS_URL, download_root=self.data_dir) \ No newline at end of file + if self._check_exists: + return + download_and_extract_archive(self._URL, download_root=self.root, md5=self._MD5) + download_and_extract_archive(self._PARTITIONS_URL, download_root=self.data_dir) From 636de7756ae0a46429a717ee0cf94e018b110d42 Mon Sep 17 00:00:00 2001 From: saswatpp Date: Mon, 3 Jan 2022 22:35:26 +0530 Subject: [PATCH 3/5] fixed requested changes --- test/test_datasets.py | 26 +++++++++++-------- torchvision/datasets/sun397.py | 47 ++++++++++++---------------------- 2 files changed, 33 insertions(+), 40 deletions(-) diff --git a/test/test_datasets.py b/test/test_datasets.py index 889a7fdba0f..6b8b3cc4306 100644 --- a/test/test_datasets.py +++ b/test/test_datasets.py @@ -2207,7 +2207,6 @@ def inject_fake_data(self, tmpdir: str, config): class SUN397TestCase(datasets_utils.ImageDatasetTestCase): DATASET_CLASS = datasets.SUN397 - FEATURE_TYPES = (PIL.Image.Image, int) ADDITIONAL_CONFIGS = datasets_utils.combinations_grid( split=("train", "test"), @@ -2216,15 +2215,15 @@ class SUN397TestCase(datasets_utils.ImageDatasetTestCase): def inject_fake_data(self, tmpdir: str, config): data_dir = pathlib.Path(tmpdir) / "SUN397" - - data_dir.mkdir(parents=True) + data_dir.mkdir() num_images_per_class = 5 sampled_classes = ("abbey", "airplane_cabin", "airport_terminal") im_paths = [] + for cls in sampled_classes: image_folder = data_dir / cls[0] - im_paths.append( + im_paths.extend( datasets_utils.create_image_folder( image_folder, image_folder / cls, @@ -2233,16 +2232,23 @@ def inject_fake_data(self, tmpdir: str, config): ) ) - with open(data_dir / "ClassName.txt", "a") as file: - file.write("/" + cls[0] + "/" + cls + "\n") + with open(data_dir / "ClassName.txt", "w") as file: + file.writelines("\n".join(f"/{cls[0]}/{cls}" for cls in sampled_classes)) if config["partition"] is not None: + num_samples = max(len(im_paths) // (2 if config["split"] == "train" else 3), 1) + with open(data_dir / f"{config['split'].title()}ing_{config['partition']:02d}.txt", "w") as file: - for f_pathlist in im_paths: - for f_path in f_pathlist: - file.write("/" + str(pathlib.Path(f_path).relative_to(data_dir).as_posix()) + "\n") + file.writelines( + "\n".join( + f"/{f_path.relative_to(data_dir).as_posix()}" + for f_path in random.choices(im_paths, k=num_samples) + ) + ) + else: + num_samples = len(im_paths) - return len(sampled_classes * num_images_per_class) + return num_samples if __name__ == "__main__": diff --git a/torchvision/datasets/sun397.py b/torchvision/datasets/sun397.py index 4be9e0e95c9..959c1b013eb 100644 --- a/torchvision/datasets/sun397.py +++ b/torchvision/datasets/sun397.py @@ -1,4 +1,3 @@ -import os from pathlib import Path from typing import Any, Tuple, Callable, Optional @@ -10,25 +9,25 @@ class SUN397(VisionDataset): """`The SUN397 Data Set `_. - The SUN397 is a dataset for scene recognition consisting of 397 categories with 108'754 images. - The dataset also provides 10 paritions for training and testing, with each partition - consisting of 50 images per class. + + The SUN397 or Scene UNderstanding (SUN) is a dataset for scene recognition consisting of + 397 categories with 108'754 images. The dataset also provides 10 partitions for training + and testing, with each partition consisting of 50 images per class. Args: root (string): Root directory of the dataset. split (string, optional): The dataset split, supports ``"train"`` (default) and ``"test"``. - parition (string, integer, optional): A valid partition can be an integer from 1 to 10 or ``"all"`` + parition (integer, optional): A valid partition can be an integer from 1 to 10 or None, for the entire dataset. transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed version. E.g, ``transforms.RandomCrop``. target_transform (callable, optional): A function/transform that takes in the target and transforms it. """ - _URL = "http://vision.princeton.edu/projects/2010/SUN/SUN397.tar.gz" - _FILENAME = "SUN397.tar.gz" - _MD5 = "8ca2778205c41d23104230ba66911c7a" + _DATASET_URL = "http://vision.princeton.edu/projects/2010/SUN/SUN397.tar.gz" + _DATASET_MD5 = "8ca2778205c41d23104230ba66911c7a" _PARTITIONS_URL = "https://vision.princeton.edu/projects/2010/SUN/download/Partitions.zip" - _PARTITIONS_FILENAME = "Partitions.zip" + _PARTITIONS_MD5 = "29a205c0a0129d21f36cbecfefe81881" def __init__( self, @@ -54,31 +53,19 @@ def __init__( if not self._check_exists(): raise RuntimeError("Dataset not found. You can use download=True to download it") - with open(self.data_dir / "ClassName.txt", "r") as f: + with open(self.data_dir / "ClassName.txt") as f: self.classes = [c[3:].strip() for c in f] self.class_to_idx = dict(zip(self.classes, range(len(self.classes)))) - - self._labels = [] - self._image_files = [] if self.partition is not None: with open(self.data_dir / f"{self.split.title()}ing_{self.partition:02d}.txt", "r") as f: - self._image_files, self._labels = zip( - *( - ( - self.data_dir.joinpath(*posix_rel_path.split("/")), - self.class_to_idx["/".join(Path(posix_rel_path).parts[1:-1])], - ) - for posix_rel_path in (line.strip()[1:] for line in f) - ) - ) - + self._image_files = [self.data_dir.joinpath(*line.strip()[1:].split("/")) for line in f] else: - for path, _, files in os.walk(self.data_dir): - for file in files: - if file[:3] == "sun": - self._image_files.append(Path(path) / file) - self._labels.append(Path(path).relative_to(self.data_dir).as_posix()[2:]) + self._image_files = list(self.data_dir.rglob("sun_*.jpg")) + + self._labels = [ + self.class_to_idx["/".join(path.relative_to(self.data_dir).parts[1:-1])] for path in self._image_files + ] def __len__(self) -> int: return len(self._image_files) @@ -104,5 +91,5 @@ def extra_repr(self) -> str: def _download(self) -> None: if self._check_exists: return - download_and_extract_archive(self._URL, download_root=self.root, md5=self._MD5) - download_and_extract_archive(self._PARTITIONS_URL, download_root=self.data_dir) + download_and_extract_archive(self._DATASET_URL, download_root=self.root, md5=self._MD5) + download_and_extract_archive(self._PARTITIONS_URL, download_root=str(self.data_dir), md5=self._PARTITIONS_MD5) From 61b4f058e29b3629856392681689113288ae2eeb Mon Sep 17 00:00:00 2001 From: saswatpp Date: Thu, 6 Jan 2022 22:23:52 +0530 Subject: [PATCH 4/5] fixed issues in sun397 --- torchvision/datasets/sun397.py | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/torchvision/datasets/sun397.py b/torchvision/datasets/sun397.py index 959c1b013eb..b326bccfb44 100644 --- a/torchvision/datasets/sun397.py +++ b/torchvision/datasets/sun397.py @@ -19,6 +19,9 @@ class SUN397(VisionDataset): split (string, optional): The dataset split, supports ``"train"`` (default) and ``"test"``. parition (integer, optional): A valid partition can be an integer from 1 to 10 or None, for the entire dataset. + download (bool, optional): If true, downloads the dataset from the internet and + puts it in root directory. If dataset is already downloaded, it is not + downloaded again. transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed version. E.g, ``transforms.RandomCrop``. target_transform (callable, optional): A function/transform that takes in the target and transforms it. @@ -41,11 +44,11 @@ def __init__( super().__init__(root, transform=transform, target_transform=target_transform) self.split = verify_str_arg(split, "split", ("train", "test")) self.partition = partition - self.data_dir = Path(self.root) / "SUN397" + self._data_dir = Path(self.root) / "SUN397" if self.partition is not None: if self.partition < 0 or self.partition > 10: - raise RuntimeError("Enter a valid integer partition from 1 to 10 or None, for entire dataset") + raise RuntimeError(f"The partition parameter should be an int in [1, 10] or None, got {partition}.") if download: self._download() @@ -53,18 +56,18 @@ def __init__( if not self._check_exists(): raise RuntimeError("Dataset not found. You can use download=True to download it") - with open(self.data_dir / "ClassName.txt") as f: + with open(self._data_dir / "ClassName.txt") as f: self.classes = [c[3:].strip() for c in f] self.class_to_idx = dict(zip(self.classes, range(len(self.classes)))) if self.partition is not None: - with open(self.data_dir / f"{self.split.title()}ing_{self.partition:02d}.txt", "r") as f: - self._image_files = [self.data_dir.joinpath(*line.strip()[1:].split("/")) for line in f] + with open(self._data_dir / f"{self.split.title()}ing_{self.partition:02d}.txt", "r") as f: + self._image_files = [self._data_dir.joinpath(*line.strip()[1:].split("/")) for line in f] else: - self._image_files = list(self.data_dir.rglob("sun_*.jpg")) + self._image_files = list(self._data_dir.rglob("sun_*.jpg")) self._labels = [ - self.class_to_idx["/".join(path.relative_to(self.data_dir).parts[1:-1])] for path in self._image_files + self.class_to_idx["/".join(path.relative_to(self._data_dir).parts[1:-1])] for path in self._image_files ] def __len__(self) -> int: @@ -83,13 +86,13 @@ def __getitem__(self, idx) -> Tuple[Any, Any]: return image, label def _check_exists(self) -> bool: - return self.data_dir.exists() and self.data_dir.is_dir() + return self._data_dir.exists() and self._data_dir.is_dir() def extra_repr(self) -> str: return "Split: {split}".format(**self.__dict__) def _download(self) -> None: - if self._check_exists: + if self._check_exists(): return - download_and_extract_archive(self._DATASET_URL, download_root=self.root, md5=self._MD5) - download_and_extract_archive(self._PARTITIONS_URL, download_root=str(self.data_dir), md5=self._PARTITIONS_MD5) + download_and_extract_archive(self._DATASET_URL, download_root=self.root, md5=self._DATASET_MD5) + download_and_extract_archive(self._PARTITIONS_URL, download_root=str(self._data_dir), md5=self._PARTITIONS_MD5) From 90cee2e7ed356b4c71bc4c945a296322e1c3b9a0 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 6 Jan 2022 17:59:25 +0000 Subject: [PATCH 5/5] Update torchvision/datasets/sun397.py --- torchvision/datasets/sun397.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/datasets/sun397.py b/torchvision/datasets/sun397.py index b326bccfb44..da34351771f 100644 --- a/torchvision/datasets/sun397.py +++ b/torchvision/datasets/sun397.py @@ -17,7 +17,7 @@ class SUN397(VisionDataset): Args: root (string): Root directory of the dataset. split (string, optional): The dataset split, supports ``"train"`` (default) and ``"test"``. - parition (integer, optional): A valid partition can be an integer from 1 to 10 or None, + partition (int, optional): A valid partition can be an integer from 1 to 10 or None, for the entire dataset. download (bool, optional): If true, downloads the dataset from the internet and puts it in root directory. If dataset is already downloaded, it is not