enhance partition utilities #1191

Merged · 3 commits · Nov 5, 2020
186 changes: 98 additions & 88 deletions monai/data/utils.py
@@ -15,7 +15,7 @@
from collections import defaultdict
from itertools import product, starmap
from pathlib import PurePath
from typing import Dict, Generator, List, Optional, Sequence, Tuple, Union
from typing import Dict, Generator, Iterable, List, Optional, Sequence, Tuple, Union

import numpy as np
import torch
@@ -570,21 +570,27 @@ def partition_dataset(
Args:
data: input dataset to split, expect a list of data.
ratios: a list of ratio numbers to split the dataset, like [8, 1, 1].
num_partitions: expected number of the partitions to evenly split, only works when no `ratios`.
num_partitions: expected number of the partitions to evenly split, only works when `ratios` not specified.
shuffle: whether to shuffle the original dataset before splitting.
seed: random seed to shuffle the dataset, only works when `shuffle` is True.
drop_last: only works when `even_divisible` is True and no ratios are specified.
if True, will drop the tail of the data to make it evenly divisible across partitions.
if False, will add extra indices to make the data evenly divisible across partitions.
even_divisible: if True, guarantees every partition has the same length.

Examples:
data: [1, 2, 3, 4, 5]
(1) ratios: [0.6, 0.2, 0.2], shuffle=False, output: [[1, 2, 3], [4], [5]]
num_partitions=2, shuffle=False
(2) even_divisible=True, drop_last=True, output: [[1, 3], [2, 4]]
(3) even_divisible=True, drop_last=False, output: [[1, 3, 5], [2, 4, 1]]
(4) even_divisible=False, drop_last=False, output: [[1, 3, 5], [2, 4]]
Examples::

>>> data = [1, 2, 3, 4, 5]
>>> partition_dataset(data, ratios=[0.6, 0.2, 0.2], shuffle=False)
[[1, 2, 3], [4], [5]]
>>> partition_dataset(data, num_partitions=2, shuffle=False)
[[1, 3, 5], [2, 4]]
>>> partition_dataset(data, num_partitions=2, shuffle=False, even_divisible=True, drop_last=True)
[[1, 3], [2, 4]]
>>> partition_dataset(data, num_partitions=2, shuffle=False, even_divisible=True, drop_last=False)
[[1, 3, 5], [2, 4, 1]]
>>> partition_dataset(data, num_partitions=2, shuffle=False, even_divisible=False, drop_last=False)
[[1, 3, 5], [2, 4]]

"""
data_len = len(data)
@@ -593,52 +599,53 @@ def partition_dataset(
indices = list(range(data_len))
if shuffle:
# deterministically shuffle based on fixed seed for every process
np.random.seed(seed)
np.random.shuffle(indices)
rs = np.random.RandomState(seed)
rs.shuffle(indices)
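
A note on this change: a local `np.random.RandomState(seed)` gives the same deterministic shuffle as the old `np.random.seed` + `np.random.shuffle` pair, but without mutating NumPy's global RNG, so unrelated random draws elsewhere in the process are unaffected. A minimal sketch of the new behavior, assuming only NumPy:

```python
import numpy as np

# An isolated RandomState reproduces the same shuffle for a fixed
# seed while leaving the global np.random state untouched.
indices = list(range(5))
rs = np.random.RandomState(0)
rs.shuffle(indices)  # shuffles the list in place
print(indices)  # deterministic: identical output on every run with seed 0
```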

if ratios is not None:
start_idx = next_idx = 0
if ratios:
next_idx = 0
rsum = sum(ratios)
for r in ratios:
start_idx = next_idx
next_idx = min(start_idx + int(r / rsum * data_len + 0.5), data_len)
datasets.append([data[i] for i in indices[start_idx:next_idx]])
return datasets
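
The ratio branch above converts each ratio to a slice boundary with nearest-integer rounding, `int(r / sum(ratios) * data_len + 0.5)`, clamped to `data_len`, so partition sizes on small datasets may not be exactly proportional. A worked check of that arithmetic, with toy values chosen for illustration:

```python
# Reproduce the boundary computation for data_len=5, ratios=[0.6, 0.2, 0.2].
data_len = 5
ratios = [0.6, 0.2, 0.2]
rsum = sum(ratios)

next_idx = 0
bounds = []
for r in ratios:
    start_idx = next_idx
    # nearest-integer rounding of this ratio's share, clamped to data_len
    next_idx = min(start_idx + int(r / rsum * data_len + 0.5), data_len)
    bounds.append((start_idx, next_idx))
print(bounds)  # [(0, 3), (3, 4), (4, 5)] -> partition sizes 3, 1, 1
```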

if not num_partitions:
raise ValueError("must specify number of partitions or ratios.")
# evenly split the data without ratios
if not even_divisible and drop_last:
raise RuntimeError("drop_last only works when even_divisible is True.")
if data_len < num_partitions:
raise RuntimeError(f"there is no enough data to be split into {num_partitions} partitions.")

if drop_last and data_len % num_partitions != 0:
# split to nearest available length that is evenly divisible
num_samples = math.ceil((data_len - num_partitions) / num_partitions)
else:
if num_partitions is None:
raise ValueError("must specify number of partitions.")
# evenly split the data without ratios
if not even_divisible and drop_last:
raise RuntimeError("drop_last only works when even_divisible is True.")
if data_len < num_partitions:
raise RuntimeError(f"there is no enough data to be splitted for {num_partitions} partitions.")

if drop_last and data_len % num_partitions != 0:
# split to nearest available length that is evenly divisible
num_samples = math.ceil((data_len - num_partitions) / num_partitions)
else:
num_samples = math.ceil(data_len / num_partitions)
# use original data length if not even divisible
total_size = num_samples * num_partitions if even_divisible else data_len
num_samples = math.ceil(data_len / num_partitions)
# use original data length if not even divisible
total_size = num_samples * num_partitions if even_divisible else data_len

if not drop_last and total_size - data_len > 0:
# add extra samples to make it evenly divisible
indices += indices[: (total_size - data_len)]
else:
# remove tail of data to make it evenly divisible
indices = indices[:total_size]
if not drop_last and total_size - data_len > 0:
# add extra samples to make it evenly divisible
indices += indices[: (total_size - data_len)]
else:
# remove tail of data to make it evenly divisible
indices = indices[:total_size]

for i in range(num_partitions):
_indices = indices[i:total_size:num_partitions]
datasets.append([data[j] for j in _indices])
for i in range(num_partitions):
_indices = indices[i:total_size:num_partitions]
datasets.append([data[j] for j in _indices])

return datasets
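
The even split without ratios is a round-robin assignment: once the index list has been padded or trimmed to `total_size`, partition `i` takes every `num_partitions`-th index starting at offset `i`. A standalone sketch of just that slicing step:

```python
# Strided slicing assigns indices round-robin: partition i receives
# indices i, i + num_partitions, i + 2 * num_partitions, ...
indices = [0, 1, 2, 3, 4, 5]
num_partitions = 2
total_size = len(indices)
partitions = [indices[i:total_size:num_partitions] for i in range(num_partitions)]
print(partitions)  # [[0, 2, 4], [1, 3, 5]]
```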


def partition_dataset_classes(
data: Sequence,
classes: Sequence[int],
ratios: Optional[Sequence[float]],
num_partitions: Optional[int],
ratios: Optional[Sequence[float]] = None,
num_partitions: Optional[int] = None,
shuffle: bool = False,
seed: int = 0,
drop_last: bool = False,
@@ -661,67 +668,70 @@ def partition_dataset_classes(
if False, will add extra indices to make the data evenly divisible across partitions.
even_divisible: if True, guarantees every partition has the same length.

Examples:
data: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]
classes: [2, 0, 2, 1, 3, 2, 2, 0, 2, 0, 3, 3, 1, 3]
shuffle: False, ratios: [2, 1]
output: [[2, 8, 4, 1, 3, 6, 5, 11, 12], [10, 13, 7, 9, 14]]
Examples::

>>> data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]
>>> classes = [2, 0, 2, 1, 3, 2, 2, 0, 2, 0, 3, 3, 1, 3]
>>> partition_dataset_classes(data, classes, shuffle=False, ratios=[2, 1])
[[2, 8, 4, 1, 3, 6, 5, 11, 12], [10, 13, 7, 9, 14]]

"""
data_len = len(data)
if not classes or len(classes) != len(data):
raise ValueError(f"length of classes {classes} must match the dataset length {len(data)}.")
datasets = list()
class_indices = defaultdict(list)
for i, c in enumerate(classes):
class_indices[c].append(i)

class_partition_indices: List[Sequence] = list()
for _, per_class_indices in sorted(class_indices.items()):
per_class_partition_indices = partition_dataset(
data=per_class_indices,
ratios=ratios,
num_partitions=num_partitions,
shuffle=shuffle,
seed=seed,
drop_last=drop_last,
even_divisible=even_divisible,
)
if len(class_partition_indices) == 0:
class_partition_indices = per_class_partition_indices
else:
for part, data_indices in zip(class_partition_indices, per_class_partition_indices):
part += data_indices

if classes is not None:
if len(classes) != data_len:
raise ValueError("length of classes must match the dataset length.")
class_indices = defaultdict(list)
for i, c in enumerate(classes):
class_indices[c].append(i)

class_partition_indices: List[Sequence] = list()
for _, per_class_indices in sorted(class_indices.items()):
per_class_partition_indices = partition_dataset(
data=per_class_indices,
ratios=ratios,
num_partitions=num_partitions,
shuffle=shuffle,
seed=seed,
drop_last=drop_last,
even_divisible=even_divisible,
)
if len(class_partition_indices) == 0:
class_partition_indices = per_class_partition_indices
else:
for part, data_indices in zip(class_partition_indices, per_class_partition_indices):
part += data_indices

for indices in class_partition_indices:
if shuffle:
np.random.seed(seed)
np.random.shuffle(indices)
datasets.append([data[j] for j in indices])
rs = np.random.RandomState(seed)
for indices in class_partition_indices:
if shuffle:
rs.shuffle(indices)
datasets.append([data[j] for j in indices])

return datasets
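
The class-stratified variant works in two stages: group sample positions by class label, then run `partition_dataset` on each class's positions with the same arguments and concatenate the per-class results partition-wise, so every output partition keeps roughly the original class balance. A standalone sketch of the grouping stage only (not the full implementation):

```python
from collections import defaultdict

# Group the positions of each class label so every class can be
# partitioned separately and recombined partition-wise.
classes = [2, 0, 2, 1, 3, 2, 2, 0, 2, 0, 3, 3, 1, 3]
class_indices = defaultdict(list)
for i, c in enumerate(classes):
    class_indices[c].append(i)

print(dict(sorted(class_indices.items())))
# {0: [1, 7, 9], 1: [3, 12], 2: [0, 2, 5, 6, 8], 3: [4, 10, 11, 13]}
```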


def select_cross_validation_folds(partitions: Sequence[List], folds: Union[Sequence[int], int]) -> List:
def select_cross_validation_folds(partitions: Sequence[Iterable], folds: Union[Sequence[int], int]) -> List:
"""
Select cross validation data based on data partitions and specified fold indice.
if a list of folds provided, concatenate the partitions of these folds.
For example, `partitions`: [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]] and `folds`: 2
The the output will be [5, 6], if `folds`: [1, 2], output will be [3, 4, 5, 6].
Select cross validation data based on data partitions and specified fold index.
If a list of fold indices is provided, concatenate the partitions of these folds.

"""
folds = ensure_tuple(folds)
for fold in folds:
if fold >= len(partitions):
raise ValueError(f"fold index: {fold} is bigger than number of partitions.")
Args:
partitions: a sequence of datasets, each item is an iterable
folds: the indices of the partitions to be combined.

Returns:
A list of combined datasets.

data_list = list()
for i, data in enumerate(partitions):
if i in folds:
data_list += data
Example::

>>> partitions = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]
>>> select_cross_validation_folds(partitions, 2)
[5, 6]
>>> select_cross_validation_folds(partitions, [1, 2])
[3, 4, 5, 6]
>>> select_cross_validation_folds(partitions, [-1, 2])
[9, 10, 5, 6]
"""
data_list = [data_item for fold_id in ensure_tuple(folds) for data_item in partitions[fold_id]]
return data_list
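
Together the two helpers cover a standard k-fold workflow: `partition_dataset` builds the folds and `select_cross_validation_folds` assembles the training and validation sets for each round. A usage sketch against the post-merge API, with arbitrary toy data:

```python
from monai.data.utils import partition_dataset, select_cross_validation_folds

data = list(range(10))
folds = partition_dataset(data, num_partitions=5, shuffle=False)

for val_fold in range(5):
    train = select_cross_validation_folds(folds, [i for i in range(5) if i != val_fold])
    val = select_cross_validation_folds(folds, val_fold)
    # e.g. for val_fold == 0: val == [0, 5], train == [1, 6, 2, 7, 3, 8, 4, 9]
```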


2 changes: 1 addition & 1 deletion tests/test_partition_dataset_classes.py
@@ -68,7 +68,7 @@
"drop_last": False,
"even_divisible": False,
},
[[13, 7, 14, 2, 3], [11, 8, 6, 5, 9, 10, 12, 1, 4]],
[[13, 7, 14, 2, 3], [6, 8, 1, 5, 12, 11, 4, 9, 10]],
]

