diff --git a/monai/data/utils.py b/monai/data/utils.py
index 959911ba7b..0b274ac625 100644
--- a/monai/data/utils.py
+++ b/monai/data/utils.py
@@ -15,7 +15,7 @@
 from collections import defaultdict
 from itertools import product, starmap
 from pathlib import PurePath
-from typing import Dict, Generator, List, Optional, Sequence, Tuple, Union
+from typing import Dict, Generator, Iterable, List, Optional, Sequence, Tuple, Union
 
 import numpy as np
 import torch
@@ -570,7 +570,7 @@ def partition_dataset(
     Args:
         data: input dataset to split, expect a list of data.
         ratios: a list of ratio number to split the dataset, like [8, 1, 1].
-        num_partitions: expected number of the partitions to evenly split, only works when no `ratios`.
+        num_partitions: expected number of the partitions to evenly split, only works when `ratios` is not specified.
         shuffle: whether to shuffle the original dataset before splitting.
         seed: random seed to shuffle the dataset, only works when `shuffle` is True.
         drop_last: only works when `even_divisible` is False and no ratios specified.
@@ -578,13 +578,19 @@
             if False, will add extra indices to make the data evenly divisible across partitions.
         even_divisible: if True, guarantee every partition has same length.
 
-    Examples:
-        data: [1, 2, 3, 4, 5]
-        (1) ratios: [0.6, 0.2, 0.2], shuffle=False, output: [[1, 2, 3], [4], [5]]
-            num_partitions=2, shuffle=False
-        (2) even_divisible=True, drop_last=True, output: [[1, 3], [2, 4]]
-        (3) even_divisible=True, drop_last=False, output: [[1, 3, 5], [2, 4, 1]]
-        (4) even_divisible=False, drop_last=False, output: [[1, 3, 5], [2, 4]]
+    Examples::
+
+        >>> data = [1, 2, 3, 4, 5]
+        >>> partition_dataset(data, ratios=[0.6, 0.2, 0.2], shuffle=False)
+        [[1, 2, 3], [4], [5]]
+        >>> partition_dataset(data, num_partitions=2, shuffle=False)
+        [[1, 3, 5], [2, 4]]
+        >>> partition_dataset(data, num_partitions=2, shuffle=False, even_divisible=True, drop_last=True)
+        [[1, 3], [2, 4]]
+        >>> partition_dataset(data, num_partitions=2, shuffle=False, even_divisible=True, drop_last=False)
+        [[1, 3, 5], [2, 4, 1]]
+        >>> partition_dataset(data, num_partitions=2, shuffle=False, even_divisible=False, drop_last=False)
+        [[1, 3, 5], [2, 4]]
 
     """
     data_len = len(data)
@@ -593,43 +599,44 @@
     indices = list(range(data_len))
     if shuffle:
         # deterministically shuffle based on fixed seed for every process
-        np.random.seed(seed)
-        np.random.shuffle(indices)
+        rs = np.random.RandomState(seed)
+        rs.shuffle(indices)
 
-    if ratios is not None:
-        start_idx = next_idx = 0
+    if ratios:
+        next_idx = 0
         rsum = sum(ratios)
         for r in ratios:
             start_idx = next_idx
             next_idx = min(start_idx + int(r / rsum * data_len + 0.5), data_len)
             datasets.append([data[i] for i in indices[start_idx:next_idx]])
+        return datasets
+
+    if not num_partitions:
+        raise ValueError("must specify number of partitions or ratios.")
+    # evenly split the data without ratios
+    if not even_divisible and drop_last:
+        raise RuntimeError("drop_last only works when even_divisible is True.")
+    if data_len < num_partitions:
+        raise RuntimeError(f"there is not enough data to be split into {num_partitions} partitions.")
+
+    if drop_last and data_len % num_partitions != 0:
+        # split to the nearest available length that is evenly divisible
+        num_samples = math.ceil((data_len - num_partitions) / num_partitions)
     else:
-        if num_partitions is None:
-            raise ValueError("must specify number of partitions.")
-        # evenly split the data without ratios
-        if not even_divisible and drop_last:
-            raise RuntimeError("drop_last only works when even_divisible is True.")
-        if data_len < num_partitions:
-            raise RuntimeError(f"there is no enough data to be splitted for {num_partitions} partitions.")
-
-        if drop_last and data_len % num_partitions != 0:
-            # split to nearest available length that is evenly divisible
-            num_samples = math.ceil((data_len - num_partitions) / num_partitions)
-        else:
-            num_samples = math.ceil(data_len / num_partitions)
-        # use original data length if not even divisible
-        total_size = num_samples * num_partitions if even_divisible else data_len
+        num_samples = math.ceil(data_len / num_partitions)
+    # use original data length if not even divisible
+    total_size = num_samples * num_partitions if even_divisible else data_len
 
-        if not drop_last and total_size - data_len > 0:
-            # add extra samples to make it evenly divisible
-            indices += indices[: (total_size - data_len)]
-        else:
-            # remove tail of data to make it evenly divisible
-            indices = indices[:total_size]
+    if not drop_last and total_size - data_len > 0:
+        # add extra samples to make it evenly divisible
+        indices += indices[: (total_size - data_len)]
+    else:
+        # remove tail of data to make it evenly divisible
+        indices = indices[:total_size]
 
-        for i in range(num_partitions):
-            _indices = indices[i:total_size:num_partitions]
-            datasets.append([data[j] for j in _indices])
+    for i in range(num_partitions):
+        _indices = indices[i:total_size:num_partitions]
+        datasets.append([data[j] for j in _indices])
 
     return datasets
 
 
@@ -637,8 +644,8 @@
 def partition_dataset_classes(
     data: Sequence,
     classes: Sequence[int],
-    ratios: Optional[Sequence[float]],
-    num_partitions: Optional[int],
+    ratios: Optional[Sequence[float]] = None,
+    num_partitions: Optional[int] = None,
     shuffle: bool = False,
     seed: int = 0,
     drop_last: bool = False,
@@ -661,67 +668,70 @@
             if False, will add extra indices to make the data evenly divisible across partitions.
         even_divisible: if True, guarantee every partition has same length.
 
-    Examples:
-        data: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]
-        classes: [2, 0, 2, 1, 3, 2, 2, 0, 2, 0, 3, 3, 1, 3]
-        shuffle: False, ratios: [2, 1]
-        output: [[2, 8, 4, 1, 3, 6, 5, 11, 12], [10, 13, 7, 9, 14]]
+    Examples::
+
+        >>> data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]
+        >>> classes = [2, 0, 2, 1, 3, 2, 2, 0, 2, 0, 3, 3, 1, 3]
+        >>> partition_dataset_classes(data, classes, shuffle=False, ratios=[2, 1])
+        [[2, 8, 4, 1, 3, 6, 5, 11, 12], [10, 13, 7, 9, 14]]
 
     """
-    data_len = len(data)
+    if not classes or len(classes) != len(data):
+        raise ValueError(f"length of classes {classes} must match the dataset length {len(data)}.")
     datasets = list()
+    class_indices = defaultdict(list)
+    for i, c in enumerate(classes):
+        class_indices[c].append(i)
+
+    class_partition_indices: List[Sequence] = list()
+    for _, per_class_indices in sorted(class_indices.items()):
+        per_class_partition_indices = partition_dataset(
+            data=per_class_indices,
+            ratios=ratios,
+            num_partitions=num_partitions,
+            shuffle=shuffle,
+            seed=seed,
+            drop_last=drop_last,
+            even_divisible=even_divisible,
+        )
+        if len(class_partition_indices) == 0:
+            class_partition_indices = per_class_partition_indices
+        else:
+            for part, data_indices in zip(class_partition_indices, per_class_partition_indices):
+                part += data_indices
 
-    if classes is not None:
-        if len(classes) != data_len:
-            raise ValueError("length of classes must match the dataset length.")
-        class_indices = defaultdict(list)
-        for i, c in enumerate(classes):
-            class_indices[c].append(i)
-
-        class_partition_indices: List[Sequence] = list()
-        for _, per_class_indices in sorted(class_indices.items()):
-            per_class_partition_indices = partition_dataset(
-                data=per_class_indices,
-                ratios=ratios,
-                num_partitions=num_partitions,
-                shuffle=shuffle,
-                seed=seed,
-                drop_last=drop_last,
-                even_divisible=even_divisible,
-            )
-            if len(class_partition_indices) == 0:
-                class_partition_indices = per_class_partition_indices
-            else:
-                for part, data_indices in zip(class_partition_indices, per_class_partition_indices):
-                    part += data_indices
-
-        for indices in class_partition_indices:
-            if shuffle:
-                np.random.seed(seed)
-                np.random.shuffle(indices)
-            datasets.append([data[j] for j in indices])
+    rs = np.random.RandomState(seed)
+    for indices in class_partition_indices:
+        if shuffle:
+            rs.shuffle(indices)
+        datasets.append([data[j] for j in indices])
 
     return datasets
 
 
-def select_cross_validation_folds(partitions: Sequence[List], folds: Union[Sequence[int], int]) -> List:
+def select_cross_validation_folds(partitions: Sequence[Iterable], folds: Union[Sequence[int], int]) -> List:
     """
-    Select cross validation data based on data partitions and specified fold indice.
-    if a list of folds provided, concatenate the partitions of these folds.
-    For example, `partitions`: [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]] and `folds`: 2
-    The the output will be [5, 6], if `folds`: [1, 2], output will be [3, 4, 5, 6].
+    Select cross validation data based on data partitions and the specified fold index.
+    If a list of fold indices is provided, concatenate the partitions of these folds.
 
-    """
-    folds = ensure_tuple(folds)
-    for fold in folds:
-        if fold >= len(partitions):
-            raise ValueError(f"fold index: {fold} is bigger than number of partitions.")
+    Args:
+        partitions: a sequence of datasets, each item is an iterable.
+        folds: the indices of the partitions to be combined.
+
+    Returns:
+        A list of combined datasets.
 
-    data_list = list()
-    for i, data in enumerate(partitions):
-        if i in folds:
-            data_list += data
+    Examples::
+
+        >>> partitions = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]
+        >>> select_cross_validation_folds(partitions, 2)
+        [5, 6]
+        >>> select_cross_validation_folds(partitions, [1, 2])
+        [3, 4, 5, 6]
+        >>> select_cross_validation_folds(partitions, [-1, 2])
+        [9, 10, 5, 6]
+    """
+    data_list = [data_item for fold_id in ensure_tuple(folds) for data_item in partitions[fold_id]]
 
     return data_list
 
 
diff --git a/tests/test_partition_dataset_classes.py b/tests/test_partition_dataset_classes.py
index 95ddb134ed..b081c496d8 100644
--- a/tests/test_partition_dataset_classes.py
+++ b/tests/test_partition_dataset_classes.py
@@ -68,7 +68,7 @@
             "drop_last": False,
             "even_divisible": False,
         },
-        [[13, 7, 14, 2, 3], [11, 8, 6, 5, 9, 10, 12, 1, 4]],
+        [[13, 7, 14, 2, 3], [6, 8, 1, 5, 12, 11, 4, 9, 10]],
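
A minimal sketch of how the updated helpers compose for k-fold cross validation,
assuming the patch above is applied (the 10-sample dataset and 5-fold setup below
are illustrative, not part of the patch; the expected outputs follow from the
partitioning logic shown in the diff):

    >>> from monai.data.utils import partition_dataset, select_cross_validation_folds
    >>> data = list(range(1, 11))  # ten samples: [1, 2, ..., 10]
    >>> # deterministic 5-fold split; 10 samples divide evenly into 5 partitions,
    >>> # so even_divisible/drop_last make no difference in this case
    >>> folds = partition_dataset(data, num_partitions=5, shuffle=False)
    >>> folds
    [[1, 6], [2, 7], [3, 8], [4, 9], [5, 10]]
    >>> # hold out fold 0 for validation, train on the remaining folds
    >>> select_cross_validation_folds(folds, 0)
    [1, 6]
    >>> select_cross_validation_folds(folds, [1, 2, 3, 4])
    [2, 7, 3, 8, 4, 9, 5, 10]

Because partitions are built with the stride indices[i:total_size:num_partitions],
consecutive samples land in different folds; and with shuffle=True, the
np.random.RandomState(seed) introduced in this patch keeps the split deterministic
for every process without mutating NumPy's global random state.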