refactor prototype datasets to inherit from IterDataPipe #5448

Merged
merged 12 commits on Feb 24, 2022
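This PR moves the prototype datasets away from the old `find(name)` / `info`-based mechanism and onto dataset classes that are themselves `IterDataPipe`s; the test mocks are reworked accordingly (they now patch `datasets.utils.Dataset2.__init__`, query `_resources()` directly, and `register_mock` takes the test configs explicitly). Below is a minimal sketch of the dataset shape those mocks assume. `Dataset2` and `_resources()` are taken from the diff; the `_datapipe()` hook, `HttpResource` helper, and all signatures are assumptions for illustration, not the PR's actual implementation.

```python
# Sketch only, not the PR's implementation: illustrates an IterDataPipe-based
# dataset as the mocks interact with it. `Dataset2` and `_resources()` appear in
# the diff; `_datapipe()`, `HttpResource`, and the signatures are assumptions.
from typing import Any, BinaryIO, Dict, List, Tuple

from torchdata.datapipes.iter import IterDataPipe, Mapper
from torchvision.prototype.datasets.utils import Dataset2, HttpResource, OnlineResource


class FakeThings(Dataset2):
    """Hypothetical dataset, used purely for illustration."""

    def __init__(self, root: str, *, split: str = "train") -> None:
        self._split = split
        super().__init__(root)  # base class resolves resources and builds the datapipe

    def _resources(self) -> List[OnlineResource]:
        # `file_name` is what the mock checks against the files it created under `root`.
        return [
            HttpResource(
                f"https://example.com/fake-things/{self._split}.tar",
                file_name=f"{self._split}.tar",
            )
        ]

    def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
        # One datapipe per resource, in the order returned by `_resources()`.
        archive_dp = resource_dps[0]
        return Mapper(archive_dp, self._prepare_sample)

    def _prepare_sample(self, data: Tuple[str, BinaryIO]) -> Dict[str, Any]:
        path, buffer = data
        return dict(path=path, image=buffer)

    def __len__(self) -> int:
        return 1_000 if self._split == "train" else 100
```

The mock machinery in the diff relies on exactly this split: by patching `Dataset2.__init__`, the test can instantiate a dataset and ask it for `_resources()` (to learn which file names must exist under the mock root) without triggering any download or datapipe construction.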
100 changes: 54 additions & 46 deletions test/builtin_dataset_mocks.py
@@ -9,17 +9,18 @@
import pathlib
import pickle
import random
import unittest.mock
import xml.etree.ElementTree as ET
from collections import defaultdict, Counter

import numpy as np
import PIL.Image
import pytest
import torch
from datasets_utils import make_zip, make_tar, create_image_folder, create_image_file
from datasets_utils import make_zip, make_tar, create_image_folder, create_image_file, combinations_grid
from torch.nn.functional import one_hot
from torch.testing import make_tensor as _make_tensor
from torchvision.prototype.datasets._api import find
from torchvision.prototype import datasets
from torchvision.prototype.utils._internal import sequence_to_str

make_tensor = functools.partial(_make_tensor, device="cpu")
@@ -30,13 +31,11 @@


class DatasetMock:
def __init__(self, name, mock_data_fn):
self.dataset = find(name)
self.info = self.dataset.info
self.name = self.info.name

def __init__(self, name, *, mock_data_fn, configs):
# FIXME: error handling for unknown names
self.name = name
self.mock_data_fn = mock_data_fn
self.configs = self.info._configs
self.configs = configs

def _parse_mock_info(self, mock_info):
if mock_info is None:
@@ -65,10 +64,13 @@ def prepare(self, home, config):
root = home / self.name
root.mkdir(exist_ok=True)

mock_info = self._parse_mock_info(self.mock_data_fn(self.info, root, config))
mock_info = self._parse_mock_info(self.mock_data_fn(root, config))

with unittest.mock.patch.object(datasets.utils.Dataset2, "__init__"):
required_file_names = {
resource.file_name for resource in datasets.load(self.name, root=root, **config)._resources()
}
available_file_names = {path.name for path in root.glob("*")}
required_file_names = {resource.file_name for resource in self.dataset.resources(config)}
missing_file_names = required_file_names - available_file_names
if missing_file_names:
raise pytest.UsageError(
@@ -123,10 +125,16 @@ def parametrize_dataset_mocks(*dataset_mocks, marks=None):
DATASET_MOCKS = {}


def register_mock(fn):
name = fn.__name__.replace("_", "-")
DATASET_MOCKS[name] = DatasetMock(name, fn)
return fn
def register_mock(name=None, *, configs):
def wrapper(mock_data_fn):
nonlocal name
if name is None:
name = mock_data_fn.__name__
DATASET_MOCKS[name] = DatasetMock(name, mock_data_fn=mock_data_fn, configs=configs)

return mock_data_fn

return wrapper


class MNISTMockData:
@@ -204,7 +212,7 @@ def generate(
return num_samples


@register_mock
# @register_mock
def mnist(info, root, config):
train = config.split == "train"
images_file = f"{'train' if train else 't10k'}-images-idx3-ubyte.gz"
@@ -217,10 +225,10 @@ def mnist(info, root, config):
)


DATASET_MOCKS.update({name: DatasetMock(name, mnist) for name in ["fashionmnist", "kmnist"]})
# DATASET_MOCKS.update({name: DatasetMock(name, mnist) for name in ["fashionmnist", "kmnist"]})


@register_mock
# @register_mock
def emnist(info, root, config):
# The image sets that merge some lower case letters in their respective upper case variant, still use dense
# labels in the data files. Thus, num_categories != len(categories) there.
@@ -247,7 +255,7 @@ def emnist(info, root, config):
return num_samples_map[config]


@register_mock
# @register_mock
def qmnist(info, root, config):
num_categories = len(info.categories)
if config.split == "train":
@@ -324,7 +332,7 @@ def generate(
make_tar(root, name, folder, compression="gz")


@register_mock
# @register_mock
def cifar10(info, root, config):
train_files = [f"data_batch_{idx}" for idx in range(1, 6)]
test_files = ["test_batch"]
@@ -342,7 +350,7 @@ def cifar10(info, root, config):
return len(train_files if config.split == "train" else test_files)


@register_mock
# @register_mock
def cifar100(info, root, config):
train_files = ["train"]
test_files = ["test"]
@@ -360,7 +368,7 @@ def cifar100(info, root, config):
return len(train_files if config.split == "train" else test_files)


@register_mock
# @register_mock
def caltech101(info, root, config):
def create_ann_file(root, name):
import scipy.io
@@ -410,7 +418,7 @@ def create_ann_folder(root, name, file_name_fn, num_examples):
return num_images_per_category * len(info.categories)


@register_mock
# @register_mock
def caltech256(info, root, config):
dir = root / "256_ObjectCategories"
num_images_per_category = 2
@@ -430,26 +438,26 @@ def caltech256(info, root, config):
return num_images_per_category * len(info.categories)


@register_mock
def imagenet(info, root, config):
@register_mock(configs=combinations_grid(split=("train", "val", "test")))
def imagenet(root, config):
from scipy.io import savemat

categories = info.categories
wnids = [info.extra.category_to_wnid[category] for category in categories]
if config.split == "train":
num_samples = len(wnids)
info = datasets.info("imagenet")

if config["split"] == "train":
num_samples = len(info["wnids"])
archive_name = "ILSVRC2012_img_train.tar"

files = []
for wnid in wnids:
for wnid in info["wnids"]:
create_image_folder(
root=root,
name=wnid,
file_name_fn=lambda image_idx: f"{wnid}_{image_idx:04d}.JPEG",
num_examples=1,
)
files.append(make_tar(root, f"{wnid}.tar"))
elif config.split == "val":
elif config["split"] == "val":
num_samples = 3
archive_name = "ILSVRC2012_img_val.tar"
files = [create_image_file(root, f"ILSVRC2012_val_{idx + 1:08d}.JPEG") for idx in range(num_samples)]
@@ -459,20 +467,20 @@ def imagenet(info, root, config):
data_root.mkdir(parents=True)

with open(data_root / "ILSVRC2012_validation_ground_truth.txt", "w") as file:
for label in torch.randint(0, len(wnids), (num_samples,)).tolist():
for label in torch.randint(0, len(info["wnids"]), (num_samples,)).tolist():
file.write(f"{label}\n")

num_children = 0
synsets = [
(idx, wnid, category, "", num_children, [], 0, 0)
for idx, (category, wnid) in enumerate(zip(categories, wnids), 1)
for idx, (category, wnid) in enumerate(zip(info["categories"], info["wnids"]), 1)
]
num_children = 1
synsets.extend((0, "", "", "", num_children, [], 0, 0) for _ in range(5))
savemat(data_root / "meta.mat", dict(synsets=synsets))

make_tar(root, devkit_root.with_suffix(".tar.gz").name, compression="gz")
else: # config.split == "test"
else: # config["split"] == "test"
num_samples = 5
archive_name = "ILSVRC2012_img_test_v10102019.tar"
files = [create_image_file(root, f"ILSVRC2012_test_{idx + 1:08d}.JPEG") for idx in range(num_samples)]
@@ -587,7 +595,7 @@ def generate(
return num_samples


@register_mock
# @register_mock
def coco(info, root, config):
return CocoMockData.generate(root, year=config.year, num_samples=5)

@@ -661,12 +669,12 @@ def generate(cls, root):
return num_samples_map


@register_mock
# @register_mock
def sbd(info, root, config):
return SBDMockData.generate(root)[config.split]


@register_mock
# @register_mock
def semeion(info, root, config):
num_samples = 3
num_categories = len(info.categories)
@@ -779,7 +787,7 @@ def generate(cls, root, *, year, trainval):
return num_samples_map


@register_mock
# @register_mock
def voc(info, root, config):
trainval = config.split != "test"
return VOCMockData.generate(root, year=config.year, trainval=trainval)[config.split]
@@ -873,12 +881,12 @@ def generate(cls, root):
return num_samples_map


@register_mock
# @register_mock
def celeba(info, root, config):
return CelebAMockData.generate(root)[config.split]


@register_mock
# @register_mock
def dtd(info, root, config):
data_folder = root / "dtd"

@@ -926,7 +934,7 @@ def dtd(info, root, config):
return num_samples_map[config]


@register_mock
# @register_mock
def fer2013(info, root, config):
num_samples = 5 if config.split == "train" else 3

@@ -951,7 +959,7 @@ def fer2013(info, root, config):
return num_samples


@register_mock
# @register_mock
def gtsrb(info, root, config):
num_examples_per_class = 5 if config.split == "train" else 3
classes = ("00000", "00042", "00012")
@@ -1021,7 +1029,7 @@ def _make_ann_file(path, num_examples, class_idx):
return num_examples


@register_mock
# @register_mock
def clevr(info, root, config):
data_folder = root / "CLEVR_v1.0"

@@ -1127,7 +1135,7 @@ def generate(self, root):
return num_samples_map


@register_mock
# @register_mock
def oxford_iiit_pet(info, root, config):
return OxfordIIITPetMockData.generate(root)[config.split]

@@ -1293,13 +1301,13 @@ def generate(cls, root):
return num_samples_map


@register_mock
# @register_mock
def cub200(info, root, config):
num_samples_map = (CUB2002011MockData if config.year == "2011" else CUB2002010MockData).generate(root)
return num_samples_map[config.split]


@register_mock
# @register_mock
def svhn(info, root, config):
import scipy.io as sio

@@ -1319,7 +1327,7 @@ def svhn(info, root, config):
return num_samples


@register_mock
# @register_mock
def pcam(info, root, config):
import h5py
