Refactor test suite to be more readable? #175

Open
pmeier opened this issue Jan 20, 2022 · 6 comments

Comments

@pmeier
Contributor

pmeier commented Jan 20, 2022

While working on #174, I also worked on the test suite. It contains some ginormous tests that are hard to parse because they do so many things at the same time, for example:

data/test/test_datapipe.py

Lines 382 to 426 in c06066a

def test_line_reader_iterdatapipe(self) -> None:
    text1 = "Line1\nLine2"
    text2 = "Line2,1\nLine2,2\nLine2,3"
    # Functional Test: read lines correctly
    source_dp = IterableWrapper([("file1", io.StringIO(text1)), ("file2", io.StringIO(text2))])
    line_reader_dp = source_dp.readlines()
    expected_result = [("file1", line) for line in text1.split("\n")] + [
        ("file2", line) for line in text2.split("\n")
    ]
    self.assertEqual(expected_result, list(line_reader_dp))
    # Functional Test: strip new lines for bytes
    source_dp = IterableWrapper(
        [("file1", io.BytesIO(text1.encode("utf-8"))), ("file2", io.BytesIO(text2.encode("utf-8")))]
    )
    line_reader_dp = source_dp.readlines()
    expected_result_bytes = [("file1", line.encode("utf-8")) for line in text1.split("\n")] + [
        ("file2", line.encode("utf-8")) for line in text2.split("\n")
    ]
    self.assertEqual(expected_result_bytes, list(line_reader_dp))
    # Functional Test: do not strip new lines
    source_dp = IterableWrapper([("file1", io.StringIO(text1)), ("file2", io.StringIO(text2))])
    line_reader_dp = source_dp.readlines(strip_newline=False)
    expected_result = [
        ("file1", "Line1\n"),
        ("file1", "Line2"),
        ("file2", "Line2,1\n"),
        ("file2", "Line2,2\n"),
        ("file2", "Line2,3"),
    ]
    self.assertEqual(expected_result, list(line_reader_dp))
    # Reset Test:
    source_dp = IterableWrapper([("file1", io.StringIO(text1)), ("file2", io.StringIO(text2))])
    line_reader_dp = LineReader(source_dp, strip_newline=False)
    n_elements_before_reset = 2
    res_before_reset, res_after_reset = reset_after_n_next_calls(line_reader_dp, n_elements_before_reset)
    self.assertEqual(expected_result[:n_elements_before_reset], res_before_reset)
    self.assertEqual(expected_result, res_after_reset)
    # __len__ Test: length isn't implemented since it cannot be known ahead of time
    with self.assertRaisesRegex(TypeError, "has no len"):
        len(line_reader_dp)

I was wondering if there is a reason for that. Can't we split this into multiple smaller ones? Utilizing pytest, placing the following class in the test module is equivalent to the test above:

class TestLineReader:
    @pytest.fixture
    def text1(self):
        return "Line1\nLine2"

    @pytest.fixture
    def text2(self):
        return "Line2,1\nLine2,2\nLine2,3"

    def test_functional_read_lines_correctly(self, text1, text2):
        source_dp = IterableWrapper([("file1", io.StringIO(text1)), ("file2", io.StringIO(text2))])
        line_reader_dp = source_dp.readlines()
        expected_result = [("file1", line) for line in text1.split("\n")] + [
            ("file2", line) for line in text2.split("\n")
        ]
        assert expected_result == list(line_reader_dp)

    def test_functional_strip_new_lines_for_bytes(self, text1, text2):
        source_dp = IterableWrapper(
            [("file1", io.BytesIO(text1.encode("utf-8"))), ("file2", io.BytesIO(text2.encode("utf-8")))]
        )
        line_reader_dp = source_dp.readlines()
        expected_result_bytes = [("file1", line.encode("utf-8")) for line in text1.split("\n")] + [
            ("file2", line.encode("utf-8")) for line in text2.split("\n")
        ]
        assert expected_result_bytes == list(line_reader_dp)

    def test_functional_do_not_strip_newlines(self, text1, text2):
        source_dp = IterableWrapper([("file1", io.StringIO(text1)), ("file2", io.StringIO(text2))])
        line_reader_dp = source_dp.readlines(strip_newline=False)
        expected_result = [
            ("file1", "Line1\n"),
            ("file1", "Line2"),
            ("file2", "Line2,1\n"),
            ("file2", "Line2,2\n"),
            ("file2", "Line2,3"),
        ]
        assert expected_result == list(line_reader_dp)

    def test_reset(self, text1, text2):
        source_dp = IterableWrapper([("file1", io.StringIO(text1)), ("file2", io.StringIO(text2))])
        line_reader_dp = LineReader(source_dp, strip_newline=False)
        expected_result = [
            ("file1", "Line1\n"),
            ("file1", "Line2"),
            ("file2", "Line2,1\n"),
            ("file2", "Line2,2\n"),
            ("file2", "Line2,3"),
        ]

        n_elements_before_reset = 2
        res_before_reset, res_after_reset = reset_after_n_next_calls(line_reader_dp, n_elements_before_reset)
        assert expected_result[:n_elements_before_reset] == res_before_reset
        assert expected_result == res_after_reset

    def test_len(self, text1, text2):
        source_dp = IterableWrapper([("file1", io.StringIO(text1)), ("file2", io.StringIO(text2))])
        line_reader_dp = LineReader(source_dp, strip_newline=False)

        with pytest.raises(TypeError, match="has no len"):
            len(line_reader_dp)

This is a lot more readable, since we now have 5 separate test cases that can fail individually. Plus, while writing this I also found that, in the original test, the reset and __len__ checks were dependent on the earlier sections because they don't define everything they use themselves: the reset check reuses expected_result from the "do not strip new lines" section, and the __len__ check reuses line_reader_dp from the reset check.
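test_reset relies on the reset_after_n_next_calls helper from the test utilities, whose body is not shown in this thread. A minimal sketch of the behavior its assertions expect, written as a hypothetical re-implementation (not the project's code) and exercised on a made-up re-iterable `Lines` class: pull `n` items from one iterator, then start a brand-new pass and collect everything.

```python
from typing import Iterable, List, Tuple, TypeVar

T = TypeVar("T")


def reset_after_n_next_calls(datapipe: Iterable[T], n: int) -> Tuple[List[T], List[T]]:
    """Hypothetical re-implementation for illustration only."""
    it = iter(datapipe)
    res_before_reset = [next(it) for _ in range(n)]
    # A second pass over a resettable pipe must start from the first item
    # again instead of continuing where ``it`` stopped.
    res_after_reset = list(datapipe)
    return res_before_reset, res_after_reset


class Lines:
    """Made-up re-iterable: every iter() call starts at the first line."""

    def __init__(self, text: str) -> None:
        self.text = text

    def __iter__(self):
        return iter(self.text.splitlines())


good_before, good_after = reset_after_n_next_calls(Lines("a\nb\nc"), 2)
# good_before == ['a', 'b'], good_after == ['a', 'b', 'c']

# A one-shot iterator cannot reset, so its "after" list only holds the rest.
once_before, once_after = reset_after_n_next_calls(iter([1, 2, 3]), 2)
# once_before == [1, 2], once_after == [3]
```

This is why the reset check is a meaningful test on its own: a DataPipe that silently behaves like a one-shot iterator would fail the second assertion.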

@pmeier
Contributor Author

pmeier commented Jan 20, 2022

Or even more readable:

class TestLineReader:
    @pytest.fixture
    def files_with_text(self):
        return [
            ("file1", "Line1\nLine2"),
            ("file2", "Line2,1\nLine2,2\nLine2,3"),
        ]

    def make_str_dp(self, files_with_text):
        return IterableWrapper([(file, io.StringIO(text)) for file, text in files_with_text])

    def make_bytes_dp(self, files_with_text):
        return IterableWrapper([(file, io.BytesIO(text.encode("utf-8"))) for file, text in files_with_text])

    def test_functional_read_lines_correctly(self, files_with_text):
        line_reader_dp = self.make_str_dp(files_with_text).readlines()

        expected = []
        for file, text in files_with_text:
            expected.extend((file, line) for line in text.splitlines())

        assert expected == list(line_reader_dp)

    def test_functional_strip_new_lines_for_bytes(self, files_with_text):
        line_reader_dp = self.make_bytes_dp(files_with_text).readlines()

        expected = []
        for file, text in files_with_text:
            expected.extend((file, line.encode("utf-8")) for line in text.splitlines())

        assert expected == list(line_reader_dp)

    def test_functional_do_not_strip_newlines(self, files_with_text):
        line_reader_dp = self.make_str_dp(files_with_text).readlines(strip_newline=False)

        expected = []
        for file, text in files_with_text:
            expected.extend((file, line) for line in text.splitlines(keepends=True))

        assert expected == list(line_reader_dp)

    def test_reset(self, files_with_text):
        line_reader_dp = LineReader(self.make_str_dp(files_with_text))

        expected = []
        for file, text in files_with_text:
            expected.extend((file, line) for line in text.splitlines())

        n_elements_before_reset = 2
        res_before_reset, res_after_reset = reset_after_n_next_calls(line_reader_dp, n_elements_before_reset)

        assert expected[:n_elements_before_reset] == res_before_reset
        assert expected == res_after_reset

    def test_len(self, files_with_text):
        line_reader_dp = LineReader(self.make_str_dp(files_with_text))

        with pytest.raises(TypeError, match="has no len"):
            len(line_reader_dp)
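Unlike the original, this version derives every expected list from files_with_text instead of hard-coding it. For these particular fixtures, `str.splitlines()` yields the same pieces as the original `split("\n")`, and `splitlines(keepends=True)` reproduces the hard-coded `strip_newline=False` list. A quick stdlib-only check of that claim:

```python
files_with_text = [
    ("file1", "Line1\nLine2"),
    ("file2", "Line2,1\nLine2,2\nLine2,3"),
]

# Expectations as derived in the refactored tests.
derived_stripped = [(f, line) for f, text in files_with_text for line in text.splitlines()]
derived_kept = [(f, line) for f, text in files_with_text for line in text.splitlines(keepends=True)]

# Expectations as written in the original monolithic test.
original_stripped = [(f, line) for f, text in files_with_text for line in text.split("\n")]
original_kept = [
    ("file1", "Line1\n"),
    ("file1", "Line2"),
    ("file2", "Line2,1\n"),
    ("file2", "Line2,2\n"),
    ("file2", "Line2,3"),
]

print(derived_stripped == original_stripped)  # True
print(derived_kept == original_kept)  # True
```

The equivalence is specific to inputs like these: `"a\n".split("\n")` returns `['a', '']` while `"a\n".splitlines()` returns `['a']`, and `splitlines` also breaks on `\r` and a few other separators. This check says nothing about how the real LineReader handles such inputs.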

@ejguan
Contributor

ejguan commented Jan 20, 2022

I like this idea!
cc: @NivekT Do you want to incorporate this into your PR pytorch/pytorch#70215

@pmeier
Contributor Author

pmeier commented Jan 20, 2022

Ah, that might be an issue. In PyTorch core you cannot rely on pytest, so if you want to have this there, you need to adapt what I proposed a little:

  • With unittest, each test class needs to inherit from unittest.TestCase or one of its subclasses.
  • @pytest.fixture is not available. A workaround might be to store files_with_text in a class constant and access it from there.

@ejguan
Contributor

ejguan commented Jan 20, 2022

@pytest.fixture's are not available. A workaround might be to store the files_with_text in a class constant and access it from there.

I believe we can use setUpClass for this case.
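A minimal sketch of the pytest-free variant: the class inherits from unittest.TestCase and `setUpClass` takes over the fixture's job. So that it runs without torchdata, `StandInLineReader` is a hypothetical stand-in for LineReader (it yields `(file, line)` pairs and defines no `__len__`); in the real suite you would build the pipe with IterableWrapper and LineReader as in the proposals above.

```python
import io
import unittest


class StandInLineReader:
    """Hypothetical stand-in for torchdata's LineReader (illustration only)."""

    def __init__(self, make_source, strip_newline=True):
        # make_source: zero-argument callable returning (file, stream) pairs.
        self.make_source = make_source
        self.strip_newline = strip_newline

    def __iter__(self):
        for file, stream in self.make_source():
            for line in stream:
                yield file, (line.rstrip("\n") if self.strip_newline else line)


class TestLineReader(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        # Runs once before the tests in this class and plays the role of the
        # pytest fixture. Only plain strings are stored here.
        cls.files_with_text = [
            ("file1", "Line1\nLine2"),
            ("file2", "Line2,1\nLine2,2\nLine2,3"),
        ]

    def make_str_source(self):
        # Fresh StringIO objects on every call, so no test receives a stream
        # that another test has already read to the end.
        return [(file, io.StringIO(text)) for file, text in self.files_with_text]

    def test_functional_read_lines_correctly(self):
        reader = StandInLineReader(self.make_str_source)
        expected = [(f, line) for f, text in self.files_with_text for line in text.splitlines()]
        self.assertEqual(expected, list(reader))

    def test_functional_do_not_strip_newlines(self):
        reader = StandInLineReader(self.make_str_source, strip_newline=False)
        expected = [
            (f, line) for f, text in self.files_with_text for line in text.splitlines(keepends=True)
        ]
        self.assertEqual(expected, list(reader))

    def test_len(self):
        reader = StandInLineReader(self.make_str_source)
        with self.assertRaisesRegex(TypeError, "has no len"):
            len(reader)
```

Run it with `python -m unittest` as usual. Storing strings rather than StringIO objects in `setUpClass` matters: `setUpClass` runs once per class, so a stream created there would already be exhausted after the first test reads it.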

@NivekT
Contributor

NivekT commented Jan 20, 2022

Thanks for the suggestion! I think this is cleaner than what we have. It will take quite a bit of manual refactoring of each DataPipe's tests to get there.

I am wondering if we can do something even better: a standard template to test DataPipes with less manual code writing (maybe just specifying the inputs), similar to what OpInfo does in PyTorch core.
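One possible shape for such a template, sketched under heavy assumptions: each DataPipe is described once by a small record of how to build it and what it should yield, and a single generic test runs the functional, reset, and `__len__` checks for every record. `DataPipeSpec`, `SPECS`, and `StandInLineReader` are hypothetical names invented for this sketch (the reader is a stand-in, not torchdata's LineReader), and it uses plain unittest so it would also work where pytest is unavailable.

```python
import io
import unittest
from dataclasses import dataclass
from typing import Any, Callable, Iterable, List, Optional


@dataclass
class DataPipeSpec:
    """Hypothetical OpInfo-style record describing how to test one DataPipe."""

    name: str
    make_dp: Callable[[], Iterable[Any]]  # builds a fresh instance under test
    expected: List[Any]  # output of one full pass
    n_before_reset: int = 2  # items to pull before restarting iteration
    expected_len: Optional[int] = None  # None: len() must raise TypeError


class StandInLineReader:
    """Stand-in for a line-reading DataPipe (not torchdata's LineReader)."""

    def __init__(self, files_with_text, strip_newline=True):
        self.files_with_text = files_with_text
        self.strip_newline = strip_newline

    def __iter__(self):
        # New StringIO objects on every pass, so the reader can be restarted.
        for file, text in self.files_with_text:
            for line in io.StringIO(text):
                yield file, (line.rstrip("\n") if self.strip_newline else line)


FILES = [("file1", "Line1\nLine2"), ("file2", "Line2,1\nLine2,2\nLine2,3")]

# Covering another DataPipe means appending a spec, not writing a new test.
SPECS = [
    DataPipeSpec(
        name="lines_stripped",
        make_dp=lambda: StandInLineReader(FILES),
        expected=[(f, line) for f, text in FILES for line in text.splitlines()],
    ),
    DataPipeSpec(
        name="lines_with_newlines",
        make_dp=lambda: StandInLineReader(FILES, strip_newline=False),
        expected=[(f, line) for f, text in FILES for line in text.splitlines(keepends=True)],
    ),
    DataPipeSpec(name="plain_list", make_dp=lambda: [1, 2, 3], expected=[1, 2, 3], expected_len=3),
]


class TestDataPipeSpecs(unittest.TestCase):
    def test_specs(self):
        for spec in SPECS:
            with self.subTest(datapipe=spec.name):
                dp = spec.make_dp()
                # Functional: one full pass yields the expected items.
                self.assertEqual(spec.expected, list(dp))
                # Reset: pull n items, then a new pass starts from the beginning.
                it = iter(dp)
                before = [next(it) for _ in range(spec.n_before_reset)]
                self.assertEqual(spec.expected[: spec.n_before_reset], before)
                self.assertEqual(spec.expected, list(dp))
                # __len__: either a known length or a TypeError.
                if spec.expected_len is None:
                    with self.assertRaises(TypeError):
                        len(dp)
                else:
                    self.assertEqual(spec.expected_len, len(dp))
```

`subTest` reports each failing spec separately instead of stopping at the first one, which keeps part of the "can fail individually" benefit of the split tests; with pytest, `pytest.mark.parametrize` over `SPECS` would give each spec its own test ID instead.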

@erip
Contributor

erip commented Jan 21, 2022

FWIW, we've started something similar in torchtext. See here if you're interested.
