USPS dataset #5647


Merged: 24 commits, Mar 25, 2022

Commits
0606674  added usps dataset (lezwon, Mar 19, 2022)
ed267f6  fixed type issues (lezwon, Mar 19, 2022)
16123ef  fix mobilnet norm layer test (#5643) (pmeier, Mar 20, 2022)
a37b3e9  More robust check in tests for 16 bits images (#5652) (NicolasHug, Mar 21, 2022)
7fd2ea0  Prefer nvidia channel for conda builds (#5648) (malfet, Mar 21, 2022)
2efd0f2  fix torchdata CI installation (#5657) (pmeier, Mar 22, 2022)
834c8d9  update urls for kinetics dataset (#5578) (sahilg06, Mar 22, 2022)
2ef164a  Port Multi-weight support from prototype to main (#5618) (datumbox, Mar 22, 2022)
f7d0a50  Apply suggestions from code review (lezwon, Mar 23, 2022)
85bc5fd  use decompressor for extracting bz2 (lezwon, Mar 23, 2022)
0b46877  Merge branch 'main' into 5352-usps-dataset (lezwon, Mar 23, 2022)
68f15ba  Apply suggestions from code review (lezwon, Mar 23, 2022)
aefad05  Apply suggestions from code review (lezwon, Mar 23, 2022)
e1ec56c  Merge branch 'main' into 5352-usps-dataset (lezwon, Mar 23, 2022)
81cfba9  fixed lint fails (lezwon, Mar 24, 2022)
e847fde  added tests for USPS (lezwon, Mar 24, 2022)
96da894  Merge branch 'main' into 5352-usps-dataset (lezwon, Mar 24, 2022)
fe7c573  check image shape (lezwon, Mar 24, 2022)
b1deb63  fix tests (lezwon, Mar 24, 2022)
9a08f51  Merge branch 'main' into 5352-usps-dataset (lezwon, Mar 24, 2022)
0624de9  check shape on image directly (lezwon, Mar 24, 2022)
d5d9386  Apply suggestions from code review (lezwon, Mar 25, 2022)
7324861  removed test and comments (lezwon, Mar 25, 2022)
d23863c  Update test/test_prototype_builtin_datasets.py (pmeier, Mar 25, 2022)

19 changes: 19 additions & 0 deletions test/builtin_dataset_mocks.py
@@ -1,3 +1,4 @@
import bz2
import collections.abc
import csv
import functools
@@ -1431,3 +1432,21 @@ def stanford_cars(info, root, config):
    make_tar(root, "car_devkit.tgz", devkit, compression="gz")

    return num_samples


@register_mock
def usps(info, root, config):
    num_samples = {"train": 15, "test": 7}[config.split]

    with bz2.open(root / f"usps{'.t' if not config.split == 'train' else ''}.bz2", "wb") as fh:
        lines = []
        for _ in range(num_samples):
            label = make_tensor(1, low=1, high=11, dtype=torch.int)
            values = make_tensor(256, low=-1, high=1, dtype=torch.float)
            lines.append(
                " ".join([f"{int(label)}", *(f"{idx}:{float(value):.6f}" for idx, value in enumerate(values, 1))])
            )

        fh.write("\n".join(lines).encode())

    return num_samples
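
The mock writes each split in the same libsvm-style text layout as the real USPS archives: one sample per line, a 1-based digit label followed by 256 one-based index:value pairs with values in [-1, 1]. A minimal, self-contained sketch of that format and of parsing a line back, for illustration only (not part of the diff):

# Illustrative only: build one libsvm-style USPS line and parse it back.
import torch

values = torch.rand(256) * 2 - 1  # pixel values in [-1, 1], as in the real archives
label = 7  # labels in the files run from 1 to 10
line = " ".join([str(label), *(f"{i}:{float(v):.6f}" for i, v in enumerate(values, 1))])

# Parsing mirrors what USPS._prepare_sample (further down) does with each decompressed line.
parsed_label, *pairs = line.split(" ")
pixels = torch.tensor([float(p.split(":")[1]) for p in pairs]).add_(1).div_(2)  # rescale to [0, 1]
assert pixels.shape == (256,) and int(parsed_label) == label
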
19 changes: 18 additions & 1 deletion test/test_prototype_builtin_datasets.py
@@ -12,7 +12,7 @@
from torchdata.datapipes.iter import IterDataPipe, Shuffler
from torchvision._utils import sequence_to_str
from torchvision.prototype import transforms, datasets

from torchvision.prototype.features import Image, Label

assert_samples_equal = functools.partial(
    assert_equal, pair_types=(TensorLikePair, ObjectPair), rtol=0, atol=0, equal_nan=True
@@ -180,3 +180,20 @@ def test_label_matches_path(self, test_home, dataset_mock, config):
        for sample in dataset:
            label_from_path = int(Path(sample["path"]).parent.name)
            assert sample["label"] == label_from_path


@parametrize_dataset_mocks(DATASET_MOCKS["usps"])
class TestUSPS:
    def test_sample_content(self, test_home, dataset_mock, config):
        dataset_mock.prepare(test_home, config)

        dataset = datasets.load(dataset_mock.name, **config)

        for sample in dataset:
            assert "image" in sample
            assert "label" in sample

            assert isinstance(sample["image"], Image)
            assert isinstance(sample["label"], Label)

            assert sample["image"].shape == (1, 16, 16)
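
Note that _prepare_sample in usps.py (further down) builds a 16x16 tensor, so the (1, 16, 16) assertion relies on the prototype Image feature promoting 2D input to a single-channel CHW image. Locally, these checks should be selectable with pytest's -k filter, for example: pytest test/test_prototype_builtin_datasets.py -k usps.
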
1 change: 1 addition & 0 deletions torchvision/prototype/datasets/_builtin/__init__.py
@@ -17,4 +17,5 @@
from .semeion import SEMEION
from .stanford_cars import StanfordCars
from .svhn import SVHN
from .usps import USPS
from .voc import VOC
54 changes: 54 additions & 0 deletions torchvision/prototype/datasets/_builtin/usps.py
@@ -0,0 +1,54 @@
from typing import Any, Dict, List

import torch
from torchdata.datapipes.iter import IterDataPipe, LineReader, Mapper, Decompressor
from torchvision.prototype.datasets.utils import Dataset, DatasetInfo, DatasetConfig, OnlineResource, HttpResource
from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling
from torchvision.prototype.features import Image, Label


class USPS(Dataset):
    def _make_info(self) -> DatasetInfo:
        return DatasetInfo(
            "usps",
            homepage="https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass.html#usps",
            valid_options=dict(
                split=("train", "test"),
            ),
            categories=10,
        )

    _URL = "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass"

    _RESOURCES = {
        "train": HttpResource(
            f"{_URL}/usps.bz2", sha256="3771e9dd6ba685185f89867b6e249233dd74652389f263963b3b741e994b034f"
        ),
        "test": HttpResource(
            f"{_URL}/usps.t.bz2", sha256="a9c0164e797d60142a50604917f0baa604f326e9a689698763793fa5d12ffc4e"
        ),
    }

    def resources(self, config: DatasetConfig) -> List[OnlineResource]:
        return [USPS._RESOURCES[config.split]]

    def _prepare_sample(self, line: str) -> Dict[str, Any]:
        label, *values = line.strip().split(" ")
        values = [float(value.split(":")[1]) for value in values]
        pixels = torch.tensor(values).add_(1).div_(2)
        return dict(
            image=Image(pixels.reshape(16, 16)),
            label=Label(int(label) - 1, categories=self.categories),
        )

    def _make_datapipe(
        self,
        resource_dps: List[IterDataPipe],
        *,
        config: DatasetConfig,
    ) -> IterDataPipe[Dict[str, Any]]:
        dp = Decompressor(resource_dps[0])
        dp = LineReader(dp, decode=True, return_path=False)
        dp = hint_sharding(dp)
        dp = hint_shuffling(dp)
        return Mapper(dp, self._prepare_sample)
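
For context, a rough sketch of how the registered dataset would be consumed through the prototype API; the load call mirrors the test above, and downloading the LIBSVM archives on first use is assumed:

# Illustrative usage sketch, not part of the diff.
from torchvision.prototype import datasets

dataset = datasets.load("usps", split="train")
sample = next(iter(dataset))
print(sample["image"].shape)  # expected: torch.Size([1, 16, 16])
print(sample["label"])        # digit label in 0-9 with attached category names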