
Commit fd8e60c

fix newline stripping in plain text readers (#174)
Summary:
Fixes #173

Note that the [input to `strip`](https://docs.python.org/3/library/stdtypes.html#str.strip)

> is a string specifying the **set of characters** to be removed.

[Emphasis mine]

Thus, stripping works something like

```python
for char in chars:
    string.replace(char, "")
```

(applied only at the leading and trailing ends of the string) rather than

```python
string.replace(chars, "")
```

This means that always stripping `"\r\n"` is harmless even if the line terminator is only `"\n"` or `"\r"`.

Reviewed By: ejguan

Differential Revision: D33684458

Pulled By: NivekT

fbshipit-source-id: 9821b77d60d3afe038ae698965beefe319783aa1
ghstack-source-id: 37a119b

Pull Request resolved: #176
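
As a sanity check of the set-of-characters semantics described above, a quick standard-library-only illustration (not part of the patch):

```python
# str.strip treats its argument as a set of characters to remove from both
# ends of the string, so stripping "\r\n" covers every common line ending.
for line in ("text\r\n", "text\n", "text\r", "text"):
    print(repr(line.strip("\r\n")))  # 'text' in all four cases

# It is not a substring replacement: interior newlines are untouched.
print(repr("a\nb\n".strip("\r\n")))  # 'a\nb'
```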
1 parent 31ab2d6 commit fd8e60c

File tree: 2 files changed, +13 −43 lines

test/test_datapipe.py

Lines changed: 2 additions & 5 deletions
```diff
@@ -520,14 +520,12 @@ def test_bucket_batcher_iterdatapipe(self) -> None:
         batch_dp = source_dp.bucketbatch(
             batch_size=3, drop_last=True, batch_num=100, bucket_num=1, in_batch_shuffle=True
         )
-        self.assertEqual(3, len(batch_dp))
         self.assertEqual(9, len(list(batch_dp.unbatch())))
 
         # Functional Test: drop last is False preserves length
         batch_dp = source_dp.bucketbatch(
             batch_size=3, drop_last=False, batch_num=100, bucket_num=1, in_batch_shuffle=False
         )
-        self.assertEqual(4, len(batch_dp))
         self.assertEqual(10, len(list(batch_dp.unbatch())))
 
         def _return_self(x):
@@ -539,14 +537,12 @@ def _return_self(x):
         )
         # bucket_num = 1 means there will be no shuffling if a sort key is given
         self.assertEqual([[0, 1, 2], [3, 4, 5], [6, 7, 8]], list(batch_dp))
-        self.assertEqual(3, len(batch_dp))
         self.assertEqual(9, len(list(batch_dp.unbatch())))
 
         # Functional Test: using sort_key, without in_batch_shuffle
         batch_dp = source_dp.bucketbatch(
             batch_size=3, drop_last=True, batch_num=100, bucket_num=2, in_batch_shuffle=False, sort_key=_return_self
         )
-        self.assertEqual(3, len(batch_dp))
         self.assertEqual(9, len(list(batch_dp.unbatch())))
 
         # Reset Test:
@@ -567,7 +563,8 @@ def _return_self(x):
         self.assertEqual(9, len([item for batch in res_after_reset for item in batch]))
 
         # __len__ Test: returns the number of batches
-        self.assertEqual(3, len(batch_dp))
+        with self.assertRaises(TypeError):
+            len(batch_dp)
 
 
 if __name__ == "__main__":
```
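
The reworked `__len__` test above reflects that a bucket-batched pipe no longer advertises a length. A minimal sketch of the new behavior, assuming torchdata's `IterableWrapper` and the registered `bucketbatch` functional (exact import path may vary by version):

```python
from torchdata.datapipes.iter import IterableWrapper  # assumed import path

source_dp = IterableWrapper(range(10))
batch_dp = source_dp.bucketbatch(batch_size=3, drop_last=True)

try:
    print(len(batch_dp))
except TypeError:
    # After this change, len() raises; count batches by iterating instead.
    print(len(list(batch_dp)))  # 3 batches of 3 elements each
```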
Lines changed: 11 additions & 38 deletions
```diff
@@ -1,6 +1,6 @@
 # Copyright (c) Facebook, Inc. and its affiliates.
 import random
-from typing import Callable, Iterator, Optional, Sized, TypeVar
+from typing import Callable, Optional, TypeVar
 
 from torch.utils.data import DataChunk
 
@@ -42,44 +42,32 @@ class BucketBatcherIterDataPipe(IterDataPipe[DataChunk[T_co]]):
     in_batch_shuffle: bool
     length: Optional[int]
 
-    def __init__(
-        self,
+    def __new__(
+        cls,
         datapipe: IterDataPipe[T_co],
         batch_size: int,
         drop_last: bool = False,
         batch_num: int = 100,
         bucket_num: int = 1,
         sort_key: Optional[Callable] = None,
         in_batch_shuffle: bool = True,
-    ) -> None:
+    ) -> IterDataPipe:
         assert batch_size > 0, "Batch size is required to be larger than 0!"
         assert batch_num > 0, "Number of batches is required to be larger than 0!"
         assert bucket_num > 0, "Number of buckets is required to be larger than 0!"
-        super().__init__()
 
-        # TODO(136): Verify _datapippe is not going to be serialized twice and is able to reconstruct
-        self._datapipe: IterDataPipe[T_co] = datapipe
-        self.batch_size: int = batch_size
-        self.drop_last: bool = drop_last
-        self.batch_num: int = batch_num
-        self.bucket_num: int = bucket_num
-        self.sort_key: Optional[Callable] = sort_key
-        self.in_batch_shuffle: bool = in_batch_shuffle
-
-        self.bucket_size = batch_size * batch_num
-        self.pool_size = self.bucket_size * bucket_num
+        bucket_size = batch_size * batch_num
+        pool_size = bucket_size * bucket_num
 
         # Shuffle by pool_size
         if bucket_num > 1 or sort_key is None:
             if in_batch_shuffle:
-                datapipe = (
-                    datapipe.batch(batch_size=self.pool_size, drop_last=False).map(fn=_in_batch_shuffle_fn).unbatch()
-                )
+                datapipe = datapipe.batch(batch_size=pool_size, drop_last=False).map(fn=_in_batch_shuffle_fn).unbatch()
             else:
-                datapipe = datapipe.shuffle(buffer_size=self.pool_size)
+                datapipe = datapipe.shuffle(buffer_size=pool_size)
         # Sort by bucket_size if sort_key is given
         if sort_key is not None:
-            datapipe = datapipe.batch(self.bucket_size).map(fn=sort_key).unbatch()
+            datapipe = datapipe.batch(bucket_size).map(fn=sort_key).unbatch()
         # Batch and drop last (if needed)
         datapipe = datapipe.batch(batch_size, drop_last=drop_last)
         # Shuffle the batched data
@@ -88,20 +76,5 @@ def __init__(
         if in_batch_shuffle:
             datapipe = datapipe.batch(batch_size=bucket_num, drop_last=False).map(fn=_in_batch_shuffle_fn).unbatch()
         else:
-            datapipe = datapipe.shuffle(buffer_size=self.bucket_size)
-        self.datapipe = datapipe
-        self.length = None
-
-    def __iter__(self) -> Iterator:
-        yield from self.datapipe
-
-    def __len__(self) -> int:
-        if self.length is not None:
-            return self.length
-        if isinstance(self._datapipe, Sized):
-            if self.drop_last:
-                self.length = len(self._datapipe) // self.batch_size
-            else:
-                self.length = (len(self._datapipe) + self.batch_size - 1) // self.batch_size
-            return self.length
-        raise TypeError(f"{type(self).__name__} instance doesn't have valid length")
+            datapipe = datapipe.shuffle(buffer_size=bucket_size)
+        return datapipe
```
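
For orientation, this is roughly the pipeline that the new `__new__` assembles under the default arguments (`in_batch_shuffle=True`, `sort_key=None`, `bucket_num=1`). The names `source` and `_shuffle_within_batch` are illustrative stand-ins, not part of the patch, and the import path is assumed:

```python
import random

from torchdata.datapipes.iter import IterableWrapper  # assumed import path


def _shuffle_within_batch(batch):
    # Stand-in for the module's _in_batch_shuffle_fn helper: shuffle one batch.
    random.shuffle(batch)
    return batch


batch_size, batch_num, bucket_num = 3, 100, 1
bucket_size = batch_size * batch_num
pool_size = bucket_size * bucket_num  # elements shuffled together before batching

source = IterableWrapper(range(10))

# Shuffle within a pool, cut into batches, then shuffle the order of batches.
dp = source.batch(batch_size=pool_size, drop_last=False).map(fn=_shuffle_within_batch).unbatch()
dp = dp.batch(batch_size, drop_last=True)
dp = dp.batch(batch_size=bucket_num, drop_last=False).map(fn=_shuffle_within_batch).unbatch()

# Because __new__ returns the composed pipe directly, there is no
# BucketBatcher instance left to implement __len__.
print(list(dp))  # three shuffled batches of three elements each
```

Returning the composed pipe from `__new__` keeps call sites unchanged while delegating iteration and length semantics to the building-block datapipes, which is why the `__len__` test above now expects a `TypeError`.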
