Skip to content

Commit c06066a

Browse files
NivekT authored and facebook-github-bot committed
Fixing HeaderIterDP's __len__ function (#166)
Summary: Pull Request resolved: #166 The previous implementation simply returned the limit as the length for `HeaderIterDataPipe`. This updated implementation takes it a step further to account for the different possible scenarios. Fixes #123 Fixes #134 Test Plan: Imported from OSS Reviewed By: ejguan Differential Revision: D33589168 Pulled By: NivekT fbshipit-source-id: a1b20357031a8a606e1ba2fa40032cf48f91e145
1 parent 24c25c0 commit c06066a

File tree

2 files changed

+37
-7
lines changed

2 files changed

+37
-7
lines changed

test/test_datapipe.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -307,9 +307,24 @@ def test_header_iterdatapipe(self) -> None:
307307
# __len__ Test: returns the limit when it is less than the length of source
308308
self.assertEqual(5, len(header_dp))
309309

310-
# TODO(123): __len__ Test: returns the length of source when it is less than the limit
311-
# header_dp = source_dp.header(30)
312-
# self.assertEqual(20, len(header_dp))
310+
# __len__ Test: returns the length of source when it is less than the limit
311+
header_dp = source_dp.header(30)
312+
self.assertEqual(20, len(header_dp))
313+
314+
# __len__ Test: returns limit if source doesn't have length
315+
source_dp_NoLen = IDP_NoLen(list(range(20)))
316+
header_dp = source_dp_NoLen.header(30)
317+
with warnings.catch_warnings(record=True) as wa:
318+
self.assertEqual(30, len(header_dp))
319+
self.assertEqual(len(wa), 1)
320+
self.assertRegex(
321+
str(wa[0].message), r"length of this HeaderIterDataPipe is inferred to be equal to its limit"
322+
)
323+
324+
# __len__ Test: returns the actual length if source doesn't have length, but it has been iterated through once
325+
for _ in header_dp:
326+
pass
327+
self.assertEqual(20, len(header_dp))
313328

314329
def test_enumerator_iterdatapipe(self) -> None:
315330
letters = "abcde"

torchdata/datapipes/iter/util/header.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Copyright (c) Facebook, Inc. and its affiliates.
22
from typing import Iterator, TypeVar
3+
from warnings import warn
34

45
from torchdata.datapipes import functional_datapipe
56
from torchdata.datapipes.iter import IterDataPipe
@@ -20,14 +21,28 @@ class HeaderIterDataPipe(IterDataPipe[T_co]):
2021
def __init__(self, source_datapipe: IterDataPipe[T_co], limit: int = 10) -> None:
    """Store the wrapped DataPipe and the maximum number of elements to yield.

    Args:
        source_datapipe: DataPipe whose leading elements are yielded.
        limit: Upper bound on the number of elements yielded per iteration.
    """
    # -1 is the "not yet known" sentinel that ``__len__`` checks for; it is
    # overwritten once a length has been computed or observed.
    self.length: int = -1
    self.limit: int = limit
    self.source_datapipe: IterDataPipe[T_co] = source_datapipe
2325

2426
def __iter__(self) -> Iterator[T_co]:
    """Yield at most ``self.limit`` leading elements of ``self.source_datapipe``.

    When the generator runs to completion, the number of elements actually
    yielded is stored in ``self.length`` so that ``__len__`` can report it.

    Yields:
        Elements of the source, in order, stopping after ``self.limit`` of them
        or when the source is exhausted, whichever comes first.
    """
    count: int = 0
    # A non-positive limit yields nothing, so the source is never touched and
    # the recorded length can never be negative.
    if self.limit > 0:
        for value in self.source_datapipe:
            count += 1
            yield value
            # Stop right after the last wanted element instead of pulling one
            # more element from the source only to discard it.
            if count >= self.limit:
                break
    # Only reached when iteration finished; a generator abandoned part-way
    # leaves ``self.length`` untouched.
    self.length = count
3035

31-
# TODO(134): Fix the case that the length of source_datapipe is shorter than limit
3236
def __len__(self) -> int:
    """Return the number of elements this DataPipe yields, or the best estimate.

    Resolution order:
      1. A length already stored in ``self.length`` (anything other than the
         ``-1`` sentinel) is returned as-is.
      2. If the source supports ``len()``, the result is
         ``min(len(source), limit)``; it is cached in ``self.length``.
      3. Otherwise a warning is issued and ``self.limit`` is returned as an
         upper bound. Nothing is cached in this case.

    Returns:
        The known length, or ``self.limit`` when the source has no length.
    """
    if self.length != -1:
        return self.length
    # Keep the ``try`` limited to the one call expected to raise for sources
    # without ``__len__``.
    try:
        source_len = len(self.source_datapipe)
    except TypeError:
        warn(
            "The length of this HeaderIterDataPipe is inferred to be equal to its limit. "
            "The actual value may be smaller if the actual length of source_datapipe is smaller than the limit."
        )
        return self.limit
    self.length = min(source_len, self.limit)
    return self.length

0 commit comments

Comments
 (0)