Skip to content

Commit 36957ff

Browse files
committed
fix semeion mock data
1 parent aaff6f6 commit 36957ff

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

test/builtin_dataset_mocks.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -668,14 +668,15 @@ def sbd(info, root, config):
668668
@register_mock
669669
def semeion(info, root, config):
670670
num_samples = 3
671+
num_categories = len(info.categories)
671672

672673
images = torch.rand(num_samples, 256)
673-
labels = one_hot(torch.randint(len(info.categories), size=(num_samples,)))
674+
labels = one_hot(torch.randint(num_categories, size=(num_samples,)), num_classes=num_categories)
674675
with open(root / "semeion.data", "w") as fh:
675676
for image, one_hot_label in zip(images, labels):
676677
image_columns = " ".join([f"{pixel.item():.4f}" for pixel in image])
677678
labels_columns = " ".join([str(label.item()) for label in one_hot_label])
678-
fh.write(f"{image_columns} {labels_columns}\n")
679+
fh.write(f"{image_columns} {labels_columns} \n")
679680

680681
return num_samples
681682

0 commit comments

Comments
 (0)