Skip to content

Commit ddb1887

Browse files
committed
Change the valid_options to ensure backward compatibility
1 parent 701539f commit ddb1887

File tree

2 files changed

+41
-19
lines changed

2 files changed

+41
-19
lines changed

test/builtin_dataset_mocks.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1480,7 +1480,9 @@ def sun397(info, root, config):
14801480
random.shuffle(keys)
14811481

14821482
for split, keys_in_split in zip(splits, random_subsets(keys, len(splits))):
1483-
if split == config.split and str(fold) == config.fold:
1483+
if config.split == "all":
1484+
num_samples = len(keys)
1485+
elif config.split == f"{split}-{fold}":
14841486
num_samples = len(keys_in_split)
14851487

14861488
with open(partitions_root / f"{split.capitalize()}ing_{fold:02d}.txt", "w") as fh:

torchvision/prototype/datasets/_builtin/sun397.py

Lines changed: 38 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,18 @@
1515

1616
class SUN397(Dataset):
1717
def _make_info(self) -> DatasetInfo:
18+
# split can be either "all", where it gets all the images (suitable for those who want a custom data split),
19+
# or it can be "train-{fold}" or "test-{fold}", where fold is a number between 1 and 10; this is suitable for
20+
# those who want to reproduce the paper's results
21+
split_choices = ["all"]
22+
split_choices.extend([f"{split_type}-{fold}" for split_type in ["train", "test"] for fold in range(1, 11)])
23+
24+
# The default split is "all" for backward compatibility with the previous datasets API
1825
return DatasetInfo(
1926
"sun397",
2027
homepage="https://vision.princeton.edu/projects/2010/SUN/",
2128
valid_options=dict(
22-
split=("train", "test"),
23-
fold=tuple(str(fold) for fold in range(1, 11)),
29+
split=split_choices,
2430
),
2531
)
2632

@@ -50,29 +56,43 @@ def _prepare_sample(self, data: Tuple[str, Tuple[str, Any]]) -> Dict[str, Any]:
5056
image=EncodedImage.from_file(buffer),
5157
)
5258

59+
def _prepare_sample_all(self, data: Tuple[str, Any]) -> Dict[str, Any]:
60+
path, buffer = data
61+
key = self._image_key(data)
62+
category = "/".join(key.split("/")[2:-1])
63+
return dict(
64+
label=Label.from_category(category, categories=self.categories),
65+
path=path,
66+
image=EncodedImage.from_file(buffer),
67+
)
68+
5369
def _make_datapipe(
5470
self, resource_dps: List[IterDataPipe], *, config: DatasetConfig
5571
) -> IterDataPipe[Dict[str, Any]]:
5672
images_dp, splits_dp = resource_dps
5773

58-
splits_dp = Filter(
59-
splits_dp,
60-
path_comparator("name", f"{config.split.capitalize()}ing_{int(config.fold):02d}.txt"),
61-
)
62-
splits_dp = LineReader(splits_dp, decode=True, return_path=False)
63-
splits_dp = hint_sharding(splits_dp)
64-
splits_dp = hint_shuffling(splits_dp)
65-
6674
images_dp = Filter(images_dp, path_comparator("suffix", ".jpg"))
75+
if config.split == "all":
76+
dp = images_dp
77+
return Mapper(dp, self._prepare_sample_all)
78+
else:
79+
split_type, fold = config.split.split("-")
80+
splits_dp = Filter(
81+
splits_dp,
82+
path_comparator("name", f"{split_type.capitalize()}ing_{int(fold):02d}.txt"),
83+
)
84+
splits_dp = LineReader(splits_dp, decode=True, return_path=False)
85+
splits_dp = hint_sharding(splits_dp)
86+
splits_dp = hint_shuffling(splits_dp)
6787

68-
dp = IterKeyZipper(
69-
splits_dp,
70-
images_dp,
71-
key_fn=getitem(),
72-
ref_key_fn=self._image_key,
73-
buffer_size=INFINITE_BUFFER_SIZE,
74-
)
75-
return Mapper(dp, self._prepare_sample)
88+
dp = IterKeyZipper(
89+
splits_dp,
90+
images_dp,
91+
key_fn=getitem(),
92+
ref_key_fn=self._image_key,
93+
buffer_size=INFINITE_BUFFER_SIZE,
94+
)
95+
return Mapper(dp, self._prepare_sample)
7696

7797
def _generate_categories(self, root: pathlib.Path) -> List[str]:
7898
resources = self.resources(self.default_config)

0 commit comments

Comments
 (0)