Skip to content

Commit 23f685a

Browse files
committed
Fix bug in tests and in food101 dataset
1 parent 1407dbd commit 23f685a

File tree

2 files changed

+9
-5
lines changed

2 files changed

+9
-5
lines changed

test/test_datasets.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2194,11 +2194,13 @@ def inject_fake_data(self, tmpdir: str, config):
21942194
file_name_fn=lambda idx: f"{idx}.jpg",
21952195
num_examples=num_images_per_class,
21962196
)
2197-
metadata[cls] = [str(fname) for fname in random.choices(im_fnames, k=n_samples_per_class)]
2197+
metadata[cls] = [
2198+
fname.parent.name + "/" + fname.name for fname in random.choices(im_fnames, k=n_samples_per_class)
2199+
]
21982200

21992201
with open(meta_folder / f"{config['split']}.json", "w") as file:
22002202
file.write(json.dumps(metadata))
2201-
2203+
print(metadata)
22022204
return len(sampled_classes * n_samples_per_class)
22032205

22042206

torchvision/datasets/food101.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,9 +58,11 @@ def __init__(
5858
self.classes = sorted(metadata.keys())
5959
self.class_to_idx = dict(zip(self.classes, range(len(self.classes))))
6060

61-
for class_label, im_ids in metadata.items():
62-
self._labels += [self.class_to_idx[class_label]] * len(im_ids)
63-
self._image_files += [self._images_folder.joinpath(*f"{im_id}.jpg".split("/")) for im_id in im_ids]
61+
for class_label, im_rel_paths in metadata.items():
62+
self._labels += [self.class_to_idx[class_label]] * len(im_rel_paths)
63+
self._image_files += [
64+
self._images_folder.joinpath(*f"{im_rel_path}".split("/")) for im_rel_path in im_rel_paths
65+
]
6466

6567
def __len__(self) -> int:
6668
return len(self._image_files)

0 commit comments

Comments
 (0)