Skip to content

Updating test code to follow single iterator constraint #386

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 20 additions & 14 deletions test/test_iterdatapipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ def test_iter_key_zipper_iterdatapipe(self) -> None:

source_dp = IterableWrapper(range(10))
ref_dp = IterableWrapper(range(20))
ref_dp2 = IterableWrapper(range(20))

# Functional Test: Output should be a zip list of tuple
zip_dp = source_dp.zip_with_iter(
Expand All @@ -114,7 +115,7 @@ def test_iter_key_zipper_iterdatapipe(self) -> None:

# Functional Test: keep_key=True, and key should show up as the first element
zip_dp_w_key = source_dp.zip_with_iter(
ref_datapipe=ref_dp, key_fn=lambda x: x, ref_key_fn=lambda x: x, keep_key=True, buffer_size=10
ref_datapipe=ref_dp2, key_fn=lambda x: x, ref_key_fn=lambda x: x, keep_key=True, buffer_size=10
)
self.assertEqual([(i, (i, i)) for i in range(10)], list(zip_dp_w_key))

Expand Down Expand Up @@ -145,13 +146,13 @@ def merge_to_string(item1, item2):

# Without a custom merge function, there will be nested tuples
zip_dp2 = zip_dp.zip_with_iter(
ref_datapipe=ref_dp, key_fn=lambda x: x[0], ref_key_fn=lambda x: x, keep_key=False, buffer_size=100
ref_datapipe=ref_dp2, key_fn=lambda x: x[0], ref_key_fn=lambda x: x, keep_key=False, buffer_size=100
)
self.assertEqual([((i, i), i) for i in range(10)], list(zip_dp2))

# With a custom merge function, nesting can be prevented
zip_dp2_w_merge = zip_dp.zip_with_iter(
ref_datapipe=ref_dp,
ref_datapipe=ref_dp2,
key_fn=lambda x: x[0],
ref_key_fn=lambda x: x,
keep_key=False,
Expand Down Expand Up @@ -524,10 +525,11 @@ def test_sample_multiplexer_iterdatapipe(self) -> None:

def test_in_batch_shuffler_iterdatapipe(self) -> None:
source_dp = IterableWrapper(range(10)).batch(3)
source_dp2 = IterableWrapper(range(10)).batch(3)

# Functional Test: drop last reduces length
filtered_dp = source_dp.in_batch_shuffle()
for ret_batch, exp_batch in zip(filtered_dp, source_dp):
for ret_batch, exp_batch in zip(filtered_dp, source_dp2):
ret_batch.sort()
self.assertEqual(ret_batch, exp_batch)

Expand Down Expand Up @@ -762,28 +764,32 @@ def test_unzipper_iterdatapipe(self):
with self.assertRaises(BufferError):
list(dp2)

# Reset Test: reset the DataPipe after reading part of it
# Reset Test: DataPipe resets when a new iterator is created, even if this datapipe hasn't been read
dp1, dp2 = source_dp.unzip(sequence_length=2)
i1, i2 = iter(dp1), iter(dp2)
_ = iter(dp1)
output2 = []
for i, n2 in enumerate(i2):
output2.append(n2)
if i == 4:
i1 = iter(dp1) # Doesn't reset because i1 hasn't been read
self.assertEqual(list(range(10, 20)), output2)
with self.assertRaisesRegex(RuntimeError, r"iterator has been invalidated"):
for i, n2 in enumerate(dp2):
output2.append(n2)
if i == 4:
_ = iter(dp1) # This will reset all child DataPipes
self.assertEqual(list(range(10, 15)), output2)

# Reset Test: DataPipe resets when some of it has been read
dp1, dp2 = source_dp.unzip(sequence_length=2)
i1, i2 = iter(dp1), iter(dp2)
output1, output2 = [], []
for i, (n1, n2) in enumerate(zip(i1, i2)):
for i, (n1, n2) in enumerate(zip(dp1, dp2)):
output1.append(n1)
output2.append(n2)
if i == 4:
with warnings.catch_warnings(record=True) as wa:
i1 = iter(dp1) # Resets all child DataPipes
_ = iter(dp1) # Resets all child DataPipes
self.assertEqual(len(wa), 1)
self.assertRegex(str(wa[0].message), r"Some child DataPipes are not exhausted")
break
for i, (n1, n2) in enumerate(zip(dp1, dp2)):
output1.append(n1)
output2.append(n2)
self.assertEqual(list(range(5)) + list(range(10)), output1)
self.assertEqual(list(range(10, 15)) + list(range(10, 20)), output2)

Expand Down
11 changes: 7 additions & 4 deletions test/test_local_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,14 +231,16 @@ def fill_hash_dict():
datapipe2 = FileOpener(datapipe1, mode="b")
hash_check_dp = HashChecker(datapipe2, hash_dict)

expected_res = list(datapipe2)

# Functional Test: Ensure the DataPipe values are unchanged if the hashes are the same
for (expected_path, expected_stream), (actual_path, actual_stream) in zip(datapipe2, hash_check_dp):
for (expected_path, expected_stream), (actual_path, actual_stream) in zip(expected_res, hash_check_dp):
self.assertEqual(expected_path, actual_path)
self.assertEqual(expected_stream.read(), actual_stream.read())

# Functional Test: Ensure the rewind option works, and the stream is empty when there is no rewind
hash_check_dp_no_reset = HashChecker(datapipe2, hash_dict, rewind=False)
for (expected_path, _), (actual_path, actual_stream) in zip(datapipe2, hash_check_dp_no_reset):
for (expected_path, _), (actual_path, actual_stream) in zip(expected_res, hash_check_dp_no_reset):
self.assertEqual(expected_path, actual_path)
self.assertEqual(b"", actual_stream.read())

Expand Down Expand Up @@ -458,7 +460,7 @@ def test_xz_archive_reader_iterdatapipe(self):
self._unordered_compressed_files_comparison_helper(self.temp_files, res_after_reset)

# Reset Test: Ensure the order is consistent between iterations
for r1, r2 in zip(xz_loader_dp, xz_loader_dp):
for r1, r2 in zip(list(xz_loader_dp), list(xz_loader_dp)):
self.assertEqual(r1[0], r2[0])

# __len__ Test: doesn't have valid length
Expand Down Expand Up @@ -497,7 +499,8 @@ def test_bz2_archive_reader_iterdatapipe(self):
self._unordered_compressed_files_comparison_helper(self.temp_files, res_after_reset)

# Reset Test: Ensure the order is consistent between iterations
for r1, r2 in zip(bz2_loader_dp, bz2_loader_dp):

for r1, r2 in zip(list(bz2_loader_dp), list(bz2_loader_dp)):
self.assertEqual(r1[0], r2[0])

# __len__ Test: doesn't have valid length
Expand Down
3 changes: 3 additions & 0 deletions test/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ def _serialization_test_for_single_dp(self, dp, use_dill, is_dataframe=False):
_ = next(it)
test_helper_fn(dp, use_dill)
# 3. Testing for serialization after DataPipe is fully read
it = iter(dp)
_ = list(it)
test_helper_fn(dp, use_dill)

Expand All @@ -146,10 +147,12 @@ def _serialization_test_for_dp_with_children(self, dp1, dp2, use_dill):
self._serialization_test_helper(dp2, use_dill=use_dill)
# 2.5. Testing for serialization after one child DataPipe is fully read
# (Only for DataPipes with children DataPipes)
it1 = iter(dp1)
_ = list(it1) # fully read one child
self._serialization_test_helper(dp1, use_dill=use_dill)
self._serialization_test_helper(dp2, use_dill=use_dill)
# 3. Testing for serialization after DataPipe is fully read
it2 = iter(dp2)
_ = list(it2) # fully read the other child
self._serialization_test_helper(dp1, use_dill=use_dill)
self._serialization_test_helper(dp2, use_dill=use_dill)
Expand Down