diff --git a/pre_commit_hooks/check_docstring_first.py b/pre_commit_hooks/check_docstring_first.py
index d55f08a5..3fc0c98e 100644
--- a/pre_commit_hooks/check_docstring_first.py
+++ b/pre_commit_hooks/check_docstring_first.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import argparse
+import ast
 import io
 import tokenize
 from tokenize import tokenize as tokenize_tokenize
@@ -12,6 +13,19 @@
 ))
 
 
+def _push_code(seen_code: io.StringIO, tok_type: int, text: str) -> None:
+    """Append ``text`` to ``seen_code``, ignoring the ENCODING token.
+
+    A single space follows every non-blank token so adjacent tokens stay
+    separated; whitespace-only text (e.g. a newline) is written as-is.
+    """
+    if tok_type == tokenize.ENCODING:
+        return
+    seen_code.write(text)
+    if text and not text.isspace():
+        seen_code.write(' ')
+
+
 def check_docstring_first(src: bytes, filename: str = '') -> int:
     """Returns nonzero if the source has what looks like a docstring that is
     not at the beginning of the source.
@@ -21,12 +35,28 @@ def check_docstring_first(src: bytes, filename: str = '') -> int:
     """
     found_docstring_line = None
     found_code_line = None
+    seen_code = io.StringIO()
 
     tok_gen = tokenize_tokenize(io.BytesIO(src).readline)
-    for tok_type, _, (sline, scol), _, _ in tok_gen:
+    for tok_type, text, (sline, scol), _, _ in tok_gen:
         # Looks like a docstring!
         if tok_type == tokenize.STRING and scol == 0:
             if found_docstring_line is not None:
+                # A string whose preceding statement is an assignment is
+                # accepted instead of being reported as a second docstring.
+                try:
+                    body = ast.parse(seen_code.getvalue()).body
+                except SyntaxError:
+                    # The token-joined text loses indentation detail and
+                    # may not re-parse; report the docstring as before.
+                    body = []
+                assignments = ast.AnnAssign, ast.Assign, ast.AugAssign
+                if body and isinstance(body[-1], assignments):
+                    # Record the string so a second consecutive string is
+                    # not accepted too, then keep checking the rest.
+                    _push_code(seen_code, tok_type, text)
+                    continue
+
                 print(
                     f'{filename}:{sline}: Multiple module docstrings '
                     f'(first docstring on line {found_docstring_line}).',
@@ -41,6 +71,9 @@ def check_docstring_first(src: bytes, filename: str = '') -> int:
             else:
                 found_docstring_line = sline
         elif tok_type not in NON_CODE_TOKENS and found_code_line is None:
+            _push_code(seen_code, tok_type, text)
             found_code_line = sline
+        else:
+            _push_code(seen_code, tok_type, text)
 
     return 0