From 38ee4dfde79dc187fa4532d0d5e1aa31a1579c29 Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Thu, 20 Jan 2022 09:49:46 +0100
Subject: [PATCH] fix newline stripping in plain text readers

---
 test/test_datapipe.py                              | 14 +++++++-------
 torchdata/datapipes/iter/util/plain_text_reader.py |  4 ++--
 2 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/test/test_datapipe.py b/test/test_datapipe.py
index 580845245..77707f4d9 100644
--- a/test/test_datapipe.py
+++ b/test/test_datapipe.py
@@ -381,13 +381,13 @@ def dict_content_test_helper(iterator):
 
     def test_line_reader_iterdatapipe(self) -> None:
         text1 = "Line1\nLine2"
-        text2 = "Line2,1\nLine2,2\nLine2,3"
+        text2 = "Line2,1\r\nLine2,2\r\nLine2,3"
 
         # Functional Test: read lines correctly
         source_dp = IterableWrapper([("file1", io.StringIO(text1)), ("file2", io.StringIO(text2))])
         line_reader_dp = source_dp.readlines()
-        expected_result = [("file1", line) for line in text1.split("\n")] + [
-            ("file2", line) for line in text2.split("\n")
+        expected_result = [("file1", line) for line in text1.splitlines()] + [
+            ("file2", line) for line in text2.splitlines()
         ]
         self.assertEqual(expected_result, list(line_reader_dp))
 
@@ -396,8 +396,8 @@ def test_line_reader_iterdatapipe(self) -> None:
             [("file1", io.BytesIO(text1.encode("utf-8"))), ("file2", io.BytesIO(text2.encode("utf-8")))]
         )
         line_reader_dp = source_dp.readlines()
-        expected_result_bytes = [("file1", line.encode("utf-8")) for line in text1.split("\n")] + [
-            ("file2", line.encode("utf-8")) for line in text2.split("\n")
+        expected_result_bytes = [("file1", line.encode("utf-8")) for line in text1.splitlines()] + [
+            ("file2", line.encode("utf-8")) for line in text2.splitlines()
         ]
         self.assertEqual(expected_result_bytes, list(line_reader_dp))
 
@@ -407,8 +407,8 @@ def test_line_reader_iterdatapipe(self) -> None:
         expected_result = [
             ("file1", "Line1\n"),
             ("file1", "Line2"),
-            ("file2", "Line2,1\n"),
-            ("file2", "Line2,2\n"),
+            ("file2", "Line2,1\r\n"),
+            ("file2", "Line2,2\r\n"),
             ("file2", "Line2,3"),
         ]
         self.assertEqual(expected_result, list(line_reader_dp))
diff --git a/torchdata/datapipes/iter/util/plain_text_reader.py b/torchdata/datapipes/iter/util/plain_text_reader.py
index 4a584ed1a..865996a3b 100644
--- a/torchdata/datapipes/iter/util/plain_text_reader.py
+++ b/torchdata/datapipes/iter/util/plain_text_reader.py
@@ -44,9 +44,9 @@ def strip_newline(self, stream: Union[Iterator[bytes], Iterator[str]]) -> Union[
 
         for line in stream:
             if isinstance(line, str):
-                yield line.strip("\n")
+                yield line.strip("\r\n")
             else:
-                yield line.strip(b"\n")
+                yield line.strip(b"\r\n")
 
     def decode(self, stream: Union[Iterator[bytes], Iterator[str]]) -> Union[Iterator[bytes], Iterator[str]]:
         if not self._decode: