Skip to content

Commit af31c60

Browse files
authored
Merge pull request #8540 from hauntsaninja/assert310
2 parents 8dd6462 + e3dc34e commit af31c60

File tree

3 files changed

+22
-6
lines changed

3 files changed

+22
-6
lines changed

AUTHORS

+1
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,7 @@ Sankt Petersbug
277277
Segev Finer
278278
Serhii Mozghovyi
279279
Seth Junot
280+
Shantanu Jain
280281
Shubham Adep
281282
Simon Gomizelj
282283
Simon Kerr

changelog/8539.bugfix.rst

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fixed assertion rewriting on Python 3.10.

src/_pytest/assertion/rewrite.py

+20-6
Original file line numberDiff line numberDiff line change
@@ -684,12 +684,9 @@ def run(self, mod: ast.Module) -> None:
684684
if not mod.body:
685685
# Nothing to do.
686686
return
687-
# Insert some special imports at the top of the module but after any
688-
# docstrings and __future__ imports.
689-
aliases = [
690-
ast.alias("builtins", "@py_builtins"),
691-
ast.alias("_pytest.assertion.rewrite", "@pytest_ar"),
692-
]
687+
688+
# We'll insert some special imports at the top of the module, but after any
689+
# docstrings and __future__ imports, so first figure out where that is.
693690
doc = getattr(mod, "docstring", None)
694691
expect_docstring = doc is None
695692
if doc is not None and self.is_rewrite_disabled(doc):
@@ -721,10 +718,27 @@ def run(self, mod: ast.Module) -> None:
721718
lineno = item.decorator_list[0].lineno
722719
else:
723720
lineno = item.lineno
721+
# Now actually insert the special imports.
722+
if sys.version_info >= (3, 10):
723+
aliases = [
724+
ast.alias("builtins", "@py_builtins", lineno=lineno, col_offset=0),
725+
ast.alias(
726+
"_pytest.assertion.rewrite",
727+
"@pytest_ar",
728+
lineno=lineno,
729+
col_offset=0,
730+
),
731+
]
732+
else:
733+
aliases = [
734+
ast.alias("builtins", "@py_builtins"),
735+
ast.alias("_pytest.assertion.rewrite", "@pytest_ar"),
736+
]
724737
imports = [
725738
ast.Import([alias], lineno=lineno, col_offset=0) for alias in aliases
726739
]
727740
mod.body[pos:pos] = imports
741+
728742
# Collect asserts.
729743
nodes: List[ast.AST] = [mod]
730744
while nodes:

0 commit comments

Comments
 (0)