Skip to content

Moving DataPipe buffers from __iter__ to instance (self) #388

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 8 commits into from
46 changes: 39 additions & 7 deletions torchdata/datapipes/iter/util/combining.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,14 +76,14 @@ def __init__(
if buffer_size is not None and buffer_size <= 0:
raise ValueError("'buffer_size' is required to be either None or a positive integer.")
self.buffer_size: int = buffer_size
self.buffer: OrderedDict = OrderedDict()

def __iter__(self) -> Iterator:
buffer: OrderedDict = OrderedDict()
ref_it = iter(self.ref_datapipe)
warn_once_flag = True
for data in self.source_datapipe:
key = self.key_fn(data)
while key not in buffer:
while key not in self.buffer:
try:
ref_data = next(ref_it)
except StopIteration:
Expand All @@ -92,18 +92,18 @@ def __iter__(self) -> Iterator:
"Please consider increasing the buffer size."
)
ref_key = self.ref_key_fn(ref_data)
if ref_key in buffer:
if ref_key in self.buffer:
raise ValueError("Duplicate key is found in reference DataPipe")
if self.buffer_size is not None and len(buffer) > self.buffer_size:
if self.buffer_size is not None and len(self.buffer) > self.buffer_size:
if warn_once_flag:
warn_once_flag = False
warnings.warn(
"Buffer reaches the upper limit, so reference key-data pair begins to "
"be removed from buffer in FIFO order. Please consider increase buffer size."
)
buffer.popitem(last=False)
buffer[ref_key] = ref_data
res = self.merge_fn(data, buffer.pop(key)) if self.merge_fn else (data, buffer.pop(key))
self.buffer.popitem(last=False)
self.buffer[ref_key] = ref_data
res = self.merge_fn(data, self.buffer.pop(key)) if self.merge_fn else (data, self.buffer.pop(key))
if self.keep_key:
yield key, res
else:
Expand All @@ -112,6 +112,38 @@ def __iter__(self) -> Iterator:
def __len__(self) -> int:
return len(self.source_datapipe)

def reset(self) -> None:
    """Discard any buffered reference data by rebinding a fresh, empty buffer."""
    self.buffer = OrderedDict()

def __getstate__(self):
    """Return picklable state for this DataPipe.

    The runtime ``buffer`` is deliberately excluded from the state — it is
    recreated empty on unpickling. NOTE(review): defers to
    ``IterDataPipe.getstate_hook`` when one is installed, presumably the
    shared serialization convention for DataPipes — confirm against the base
    class.
    """
    hook = IterDataPipe.getstate_hook
    if hook is not None:
        return hook(self)
    return (
        self.source_datapipe,
        self.ref_datapipe,
        self.key_fn,
        self.ref_key_fn,
        self.keep_key,
        self.merge_fn,
        self.buffer_size,
    )

def __setstate__(self, state):
    """Restore attributes from pickled ``state``; the buffer always starts empty."""
    source, ref, key_fn, ref_key_fn, keep_key, merge_fn, buf_size = state
    self.source_datapipe = source
    self.ref_datapipe = ref
    self.key_fn = key_fn
    self.ref_key_fn = ref_key_fn
    self.keep_key = keep_key
    self.merge_fn = merge_fn
    self.buffer_size = buf_size
    # Buffered reference data is never serialized, so start from scratch.
    self.buffer = OrderedDict()

def __del__(self):
    """Release buffered reference data when the pipe is garbage-collected.

    Guarded with ``getattr`` because ``__del__`` may run on a
    partially-constructed instance — e.g. when ``__init__`` raised its
    ``buffer_size`` ValueError before ``self.buffer`` was assigned — in
    which case a bare ``self.buffer.clear()`` would raise ``AttributeError``.
    """
    buffer = getattr(self, "buffer", None)
    if buffer is not None:
        buffer.clear()


@functional_datapipe("zip_with_map")
class MapKeyZipperIterDataPipe(IterDataPipe[T_co]):
Expand Down
32 changes: 24 additions & 8 deletions torchdata/datapipes/iter/util/paragraphaggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,22 +46,38 @@ def __init__(self, source_datapipe: IterDataPipe[Tuple[str, T_co]], joiner: Call
self.source_datapipe: IterDataPipe[Tuple[str, T_co]] = source_datapipe
_check_lambda_fn(joiner)
self.joiner: Callable = joiner
self.buffer: List = []

def __iter__(self) -> Iterator[Tuple[str, str]]:
buffer = []
prev_filename = None
for filename, line in self.source_datapipe:
if prev_filename is None:
prev_filename = filename
if line and prev_filename == filename:
buffer.append(line)
self.buffer.append(line)
else:
if buffer:
yield prev_filename, self.joiner(buffer) # type: ignore[misc]
if self.buffer:
yield prev_filename, self.joiner(self.buffer) # type: ignore[misc]
if line:
buffer = [line]
self.buffer = [line]
else:
buffer = []
self.buffer = []
prev_filename = filename
if buffer:
yield prev_filename, self.joiner(buffer) # type: ignore[misc]
if self.buffer:
yield prev_filename, self.joiner(self.buffer) # type: ignore[misc]

def reset(self) -> None:
    """Drop any partially accumulated lines by rebinding an empty list."""
    self.buffer = []

def __getstate__(self):
    """Return picklable state; the in-progress line ``buffer`` is excluded.

    NOTE(review): defers to ``IterDataPipe.getstate_hook`` when installed,
    presumably the shared DataPipe serialization convention — confirm
    against the base class.
    """
    hook = IterDataPipe.getstate_hook
    if hook is not None:
        return hook(self)
    return (self.source_datapipe, self.joiner)

def __setstate__(self, state):
    """Restore attributes from pickled ``state`` and start with an empty buffer."""
    source, joiner = state
    self.source_datapipe = source
    self.joiner = joiner
    # The line buffer is never serialized, so aggregation restarts fresh.
    self.buffer = []

def __del__(self):
    """Clear buffered lines when the pipe is garbage-collected.

    Guarded with ``getattr`` because ``__del__`` can run on a
    partially-initialized instance (e.g. if ``__init__`` raised in
    ``_check_lambda_fn`` before ``self.buffer`` was assigned), where a bare
    ``self.buffer.clear()`` would raise ``AttributeError``.
    """
    buffer = getattr(self, "buffer", None)
    if buffer is not None:
        buffer.clear()