Skip to content

Commit 1de5e55

Browse files
authored
Fix crashes in class scoped imports (#12023)
Fixes #11045, fixes huggingface/transformers#13390 Fixes #10488 Fixes #7045 Fixes #7806 Fixes #11641 Fixes #11351 Fixes #10488 Co-authored-by: @A5rocks
1 parent 84aaef5 commit 1de5e55

File tree

5 files changed

+149
-8
lines changed

5 files changed

+149
-8
lines changed

mypy/semanal.py

+35-3
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,11 @@
4848
reduce memory use).
4949
"""
5050

51+
import copy
5152
from contextlib import contextmanager
5253

5354
from typing import (
54-
List, Dict, Set, Tuple, cast, TypeVar, Union, Optional, Callable, Iterator, Iterable
55+
Any, List, Dict, Set, Tuple, cast, TypeVar, Union, Optional, Callable, Iterator, Iterable
5556
)
5657
from typing_extensions import Final, TypeAlias as _TypeAlias
5758

@@ -78,7 +79,7 @@
7879
typing_extensions_aliases,
7980
EnumCallExpr, RUNTIME_PROTOCOL_DECOS, FakeExpression, Statement, AssignmentExpr,
8081
ParamSpecExpr, EllipsisExpr, TypeVarLikeExpr, implicit_module_attrs,
81-
MatchStmt,
82+
MatchStmt, FuncBase
8283
)
8384
from mypy.patterns import (
8485
AsPattern, OrPattern, ValuePattern, SequencePattern,
@@ -4798,7 +4799,38 @@ def add_imported_symbol(self,
47984799
module_hidden: bool) -> None:
47994800
"""Add an alias to an existing symbol through import."""
48004801
assert not module_hidden or not module_public
4801-
symbol = SymbolTableNode(node.kind, node.node,
4802+
4803+
symbol_node: Optional[SymbolNode] = node.node
4804+
4805+
if self.is_class_scope():
4806+
# I promise this type checks; I'm just making mypyc issues go away.
4807+
# mypyc is absolutely convinced that `symbol_node` narrows to a Var in the following,
4808+
# when it can also be a FuncBase. Once fixed, `f` in the following can be removed.
4809+
# See also https://github.com/mypyc/mypyc/issues/892
4810+
f = cast(Any, lambda x: x)
4811+
if isinstance(f(symbol_node), (FuncBase, Var)):
4812+
# For imports in class scope, we construct a new node to represent the symbol and
4813+
# set its `info` attribute to `self.type`.
4814+
existing = self.current_symbol_table().get(name)
4815+
if (
4816+
# The redefinition checks in `add_symbol_table_node` don't work for our
4817+
# constructed Var / FuncBase, so check for possible redefinitions here.
4818+
existing is not None
4819+
and isinstance(f(existing.node), (FuncBase, Var))
4820+
and f(existing.type) == f(symbol_node).type
4821+
):
4822+
symbol_node = existing.node
4823+
else:
4824+
# Construct the new node
4825+
constructed_node = copy.copy(f(symbol_node))
4826+
assert self.type is not None # guaranteed by is_class_scope
4827+
constructed_node.line = context.line
4828+
constructed_node.column = context.column
4829+
constructed_node.info = self.type
4830+
constructed_node._fullname = self.qualified_name(name)
4831+
symbol_node = constructed_node
4832+
4833+
symbol = SymbolTableNode(node.kind, symbol_node,
48024834
module_public=module_public,
48034835
module_hidden=module_hidden)
48044836
self.add_symbol_table_node(name, symbol, context)

test-data/unit/check-classes.test

+109
Original file line numberDiff line numberDiff line change
@@ -7134,3 +7134,112 @@ class B(A): # E: Final class __main__.B has abstract attributes "foo"
71347134
[case testUndefinedBaseclassInNestedClass]
71357135
class C:
71367136
class C1(XX): pass # E: Name "XX" is not defined
7137+
7138+
[case testClassScopeImportFunction]
7139+
class Foo:
7140+
from mod import foo
7141+
7142+
reveal_type(Foo.foo) # N: Revealed type is "def (x: builtins.int, y: builtins.int) -> builtins.int"
7143+
reveal_type(Foo().foo) # E: Invalid self argument "Foo" to attribute function "foo" with type "Callable[[int, int], int]" \
7144+
# N: Revealed type is "def (y: builtins.int) -> builtins.int"
7145+
[file mod.py]
7146+
def foo(x: int, y: int) -> int: ...
7147+
7148+
[case testClassScopeImportVariable]
7149+
class Foo:
7150+
from mod import foo
7151+
7152+
reveal_type(Foo.foo) # N: Revealed type is "builtins.int"
7153+
reveal_type(Foo().foo) # N: Revealed type is "builtins.int"
7154+
[file mod.py]
7155+
foo: int
7156+
7157+
[case testClassScopeImportModule]
7158+
class Foo:
7159+
import mod
7160+
7161+
reveal_type(Foo.mod) # N: Revealed type is "builtins.object"
7162+
reveal_type(Foo.mod.foo) # N: Revealed type is "builtins.int"
7163+
[file mod.py]
7164+
foo: int
7165+
7166+
[case testClassScopeImportFunctionAlias]
7167+
class Foo:
7168+
from mod import foo
7169+
bar = foo
7170+
7171+
from mod import const_foo
7172+
const_bar = const_foo
7173+
7174+
reveal_type(Foo.foo) # N: Revealed type is "def (x: builtins.int, y: builtins.int) -> builtins.int"
7175+
reveal_type(Foo.bar) # N: Revealed type is "def (x: builtins.int, y: builtins.int) -> builtins.int"
7176+
reveal_type(Foo.const_foo) # N: Revealed type is "builtins.int"
7177+
reveal_type(Foo.const_bar) # N: Revealed type is "builtins.int"
7178+
[file mod.py]
7179+
def foo(x: int, y: int) -> int: ...
7180+
const_foo: int
7181+
7182+
[case testClassScopeImportModuleStar]
7183+
class Foo:
7184+
from mod import *
7185+
7186+
reveal_type(Foo.foo) # N: Revealed type is "builtins.int"
7187+
reveal_type(Foo.bar) # N: Revealed type is "def (x: builtins.int) -> builtins.int"
7188+
reveal_type(Foo.baz) # E: "Type[Foo]" has no attribute "baz" \
7189+
# N: Revealed type is "Any"
7190+
[file mod.py]
7191+
foo: int
7192+
def bar(x: int) -> int: ...
7193+
7194+
[case testClassScopeImportFunctionNested]
7195+
class Foo:
7196+
class Bar:
7197+
from mod import baz
7198+
7199+
reveal_type(Foo.Bar.baz) # N: Revealed type is "def (x: builtins.int) -> builtins.int"
7200+
reveal_type(Foo.Bar().baz) # E: Invalid self argument "Bar" to attribute function "baz" with type "Callable[[int], int]" \
7201+
# N: Revealed type is "def () -> builtins.int"
7202+
[file mod.py]
7203+
def baz(x: int) -> int: ...
7204+
7205+
[case testClassScopeImportUndefined]
7206+
class Foo:
7207+
from unknown import foo # E: Cannot find implementation or library stub for module named "unknown" \
7208+
# N: See https://mypy.readthedocs.io/en/stable/running_mypy.html#missing-imports
7209+
7210+
reveal_type(Foo.foo) # N: Revealed type is "Any"
7211+
reveal_type(Foo().foo) # N: Revealed type is "Any"
7212+
7213+
[case testClassScopeImportWithFollowImports]
7214+
# flags: --follow-imports=skip
7215+
class Foo:
7216+
from mod import foo
7217+
7218+
reveal_type(Foo().foo) # N: Revealed type is "Any"
7219+
[file mod.py]
7220+
def foo(x: int, y: int) -> int: ...
7221+
7222+
[case testClassScopeImportVarious]
7223+
class Foo:
7224+
from mod1 import foo
7225+
from mod2 import foo # E: Name "foo" already defined on line 2
7226+
7227+
from mod1 import meth1
7228+
def meth1(self, a: str) -> str: ... # E: Name "meth1" already defined on line 5
7229+
7230+
def meth2(self, a: str) -> str: ...
7231+
from mod1 import meth2 # E: Name "meth2" already defined on line 8
7232+
7233+
class Bar:
7234+
from mod1 import foo
7235+
7236+
import mod1
7237+
reveal_type(Foo.foo) # N: Revealed type is "def (x: builtins.int, y: builtins.int) -> builtins.int"
7238+
reveal_type(Bar.foo) # N: Revealed type is "def (x: builtins.int, y: builtins.int) -> builtins.int"
7239+
reveal_type(mod1.foo) # N: Revealed type is "def (x: builtins.int, y: builtins.int) -> builtins.int"
7240+
[file mod1.py]
7241+
def foo(x: int, y: int) -> int: ...
7242+
def meth1(x: int) -> int: ...
7243+
def meth2(x: int) -> int: ...
7244+
[file mod2.py]
7245+
def foo(z: str) -> int: ...

test-data/unit/check-modules.test

+2-2
Original file line numberDiff line numberDiff line change
@@ -131,9 +131,9 @@ def f() -> None: pass
131131
[case testImportWithinClassBody2]
132132
import typing
133133
class C:
134-
from m import f
134+
from m import f # E: Method must have at least one argument
135135
f()
136-
f(C) # E: Too many arguments for "f"
136+
f(C) # E: Too many arguments for "f" of "C"
137137
[file m.py]
138138
def f() -> None: pass
139139
[out]

test-data/unit/check-newsemanal.test

+2-2
Original file line numberDiff line numberDiff line change
@@ -2722,7 +2722,7 @@ import m
27222722

27232723
[file m.py]
27242724
class C:
2725-
from mm import f
2725+
from mm import f # E: Method must have at least one argument
27262726
@dec(f)
27272727
def m(self): pass
27282728

@@ -2742,7 +2742,7 @@ import m
27422742

27432743
[file m/__init__.py]
27442744
class C:
2745-
from m.m import f
2745+
from m.m import f # E: Method must have at least one argument
27462746
@dec(f)
27472747
def m(self): pass
27482748

test-data/unit/semanal-modules.test

+1-1
Original file line numberDiff line numberDiff line change
@@ -568,7 +568,7 @@ MypyFile:1(
568568
ImportFrom:2(_x, [y])
569569
AssignmentStmt:3(
570570
NameExpr(z* [m])
571-
NameExpr(y [_x.y]))))
571+
NameExpr(y [__main__.A.y]))))
572572

573573
[case testImportInClassBody2]
574574
class A:

0 commit comments

Comments
 (0)