diff --git a/mypy/server/astdiff.py b/mypy/server/astdiff.py index 37e195f5e0b1..815e2ca281eb 100644 --- a/mypy/server/astdiff.py +++ b/mypy/server/astdiff.py @@ -52,9 +52,10 @@ class level -- these are handled at attribute level (say, 'mod.Cls.method' from __future__ import annotations -from typing import Sequence, Tuple +from typing import Sequence, Tuple, cast from typing_extensions import TypeAlias as _TypeAlias +from mypy.expandtype import expand_type from mypy.nodes import ( UNBOUND_IMPORTED, Decorator, @@ -88,6 +89,8 @@ class level -- these are handled at attribute level (say, 'mod.Cls.method' TypeAliasType, TypedDictType, TypeType, + TypeVarId, + TypeVarLikeType, TypeVarTupleType, TypeVarType, TypeVisitor, @@ -388,7 +391,8 @@ def visit_parameters(self, typ: Parameters) -> SnapshotItem: ) def visit_callable_type(self, typ: CallableType) -> SnapshotItem: - # FIX generics + if typ.is_generic(): + typ = self.normalize_callable_variables(typ) return ( "CallableType", snapshot_types(typ.arg_types), @@ -397,8 +401,26 @@ def visit_callable_type(self, typ: CallableType) -> SnapshotItem: tuple(typ.arg_kinds), typ.is_type_obj(), typ.is_ellipsis_args, + snapshot_types(typ.variables), ) + def normalize_callable_variables(self, typ: CallableType) -> CallableType: + """Normalize all type variable ids to run from -1 to -len(variables).""" + tvs = [] + tvmap: dict[TypeVarId, Type] = {} + for i, v in enumerate(typ.variables): + tid = TypeVarId(-1 - i) + if isinstance(v, TypeVarType): + tv: TypeVarLikeType = v.copy_modified(id=tid) + elif isinstance(v, TypeVarTupleType): + tv = v.copy_modified(id=tid) + else: + assert isinstance(v, ParamSpecType) + tv = v.copy_modified(id=tid) + tvs.append(tv) + tvmap[v.id] = tv + return cast(CallableType, expand_type(typ, tvmap)).copy_modified(variables=tvs) + def visit_tuple_type(self, typ: TupleType) -> SnapshotItem: return ("TupleType", snapshot_types(typ.items)) diff --git a/mypy/types.py b/mypy/types.py index 9fb0ede51a68..85d8ccd1c7d9 100644 --- a/mypy/types.py +++ b/mypy/types.py @@ -517,15 +517,23 @@ def __init__( @staticmethod def new_unification_variable(old: TypeVarType) -> TypeVarType: new_id = TypeVarId.new(meta_level=1) + return old.copy_modified(id=new_id) + + def copy_modified( + self, + values: Bogus[list[Type]] = _dummy, + upper_bound: Bogus[Type] = _dummy, + id: Bogus[TypeVarId | int] = _dummy, + ) -> TypeVarType: return TypeVarType( - old.name, - old.fullname, - new_id, - old.values, - old.upper_bound, - old.variance, - old.line, - old.column, + self.name, + self.fullname, + self.id if id is _dummy else id, + self.values if values is _dummy else values, + self.upper_bound if upper_bound is _dummy else upper_bound, + self.variance, + self.line, + self.column, ) def accept(self, visitor: TypeVisitor[T]) -> T: @@ -616,16 +624,7 @@ def __init__( @staticmethod def new_unification_variable(old: ParamSpecType) -> ParamSpecType: new_id = TypeVarId.new(meta_level=1) - return ParamSpecType( - old.name, - old.fullname, - new_id, - old.flavor, - old.upper_bound, - line=old.line, - column=old.column, - prefix=old.prefix, - ) + return old.copy_modified(id=new_id) def with_flavor(self, flavor: int) -> ParamSpecType: return ParamSpecType( @@ -737,8 +736,16 @@ def __eq__(self, other: object) -> bool: @staticmethod def new_unification_variable(old: TypeVarTupleType) -> TypeVarTupleType: new_id = TypeVarId.new(meta_level=1) + return old.copy_modified(id=new_id) + + def copy_modified(self, id: Bogus[TypeVarId | int] = _dummy) -> TypeVarTupleType: return TypeVarTupleType( - old.name, old.fullname, new_id, old.upper_bound, line=old.line, column=old.column + self.name, + self.fullname, + self.id if id is _dummy else id, + self.upper_bound, + line=self.line, + column=self.column, ) diff --git a/test-data/unit/fine-grained.test b/test-data/unit/fine-grained.test index 0e443abc7237..6a9b060e9f07 100644 --- a/test-data/unit/fine-grained.test +++ b/test-data/unit/fine-grained.test @@ -9969,3 +9969,65 @@ m.py:9: note: Expected: m.py:9: note: def update() -> bool m.py:9: note: Got: m.py:9: note: def update() -> str + +[case testBoundGenericMethodFine] +import main +[file main.py] +import lib +[file main.py.3] +import lib +reveal_type(lib.foo(42)) +[file lib/__init__.pyi] +from lib import context +foo = context.test.foo +[file lib/context.pyi] +from typing import TypeVar +import lib.other + +T = TypeVar("T") +class Test: + def foo(self, x: T, n: lib.other.C = ...) -> T: ... +test: Test + +[file lib/other.pyi] +class C: ... +[file lib/other.pyi.2] +class B: ... +class C(B): ... +[out] +== +== +main.py:2: note: Revealed type is "builtins.int" + +[case testBoundGenericMethodParamSpecFine] +import main +[file main.py] +import lib +[file main.py.3] +from typing import Callable +import lib +f: Callable[[], int] +reveal_type(lib.foo(f)) +[file lib/__init__.pyi] +from lib import context +foo = context.test.foo +[file lib/context.pyi] +from typing_extensions import ParamSpec +from typing import Callable +import lib.other + +P = ParamSpec("P") +class Test: + def foo(self, x: Callable[P, int], n: lib.other.C = ...) -> Callable[P, str]: ... +test: Test + +[file lib/other.pyi] +class C: ... +[file lib/other.pyi.2] +class B: ... +class C(B): ... +[builtins fixtures/dict.pyi] +[out] +== +== +main.py:4: note: Revealed type is "def () -> builtins.str"