File tree 2 files changed +37
-7
lines changed
torchdata/datapipes/iter/util 2 files changed +37
-7
lines changed Original file line number Diff line number Diff line change @@ -307,9 +307,24 @@ def test_header_iterdatapipe(self) -> None:
307
307
# __len__ Test: returns the limit when it is less than the length of source
308
308
self .assertEqual (5 , len (header_dp ))
309
309
310
- # TODO(123): __len__ Test: returns the length of source when it is less than the limit
311
- # header_dp = source_dp.header(30)
312
- # self.assertEqual(20, len(header_dp))
310
+ # __len__ Test: returns the length of source when it is less than the limit
311
+ header_dp = source_dp .header (30 )
312
+ self .assertEqual (20 , len (header_dp ))
313
+
314
+ # __len__ Test: returns limit if source doesn't have length
315
+ source_dp_NoLen = IDP_NoLen (list (range (20 )))
316
+ header_dp = source_dp_NoLen .header (30 )
317
+ with warnings .catch_warnings (record = True ) as wa :
318
+ self .assertEqual (30 , len (header_dp ))
319
+ self .assertEqual (len (wa ), 1 )
320
+ self .assertRegex (
321
+ str (wa [0 ].message ), r"length of this HeaderIterDataPipe is inferred to be equal to its limit"
322
+ )
323
+
324
+ # __len__ Test: returns limit if source doesn't have length, but it has been iterated through once
325
+ for _ in header_dp :
326
+ pass
327
+ self .assertEqual (20 , len (header_dp ))
313
328
314
329
def test_enumerator_iterdatapipe (self ) -> None :
315
330
letters = "abcde"
Original file line number Diff line number Diff line change 1
1
# Copyright (c) Facebook, Inc. and its affiliates.
2
2
from typing import Iterator , TypeVar
3
+ from warnings import warn
3
4
4
5
from torchdata .datapipes import functional_datapipe
5
6
from torchdata .datapipes .iter import IterDataPipe
@@ -20,14 +21,28 @@ class HeaderIterDataPipe(IterDataPipe[T_co]):
20
21
def __init__ (self , source_datapipe : IterDataPipe [T_co ], limit : int = 10 ) -> None :
21
22
self .source_datapipe : IterDataPipe [T_co ] = source_datapipe
22
23
self .limit : int = limit
24
+ self .length : int = - 1
23
25
24
26
def __iter__ (self ) -> Iterator [T_co ]:
25
- for i , value in enumerate (self .source_datapipe ):
26
- if i < self .limit :
27
+ i : int = 0
28
+ for value in self .source_datapipe :
29
+ i += 1
30
+ if i <= self .limit :
27
31
yield value
28
32
else :
29
33
break
34
+ self .length = min (i , self .limit ) # We know length with certainty when we reach here
30
35
31
- # TODO(134): Fix the case that the length of source_datapipe is shorter than limit
32
36
def __len__ (self ) -> int :
33
- return self .limit
37
+ if self .length != - 1 :
38
+ return self .length
39
+ try :
40
+ source_len = len (self .source_datapipe )
41
+ self .length = min (source_len , self .limit )
42
+ return self .length
43
+ except TypeError :
44
+ warn (
45
+ "The length of this HeaderIterDataPipe is inferred to be equal to its limit."
46
+ "The actual value may be smaller if the actual length of source_datapipe is smaller than the limit."
47
+ )
48
+ return self .limit
You can’t perform that action at this time.
0 commit comments