Skip to content

Migrate USPS prototype dataset #5771

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions test/builtin_dataset_mocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1515,11 +1515,11 @@ def stanford_cars(root, config):
return num_samples


# @register_mock
def usps(info, root, config):
num_samples = {"train": 15, "test": 7}[config.split]
@register_mock(configs=combinations_grid(split=("train", "test")))
def usps(root, config):
num_samples = {"train": 15, "test": 7}[config["split"]]

with bz2.open(root / f"usps{'.t' if not config.split == 'train' else ''}.bz2", "wb") as fh:
with bz2.open(root / f"usps{'.t' if not config['split'] == 'train' else ''}.bz2", "wb") as fh:
lines = []
for _ in range(num_samples):
label = make_tensor(1, low=1, high=11, dtype=torch.int)
Expand Down
57 changes: 36 additions & 21 deletions torchvision/prototype/datasets/_builtin/usps.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,39 @@
from typing import Any, Dict, List
import pathlib
from typing import Any, Dict, List, Union

import torch
from torchdata.datapipes.iter import IterDataPipe, LineReader, Mapper, Decompressor
from torchvision.prototype.datasets.utils import Dataset, DatasetInfo, DatasetConfig, OnlineResource, HttpResource
from torchvision.prototype.datasets.utils import Dataset2, OnlineResource, HttpResource
from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling
from torchvision.prototype.features import Image, Label

from .._api import register_dataset, register_info

class USPS(Dataset):
def _make_info(self) -> DatasetInfo:
return DatasetInfo(
"usps",
homepage="https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass.html#usps",
valid_options=dict(
split=("train", "test"),
),
categories=10,
)
NAME = "usps"


@register_info(NAME)
def _info() -> Dict[str, Any]:
    """Return the metadata dict handed to ``register_info`` under ``NAME``.

    The dict has a single key, ``categories``, holding the ten digit labels
    ``"0"`` through ``"9"`` as strings in ascending order.
    """
    return dict(categories=[str(c) for c in range(10)])


@register_dataset(NAME)
class USPS(Dataset2):
"""USPS Dataset
homepage="https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass.html#usps",
"""

def __init__(
    self,
    root: Union[str, pathlib.Path],
    *,
    split: str = "train",
    skip_integrity_check: bool = False,
) -> None:
    """Configure the dataset for a single split.

    Args:
        root: Forwarded unchanged to the base-class constructor.
        split: Either ``"train"`` or ``"test"``; checked through
            ``self._verify_str_arg`` and stored as ``self._split``.
        skip_integrity_check: Forwarded unchanged to the base-class
            constructor as a keyword argument.
    """
    # Instance state is assigned before delegating to the base class.
    # NOTE(review): presumably the base constructor relies on these
    # attributes already being set -- confirm before reordering.
    self._split = self._verify_str_arg(split, "split", {"train", "test"})

    self._categories = _info()["categories"]
    super().__init__(root, skip_integrity_check=skip_integrity_check)

_URL = "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass"

Expand All @@ -29,26 +46,24 @@ def _make_info(self) -> DatasetInfo:
),
}

def resources(self, config: DatasetConfig) -> List[OnlineResource]:
return [USPS._RESOURCES[config.split]]
def _resources(self) -> List[OnlineResource]:
return [USPS._RESOURCES[self._split]]

def _prepare_sample(self, line: str) -> Dict[str, Any]:
label, *values = line.strip().split(" ")
values = [float(value.split(":")[1]) for value in values]
pixels = torch.tensor(values).add_(1).div_(2)
return dict(
image=Image(pixels.reshape(16, 16)),
label=Label(int(label) - 1, categories=self.categories),
label=Label(int(label) - 1, categories=self._categories),
)

def _make_datapipe(
self,
resource_dps: List[IterDataPipe],
*,
config: DatasetConfig,
) -> IterDataPipe[Dict[str, Any]]:
def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
dp = Decompressor(resource_dps[0])
dp = LineReader(dp, decode=True, return_path=False)
dp = hint_shuffling(dp)
dp = hint_sharding(dp)
return Mapper(dp, self._prepare_sample)

def __len__(self) -> int:
return 7_291 if self._split == "train" else 2_007