diff --git a/mypy/subtypes.py b/mypy/subtypes.py index 3032aff9d061..b40806917644 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -36,16 +36,21 @@ def check_type_parameter(lefta: Type, righta: Type, variance: int) -> bool: if variance == COVARIANT: - return is_subtype(lefta, righta, check_type_parameter) + return is_subtype(lefta, righta) elif variance == CONTRAVARIANT: - return is_subtype(righta, lefta, check_type_parameter) + return is_subtype(righta, lefta) else: - return is_equivalent(lefta, righta, check_type_parameter) + return is_equivalent(lefta, righta) + + +def ignore_type_parameter(s: Type, t: Type, v: int) -> bool: + return True def is_subtype(left: Type, right: Type, - type_parameter_checker: Optional[TypeParameterChecker] = None, - *, ignore_pos_arg_names: bool = False, + *, + ignore_type_params: bool = False, + ignore_pos_arg_names: bool = False, ignore_declared_variance: bool = False, ignore_promotions: bool = False) -> bool: """Is 'left' subtype of 'right'? @@ -59,7 +64,6 @@ def is_subtype(left: Type, right: Type, between the type arguments (e.g., A and B), taking the variance of the type var into account. """ - type_parameter_checker = type_parameter_checker or check_type_parameter if (isinstance(right, AnyType) or isinstance(right, UnboundType) or isinstance(right, ErasedType)): return True @@ -67,7 +71,8 @@ def is_subtype(left: Type, right: Type, # Normally, when 'left' is not itself a union, the only way # 'left' can be a subtype of the union 'right' is if it is a # subtype of one of the items making up the union. - is_subtype_of_item = any(is_subtype(left, item, type_parameter_checker, + is_subtype_of_item = any(is_subtype(left, item, + ignore_type_params=ignore_type_params, ignore_pos_arg_names=ignore_pos_arg_names, ignore_declared_variance=ignore_declared_variance, ignore_promotions=ignore_promotions) @@ -83,69 +88,65 @@ def is_subtype(left: Type, right: Type, elif is_subtype_of_item: return True # otherwise, fall through - return left.accept(SubtypeVisitor(right, type_parameter_checker, + return left.accept(SubtypeVisitor(right, + ignore_type_params=ignore_type_params, ignore_pos_arg_names=ignore_pos_arg_names, ignore_declared_variance=ignore_declared_variance, ignore_promotions=ignore_promotions)) def is_subtype_ignoring_tvars(left: Type, right: Type) -> bool: - def ignore_tvars(s: Type, t: Type, v: int) -> bool: - return True - return is_subtype(left, right, ignore_tvars) + return is_subtype(left, right, ignore_type_params=True) -def is_equivalent(a: Type, - b: Type, - type_parameter_checker: Optional[TypeParameterChecker] = None, +def is_equivalent(a: Type, b: Type, *, + ignore_type_params: bool = False, ignore_pos_arg_names: bool = False ) -> bool: return ( - is_subtype(a, b, type_parameter_checker, ignore_pos_arg_names=ignore_pos_arg_names) - and is_subtype(b, a, type_parameter_checker, ignore_pos_arg_names=ignore_pos_arg_names)) + is_subtype(a, b, ignore_type_params=ignore_type_params, + ignore_pos_arg_names=ignore_pos_arg_names) + and is_subtype(b, a, ignore_type_params=ignore_type_params, + ignore_pos_arg_names=ignore_pos_arg_names)) class SubtypeVisitor(TypeVisitor[bool]): def __init__(self, right: Type, - type_parameter_checker: TypeParameterChecker, - *, ignore_pos_arg_names: bool = False, + *, + ignore_type_params: bool, + ignore_pos_arg_names: bool = False, ignore_declared_variance: bool = False, ignore_promotions: bool = False) -> None: self.right = right - self.check_type_parameter = type_parameter_checker + self.ignore_type_params = ignore_type_params self.ignore_pos_arg_names = ignore_pos_arg_names self.ignore_declared_variance = ignore_declared_variance self.ignore_promotions = ignore_promotions + self.check_type_parameter = (ignore_type_parameter if ignore_type_params else + check_type_parameter) self._subtype_kind = SubtypeVisitor.build_subtype_kind( - type_parameter_checker=type_parameter_checker, + ignore_type_params=ignore_type_params, ignore_pos_arg_names=ignore_pos_arg_names, ignore_declared_variance=ignore_declared_variance, ignore_promotions=ignore_promotions) @staticmethod def build_subtype_kind(*, - type_parameter_checker: Optional[TypeParameterChecker] = None, + ignore_type_params: bool = False, ignore_pos_arg_names: bool = False, ignore_declared_variance: bool = False, ignore_promotions: bool = False) -> SubtypeKind: - type_parameter_checker = type_parameter_checker or check_type_parameter - return ('subtype', - type_parameter_checker, + return (False, # is proper subtype? + ignore_type_params, ignore_pos_arg_names, ignore_declared_variance, ignore_promotions) - def _lookup_cache(self, left: Instance, right: Instance) -> bool: - return TypeState.is_cached_subtype_check(self._subtype_kind, left, right) - - def _record_cache(self, left: Instance, right: Instance) -> None: - TypeState.record_subtype_cache_entry(self._subtype_kind, left, right) - def _is_subtype(self, left: Type, right: Type) -> bool: return is_subtype(left, right, - type_parameter_checker=self.check_type_parameter, + ignore_type_params=self.ignore_type_params, ignore_pos_arg_names=self.ignore_pos_arg_names, ignore_declared_variance=self.ignore_declared_variance, ignore_promotions=self.ignore_promotions) @@ -190,12 +191,12 @@ def visit_instance(self, left: Instance) -> bool: if isinstance(right, TupleType) and right.fallback.type.is_enum: return self._is_subtype(left, right.fallback) if isinstance(right, Instance): - if self._lookup_cache(left, right): + if TypeState.is_cached_subtype_check(self._subtype_kind, left, right): return True if not self.ignore_promotions: for base in left.type.mro: if base._promote and self._is_subtype(base._promote, self.right): - self._record_cache(left, right) + TypeState.record_subtype_cache_entry(self._subtype_kind, left, right) return True rname = right.type.fullname() # Always try a nominal check if possible, @@ -208,7 +209,7 @@ def visit_instance(self, left: Instance) -> bool: for lefta, righta, tvar in zip(t.args, right.args, right.type.defn.type_vars)) if nominal: - self._record_cache(left, right) + TypeState.record_subtype_cache_entry(self._subtype_kind, left, right) return nominal if right.type.is_protocol and is_protocol_implementation(left, right): return True @@ -303,7 +304,8 @@ def visit_typeddict_type(self, left: TypedDictType) -> bool: if not left.names_are_wider_than(right): return False for name, l, r in left.zip(right): - if not is_equivalent(l, r, self.check_type_parameter): + if not is_equivalent(l, r, + ignore_type_params=self.ignore_type_params): return False # Non-required key is not compatible with a required key since # indexing may fail unexpectedly if a required key is missing. @@ -1033,13 +1035,7 @@ def __init__(self, right: Type, *, ignore_promotions: bool = False) -> None: @staticmethod def build_subtype_kind(*, ignore_promotions: bool = False) -> SubtypeKind: - return ('subtype_proper', ignore_promotions) - - def _lookup_cache(self, left: Instance, right: Instance) -> bool: - return TypeState.is_cached_subtype_check(self._subtype_kind, left, right) - - def _record_cache(self, left: Instance, right: Instance) -> None: - TypeState.record_subtype_cache_entry(self._subtype_kind, left, right) + return (True, ignore_promotions) def _is_proper_subtype(self, left: Type, right: Type) -> bool: return is_proper_subtype(left, right, ignore_promotions=self.ignore_promotions) @@ -1073,12 +1069,12 @@ def visit_deleted_type(self, left: DeletedType) -> bool: def visit_instance(self, left: Instance) -> bool: right = self.right if isinstance(right, Instance): - if self._lookup_cache(left, right): + if TypeState.is_cached_subtype_check(self._subtype_kind, left, right): return True if not self.ignore_promotions: for base in left.type.mro: if base._promote and self._is_proper_subtype(base._promote, right): - self._record_cache(left, right) + TypeState.record_subtype_cache_entry(self._subtype_kind, left, right) return True if left.type.has_base(right.type.fullname()): @@ -1095,7 +1091,7 @@ def check_argument(leftarg: Type, rightarg: Type, variance: int) -> bool: nominal = all(check_argument(ta, ra, tvar.variance) for ta, ra, tvar in zip(left.args, right.args, right.type.defn.type_vars)) if nominal: - self._record_cache(left, right) + TypeState.record_subtype_cache_entry(self._subtype_kind, left, right) return nominal if (right.type.is_protocol and is_protocol_implementation(left, right, proper_subtype=True)): diff --git a/mypy/typestate.py b/mypy/typestate.py index 318c51cde21c..d694124a24fa 100644 --- a/mypy/typestate.py +++ b/mypy/typestate.py @@ -17,7 +17,7 @@ # A tuple encoding the specific conditions under which we performed the subtype check. # (e.g. did we want a proper subtype? A regular subtype while ignoring variance?) -SubtypeKind = Tuple[Any, ...] +SubtypeKind = Tuple[bool, ...] # A cache that keeps track of whether the given TypeInfo is a part of a particular # subtype relationship @@ -83,7 +83,8 @@ def reset_all_subtype_caches(cls) -> None: @classmethod def reset_subtype_caches_for(cls, info: TypeInfo) -> None: """Reset subtype caches (if any) for a given supertype TypeInfo.""" - cls._subtype_caches.setdefault(info, dict()).clear() + if info in cls._subtype_caches: + cls._subtype_caches[info].clear() @classmethod def reset_all_subtype_caches_for(cls, info: TypeInfo) -> None: @@ -93,14 +94,19 @@ def reset_all_subtype_caches_for(cls, info: TypeInfo) -> None: @classmethod def is_cached_subtype_check(cls, kind: SubtypeKind, left: Instance, right: Instance) -> bool: - subtype_kinds = cls._subtype_caches.setdefault(right.type, dict()) - return (left, right) in subtype_kinds.setdefault(kind, set()) + info = right.type + if info not in cls._subtype_caches: + return False + cache = cls._subtype_caches[info] + if kind not in cache: + return False + return (left, right) in cache[kind] @classmethod def record_subtype_cache_entry(cls, kind: SubtypeKind, left: Instance, right: Instance) -> None: - subtype_kinds = cls._subtype_caches.setdefault(right.type, dict()) - subtype_kinds.setdefault(kind, set()).add((left, right)) + cache = cls._subtype_caches.setdefault(right.type, dict()) + cache.setdefault(kind, set()).add((left, right)) @classmethod def reset_protocol_deps(cls) -> None: