Skip to content

Commit 3010efc

Browse files
ilevkivskyijhance
authored andcommitted
Consolidate descriptor handling in checkmember.py (#18831)
This is not a pure refactoring, but almost. Right now we are in a weird situation where we have two inconsistencies: * `__set__()` is handled in `checker.py` while `__get__()` is handled in `checkmember.py` * rules for when to use binder are slightly different between descriptors and settable properties. This PR fixes these two things. As a nice bonus we should get free support for unions in `__set__()`.
1 parent b1ac028 commit 3010efc

File tree

3 files changed

+133
-27
lines changed

3 files changed

+133
-27
lines changed

mypy/checker.py

+16-8
Original file line numberDiff line numberDiff line change
@@ -3170,7 +3170,7 @@ def check_assignment(
31703170
)
31713171
else:
31723172
self.try_infer_partial_generic_type_from_assignment(lvalue, rvalue, "=")
3173-
lvalue_type, index_lvalue, inferred = self.check_lvalue(lvalue)
3173+
lvalue_type, index_lvalue, inferred = self.check_lvalue(lvalue, rvalue)
31743174
# If we're assigning to __getattr__ or similar methods, check that the signature is
31753175
# valid.
31763176
if isinstance(lvalue, NameExpr) and lvalue.node:
@@ -4263,7 +4263,9 @@ def check_multi_assignment_from_iterable(
42634263
else:
42644264
self.msg.type_not_iterable(rvalue_type, context)
42654265

4266-
def check_lvalue(self, lvalue: Lvalue) -> tuple[Type | None, IndexExpr | None, Var | None]:
4266+
def check_lvalue(
4267+
self, lvalue: Lvalue, rvalue: Expression | None = None
4268+
) -> tuple[Type | None, IndexExpr | None, Var | None]:
42674269
lvalue_type = None
42684270
index_lvalue = None
42694271
inferred = None
@@ -4281,7 +4283,7 @@ def check_lvalue(self, lvalue: Lvalue) -> tuple[Type | None, IndexExpr | None, V
42814283
elif isinstance(lvalue, IndexExpr):
42824284
index_lvalue = lvalue
42834285
elif isinstance(lvalue, MemberExpr):
4284-
lvalue_type = self.expr_checker.analyze_ordinary_member_access(lvalue, True)
4286+
lvalue_type = self.expr_checker.analyze_ordinary_member_access(lvalue, True, rvalue)
42854287
self.store_type(lvalue, lvalue_type)
42864288
elif isinstance(lvalue, NameExpr):
42874289
lvalue_type = self.expr_checker.analyze_ref_expr(lvalue, lvalue=True)
@@ -4552,12 +4554,8 @@ def check_member_assignment(
45524554
45534555
Return the inferred rvalue_type, inferred lvalue_type, and whether to use the binder
45544556
for this assignment.
4555-
4556-
Note: this method exists here and not in checkmember.py, because we need to take
4557-
care about interaction between binder and __set__().
45584557
"""
45594558
instance_type = get_proper_type(instance_type)
4560-
attribute_type = get_proper_type(attribute_type)
45614559
# Descriptors don't participate in class-attribute access
45624560
if (isinstance(instance_type, FunctionLike) and instance_type.is_type_obj()) or isinstance(
45634561
instance_type, TypeType
@@ -4569,8 +4567,8 @@ def check_member_assignment(
45694567
get_lvalue_type = self.expr_checker.analyze_ordinary_member_access(
45704568
lvalue, is_lvalue=False
45714569
)
4572-
use_binder = is_same_type(get_lvalue_type, attribute_type)
45734570

4571+
<<<<<<< HEAD
45744572
if not isinstance(attribute_type, Instance):
45754573
# TODO: support __set__() for union types.
45764574
rvalue_type = self.check_simple_assignment(attribute_type, rvalue, context)
@@ -4664,13 +4662,23 @@ def check_member_assignment(
46644662
return AnyType(TypeOfAny.from_error), get_type, False
46654663

46664664
set_type = inferred_dunder_set_type.arg_types[1]
4665+
=======
4666+
>>>>>>> df9ddfcac (Consolidate descriptor handling in checkmember.py (#18831))
46674667
# Special case: if the rvalue_type is a subtype of both '__get__' and '__set__' types,
46684668
# and '__get__' type is narrower than '__set__', then we invoke the binder to narrow type
46694669
# by this assignment. Technically, this is not safe, but in practice this is
46704670
# what a user expects.
4671+
<<<<<<< HEAD
46714672
rvalue_type = self.check_simple_assignment(set_type, rvalue, context)
46724673
infer = is_subtype(rvalue_type, get_type) and is_subtype(get_type, set_type)
46734674
return rvalue_type if infer else set_type, get_type, infer
4675+
=======
4676+
rvalue_type, _ = self.check_simple_assignment(attribute_type, rvalue, context)
4677+
infer = is_subtype(rvalue_type, get_lvalue_type) and is_subtype(
4678+
get_lvalue_type, attribute_type
4679+
)
4680+
return rvalue_type if infer else attribute_type, attribute_type, infer
4681+
>>>>>>> df9ddfcac (Consolidate descriptor handling in checkmember.py (#18831))
46744682

46754683
def check_indexed_assignment(
46764684
self, lvalue: IndexExpr, rvalue: Expression, context: Context

mypy/checkexpr.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -3327,8 +3327,13 @@ def visit_member_expr(self, e: MemberExpr, is_lvalue: bool = False) -> Type:
33273327
self.chk.warn_deprecated(e.node, e)
33283328
return narrowed
33293329

3330-
def analyze_ordinary_member_access(self, e: MemberExpr, is_lvalue: bool) -> Type:
3331-
"""Analyse member expression or member lvalue."""
3330+
def analyze_ordinary_member_access(
3331+
self, e: MemberExpr, is_lvalue: bool, rvalue: Expression | None = None
3332+
) -> Type:
3333+
"""Analyse member expression or member lvalue.
3334+
3335+
An rvalue can be provided optionally to infer better setter type when is_lvalue is True.
3336+
"""
33323337
if e.kind is not None:
33333338
# This is a reference to a module attribute.
33343339
return self.analyze_ref_expr(e)
@@ -3360,6 +3365,7 @@ def analyze_ordinary_member_access(self, e: MemberExpr, is_lvalue: bool) -> Type
33603365
in_literal_context=self.is_literal_context(),
33613366
module_symbol_table=module_symbol_table,
33623367
is_self=is_self,
3368+
rvalue=rvalue,
33633369
)
33643370

33653371
return member_type

mypy/checkmember.py

+109-17
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
ArgKind,
2424
Context,
2525
Decorator,
26+
Expression,
2627
FuncBase,
2728
FuncDef,
2829
IndexExpr,
@@ -101,6 +102,7 @@ def __init__(
101102
module_symbol_table: SymbolTable | None = None,
102103
no_deferral: bool = False,
103104
is_self: bool = False,
105+
rvalue: Expression | None = None,
104106
) -> None:
105107
self.is_lvalue = is_lvalue
106108
self.is_super = is_super
@@ -113,6 +115,9 @@ def __init__(
113115
self.module_symbol_table = module_symbol_table
114116
self.no_deferral = no_deferral
115117
self.is_self = is_self
118+
if rvalue is not None:
119+
assert is_lvalue
120+
self.rvalue = rvalue
116121

117122
def named_type(self, name: str) -> Instance:
118123
return self.chk.named_type(name)
@@ -139,6 +144,7 @@ def copy_modified(
139144
self_type=self.self_type,
140145
module_symbol_table=self.module_symbol_table,
141146
no_deferral=self.no_deferral,
147+
rvalue=self.rvalue,
142148
)
143149
if messages is not None:
144150
mx.msg = messages
@@ -168,6 +174,7 @@ def analyze_member_access(
168174
module_symbol_table: SymbolTable | None = None,
169175
no_deferral: bool = False,
170176
is_self: bool = False,
177+
rvalue: Expression | None = None,
171178
) -> Type:
172179
"""Return the type of attribute 'name' of 'typ'.
173180
@@ -186,11 +193,14 @@ def analyze_member_access(
186193
of 'original_type'. 'original_type' is always preserved as the 'typ' type used in
187194
the initial, non-recursive call. The 'self_type' is a component of 'original_type'
188195
to which generic self should be bound (a narrower type that has a fallback to instance).
189-
Currently this is used only for union types.
196+
Currently, this is used only for union types.
190197
191-
'module_symbol_table' is passed to this function if 'typ' is actually a module
198+
'module_symbol_table' is passed to this function if 'typ' is actually a module,
192199
and we want to keep track of the available attributes of the module (since they
193200
are not available via the type object directly)
201+
202+
'rvalue' can be provided optionally to infer better setter type when is_lvalue is True,
203+
most notably this helps for descriptors with overloaded __set__() method.
194204
"""
195205
mx = MemberContext(
196206
is_lvalue=is_lvalue,
@@ -204,6 +214,7 @@ def analyze_member_access(
204214
module_symbol_table=module_symbol_table,
205215
no_deferral=no_deferral,
206216
is_self=is_self,
217+
rvalue=rvalue,
207218
)
208219
result = _analyze_member_access(name, typ, mx, override_info)
209220
possible_literal = get_proper_type(result)
@@ -629,17 +640,15 @@ def check_final_member(name: str, info: TypeInfo, msg: MessageBuilder, ctx: Cont
629640
msg.cant_assign_to_final(name, attr_assign=True, ctx=ctx)
630641

631642

632-
def analyze_descriptor_access(
633-
descriptor_type: Type, mx: MemberContext, *, assignment: bool = False
634-
) -> Type:
643+
def analyze_descriptor_access(descriptor_type: Type, mx: MemberContext) -> Type:
635644
"""Type check descriptor access.
636645
637646
Arguments:
638647
descriptor_type: The type of the descriptor attribute being accessed
639648
(the type of ``f`` in ``a.f`` when ``f`` is a descriptor).
640649
mx: The current member access context.
641650
Return:
642-
The return type of the appropriate ``__get__`` overload for the descriptor.
651+
The return type of the appropriate ``__get__/__set__`` overload for the descriptor.
643652
"""
644653
instance_type = get_proper_type(mx.self_type)
645654
orig_descriptor_type = descriptor_type
@@ -648,15 +657,24 @@ def analyze_descriptor_access(
648657
if isinstance(descriptor_type, UnionType):
649658
# Map the access over union types
650659
return make_simplified_union(
651-
[
652-
analyze_descriptor_access(typ, mx, assignment=assignment)
653-
for typ in descriptor_type.items
654-
]
660+
[analyze_descriptor_access(typ, mx) for typ in descriptor_type.items]
655661
)
656662
elif not isinstance(descriptor_type, Instance):
657663
return orig_descriptor_type
658664

659-
if not descriptor_type.type.has_readable_member("__get__"):
665+
if not mx.is_lvalue and not descriptor_type.type.has_readable_member("__get__"):
666+
return orig_descriptor_type
667+
668+
# We do this check first to accommodate for descriptors with only __set__ method.
669+
# If there is no __set__, we type-check that the assigned value matches
670+
# the return type of __get__. This doesn't match the python semantics,
671+
# (which allow you to override the descriptor with any value), but preserves
672+
# the type of accessing the attribute (even after the override).
673+
if mx.is_lvalue and descriptor_type.type.has_readable_member("__set__"):
674+
return analyze_descriptor_assign(descriptor_type, mx)
675+
676+
if mx.is_lvalue and not descriptor_type.type.has_readable_member("__get__"):
677+
# This turned out to be not a descriptor after all.
660678
return orig_descriptor_type
661679

662680
dunder_get = descriptor_type.type.get_method("__get__")
@@ -713,11 +731,10 @@ def analyze_descriptor_access(
713731
callable_name=callable_name,
714732
)
715733

716-
if not assignment:
717-
mx.chk.check_deprecated(dunder_get, mx.context)
718-
mx.chk.warn_deprecated_overload_item(
719-
dunder_get, mx.context, target=inferred_dunder_get_type, selftype=descriptor_type
720-
)
734+
mx.chk.check_deprecated(dunder_get, mx.context)
735+
mx.chk.warn_deprecated_overload_item(
736+
dunder_get, mx.context, target=inferred_dunder_get_type, selftype=descriptor_type
737+
)
721738

722739
inferred_dunder_get_type = get_proper_type(inferred_dunder_get_type)
723740
if isinstance(inferred_dunder_get_type, AnyType):
@@ -736,6 +753,79 @@ def analyze_descriptor_access(
736753
return inferred_dunder_get_type.ret_type
737754

738755

756+
def analyze_descriptor_assign(descriptor_type: Instance, mx: MemberContext) -> Type:
757+
instance_type = get_proper_type(mx.self_type)
758+
dunder_set = descriptor_type.type.get_method("__set__")
759+
if dunder_set is None:
760+
mx.chk.fail(
761+
message_registry.DESCRIPTOR_SET_NOT_CALLABLE.format(
762+
descriptor_type.str_with_options(mx.msg.options)
763+
),
764+
mx.context,
765+
)
766+
return AnyType(TypeOfAny.from_error)
767+
768+
bound_method = analyze_decorator_or_funcbase_access(
769+
defn=dunder_set,
770+
itype=descriptor_type,
771+
name="__set__",
772+
mx=mx.copy_modified(is_lvalue=False, self_type=descriptor_type),
773+
)
774+
typ = map_instance_to_supertype(descriptor_type, dunder_set.info)
775+
dunder_set_type = expand_type_by_instance(bound_method, typ)
776+
777+
callable_name = mx.chk.expr_checker.method_fullname(descriptor_type, "__set__")
778+
rvalue = mx.rvalue or TempNode(AnyType(TypeOfAny.special_form), context=mx.context)
779+
dunder_set_type = mx.chk.expr_checker.transform_callee_type(
780+
callable_name,
781+
dunder_set_type,
782+
[TempNode(instance_type, context=mx.context), rvalue],
783+
[ARG_POS, ARG_POS],
784+
mx.context,
785+
object_type=descriptor_type,
786+
)
787+
788+
# For non-overloaded setters, the result should be type-checked like a regular assignment.
789+
# Hence, we first only try to infer the type by using the rvalue as type context.
790+
type_context = rvalue
791+
with mx.msg.filter_errors():
792+
_, inferred_dunder_set_type = mx.chk.expr_checker.check_call(
793+
dunder_set_type,
794+
[TempNode(instance_type, context=mx.context), type_context],
795+
[ARG_POS, ARG_POS],
796+
mx.context,
797+
object_type=descriptor_type,
798+
callable_name=callable_name,
799+
)
800+
801+
# And now we in fact type check the call, to show errors related to wrong arguments
802+
# count, etc., replacing the type context for non-overloaded setters only.
803+
inferred_dunder_set_type = get_proper_type(inferred_dunder_set_type)
804+
if isinstance(inferred_dunder_set_type, CallableType):
805+
type_context = TempNode(AnyType(TypeOfAny.special_form), context=mx.context)
806+
mx.chk.expr_checker.check_call(
807+
dunder_set_type,
808+
[TempNode(instance_type, context=mx.context), type_context],
809+
[ARG_POS, ARG_POS],
810+
mx.context,
811+
object_type=descriptor_type,
812+
callable_name=callable_name,
813+
)
814+
815+
# Search for possible deprecations:
816+
mx.chk.check_deprecated(dunder_set, mx.context)
817+
mx.chk.warn_deprecated_overload_item(
818+
dunder_set, mx.context, target=inferred_dunder_set_type, selftype=descriptor_type
819+
)
820+
821+
# In the following cases, a message already will have been recorded in check_call.
822+
if (not isinstance(inferred_dunder_set_type, CallableType)) or (
823+
len(inferred_dunder_set_type.arg_types) < 2
824+
):
825+
return AnyType(TypeOfAny.from_error)
826+
return inferred_dunder_set_type.arg_types[1]
827+
828+
739829
def is_instance_var(var: Var) -> bool:
740830
"""Return if var is an instance variable according to PEP 526."""
741831
return (
@@ -820,6 +910,7 @@ def analyze_var(
820910
# A property cannot have an overloaded type => the cast is fine.
821911
assert isinstance(expanded_signature, CallableType)
822912
if var.is_settable_property and mx.is_lvalue and var.setter_type is not None:
913+
# TODO: use check_call() to infer better type, same as for __set__().
823914
result = expanded_signature.arg_types[0]
824915
else:
825916
result = expanded_signature.ret_type
@@ -832,7 +923,7 @@ def analyze_var(
832923
result = AnyType(TypeOfAny.special_form)
833924
fullname = f"{var.info.fullname}.{name}"
834925
hook = mx.chk.plugin.get_attribute_hook(fullname)
835-
if result and not mx.is_lvalue and not implicit:
926+
if result and not (implicit or var.info.is_protocol and is_instance_var(var)):
836927
result = analyze_descriptor_access(result, mx)
837928
if hook:
838929
result = hook(
@@ -1106,6 +1197,7 @@ def analyze_class_attribute_access(
11061197
result = add_class_tvars(
11071198
t, isuper, is_classmethod, is_staticmethod, mx.self_type, original_vars=original_vars
11081199
)
1200+
# __set__ is not called on class objects.
11091201
if not mx.is_lvalue:
11101202
result = analyze_descriptor_access(result, mx)
11111203

0 commit comments

Comments
 (0)