Port SBU dataset #5683

Open · wants to merge 12 commits into main
21 changes: 21 additions & 0 deletions test/builtin_dataset_mocks.py
@@ -1454,3 +1454,24 @@ def usps(info, root, config):
        fh.write("\n".join(lines).encode())

    return num_samples


@register_mock
def sbu(info, root, config):
    num_samples = 10

    dataset_folder = pathlib.Path(root).joinpath("dataset")
    dataset_folder.mkdir(parents=True)

    photo_urls_file = "SBU_captioned_photo_dataset_urls.txt"
    photo_captions_file = "SBU_captioned_photo_dataset_captions.txt"
    with open(dataset_folder.joinpath(photo_urls_file), "w") as url_file, open(
        dataset_folder.joinpath(photo_captions_file), "w"
    ) as caption_file:
        urls = [f"https://via.placeholder.com/{random.randint(100, 1000)}.jpg" for _ in range(num_samples)]
pmeier (Collaborator) commented on Mar 30, 2022:
This is a really cool idea and I'm definitely going to use this website for other things in the future 🚀 Unfortunately, we cannot have an actual download during mock data generation, for two reasons:

  1. Downloading these images takes quite some time, and we want the tests to be fast.
  2. Meta's internal test systems do not have access to the internet and would thus fail here.

I propose sending a patch for the test suite that allows us to generate only the already-preprocessed files. That way, we only add an SBUCaptionedPhotoDataset archive that already includes test images. I'll ping you on the PR.

Collaborator:
See #5706.

Contributor (Author):

@pmeier I'll wait for that PR to get merged, right? I can make the necessary changes after it.

Collaborator:

Yes, sorry for the delay. I'll try to get it merged soon.
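
A minimal sketch of the locally-generated mock images proposed above (illustration only, not the actual patch from #5706; the make_fake_images helper and the 16×16 size are invented here, and PIL is assumed to be available in the test environment):

import pathlib
from typing import List

from PIL import Image


def make_fake_images(image_folder: pathlib.Path, urls: List[str]) -> None:
    # Write one tiny, valid JPEG per URL so the mock archive ships its own
    # images and mock data generation never touches the network.
    image_folder.mkdir(parents=True, exist_ok=True)
    for url in urls:
        # Name each file after the URL's basename, mirroring how the real
        # preprocessing step stores downloaded images on disk.
        name = url.rsplit("/", 1)[-1]
        Image.new("RGB", (16, 16), color=(0, 128, 255)).save(image_folder / name, format="JPEG")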

captions = [f"Caption {i} for {url}" for i, url in enumerate(urls)]
url_file.write("\n".join(urls))
caption_file.write("\n".join(captions))

make_tar(root, "SBUCaptionedPhotoDataset.tar.gz", dataset_folder, compression="gz")
return num_samples
17 changes: 16 additions & 1 deletion test/test_prototype_builtin_datasets.py
@@ -13,7 +13,7 @@
from torchvision._utils import sequence_to_str
from torchvision.prototype import transforms, datasets
from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE
-from torchvision.prototype.features import Image, Label
+from torchvision.prototype.features import EncodedImage, Label, Image

assert_samples_equal = functools.partial(
    assert_equal, pair_types=(TensorLikePair, ObjectPair), rtol=0, atol=0, equal_nan=True
@@ -213,3 +213,18 @@ def test_sample_content(self, test_home, dataset_mock, config):
            assert isinstance(sample["label"], Label)

            assert sample["image"].shape == (1, 16, 16)


@parametrize_dataset_mocks(DATASET_MOCKS["sbu"])
class TestSBU:
    def test_sample_content(self, test_home, dataset_mock, config):
        dataset_mock.prepare(test_home, config)

        dataset = datasets.load(dataset_mock.name, **config)

        for sample in dataset:
            assert "image" in sample
            assert "caption" in sample

            assert isinstance(sample["image"], EncodedImage)
            assert isinstance(sample["caption"], str)
1 change: 1 addition & 0 deletions torchvision/prototype/datasets/_builtin/__init__.py
@@ -14,6 +14,7 @@
from .oxford_iiit_pet import OxfordIITPet
from .pcam import PCAM
from .sbd import SBD
from .sbu import SBU
from .semeion import SEMEION
from .stanford_cars import StanfordCars
from .svhn import SVHN
100 changes: 100 additions & 0 deletions torchvision/prototype/datasets/_builtin/sbu.py
@@ -0,0 +1,100 @@
import pathlib
import warnings
from typing import List, Any, Dict, Optional, Tuple, BinaryIO

from torch.utils.model_zoo import tqdm
from torchdata.datapipes.iter import IterDataPipe, Demultiplexer, LineReader, Zipper, Mapper, IterKeyZipper
from torchvision.prototype.datasets.utils import Dataset, DatasetInfo, DatasetConfig, OnlineResource, HttpResource
from torchvision.prototype.datasets.utils._internal import (
    hint_sharding,
    hint_shuffling,
    INFINITE_BUFFER_SIZE,
    path_accessor,
)
from torchvision.prototype.features import EncodedImage


class SBU(Dataset):
    def _make_info(self) -> DatasetInfo:
        return DatasetInfo(
            name="sbu",
            homepage="http://www.cs.virginia.edu/~vicente/sbucaptions/",
        )

    def _preprocess(self, path: pathlib.Path) -> pathlib.Path:
        folder = OnlineResource._extract(path)
        data_folder = folder / "dataset"
        image_folder = data_folder / "images"
        image_folder.mkdir()
        broken_urls = []
        with open(data_folder / "SBU_captioned_photo_dataset_urls.txt") as fh:
            urls = fh.read().splitlines()

        # TODO: Use workers to download images
        for url in tqdm(urls):
            try:
                # TODO: suppress print statements within HttpResource.download()
                HttpResource(url).download(image_folder)
            except Exception:
                broken_urls.append(url)

        if broken_urls:
            broken_urls_file = folder.parent / "broken_urls.txt"
            warnings.warn(
                f"Failed to download {len(broken_urls)} ({len(broken_urls) / len(urls):.2%}) images. "
                f"They are logged in {broken_urls_file}."
            )
            with open(broken_urls_file, "w") as fh:
                fh.write("\n".join(broken_urls) + "\n")

        return folder

    def resources(self, config: DatasetConfig) -> List[OnlineResource]:
        return [
            HttpResource(
                "http://www.cs.virginia.edu/~vicente/sbucaptions/SBUCaptionedPhotoDataset.tar.gz",
                sha256="2bf37d5e1c9e1c6eae7d5103030d58a7f2117fc5e8c6aa9620f0df165acebf09",
                preprocess=self._preprocess,
            )
        ]

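    # Return values route archive members in the Demultiplexer below:
    # 0 -> image files, 1 -> URL list, 2 -> caption list, None -> dropped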
    def _classify_files(self, data: Tuple[str, Any]) -> Optional[int]:
        path = pathlib.Path(data[0])
        if path.parent.name == "images":
            return 0
        elif path.name == "SBU_captioned_photo_dataset_urls.txt":
            return 1
        elif path.name == "SBU_captioned_photo_dataset_captions.txt":
            return 2
        else:
            return None

    def _make_datapipe(
        self,
        resource_dps: List[IterDataPipe],
        *,
        config: DatasetConfig,
    ) -> IterDataPipe[Dict[str, Any]]:

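        # Split the archive stream into image files, the URL list, and the caption list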
        images_dp, urls_dp, captions_dp = Demultiplexer(
            resource_dps[0], 3, self._classify_files, drop_none=True, buffer_size=INFINITE_BUFFER_SIZE
        )

        images_dp = hint_shuffling(images_dp)
        images_dp = hint_sharding(images_dp)

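        # Each line of the two text files describes one sample; zip them into (url, caption) pairs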
        urls_dp = LineReader(urls_dp, decode=True, return_path=False)
        captions_dp = LineReader(captions_dp, decode=True, return_path=False)
        anns_dp = Zipper(urls_dp, captions_dp)

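        # Match each image to its (url, caption) pair via the file name, which is the URL's basename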
        dp = IterKeyZipper(images_dp, anns_dp, path_accessor("name"), buffer_size=INFINITE_BUFFER_SIZE)
        return Mapper(dp, self._prepare_sample)

    def _prepare_sample(self, data: Tuple[Tuple[str, BinaryIO], Tuple[str, str]]) -> Dict[str, Any]:
        (path, buffer), (_, caption) = data
        return dict(
            path=path,
            image=EncodedImage.from_file(buffer),
            caption=caption.strip(),
        )
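
The `# TODO: Use workers to download images` in `_preprocess` above could be handled with a small thread pool. A minimal sketch, reusing the `HttpResource(url).download(image_folder)` call from the existing loop (the `download_all` helper is hypothetical and not part of this PR):

import pathlib
from concurrent.futures import ThreadPoolExecutor
from typing import List

from torchvision.prototype.datasets.utils import HttpResource


def download_all(urls: List[str], image_folder: pathlib.Path, max_workers: int = 8) -> List[str]:
    # Download images concurrently and return the URLs that failed.
    broken_urls = []

    def fetch(url: str) -> None:
        try:
            HttpResource(url).download(image_folder)
        except Exception:
            # list.append is atomic in CPython, so no extra locking is needed here.
            broken_urls.append(url)

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Consume the iterator so all downloads complete before we return.
        list(executor.map(fetch, urls))

    return broken_urls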