diff --git a/test/test_iterdatapipe.py b/test/test_iterdatapipe.py
index 8fc4cb16a..b26944972 100644
--- a/test/test_iterdatapipe.py
+++ b/test/test_iterdatapipe.py
@@ -105,6 +105,7 @@ def test_iter_key_zipper_iterdatapipe(self) -> None:
 
         source_dp = IterableWrapper(range(10))
         ref_dp = IterableWrapper(range(20))
+        ref_dp2 = IterableWrapper(range(20))
 
         # Functional Test: Output should be a zip list of tuple
         zip_dp = source_dp.zip_with_iter(
@@ -114,7 +115,7 @@ def test_iter_key_zipper_iterdatapipe(self) -> None:
 
         # Functional Test: keep_key=True, and key should show up as the first element
         zip_dp_w_key = source_dp.zip_with_iter(
-            ref_datapipe=ref_dp, key_fn=lambda x: x, ref_key_fn=lambda x: x, keep_key=True, buffer_size=10
+            ref_datapipe=ref_dp2, key_fn=lambda x: x, ref_key_fn=lambda x: x, keep_key=True, buffer_size=10
         )
         self.assertEqual([(i, (i, i)) for i in range(10)], list(zip_dp_w_key))
 
@@ -145,13 +146,13 @@ def merge_to_string(item1, item2):
 
         # Without a custom merge function, there will be nested tuples
         zip_dp2 = zip_dp.zip_with_iter(
-            ref_datapipe=ref_dp, key_fn=lambda x: x[0], ref_key_fn=lambda x: x, keep_key=False, buffer_size=100
+            ref_datapipe=ref_dp2, key_fn=lambda x: x[0], ref_key_fn=lambda x: x, keep_key=False, buffer_size=100
         )
         self.assertEqual([((i, i), i) for i in range(10)], list(zip_dp2))
 
         # With a custom merge function, nesting can be prevented
         zip_dp2_w_merge = zip_dp.zip_with_iter(
-            ref_datapipe=ref_dp,
+            ref_datapipe=ref_dp2,
             key_fn=lambda x: x[0],
             ref_key_fn=lambda x: x,
             keep_key=False,
@@ -524,10 +525,11 @@ def test_sample_multiplexer_iterdatapipe(self) -> None:
 
     def test_in_batch_shuffler_iterdatapipe(self) -> None:
         source_dp = IterableWrapper(range(10)).batch(3)
+        source_dp2 = IterableWrapper(range(10)).batch(3)
 
         # Functional Test: drop last reduces length
         filtered_dp = source_dp.in_batch_shuffle()
-        for ret_batch, exp_batch in zip(filtered_dp, source_dp):
+        for ret_batch, exp_batch in zip(filtered_dp, source_dp2):
             ret_batch.sort()
             self.assertEqual(ret_batch, exp_batch)
 
@@ -762,28 +764,32 @@ def test_unzipper_iterdatapipe(self):
         with self.assertRaises(BufferError):
             list(dp2)
 
-        # Reset Test: reset the DataPipe after reading part of it
+        # Reset Test: DataPipe resets when a new iterator is created, even if this DataPipe hasn't been read
         dp1, dp2 = source_dp.unzip(sequence_length=2)
-        i1, i2 = iter(dp1), iter(dp2)
+        _ = iter(dp1)
         output2 = []
-        for i, n2 in enumerate(i2):
-            output2.append(n2)
-            if i == 4:
-                i1 = iter(dp1)  # Doesn't reset because i1 hasn't been read
-        self.assertEqual(list(range(10, 20)), output2)
+        with self.assertRaisesRegex(RuntimeError, r"iterator has been invalidated"):
+            for i, n2 in enumerate(dp2):
+                output2.append(n2)
+                if i == 4:
+                    _ = iter(dp1)  # This will reset all child DataPipes
+        self.assertEqual(list(range(10, 15)), output2)
 
         # Reset Test: DataPipe reset when some of it have been read
         dp1, dp2 = source_dp.unzip(sequence_length=2)
-        i1, i2 = iter(dp1), iter(dp2)
         output1, output2 = [], []
-        for i, (n1, n2) in enumerate(zip(i1, i2)):
+        for i, (n1, n2) in enumerate(zip(dp1, dp2)):
             output1.append(n1)
             output2.append(n2)
             if i == 4:
                 with warnings.catch_warnings(record=True) as wa:
-                    i1 = iter(dp1)  # Reset both all child DataPipe
+                    _ = iter(dp1)  # Resets all child DataPipes
                     self.assertEqual(len(wa), 1)
                     self.assertRegex(str(wa[0].message), r"Some child DataPipes are not exhausted")
+                break
+        for i, (n1, n2) in enumerate(zip(dp1, dp2)):
+            output1.append(n1)
+            output2.append(n2)
         self.assertEqual(list(range(5)) + list(range(10)), output1)
         self.assertEqual(list(range(10, 15)) + list(range(10, 20)), output2)
 
diff --git a/test/test_local_io.py b/test/test_local_io.py
index c20d685ce..dea43e633 100644
--- a/test/test_local_io.py
+++ b/test/test_local_io.py
@@ -231,14 +231,16 @@ def fill_hash_dict():
         datapipe2 = FileOpener(datapipe1, mode="b")
         hash_check_dp = HashChecker(datapipe2, hash_dict)
 
+        expected_res = list(datapipe2)
+
         # Functional Test: Ensure the DataPipe values are unchanged if the hashes are the same
-        for (expected_path, expected_stream), (actual_path, actual_stream) in zip(datapipe2, hash_check_dp):
+        for (expected_path, expected_stream), (actual_path, actual_stream) in zip(expected_res, hash_check_dp):
             self.assertEqual(expected_path, actual_path)
             self.assertEqual(expected_stream.read(), actual_stream.read())
 
         # Functional Test: Ensure the rewind option works, and the stream is empty when there is no rewind
         hash_check_dp_no_reset = HashChecker(datapipe2, hash_dict, rewind=False)
-        for (expected_path, _), (actual_path, actual_stream) in zip(datapipe2, hash_check_dp_no_reset):
+        for (expected_path, _), (actual_path, actual_stream) in zip(expected_res, hash_check_dp_no_reset):
             self.assertEqual(expected_path, actual_path)
             self.assertEqual(b"", actual_stream.read())
 
@@ -458,7 +460,7 @@ def test_xz_archive_reader_iterdatapipe(self):
         self._unordered_compressed_files_comparison_helper(self.temp_files, res_after_reset)
 
         # Reset Test: Ensure the order is consistent between iterations
-        for r1, r2 in zip(xz_loader_dp, xz_loader_dp):
+        for r1, r2 in zip(list(xz_loader_dp), list(xz_loader_dp)):
             self.assertEqual(r1[0], r2[0])
 
         # __len__ Test: doesn't have valid length
@@ -497,7 +499,8 @@ def test_bz2_archive_reader_iterdatapipe(self):
         self._unordered_compressed_files_comparison_helper(self.temp_files, res_after_reset)
 
         # Reset Test: Ensure the order is consistent between iterations
+
-        for r1, r2 in zip(bz2_loader_dp, bz2_loader_dp):
+        for r1, r2 in zip(list(bz2_loader_dp), list(bz2_loader_dp)):
             self.assertEqual(r1[0], r2[0])
 
         # __len__ Test: doesn't have valid length
diff --git a/test/test_serialization.py b/test/test_serialization.py
index c7a744d63..2237e63f4 100644
--- a/test/test_serialization.py
+++ b/test/test_serialization.py
@@ -132,6 +132,7 @@ def _serialization_test_for_single_dp(self, dp, use_dill, is_dataframe=False):
         _ = next(it)
         test_helper_fn(dp, use_dill)
         # 3. Testing for serialization after DataPipe is fully read
+        it = iter(dp)
         _ = list(it)
         test_helper_fn(dp, use_dill)
 
@@ -146,10 +147,12 @@ def _serialization_test_for_dp_with_children(self, dp1, dp2, use_dill):
         self._serialization_test_helper(dp2, use_dill=use_dill)
         # 2.5. Testing for serialization after one child DataPipe is fully read
         # (Only for DataPipes with children DataPipes)
+        it1 = iter(dp1)
         _ = list(it1)  # fully read one child
         self._serialization_test_helper(dp1, use_dill=use_dill)
         self._serialization_test_helper(dp2, use_dill=use_dill)
         # 3. Testing for serialization after DataPipe is fully read
+        it2 = iter(dp2)
         _ = list(it2)  # fully read the other child
         self._serialization_test_helper(dp1, use_dill=use_dill)
         self._serialization_test_helper(dp2, use_dill=use_dill)