
Commit 6a8415b

NivekT authored and facebook-github-bot committed
Moving DataPipe buffers from __iter__ to instance (self) (#388)
Summary:
Pull Request resolved: #388

This PR impacts `ParagraphAggregator` and `IterKeyZipper`, as they previously kept their buffers with the iterator rather than with the instance. It should be landed along with pytorch/pytorch#76999 and pytorch/pytorch#77775. Note that the CI is expected to fail until that PR goes into nightly.

Test Plan: Imported from OSS

Reviewed By: ejguan

Differential Revision: D36519522

Pulled By: NivekT

fbshipit-source-id: cc3616e1941777771e8c8b5cd26953f2f6ed48c8
1 parent 7362f7d commit 6a8415b
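
The core of the change is a pattern rather than new matching behavior: a buffer that used to be a local variable inside `__iter__` becomes an instance attribute, so the DataPipe can clear it in `reset()`, rebuild it after `__setstate__`, and release it in `__del__`. The sketch below illustrates that pattern in isolation; it is a standalone toy (hypothetical class name, standard library only), not torchdata code.

# A minimal, standalone sketch of the pattern this commit applies: the buffer
# lives on the instance so it can be cleared on reset and rebuilt after
# deserialization, instead of living as a local inside __iter__.
from collections import OrderedDict


class BufferedKeyZipperSketch:
    def __init__(self, source, ref, buffer_size=10):
        self.source = source          # iterable of (key, value)
        self.ref = ref                # iterable of (key, value) to match against
        self.buffer_size = buffer_size
        # Buffer is an instance attribute (self), not a local in __iter__.
        self.buffer: OrderedDict = OrderedDict()

    def __iter__(self):
        ref_it = iter(self.ref)
        for key, value in self.source:
            # Pull from the reference stream until the matching key shows up.
            while key not in self.buffer:
                try:
                    ref_key, ref_value = next(ref_it)
                except StopIteration:
                    raise BufferError(f"no match found for key {key!r}")
                if len(self.buffer) > self.buffer_size:
                    self.buffer.popitem(last=False)  # FIFO eviction
                self.buffer[ref_key] = ref_value
            yield value, self.buffer.pop(key)

    def reset(self) -> None:
        # Called when the pipeline restarts, so stale pairs never leak across epochs.
        self.buffer = OrderedDict()

    def __getstate__(self):
        # Serialize only the configuration; the buffer is rebuilt on load.
        return (self.source, self.ref, self.buffer_size)

    def __setstate__(self, state):
        (self.source, self.ref, self.buffer_size) = state
        self.buffer = OrderedDict()


# Tiny demonstration of the matching behavior:
src = [("a", 1), ("b", 2)]
ref = [("b", 20), ("a", 10)]
print(list(BufferedKeyZipperSketch(src, ref)))  # [(1, 10), (2, 20)]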

File tree

2 files changed (+63, -15 lines)


torchdata/datapipes/iter/util/combining.py

Lines changed: 39 additions & 7 deletions
@@ -76,14 +76,14 @@ def __init__(
         if buffer_size is not None and buffer_size <= 0:
             raise ValueError("'buffer_size' is required to be either None or a positive integer.")
         self.buffer_size: int = buffer_size
+        self.buffer: OrderedDict = OrderedDict()
 
     def __iter__(self) -> Iterator:
-        buffer: OrderedDict = OrderedDict()
         ref_it = iter(self.ref_datapipe)
         warn_once_flag = True
         for data in self.source_datapipe:
             key = self.key_fn(data)
-            while key not in buffer:
+            while key not in self.buffer:
                 try:
                     ref_data = next(ref_it)
                 except StopIteration:
@@ -92,18 +92,18 @@ def __iter__(self) -> Iterator:
                         "Please consider increasing the buffer size."
                     )
                 ref_key = self.ref_key_fn(ref_data)
-                if ref_key in buffer:
+                if ref_key in self.buffer:
                     raise ValueError("Duplicate key is found in reference DataPipe")
-                if self.buffer_size is not None and len(buffer) > self.buffer_size:
+                if self.buffer_size is not None and len(self.buffer) > self.buffer_size:
                     if warn_once_flag:
                         warn_once_flag = False
                         warnings.warn(
                             "Buffer reaches the upper limit, so reference key-data pair begins to "
                             "be removed from buffer in FIFO order. Please consider increase buffer size."
                         )
-                    buffer.popitem(last=False)
-                buffer[ref_key] = ref_data
-            res = self.merge_fn(data, buffer.pop(key)) if self.merge_fn else (data, buffer.pop(key))
+                    self.buffer.popitem(last=False)
+                self.buffer[ref_key] = ref_data
+            res = self.merge_fn(data, self.buffer.pop(key)) if self.merge_fn else (data, self.buffer.pop(key))
             if self.keep_key:
                 yield key, res
             else:
@@ -112,6 +112,38 @@ def __iter__(self) -> Iterator:
     def __len__(self) -> int:
         return len(self.source_datapipe)
 
+    def reset(self) -> None:
+        self.buffer = OrderedDict()
+
+    def __getstate__(self):
+        if IterDataPipe.getstate_hook is not None:
+            return IterDataPipe.getstate_hook(self)
+        state = (
+            self.source_datapipe,
+            self.ref_datapipe,
+            self.key_fn,
+            self.ref_key_fn,
+            self.keep_key,
+            self.merge_fn,
+            self.buffer_size,
+        )
+        return state
+
+    def __setstate__(self, state):
+        (
+            self.source_datapipe,
+            self.ref_datapipe,
+            self.key_fn,
+            self.ref_key_fn,
+            self.keep_key,
+            self.merge_fn,
+            self.buffer_size,
+        ) = state
+        self.buffer = OrderedDict()
+
+    def __del__(self):
+        self.buffer.clear()
+
 
 @functional_datapipe("zip_with_map")
 class MapKeyZipperIterDataPipe(IterDataPipe[T_co]):
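
A hedged usage sketch for `IterKeyZipper` after this change. The keyword arguments mirror the attributes serialized in `__getstate__` above; the exact signature and the import path should be checked against the installed torchdata release.

from torchdata.datapipes.iter import IterableWrapper, IterKeyZipper

source = IterableWrapper([("a", 1), ("b", 2), ("c", 3)])
ref = IterableWrapper([("c", 30), ("a", 10), ("b", 20)])

dp = IterKeyZipper(
    source,
    ref,
    key_fn=lambda x: x[0],      # key for each source item
    ref_key_fn=lambda x: x[0],  # key for each reference item
    keep_key=False,
    buffer_size=10,
)

# Each source item is paired with the reference item sharing its key;
# out-of-order reference items wait in self.buffer until they are matched.
print(list(dp))
# [(('a', 1), ('a', 10)), (('b', 2), ('b', 20)), (('c', 3), ('c', 30))]

Because the buffer is now on the instance, resetting the pipeline or recreating it from serialized state via `__setstate__` starts from an empty buffer instead of carrying over leftover reference items.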

torchdata/datapipes/iter/util/paragraphaggregator.py

Lines changed: 24 additions & 8 deletions
@@ -46,22 +46,38 @@ def __init__(self, source_datapipe: IterDataPipe[Tuple[str, T_co]], joiner: Call
         self.source_datapipe: IterDataPipe[Tuple[str, T_co]] = source_datapipe
         _check_lambda_fn(joiner)
         self.joiner: Callable = joiner
+        self.buffer: List = []
 
     def __iter__(self) -> Iterator[Tuple[str, str]]:
-        buffer = []
         prev_filename = None
         for filename, line in self.source_datapipe:
             if prev_filename is None:
                 prev_filename = filename
             if line and prev_filename == filename:
-                buffer.append(line)
+                self.buffer.append(line)
             else:
-                if buffer:
-                    yield prev_filename, self.joiner(buffer)  # type: ignore[misc]
+                if self.buffer:
+                    yield prev_filename, self.joiner(self.buffer)  # type: ignore[misc]
                 if line:
-                    buffer = [line]
+                    self.buffer = [line]
                 else:
-                    buffer = []
+                    self.buffer = []
                 prev_filename = filename
-        if buffer:
-            yield prev_filename, self.joiner(buffer)  # type: ignore[misc]
+        if self.buffer:
+            yield prev_filename, self.joiner(self.buffer)  # type: ignore[misc]
+
+    def reset(self) -> None:
+        self.buffer = []
+
+    def __getstate__(self):
+        if IterDataPipe.getstate_hook is not None:
+            return IterDataPipe.getstate_hook(self)
+        state = (self.source_datapipe, self.joiner)
+        return state
+
+    def __setstate__(self, state):
+        (self.source_datapipe, self.joiner) = state
+        self.buffer = []
+
+    def __del__(self):
+        self.buffer.clear()
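
`ParagraphAggregator` gets the same treatment: its line buffer now lives on the instance and is reinitialized by `reset()` and `__setstate__`. A small usage sketch, assuming the class is importable from `torchdata.datapipes.iter` and takes a user-supplied `joiner` callable as the constructor above shows.

from torchdata.datapipes.iter import IterableWrapper, ParagraphAggregator

lines = IterableWrapper([
    ("doc1.txt", "Hello"),
    ("doc1.txt", "world."),
    ("doc2.txt", "Another"),
    ("doc2.txt", "file."),
])

paragraphs = ParagraphAggregator(lines, joiner=lambda ls: " ".join(ls))

# Consecutive lines from the same file are collected in self.buffer and
# emitted as one joined paragraph when the filename changes (or input ends).
print(list(paragraphs))
# [('doc1.txt', 'Hello world.'), ('doc2.txt', 'Another file.')]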
