Skip to content

Commit 998daf6

Browse files
committed
enhance partition utilities
Signed-off-by: Wenqi Li <[email protected]>
1 parent 8e5f154 commit 998daf6

File tree

2 files changed

+101
-89
lines changed

2 files changed

+101
-89
lines changed

monai/data/utils.py

Lines changed: 100 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from collections import defaultdict
1616
from itertools import product, starmap
1717
from pathlib import PurePath
18-
from typing import Dict, Generator, List, Optional, Sequence, Tuple, Union
18+
from typing import Dict, Generator, Iterable, List, Optional, Sequence, Tuple, Union
1919

2020
import numpy as np
2121
import torch
@@ -570,21 +570,27 @@ def partition_dataset(
570570
Args:
571571
data: input dataset to split, expect a list of data.
572572
ratios: a list of ratio number to split the dataset, like [8, 1, 1].
573-
num_partitions: expected number of the partitions to evenly split, only works when no `ratios`.
573+
num_partitions: expected number of the partitions to evenly split, only works when `ratios` not specified.
574574
shuffle: whether to shuffle the original dataset before splitting.
575575
seed: random seed to shuffle the dataset, only works when `shuffle` is True.
576576
drop_last: only works when `even_divisible` is False and no ratios specified.
577577
if True, will drop the tail of the data to make it evenly divisible across partitions.
578578
if False, will add extra indices to make the data evenly divisible across partitions.
579579
even_divisible: if True, guarantee every partition has same length.
580580
581-
Examples:
582-
data: [1, 2, 3, 4, 5]
583-
(1) ratios: [0.6, 0.2, 0.2], shuffle=False, output: [[1, 2, 3], [4], [5]]
584-
num_partitions=2, shuffle=False
585-
(2) even_divisible=True, drop_last=True, output: [[1, 3], [2, 4]]
586-
(3) even_divisible=True, drop_last=False, output: [[1, 3, 5], [2, 4, 1]]
587-
(4) even_divisible=False, drop_last=False, output: [[1, 3, 5], [2, 4]]
581+
Examples::
582+
583+
>>> data = [1, 2, 3, 4, 5]
584+
>>> partition_dataset(data, ratios=[0.6, 0.2, 0.2], shuffle=False)
585+
[[1, 2, 3], [4], [5]]
586+
>>> partition_dataset(data, num_partitions=2, shuffle=False)
587+
[[1, 3, 5], [2, 4]]
588+
>>> partition_dataset(data, num_partitions=2, shuffle=False, even_divisible=True, drop_last=True)
589+
[[1, 3], [2, 4]]
590+
>>> partition_dataset(data, num_partitions=2, shuffle=False, even_divisible=True, drop_last=False)
591+
[[1, 3, 5], [2, 4, 1]]
592+
>>> partition_dataset(data, num_partitions=2, shuffle=False, even_divisible=False, drop_last=False)
593+
[[1, 3, 5], [2, 4]]
588594
589595
"""
590596
data_len = len(data)
@@ -593,52 +599,53 @@ def partition_dataset(
593599
indices = list(range(data_len))
594600
if shuffle:
595601
# deterministically shuffle based on fixed seed for every process
596-
np.random.seed(seed)
597-
np.random.shuffle(indices)
602+
rs = np.random.RandomState(seed)
603+
rs.shuffle(indices)
598604

599-
if ratios is not None:
600-
start_idx = next_idx = 0
605+
if ratios:
606+
next_idx = 0
601607
rsum = sum(ratios)
602608
for r in ratios:
603609
start_idx = next_idx
604610
next_idx = min(start_idx + int(r / rsum * data_len + 0.5), data_len)
605611
datasets.append([data[i] for i in indices[start_idx:next_idx]])
612+
return datasets
613+
614+
if not num_partitions:
615+
raise ValueError("must specify number of partitions or ratios.")
616+
# evenly split the data without ratios
617+
if not even_divisible and drop_last:
618+
raise RuntimeError("drop_last only works when even_divisible is True.")
619+
if data_len < num_partitions:
620+
raise RuntimeError(f"there is not enough data to be split into {num_partitions} partitions.")
621+
622+
if drop_last and data_len % num_partitions != 0:
623+
# split to nearest available length that is evenly divisible
624+
num_samples = math.ceil((data_len - num_partitions) / num_partitions)
606625
else:
607-
if num_partitions is None:
608-
raise ValueError("must specify number of partitions.")
609-
# evenly split the data without ratios
610-
if not even_divisible and drop_last:
611-
raise RuntimeError("drop_last only works when even_divisible is True.")
612-
if data_len < num_partitions:
613-
raise RuntimeError(f"there is no enough data to be splitted for {num_partitions} partitions.")
614-
615-
if drop_last and data_len % num_partitions != 0:
616-
# split to nearest available length that is evenly divisible
617-
num_samples = math.ceil((data_len - num_partitions) / num_partitions)
618-
else:
619-
num_samples = math.ceil(data_len / num_partitions)
620-
# use original data length if not even divisible
621-
total_size = num_samples * num_partitions if even_divisible else data_len
626+
num_samples = math.ceil(data_len / num_partitions)
627+
# use original data length if not even divisible
628+
total_size = num_samples * num_partitions if even_divisible else data_len
622629

623-
if not drop_last and total_size - data_len > 0:
624-
# add extra samples to make it evenly divisible
625-
indices += indices[: (total_size - data_len)]
626-
else:
627-
# remove tail of data to make it evenly divisible
628-
indices = indices[:total_size]
630+
if not drop_last and total_size - data_len > 0:
631+
# add extra samples to make it evenly divisible
632+
indices += indices[: (total_size - data_len)]
633+
else:
634+
# remove tail of data to make it evenly divisible
635+
indices = indices[:total_size]
629636

630-
for i in range(num_partitions):
631-
_indices = indices[i:total_size:num_partitions]
632-
datasets.append([data[j] for j in _indices])
637+
for i in range(num_partitions):
638+
_indices = indices[i:total_size:num_partitions]
639+
datasets.append([data[j] for j in _indices])
633640

634641
return datasets
635642

636643

637644
def partition_dataset_classes(
638645
data: Sequence,
639646
classes: Sequence[int],
640-
ratios: Optional[Sequence[float]],
641-
num_partitions: Optional[int],
647+
ratios: Optional[Sequence[float]] = None,
648+
num_partitions: Optional[int] = None,
642649
shuffle: bool = False,
643650
seed: int = 0,
644651
drop_last: bool = False,
@@ -661,67 +668,72 @@ def partition_dataset_classes(
661668
if False, will add extra indices to make the data evenly divisible across partitions.
662669
even_divisible: if True, guarantee every partition has same length.
663670
664-
Examples:
665-
data: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]
666-
classes: [2, 0, 2, 1, 3, 2, 2, 0, 2, 0, 3, 3, 1, 3]
667-
shuffle: False, ratios: [2, 1]
668-
output: [[2, 8, 4, 1, 3, 6, 5, 11, 12], [10, 13, 7, 9, 14]]
671+
Examples::
672+
673+
>>> data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]
674+
>>> classes = [2, 0, 2, 1, 3, 2, 2, 0, 2, 0, 3, 3, 1, 3]
675+
>>> partition_dataset_classes(data, classes, shuffle=False, ratios=[2, 1])
676+
[[2, 8, 4, 1, 3, 6, 5, 11, 12], [10, 13, 7, 9, 14]]
669677
670678
"""
671-
data_len = len(data)
672679
datasets = list()
680+
if not classes:
681+
return []
682+
if len(classes) != len(data):
683+
raise ValueError(f"length of classes {len(classes)} must match the dataset length {len(data)}.")
684+
class_indices = defaultdict(list)
685+
for i, c in enumerate(classes):
686+
class_indices[c].append(i)
687+
688+
class_partition_indices: List[Sequence] = list()
689+
for _, per_class_indices in sorted(class_indices.items()):
690+
per_class_partition_indices = partition_dataset(
691+
data=per_class_indices,
692+
ratios=ratios,
693+
num_partitions=num_partitions,
694+
shuffle=shuffle,
695+
seed=seed,
696+
drop_last=drop_last,
697+
even_divisible=even_divisible,
698+
)
699+
if len(class_partition_indices) == 0:
700+
class_partition_indices = per_class_partition_indices
701+
else:
702+
for part, data_indices in zip(class_partition_indices, per_class_partition_indices):
703+
part += data_indices
673704

674-
if classes is not None:
675-
if len(classes) != data_len:
676-
raise ValueError("length of classes must match the dataset length.")
677-
class_indices = defaultdict(list)
678-
for i, c in enumerate(classes):
679-
class_indices[c].append(i)
680-
681-
class_partition_indices: List[Sequence] = list()
682-
for _, per_class_indices in sorted(class_indices.items()):
683-
per_class_partition_indices = partition_dataset(
684-
data=per_class_indices,
685-
ratios=ratios,
686-
num_partitions=num_partitions,
687-
shuffle=shuffle,
688-
seed=seed,
689-
drop_last=drop_last,
690-
even_divisible=even_divisible,
691-
)
692-
if len(class_partition_indices) == 0:
693-
class_partition_indices = per_class_partition_indices
694-
else:
695-
for part, data_indices in zip(class_partition_indices, per_class_partition_indices):
696-
part += data_indices
697-
698-
for indices in class_partition_indices:
699-
if shuffle:
700-
np.random.seed(seed)
701-
np.random.shuffle(indices)
702-
datasets.append([data[j] for j in indices])
705+
rs = np.random.RandomState(seed)
706+
for indices in class_partition_indices:
707+
if shuffle:
708+
rs.shuffle(indices)
709+
datasets.append([data[j] for j in indices])
703710

704711
return datasets
705712

706713

707-
def select_cross_validation_folds(partitions: Sequence[List], folds: Union[Sequence[int], int]) -> List:
714+
def select_cross_validation_folds(partitions: Sequence[Iterable], folds: Union[Sequence[int], int]) -> List:
708715
"""
709-
Select cross validation data based on data partitions and specified fold indice.
710-
if a list of folds provided, concatenate the partitions of these folds.
711-
For example, `partitions`: [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]] and `folds`: 2
712-
The the output will be [5, 6], if `folds`: [1, 2], output will be [3, 4, 5, 6].
716+
Select cross validation data based on data partitions and specified fold index.
717+
If a list of fold indices is provided, concatenate the partitions of these folds.
713718
714-
"""
715-
folds = ensure_tuple(folds)
716-
for fold in folds:
717-
if fold >= len(partitions):
718-
raise ValueError(f"fold index: {fold} is bigger than number of partitions.")
719+
Args:
720+
partitions: a sequence of datasets, each item is an iterable
721+
folds: the indices of the partitions to be combined.
722+
723+
Returns:
724+
A list of combined datasets.
719725
720-
data_list = list()
721-
for i, data in enumerate(partitions):
722-
if i in folds:
723-
data_list += data
726+
Example::
724727
728+
>>> partitions = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]
729+
>>> select_cross_validation_folds(partitions, 2)
730+
[5, 6]
731+
>>> select_cross_validation_folds(partitions, [1, 2])
732+
[3, 4, 5, 6]
733+
>>> select_cross_validation_folds(partitions, [-1, 2])
734+
[9, 10, 5, 6]
735+
"""
736+
data_list = [data_item for fold_id in ensure_tuple(folds) for data_item in partitions[fold_id]]
725737
return data_list
726738

727739

tests/test_partition_dataset_classes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@
6868
"drop_last": False,
6969
"even_divisible": False,
7070
},
71-
[[13, 7, 14, 2, 3], [11, 8, 6, 5, 9, 10, 12, 1, 4]],
71+
[[13, 7, 14, 2, 3], [6, 8, 1, 5, 12, 11, 4, 9, 10]],
7272
]
7373

7474

0 commit comments

Comments
 (0)