Skip to content

Commit 1f17f5f

Browse files
pmeierdatumbox
andauthored
add tests for SEMEION dataset (#3465)
* add tests for SEMEION dataset * add missing imports Co-authored-by: Vasilis Vryniotis <[email protected]>
1 parent b5f29cc commit 1f17f5f

File tree

1 file changed

+18
-0
lines changed

1 file changed

+18
-0
lines changed

test/test_datasets.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import shutil
2424
import json
2525
import random
26+
import torch.nn.functional as F
2627
import string
2728
import io
2829

@@ -1155,5 +1156,22 @@ def _create_captions_txt(self, root, num_images):
11551156
fh.write(f"{datasets_utils.create_random_string(10)}\n")
11561157

11571158

1159+
class SEMEIONTestCase(datasets_utils.ImageDatasetTestCase):
1160+
DATASET_CLASS = datasets.SEMEION
1161+
1162+
def inject_fake_data(self, tmpdir, config):
1163+
num_images = 3
1164+
1165+
images = torch.rand(num_images, 256)
1166+
labels = F.one_hot(torch.randint(10, size=(num_images,)))
1167+
with open(pathlib.Path(tmpdir) / "semeion.data", "w") as fh:
1168+
for image, one_hot_labels in zip(images, labels):
1169+
image_columns = " ".join([f"{pixel.item():.4f}" for pixel in image])
1170+
labels_columns = " ".join([str(label.item()) for label in one_hot_labels])
1171+
fh.write(f"{image_columns} {labels_columns}\n")
1172+
1173+
return num_images
1174+
1175+
11581176
if __name__ == "__main__":
11591177
unittest.main()

0 commit comments

Comments
 (0)