15
15
from collections import defaultdict
16
16
from itertools import product , starmap
17
17
from pathlib import PurePath
18
- from typing import Dict , Generator , List , Optional , Sequence , Tuple , Union
18
+ from typing import Dict , Generator , Iterable , List , Optional , Sequence , Tuple , Union
19
19
20
20
import numpy as np
21
21
import torch
@@ -570,21 +570,27 @@ def partition_dataset(
570
570
Args:
571
571
data: input dataset to split, expect a list of data.
572
572
ratios: a list of ratio number to split the dataset, like [8, 1, 1].
573
- num_partitions: expected number of the partitions to evenly split, only works when no `ratios`.
573
+ num_partitions: expected number of the partitions to evenly split, only works when `ratios` not specified .
574
574
shuffle: whether to shuffle the original dataset before splitting.
575
575
seed: random seed to shuffle the dataset, only works when `shuffle` is True.
576
576
drop_last: only works when `even_divisible` is False and no ratios specified.
577
577
if True, will drop the tail of the data to make it evenly divisible across partitions.
578
578
if False, will add extra indices to make the data evenly divisible across partitions.
579
579
even_divisible: if True, guarantee every partition has same length.
580
580
581
- Examples:
582
- data: [1, 2, 3, 4, 5]
583
- (1) ratios: [0.6, 0.2, 0.2], shuffle=False, output: [[1, 2, 3], [4], [5]]
584
- num_partitions=2, shuffle=False
585
- (2) even_divisible=True, drop_last=True, output: [[1, 3], [2, 4]]
586
- (3) even_divisible=True, drop_last=False, output: [[1, 3, 5], [2, 4, 1]]
587
- (4) even_divisible=False, drop_last=False, output: [[1, 3, 5], [2, 4]]
581
+ Examples::
582
+
583
+ >>> data = [1, 2, 3, 4, 5]
584
+ >>> partition_dataset(data, ratios=[0.6, 0.2, 0.2], shuffle=False)
585
+ [[1, 2, 3], [4], [5]]
586
+ >>> partition_dataset(data, num_partitions=2, shuffle=False)
587
+ [[1, 3, 5], [2, 4]]
588
+ >>> partition_dataset(data, num_partitions=2, shuffle=False, even_divisible=True, drop_last=True)
589
+ [[1, 3], [2, 4]]
590
+ >>> partition_dataset(data, num_partitions=2, shuffle=False, even_divisible=True, drop_last=False)
591
+ [[1, 3, 5], [2, 4, 1]]
592
+ >>> partition_dataset(data, num_partitions=2, shuffle=False, even_divisible=False, drop_last=False)
593
+ [[1, 3, 5], [2, 4]]
588
594
589
595
"""
590
596
data_len = len (data )
@@ -593,52 +599,53 @@ def partition_dataset(
593
599
indices = list (range (data_len ))
594
600
if shuffle :
595
601
# deterministically shuffle based on fixed seed for every process
596
- np .random .seed (seed )
597
- np . random .shuffle (indices )
602
+ rs = np .random .RandomState (seed )
603
+ rs .shuffle (indices )
598
604
599
- if ratios is not None :
600
- start_idx = next_idx = 0
605
+ if ratios :
606
+ next_idx = 0
601
607
rsum = sum (ratios )
602
608
for r in ratios :
603
609
start_idx = next_idx
604
610
next_idx = min (start_idx + int (r / rsum * data_len + 0.5 ), data_len )
605
611
datasets .append ([data [i ] for i in indices [start_idx :next_idx ]])
612
+ return datasets
613
+
614
+ if not num_partitions :
615
+ raise ValueError ("must specify number of partitions or ratios." )
616
+ # evenly split the data without ratios
617
+ if not even_divisible and drop_last :
618
+ raise RuntimeError ("drop_last only works when even_divisible is True." )
619
+ if data_len < num_partitions :
620
+ raise RuntimeError (f"there is no enough data to be split into { num_partitions } partitions." )
621
+
622
+ if drop_last and data_len % num_partitions != 0 :
623
+ # split to nearest available length that is evenly divisible
624
+ num_samples = math .ceil ((data_len - num_partitions ) / num_partitions )
606
625
else :
607
- if num_partitions is None :
608
- raise ValueError ("must specify number of partitions." )
609
- # evenly split the data without ratios
610
- if not even_divisible and drop_last :
611
- raise RuntimeError ("drop_last only works when even_divisible is True." )
612
- if data_len < num_partitions :
613
- raise RuntimeError (f"there is no enough data to be splitted for { num_partitions } partitions." )
614
-
615
- if drop_last and data_len % num_partitions != 0 :
616
- # split to nearest available length that is evenly divisible
617
- num_samples = math .ceil ((data_len - num_partitions ) / num_partitions )
618
- else :
619
- num_samples = math .ceil (data_len / num_partitions )
620
- # use original data length if not even divisible
621
- total_size = num_samples * num_partitions if even_divisible else data_len
626
+ num_samples = math .ceil (data_len / num_partitions )
627
+ # use original data length if not even divisible
628
+ total_size = num_samples * num_partitions if even_divisible else data_len
622
629
623
- if not drop_last and total_size - data_len > 0 :
624
- # add extra samples to make it evenly divisible
625
- indices += indices [: (total_size - data_len )]
626
- else :
627
- # remove tail of data to make it evenly divisible
628
- indices = indices [:total_size ]
630
+ if not drop_last and total_size - data_len > 0 :
631
+ # add extra samples to make it evenly divisible
632
+ indices += indices [: (total_size - data_len )]
633
+ else :
634
+ # remove tail of data to make it evenly divisible
635
+ indices = indices [:total_size ]
629
636
630
- for i in range (num_partitions ):
631
- _indices = indices [i :total_size :num_partitions ]
632
- datasets .append ([data [j ] for j in _indices ])
637
+ for i in range (num_partitions ):
638
+ _indices = indices [i :total_size :num_partitions ]
639
+ datasets .append ([data [j ] for j in _indices ])
633
640
634
641
return datasets
635
642
636
643
637
644
def partition_dataset_classes (
638
645
data : Sequence ,
639
646
classes : Sequence [int ],
640
- ratios : Optional [Sequence [float ]],
641
- num_partitions : Optional [int ],
647
+ ratios : Optional [Sequence [float ]] = None ,
648
+ num_partitions : Optional [int ] = None ,
642
649
shuffle : bool = False ,
643
650
seed : int = 0 ,
644
651
drop_last : bool = False ,
@@ -661,67 +668,70 @@ def partition_dataset_classes(
661
668
if False, will add extra indices to make the data evenly divisible across partitions.
662
669
even_divisible: if True, guarantee every partition has same length.
663
670
664
- Examples:
665
- data: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]
666
- classes: [2, 0, 2, 1, 3, 2, 2, 0, 2, 0, 3, 3, 1, 3]
667
- shuffle: False, ratios: [2, 1]
668
- output: [[2, 8, 4, 1, 3, 6, 5, 11, 12], [10, 13, 7, 9, 14]]
671
+ Examples::
672
+
673
+ >>> data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]
674
+ >>> classes = [2, 0, 2, 1, 3, 2, 2, 0, 2, 0, 3, 3, 1, 3]
675
+ >>> partition_dataset_classes(data, classes, shuffle=False, ratios=[2, 1])
676
+ [[2, 8, 4, 1, 3, 6, 5, 11, 12], [10, 13, 7, 9, 14]]
669
677
670
678
"""
671
- data_len = len (data )
679
+ if not classes or len (classes ) != len (data ):
680
+ raise ValueError (f"length of classes { classes } must match the dataset length { len (data )} ." )
672
681
datasets = list ()
682
+ class_indices = defaultdict (list )
683
+ for i , c in enumerate (classes ):
684
+ class_indices [c ].append (i )
685
+
686
+ class_partition_indices : List [Sequence ] = list ()
687
+ for _ , per_class_indices in sorted (class_indices .items ()):
688
+ per_class_partition_indices = partition_dataset (
689
+ data = per_class_indices ,
690
+ ratios = ratios ,
691
+ num_partitions = num_partitions ,
692
+ shuffle = shuffle ,
693
+ seed = seed ,
694
+ drop_last = drop_last ,
695
+ even_divisible = even_divisible ,
696
+ )
697
+ if len (class_partition_indices ) == 0 :
698
+ class_partition_indices = per_class_partition_indices
699
+ else :
700
+ for part , data_indices in zip (class_partition_indices , per_class_partition_indices ):
701
+ part += data_indices
673
702
674
- if classes is not None :
675
- if len (classes ) != data_len :
676
- raise ValueError ("length of classes must match the dataset length." )
677
- class_indices = defaultdict (list )
678
- for i , c in enumerate (classes ):
679
- class_indices [c ].append (i )
680
-
681
- class_partition_indices : List [Sequence ] = list ()
682
- for _ , per_class_indices in sorted (class_indices .items ()):
683
- per_class_partition_indices = partition_dataset (
684
- data = per_class_indices ,
685
- ratios = ratios ,
686
- num_partitions = num_partitions ,
687
- shuffle = shuffle ,
688
- seed = seed ,
689
- drop_last = drop_last ,
690
- even_divisible = even_divisible ,
691
- )
692
- if len (class_partition_indices ) == 0 :
693
- class_partition_indices = per_class_partition_indices
694
- else :
695
- for part , data_indices in zip (class_partition_indices , per_class_partition_indices ):
696
- part += data_indices
697
-
698
- for indices in class_partition_indices :
699
- if shuffle :
700
- np .random .seed (seed )
701
- np .random .shuffle (indices )
702
- datasets .append ([data [j ] for j in indices ])
703
+ rs = np .random .RandomState (seed )
704
+ for indices in class_partition_indices :
705
+ if shuffle :
706
+ rs .shuffle (indices )
707
+ datasets .append ([data [j ] for j in indices ])
703
708
704
709
return datasets
705
710
706
711
707
def select_cross_validation_folds(partitions: Sequence[Iterable], folds: Union[Sequence[int], int]) -> List:
    """
    Select cross validation data based on data partitions and specified fold indices.

    If a list of fold indices is provided, the corresponding partitions are
    concatenated in the given order. Negative indices follow normal Python
    sequence indexing (count from the end of ``partitions``).

    Args:
        partitions: a sequence of datasets, each item is an iterable.
        folds: index or indices of the partition(s) to be combined.

    Returns:
        A list of the combined datasets.

    Raises:
        IndexError: when a fold index is out of range for ``partitions``.

    Example::

        >>> partitions = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]
        >>> select_cross_validation_folds(partitions, 2)
        [5, 6]
        >>> select_cross_validation_folds(partitions, [1, 2])
        [3, 4, 5, 6]
        >>> select_cross_validation_folds(partitions, [-1, 2])
        [9, 10, 5, 6]
    """
    # ensure_tuple normalizes a single int into a one-element tuple, so
    # `folds=2` and `folds=[1, 2]` are handled uniformly by one comprehension.
    data_list = [data_item for fold_id in ensure_tuple(folds) for data_item in partitions[fold_id]]
    return data_list
726
736
727
737
0 commit comments