15
15
from collections import defaultdict
16
16
from itertools import product , starmap
17
17
from pathlib import PurePath
18
- from typing import Dict , Generator , List , Optional , Sequence , Tuple , Union
18
+ from typing import Dict , Generator , Iterable , List , Optional , Sequence , Tuple , Union
19
19
20
20
import numpy as np
21
21
import torch
@@ -570,21 +570,27 @@ def partition_dataset(
570
570
Args:
571
571
data: input dataset to split, expect a list of data.
572
572
ratios: a list of ratio number to split the dataset, like [8, 1, 1].
573
- num_partitions: expected number of the partitions to evenly split, only works when no `ratios`.
573
+ num_partitions: expected number of the partitions to evenly split, only works when `ratios` not specified .
574
574
shuffle: whether to shuffle the original dataset before splitting.
575
575
seed: random seed to shuffle the dataset, only works when `shuffle` is True.
576
576
drop_last: only works when `even_divisible` is False and no ratios specified.
577
577
if True, will drop the tail of the data to make it evenly divisible across partitions.
578
578
if False, will add extra indices to make the data evenly divisible across partitions.
579
579
even_divisible: if True, guarantee every partition has same length.
580
580
581
- Examples:
582
- data: [1, 2, 3, 4, 5]
583
- (1) ratios: [0.6, 0.2, 0.2], shuffle=False, output: [[1, 2, 3], [4], [5]]
584
- num_partitions=2, shuffle=False
585
- (2) even_divisible=True, drop_last=True, output: [[1, 3], [2, 4]]
586
- (3) even_divisible=True, drop_last=False, output: [[1, 3, 5], [2, 4, 1]]
587
- (4) even_divisible=False, drop_last=False, output: [[1, 3, 5], [2, 4]]
581
+ Examples::
582
+
583
+ >>> data = [1, 2, 3, 4, 5]
584
+ >>> partition_dataset(data, ratios=[0.6, 0.2, 0.2], shuffle=False)
585
+ [[1, 2, 3], [4], [5]]
586
+ >>> partition_dataset(data, num_partitions=2, shuffle=False)
587
+ [[1, 3, 5], [2, 4]]
588
+ >>> partition_dataset(data, num_partitions=2, shuffle=False, even_divisible=True, drop_last=True)
589
+ [[1, 3], [2, 4]]
590
+ >>> partition_dataset(data, num_partitions=2, shuffle=False, even_divisible=True, drop_last=False)
591
+ [[1, 3, 5], [2, 4, 1]]
592
+ >>> partition_dataset(data, num_partitions=2, shuffle=False, even_divisible=False, drop_last=False)
593
+ [[1, 3, 5], [2, 4]]
588
594
589
595
"""
590
596
data_len = len (data )
@@ -593,52 +599,53 @@ def partition_dataset(
593
599
indices = list (range (data_len ))
594
600
if shuffle :
595
601
# deterministically shuffle based on fixed seed for every process
596
- np .random .seed (seed )
597
- np . random .shuffle (indices )
602
+ rs = np .random .RandomState (seed )
603
+ rs .shuffle (indices )
598
604
599
- if ratios is not None :
600
- start_idx = next_idx = 0
605
+ if ratios :
606
+ next_idx = 0
601
607
rsum = sum (ratios )
602
608
for r in ratios :
603
609
start_idx = next_idx
604
610
next_idx = min (start_idx + int (r / rsum * data_len + 0.5 ), data_len )
605
611
datasets .append ([data [i ] for i in indices [start_idx :next_idx ]])
612
+ return datasets
613
+
614
+ if not num_partitions :
615
+ raise ValueError ("must specify number of partitions or ratios." )
616
+ # evenly split the data without ratios
617
+ if not even_divisible and drop_last :
618
+ raise RuntimeError ("drop_last only works when even_divisible is True." )
619
+ if data_len < num_partitions :
620
+ raise RuntimeError (f"there is no enough data to be split into { num_partitions } partitions." )
621
+
622
+ if drop_last and data_len % num_partitions != 0 :
623
+ # split to nearest available length that is evenly divisible
624
+ num_samples = math .ceil ((data_len - num_partitions ) / num_partitions )
606
625
else :
607
- if num_partitions is None :
608
- raise ValueError ("must specify number of partitions." )
609
- # evenly split the data without ratios
610
- if not even_divisible and drop_last :
611
- raise RuntimeError ("drop_last only works when even_divisible is True." )
612
- if data_len < num_partitions :
613
- raise RuntimeError (f"there is no enough data to be splitted for { num_partitions } partitions." )
614
-
615
- if drop_last and data_len % num_partitions != 0 :
616
- # split to nearest available length that is evenly divisible
617
- num_samples = math .ceil ((data_len - num_partitions ) / num_partitions )
618
- else :
619
- num_samples = math .ceil (data_len / num_partitions )
620
- # use original data length if not even divisible
621
- total_size = num_samples * num_partitions if even_divisible else data_len
626
+ num_samples = math .ceil (data_len / num_partitions )
627
+ # use original data length if not even divisible
628
+ total_size = num_samples * num_partitions if even_divisible else data_len
622
629
623
- if not drop_last and total_size - data_len > 0 :
624
- # add extra samples to make it evenly divisible
625
- indices += indices [: (total_size - data_len )]
626
- else :
627
- # remove tail of data to make it evenly divisible
628
- indices = indices [:total_size ]
630
+ if not drop_last and total_size - data_len > 0 :
631
+ # add extra samples to make it evenly divisible
632
+ indices += indices [: (total_size - data_len )]
633
+ else :
634
+ # remove tail of data to make it evenly divisible
635
+ indices = indices [:total_size ]
629
636
630
- for i in range (num_partitions ):
631
- _indices = indices [i :total_size :num_partitions ]
632
- datasets .append ([data [j ] for j in _indices ])
637
+ for i in range (num_partitions ):
638
+ _indices = indices [i :total_size :num_partitions ]
639
+ datasets .append ([data [j ] for j in _indices ])
633
640
634
641
return datasets
635
642
636
643
637
644
def partition_dataset_classes (
638
645
data : Sequence ,
639
646
classes : Sequence [int ],
640
- ratios : Optional [Sequence [float ]],
641
- num_partitions : Optional [int ],
647
+ ratios : Optional [Sequence [float ]] = None ,
648
+ num_partitions : Optional [int ] = None ,
642
649
shuffle : bool = False ,
643
650
seed : int = 0 ,
644
651
drop_last : bool = False ,
@@ -661,67 +668,72 @@ def partition_dataset_classes(
661
668
if False, will add extra indices to make the data evenly divisible across partitions.
662
669
even_divisible: if True, guarantee every partition has same length.
663
670
664
- Examples:
665
- data: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]
666
- classes: [2, 0, 2, 1, 3, 2, 2, 0, 2, 0, 3, 3, 1, 3]
667
- shuffle: False, ratios: [2, 1]
668
- output: [[2, 8, 4, 1, 3, 6, 5, 11, 12], [10, 13, 7, 9, 14]]
671
+ Examples::
672
+
673
+ >>> data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]
674
+ >>> classes = [2, 0, 2, 1, 3, 2, 2, 0, 2, 0, 3, 3, 1, 3]
675
+ >>> partition_dataset_classes(data, classes, shuffle=False, ratios=[2, 1])
676
+ [[2, 8, 4, 1, 3, 6, 5, 11, 12], [10, 13, 7, 9, 14]]
669
677
670
678
"""
671
- data_len = len (data )
672
679
datasets = list ()
680
+ if not classes :
681
+ return []
682
+ if len (classes ) != len (data ):
683
+ raise ValueError (f"length of classes { len (classes )} must match the dataset length { len (data )} ." )
684
+ class_indices = defaultdict (list )
685
+ for i , c in enumerate (classes ):
686
+ class_indices [c ].append (i )
687
+
688
+ class_partition_indices : List [Sequence ] = list ()
689
+ for _ , per_class_indices in sorted (class_indices .items ()):
690
+ per_class_partition_indices = partition_dataset (
691
+ data = per_class_indices ,
692
+ ratios = ratios ,
693
+ num_partitions = num_partitions ,
694
+ shuffle = shuffle ,
695
+ seed = seed ,
696
+ drop_last = drop_last ,
697
+ even_divisible = even_divisible ,
698
+ )
699
+ if len (class_partition_indices ) == 0 :
700
+ class_partition_indices = per_class_partition_indices
701
+ else :
702
+ for part , data_indices in zip (class_partition_indices , per_class_partition_indices ):
703
+ part += data_indices
673
704
674
- if classes is not None :
675
- if len (classes ) != data_len :
676
- raise ValueError ("length of classes must match the dataset length." )
677
- class_indices = defaultdict (list )
678
- for i , c in enumerate (classes ):
679
- class_indices [c ].append (i )
680
-
681
- class_partition_indices : List [Sequence ] = list ()
682
- for _ , per_class_indices in sorted (class_indices .items ()):
683
- per_class_partition_indices = partition_dataset (
684
- data = per_class_indices ,
685
- ratios = ratios ,
686
- num_partitions = num_partitions ,
687
- shuffle = shuffle ,
688
- seed = seed ,
689
- drop_last = drop_last ,
690
- even_divisible = even_divisible ,
691
- )
692
- if len (class_partition_indices ) == 0 :
693
- class_partition_indices = per_class_partition_indices
694
- else :
695
- for part , data_indices in zip (class_partition_indices , per_class_partition_indices ):
696
- part += data_indices
697
-
698
- for indices in class_partition_indices :
699
- if shuffle :
700
- np .random .seed (seed )
701
- np .random .shuffle (indices )
702
- datasets .append ([data [j ] for j in indices ])
705
+ rs = np .random .RandomState (seed )
706
+ for indices in class_partition_indices :
707
+ if shuffle :
708
+ rs .shuffle (indices )
709
+ datasets .append ([data [j ] for j in indices ])
703
710
704
711
return datasets
705
712
706
713
707
def select_cross_validation_folds(partitions: Sequence[Iterable], folds: Union[Sequence[int], int]) -> List:
    """
    Select cross-validation data based on data partitions and the specified fold indices.
    If a sequence of fold indices is provided, the partitions of those folds are concatenated.

    Args:
        partitions: a sequence of datasets, each item is an iterable.
        folds: index (or indices) of the partition(s) to be combined.

    Returns:
        A list containing the combined data of the selected folds.

    Example::

        >>> partitions = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]
        >>> select_cross_validation_folds(partitions, 2)
        [5, 6]
        >>> select_cross_validation_folds(partitions, [1, 2])
        [3, 4, 5, 6]
        >>> select_cross_validation_folds(partitions, [-1, 2])
        [9, 10, 5, 6]
    """
    # ensure_tuple normalizes a single fold index into a one-element tuple,
    # so a plain int and a sequence of ints are handled uniformly.
    data_list: List = []
    for fold_id in ensure_tuple(folds):
        # negative fold indices select from the end, as with plain sequence indexing
        data_list.extend(partitions[fold_id])
    return data_list
726
738
727
739
0 commit comments