
Commit 7362f7d

NivekT authored and facebook-github-bot committed
Updating test code to follow single iterator constraint (#386)
Summary: Pull Request resolved: #386

This should be landed after pytorch/pytorch#75995

Test Plan: Imported from OSS

Reviewed By: ejguan

Differential Revision: D36182556

Pulled By: NivekT

fbshipit-source-id: b9dbbb6a75f97d808a562bb1b5309711eff82f16
1 parent 7555779 commit 7362f7d

3 files changed: +30 -18 lines
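Context for the diffs below: the "single iterator constraint" referenced in the summary is the behavior introduced by pytorch/pytorch#75995, under which each IterDataPipe may have at most one active iterator; creating a second iterator invalidates the first. A minimal sketch of the behavior the updated tests accommodate (assuming a torchdata build that includes that change):

from torchdata.datapipes.iter import IterableWrapper

dp = IterableWrapper(range(10))

it1 = iter(dp)
print(next(it1))  # 0

it2 = iter(dp)    # Creating a second iterator invalidates it1
print(next(it2))  # 0 -- the new iterator restarts from the beginning

next(it1)         # RuntimeError: this iterator has been invalidated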

test/test_iterdatapipe.py

Lines changed: 20 additions & 14 deletions
@@ -105,6 +105,7 @@ def test_iter_key_zipper_iterdatapipe(self) -> None:
 
         source_dp = IterableWrapper(range(10))
         ref_dp = IterableWrapper(range(20))
+        ref_dp2 = IterableWrapper(range(20))
 
         # Functional Test: Output should be a zip list of tuple
         zip_dp = source_dp.zip_with_iter(
@@ -114,7 +115,7 @@ def test_iter_key_zipper_iterdatapipe(self) -> None:
 
         # Functional Test: keep_key=True, and key should show up as the first element
         zip_dp_w_key = source_dp.zip_with_iter(
-            ref_datapipe=ref_dp, key_fn=lambda x: x, ref_key_fn=lambda x: x, keep_key=True, buffer_size=10
+            ref_datapipe=ref_dp2, key_fn=lambda x: x, ref_key_fn=lambda x: x, keep_key=True, buffer_size=10
         )
         self.assertEqual([(i, (i, i)) for i in range(10)], list(zip_dp_w_key))
 
@@ -145,13 +146,13 @@ def merge_to_string(item1, item2):
 
         # Without a custom merge function, there will be nested tuples
         zip_dp2 = zip_dp.zip_with_iter(
-            ref_datapipe=ref_dp, key_fn=lambda x: x[0], ref_key_fn=lambda x: x, keep_key=False, buffer_size=100
+            ref_datapipe=ref_dp2, key_fn=lambda x: x[0], ref_key_fn=lambda x: x, keep_key=False, buffer_size=100
         )
         self.assertEqual([((i, i), i) for i in range(10)], list(zip_dp2))
 
         # With a custom merge function, nesting can be prevented
         zip_dp2_w_merge = zip_dp.zip_with_iter(
-            ref_datapipe=ref_dp,
+            ref_datapipe=ref_dp2,
             key_fn=lambda x: x[0],
             ref_key_fn=lambda x: x,
             keep_key=False,
@@ -524,10 +525,11 @@ def test_sample_multiplexer_iterdatapipe(self) -> None:
 
     def test_in_batch_shuffler_iterdatapipe(self) -> None:
         source_dp = IterableWrapper(range(10)).batch(3)
+        source_dp2 = IterableWrapper(range(10)).batch(3)
 
         # Functional Test: drop last reduces length
         filtered_dp = source_dp.in_batch_shuffle()
-        for ret_batch, exp_batch in zip(filtered_dp, source_dp):
+        for ret_batch, exp_batch in zip(filtered_dp, source_dp2):
             ret_batch.sort()
             self.assertEqual(ret_batch, exp_batch)
 
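A recurring fix in the hunks above is constructing a second, identical DataPipe (ref_dp2, source_dp2) rather than reusing one instance in two places: zipping a DataPipe against a pipeline that already reads from it would open a second iterator over the same instance and invalidate the first. A minimal sketch of the failure mode and the workaround, following the test_in_batch_shuffler_iterdatapipe pattern above:

from torchdata.datapipes.iter import IterableWrapper

source_dp = IterableWrapper(range(10)).batch(3)
shuffled_dp = source_dp.in_batch_shuffle()

# zip(shuffled_dp, source_dp) would open a second iterator over
# source_dp (shuffled_dp already reads from it) and raise RuntimeError.
# The workaround used in the tests: build an identical, independent
# DataPipe to serve as the expected reference.
source_dp2 = IterableWrapper(range(10)).batch(3)

for ret_batch, exp_batch in zip(shuffled_dp, source_dp2):
    ret_batch.sort()
    assert ret_batch == exp_batch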

@@ -762,28 +764,32 @@ def test_unzipper_iterdatapipe(self):
         with self.assertRaises(BufferError):
             list(dp2)
 
-        # Reset Test: reset the DataPipe after reading part of it
+        # Reset Test: DataPipe resets when a new iterator is created, even if this datapipe hasn't been read
         dp1, dp2 = source_dp.unzip(sequence_length=2)
-        i1, i2 = iter(dp1), iter(dp2)
+        _ = iter(dp1)
         output2 = []
-        for i, n2 in enumerate(i2):
-            output2.append(n2)
-            if i == 4:
-                i1 = iter(dp1)  # Doesn't reset because i1 hasn't been read
-        self.assertEqual(list(range(10, 20)), output2)
+        with self.assertRaisesRegex(RuntimeError, r"iterator has been invalidated"):
+            for i, n2 in enumerate(dp2):
+                output2.append(n2)
+                if i == 4:
+                    _ = iter(dp1)  # This will reset all child DataPipes
+        self.assertEqual(list(range(10, 15)), output2)
 
         # Reset Test: DataPipe reset when some of it have been read
         dp1, dp2 = source_dp.unzip(sequence_length=2)
-        i1, i2 = iter(dp1), iter(dp2)
         output1, output2 = [], []
-        for i, (n1, n2) in enumerate(zip(i1, i2)):
+        for i, (n1, n2) in enumerate(zip(dp1, dp2)):
             output1.append(n1)
             output2.append(n2)
             if i == 4:
                 with warnings.catch_warnings(record=True) as wa:
-                    i1 = iter(dp1)  # Reset both all child DataPipe
+                    _ = iter(dp1)  # Reset both all child DataPipe
                     self.assertEqual(len(wa), 1)
                     self.assertRegex(str(wa[0].message), r"Some child DataPipes are not exhausted")
+                break
+        for i, (n1, n2) in enumerate(zip(dp1, dp2)):
+            output1.append(n1)
+            output2.append(n2)
         self.assertEqual(list(range(5)) + list(range(10)), output1)
         self.assertEqual(list(range(10, 15)) + list(range(10, 20)), output2)
 
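The rewritten unzipper test encodes the new reset semantics: calling iter() on one child of unzip resets all children and invalidates their live iterators, so the partially read dp2 now raises rather than continuing. A minimal sketch of that behavior, reusing the shapes from the test above:

from torchdata.datapipes.iter import IterableWrapper

source_dp = IterableWrapper([(i, 10 + i) for i in range(10)])
dp1, dp2 = source_dp.unzip(sequence_length=2)

it2 = iter(dp2)
print(next(it2))  # 10

_ = iter(dp1)     # Resets all child DataPipes of the unzipper...
next(it2)         # ...so it2 raises
                  # "RuntimeError: ... iterator has been invalidated"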

test/test_local_io.py

Lines changed: 7 additions & 4 deletions
@@ -231,14 +231,16 @@ def fill_hash_dict():
         datapipe2 = FileOpener(datapipe1, mode="b")
         hash_check_dp = HashChecker(datapipe2, hash_dict)
 
+        expected_res = list(datapipe2)
+
         # Functional Test: Ensure the DataPipe values are unchanged if the hashes are the same
-        for (expected_path, expected_stream), (actual_path, actual_stream) in zip(datapipe2, hash_check_dp):
+        for (expected_path, expected_stream), (actual_path, actual_stream) in zip(expected_res, hash_check_dp):
             self.assertEqual(expected_path, actual_path)
             self.assertEqual(expected_stream.read(), actual_stream.read())
 
         # Functional Test: Ensure the rewind option works, and the stream is empty when there is no rewind
         hash_check_dp_no_reset = HashChecker(datapipe2, hash_dict, rewind=False)
-        for (expected_path, _), (actual_path, actual_stream) in zip(datapipe2, hash_check_dp_no_reset):
+        for (expected_path, _), (actual_path, actual_stream) in zip(expected_res, hash_check_dp_no_reset):
             self.assertEqual(expected_path, actual_path)
             self.assertEqual(b"", actual_stream.read())
 
@@ -458,7 +460,7 @@ def test_xz_archive_reader_iterdatapipe(self):
         self._unordered_compressed_files_comparison_helper(self.temp_files, res_after_reset)
 
         # Reset Test: Ensure the order is consistent between iterations
-        for r1, r2 in zip(xz_loader_dp, xz_loader_dp):
+        for r1, r2 in zip(list(xz_loader_dp), list(xz_loader_dp)):
             self.assertEqual(r1[0], r2[0])
 
         # __len__ Test: doesn't have valid length
@@ -497,7 +499,8 @@ def test_bz2_archive_reader_iterdatapipe(self):
         self._unordered_compressed_files_comparison_helper(self.temp_files, res_after_reset)
 
         # Reset Test: Ensure the order is consistent between iterations
-        for r1, r2 in zip(bz2_loader_dp, bz2_loader_dp):
+
+        for r1, r2 in zip(list(bz2_loader_dp), list(bz2_loader_dp)):
             self.assertEqual(r1[0], r2[0])
 
         # __len__ Test: doesn't have valid length
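Both archive-reader fixes replace zip(dp, dp) with zip(list(dp), list(dp)). zip() creates both iterators up front, so zipping a DataPipe with itself immediately invalidates its own first iterator; two sequential list() calls, by contrast, each hold the single active iterator in turn. Roughly:

from torchdata.datapipes.iter import IterableWrapper

dp = IterableWrapper(range(5))

# Two concurrent iterators over one DataPipe are no longer allowed:
# for r1, r2 in zip(dp, dp):   # raises RuntimeError under the constraint
#     ...

# Materializing each pass sequentially keeps one live iterator at a time.
for r1, r2 in zip(list(dp), list(dp)):
    assert r1 == r2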

test/test_serialization.py

Lines changed: 3 additions & 0 deletions
@@ -132,6 +132,7 @@ def _serialization_test_for_single_dp(self, dp, use_dill, is_dataframe=False):
         _ = next(it)
         test_helper_fn(dp, use_dill)
         # 3. Testing for serialization after DataPipe is fully read
+        it = iter(dp)
         _ = list(it)
         test_helper_fn(dp, use_dill)
 
@@ -146,10 +147,12 @@ def _serialization_test_for_dp_with_children(self, dp1, dp2, use_dill):
         self._serialization_test_helper(dp2, use_dill=use_dill)
         # 2.5. Testing for serialization after one child DataPipe is fully read
         # (Only for DataPipes with children DataPipes)
+        it1 = iter(dp1)
         _ = list(it1)  # fully read one child
         self._serialization_test_helper(dp1, use_dill=use_dill)
         self._serialization_test_helper(dp2, use_dill=use_dill)
         # 3. Testing for serialization after DataPipe is fully read
+        it2 = iter(dp2)
         _ = list(it2)  # fully read the other child
         self._serialization_test_helper(dp1, use_dill=use_dill)
         self._serialization_test_helper(dp2, use_dill=use_dill)
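The serialization helpers gain explicit it = iter(dp) calls, presumably because the iterator captured earlier may have been invalidated by any intermediate iteration of the DataPipe (for example, inside test_helper_fn); a fresh iterator is then required before the "fully read" step. The pattern, sketched with a plain IterableWrapper:

from torchdata.datapipes.iter import IterableWrapper

dp = IterableWrapper(range(10))

it = iter(dp)
_ = next(it)   # partial read, as in step 2 of the helper

_ = list(dp)   # anything that iterates dp again invalidates `it`

it = iter(dp)  # so a fresh iterator is created...
_ = list(it)   # ...before fully reading the DataPipe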
