|
15 | 15 |
|
16 | 16 | class SUN397(Dataset):
|
17 | 17 | def _make_info(self) -> DatasetInfo:
|
| 18 | + # split can be either "all", which yields all the images (suitable for those who want a custom data split), |
| 19 | + # or "train-{fold}" or "test-{fold}", where fold is a number from 1 to 10; this is suitable for |
| 20 | + # those who want to reproduce the paper's results |
| 21 | + split_choices = ["all"] |
| 22 | + split_choices.extend([f"{split_type}-{fold}" for split_type in ["train", "test"] for fold in range(1, 11)]) |
| 23 | + |
| 24 | + # The default split is "all" for backward compatibility with the previous datasets API |
18 | 25 | return DatasetInfo(
|
19 | 26 | "sun397",
|
20 | 27 | homepage="https://vision.princeton.edu/projects/2010/SUN/",
|
21 | 28 | valid_options=dict(
|
22 |
| - split=("train", "test"), |
23 |
| - fold=tuple(str(fold) for fold in range(1, 11)), |
| 29 | + split=split_choices, |
24 | 30 | ),
|
25 | 31 | )
|
26 | 32 |
|
@@ -50,29 +56,43 @@ def _prepare_sample(self, data: Tuple[str, Tuple[str, Any]]) -> Dict[str, Any]:
|
50 | 56 | image=EncodedImage.from_file(buffer),
|
51 | 57 | )
|
52 | 58 |
|
| 59 | + def _prepare_sample_all(self, data: Tuple[str, Any]) -> Dict[str, Any]: |
| 60 | + path, buffer = data |
| 61 | + key = self._image_key(data) |
| 62 | + category = "/".join(key.split("/")[2:-1]) |
| 63 | + return dict( |
| 64 | + label=Label.from_category(category, categories=self.categories), |
| 65 | + path=path, |
| 66 | + image=EncodedImage.from_file(buffer), |
| 67 | + ) |
| 68 | + |
53 | 69 | def _make_datapipe(
|
54 | 70 | self, resource_dps: List[IterDataPipe], *, config: DatasetConfig
|
55 | 71 | ) -> IterDataPipe[Dict[str, Any]]:
|
56 | 72 | images_dp, splits_dp = resource_dps
|
57 | 73 |
|
58 |
| - splits_dp = Filter( |
59 |
| - splits_dp, |
60 |
| - path_comparator("name", f"{config.split.capitalize()}ing_{int(config.fold):02d}.txt"), |
61 |
| - ) |
62 |
| - splits_dp = LineReader(splits_dp, decode=True, return_path=False) |
63 |
| - splits_dp = hint_sharding(splits_dp) |
64 |
| - splits_dp = hint_shuffling(splits_dp) |
65 |
| - |
66 | 74 | images_dp = Filter(images_dp, path_comparator("suffix", ".jpg"))
|
| 75 | + if config.split == "all": |
| 76 | + dp = images_dp |
| 77 | + return Mapper(dp, self._prepare_sample_all) |
| 78 | + else: |
| 79 | + split_type, fold = config.split.split("-") |
| 80 | + splits_dp = Filter( |
| 81 | + splits_dp, |
| 82 | + path_comparator("name", f"{split_type.capitalize()}ing_{int(fold):02d}.txt"), |
| 83 | + ) |
| 84 | + splits_dp = LineReader(splits_dp, decode=True, return_path=False) |
| 85 | + splits_dp = hint_sharding(splits_dp) |
| 86 | + splits_dp = hint_shuffling(splits_dp) |
67 | 87 |
|
68 |
| - dp = IterKeyZipper( |
69 |
| - splits_dp, |
70 |
| - images_dp, |
71 |
| - key_fn=getitem(), |
72 |
| - ref_key_fn=self._image_key, |
73 |
| - buffer_size=INFINITE_BUFFER_SIZE, |
74 |
| - ) |
75 |
| - return Mapper(dp, self._prepare_sample) |
| 88 | + dp = IterKeyZipper( |
| 89 | + splits_dp, |
| 90 | + images_dp, |
| 91 | + key_fn=getitem(), |
| 92 | + ref_key_fn=self._image_key, |
| 93 | + buffer_size=INFINITE_BUFFER_SIZE, |
| 94 | + ) |
| 95 | + return Mapper(dp, self._prepare_sample) |
76 | 96 |
|
77 | 97 | def _generate_categories(self, root: pathlib.Path) -> List[str]:
|
78 | 98 | resources = self.resources(self.default_config)
|
|
0 commit comments