@@ -105,6 +105,7 @@ def test_iter_key_zipper_iterdatapipe(self) -> None:
105
105
106
106
source_dp = IterableWrapper (range (10 ))
107
107
ref_dp = IterableWrapper (range (20 ))
108
+ ref_dp2 = IterableWrapper (range (20 ))
108
109
109
110
# Functional Test: Output should be a zipped list of tuples
110
111
zip_dp = source_dp .zip_with_iter (
@@ -114,7 +115,7 @@ def test_iter_key_zipper_iterdatapipe(self) -> None:
114
115
115
116
# Functional Test: keep_key=True, and key should show up as the first element
116
117
zip_dp_w_key = source_dp .zip_with_iter (
117
- ref_datapipe = ref_dp , key_fn = lambda x : x , ref_key_fn = lambda x : x , keep_key = True , buffer_size = 10
118
+ ref_datapipe = ref_dp2 , key_fn = lambda x : x , ref_key_fn = lambda x : x , keep_key = True , buffer_size = 10
118
119
)
119
120
self .assertEqual ([(i , (i , i )) for i in range (10 )], list (zip_dp_w_key ))
120
121
@@ -145,13 +146,13 @@ def merge_to_string(item1, item2):
145
146
146
147
# Without a custom merge function, there will be nested tuples
147
148
zip_dp2 = zip_dp .zip_with_iter (
148
- ref_datapipe = ref_dp , key_fn = lambda x : x [0 ], ref_key_fn = lambda x : x , keep_key = False , buffer_size = 100
149
+ ref_datapipe = ref_dp2 , key_fn = lambda x : x [0 ], ref_key_fn = lambda x : x , keep_key = False , buffer_size = 100
149
150
)
150
151
self .assertEqual ([((i , i ), i ) for i in range (10 )], list (zip_dp2 ))
151
152
152
153
# With a custom merge function, nesting can be prevented
153
154
zip_dp2_w_merge = zip_dp .zip_with_iter (
154
- ref_datapipe = ref_dp ,
155
+ ref_datapipe = ref_dp2 ,
155
156
key_fn = lambda x : x [0 ],
156
157
ref_key_fn = lambda x : x ,
157
158
keep_key = False ,
@@ -524,10 +525,11 @@ def test_sample_multiplexer_iterdatapipe(self) -> None:
524
525
525
526
def test_in_batch_shuffler_iterdatapipe (self ) -> None :
526
527
source_dp = IterableWrapper (range (10 )).batch (3 )
528
+ source_dp2 = IterableWrapper (range (10 )).batch (3 )
527
529
528
530
# Functional Test: drop last reduces length
529
531
filtered_dp = source_dp .in_batch_shuffle ()
530
- for ret_batch , exp_batch in zip (filtered_dp , source_dp ):
532
+ for ret_batch , exp_batch in zip (filtered_dp , source_dp2 ):
531
533
ret_batch .sort ()
532
534
self .assertEqual (ret_batch , exp_batch )
533
535
@@ -762,28 +764,32 @@ def test_unzipper_iterdatapipe(self):
762
764
with self .assertRaises (BufferError ):
763
765
list (dp2 )
764
766
765
- # Reset Test: reset the DataPipe after reading part of it
767
+ # Reset Test: DataPipe resets when a new iterator is created, even if this datapipe hasn't been read
766
768
dp1 , dp2 = source_dp .unzip (sequence_length = 2 )
767
- i1 , i2 = iter (dp1 ), iter ( dp2 )
769
+ _ = iter (dp1 )
768
770
output2 = []
769
- for i , n2 in enumerate (i2 ):
770
- output2 .append (n2 )
771
- if i == 4 :
772
- i1 = iter (dp1 ) # Doesn't reset because i1 hasn't been read
773
- self .assertEqual (list (range (10 , 20 )), output2 )
771
+ with self .assertRaisesRegex (RuntimeError , r"iterator has been invalidated" ):
772
+ for i , n2 in enumerate (dp2 ):
773
+ output2 .append (n2 )
774
+ if i == 4 :
775
+ _ = iter (dp1 ) # This will reset all child DataPipes
776
+ self .assertEqual (list (range (10 , 15 )), output2 )
774
777
775
778
# Reset Test: DataPipe resets when some of it has been read
776
779
dp1 , dp2 = source_dp .unzip (sequence_length = 2 )
777
- i1 , i2 = iter (dp1 ), iter (dp2 )
778
780
output1 , output2 = [], []
779
- for i , (n1 , n2 ) in enumerate (zip (i1 , i2 )):
781
+ for i , (n1 , n2 ) in enumerate (zip (dp1 , dp2 )):
780
782
output1 .append (n1 )
781
783
output2 .append (n2 )
782
784
if i == 4 :
783
785
with warnings .catch_warnings (record = True ) as wa :
784
- i1 = iter (dp1 ) # Reset both all child DataPipe
786
+ _ = iter (dp1 ) # Reset both all child DataPipe
785
787
self .assertEqual (len (wa ), 1 )
786
788
self .assertRegex (str (wa [0 ].message ), r"Some child DataPipes are not exhausted" )
789
+ break
790
+ for i , (n1 , n2 ) in enumerate (zip (dp1 , dp2 )):
791
+ output1 .append (n1 )
792
+ output2 .append (n2 )
787
793
self .assertEqual (list (range (5 )) + list (range (10 )), output1 )
788
794
self .assertEqual (list (range (10 , 15 )) + list (range (10 , 20 )), output2 )
789
795
0 commit comments