23
23
ArgKind ,
24
24
Context ,
25
25
Decorator ,
26
+ Expression ,
26
27
FuncBase ,
27
28
FuncDef ,
28
29
IndexExpr ,
@@ -101,6 +102,7 @@ def __init__(
101
102
module_symbol_table : SymbolTable | None = None ,
102
103
no_deferral : bool = False ,
103
104
is_self : bool = False ,
105
+ rvalue : Expression | None = None ,
104
106
) -> None :
105
107
self .is_lvalue = is_lvalue
106
108
self .is_super = is_super
@@ -113,6 +115,9 @@ def __init__(
113
115
self .module_symbol_table = module_symbol_table
114
116
self .no_deferral = no_deferral
115
117
self .is_self = is_self
118
+ if rvalue is not None :
119
+ assert is_lvalue
120
+ self .rvalue = rvalue
116
121
117
122
def named_type (self , name : str ) -> Instance :
118
123
return self .chk .named_type (name )
@@ -139,6 +144,7 @@ def copy_modified(
139
144
self_type = self .self_type ,
140
145
module_symbol_table = self .module_symbol_table ,
141
146
no_deferral = self .no_deferral ,
147
+ rvalue = self .rvalue ,
142
148
)
143
149
if messages is not None :
144
150
mx .msg = messages
@@ -168,6 +174,7 @@ def analyze_member_access(
168
174
module_symbol_table : SymbolTable | None = None ,
169
175
no_deferral : bool = False ,
170
176
is_self : bool = False ,
177
+ rvalue : Expression | None = None ,
171
178
) -> Type :
172
179
"""Return the type of attribute 'name' of 'typ'.
173
180
@@ -186,11 +193,14 @@ def analyze_member_access(
186
193
of 'original_type'. 'original_type' is always preserved as the 'typ' type used in
187
194
the initial, non-recursive call. The 'self_type' is a component of 'original_type'
188
195
to which generic self should be bound (a narrower type that has a fallback to instance).
189
- Currently this is used only for union types.
196
+ Currently, this is used only for union types.
190
197
191
- 'module_symbol_table' is passed to this function if 'typ' is actually a module
198
+ 'module_symbol_table' is passed to this function if 'typ' is actually a module,
192
199
and we want to keep track of the available attributes of the module (since they
193
200
are not available via the type object directly)
201
+
202
+ 'rvalue' can be provided optionally to infer better setter type when is_lvalue is True,
203
+ most notably this helps for descriptors with overloaded __set__() method.
194
204
"""
195
205
mx = MemberContext (
196
206
is_lvalue = is_lvalue ,
@@ -204,6 +214,7 @@ def analyze_member_access(
204
214
module_symbol_table = module_symbol_table ,
205
215
no_deferral = no_deferral ,
206
216
is_self = is_self ,
217
+ rvalue = rvalue ,
207
218
)
208
219
result = _analyze_member_access (name , typ , mx , override_info )
209
220
possible_literal = get_proper_type (result )
@@ -629,17 +640,15 @@ def check_final_member(name: str, info: TypeInfo, msg: MessageBuilder, ctx: Cont
629
640
msg .cant_assign_to_final (name , attr_assign = True , ctx = ctx )
630
641
631
642
632
- def analyze_descriptor_access (
633
- descriptor_type : Type , mx : MemberContext , * , assignment : bool = False
634
- ) -> Type :
643
+ def analyze_descriptor_access (descriptor_type : Type , mx : MemberContext ) -> Type :
635
644
"""Type check descriptor access.
636
645
637
646
Arguments:
638
647
descriptor_type: The type of the descriptor attribute being accessed
639
648
(the type of ``f`` in ``a.f`` when ``f`` is a descriptor).
640
649
mx: The current member access context.
641
650
Return:
642
- The return type of the appropriate ``__get__`` overload for the descriptor.
651
+ The return type of the appropriate ``__get__/__set__ `` overload for the descriptor.
643
652
"""
644
653
instance_type = get_proper_type (mx .self_type )
645
654
orig_descriptor_type = descriptor_type
@@ -648,15 +657,24 @@ def analyze_descriptor_access(
648
657
if isinstance (descriptor_type , UnionType ):
649
658
# Map the access over union types
650
659
return make_simplified_union (
651
- [
652
- analyze_descriptor_access (typ , mx , assignment = assignment )
653
- for typ in descriptor_type .items
654
- ]
660
+ [analyze_descriptor_access (typ , mx ) for typ in descriptor_type .items ]
655
661
)
656
662
elif not isinstance (descriptor_type , Instance ):
657
663
return orig_descriptor_type
658
664
659
- if not descriptor_type .type .has_readable_member ("__get__" ):
665
+ if not mx .is_lvalue and not descriptor_type .type .has_readable_member ("__get__" ):
666
+ return orig_descriptor_type
667
+
668
+ # We do this check first to accommodate for descriptors with only __set__ method.
669
+ # If there is no __set__, we type-check that the assigned value matches
670
+ # the return type of __get__. This doesn't match the python semantics,
671
+ # (which allow you to override the descriptor with any value), but preserves
672
+ # the type of accessing the attribute (even after the override).
673
+ if mx .is_lvalue and descriptor_type .type .has_readable_member ("__set__" ):
674
+ return analyze_descriptor_assign (descriptor_type , mx )
675
+
676
+ if mx .is_lvalue and not descriptor_type .type .has_readable_member ("__get__" ):
677
+ # This turned out to be not a descriptor after all.
660
678
return orig_descriptor_type
661
679
662
680
dunder_get = descriptor_type .type .get_method ("__get__" )
@@ -713,11 +731,10 @@ def analyze_descriptor_access(
713
731
callable_name = callable_name ,
714
732
)
715
733
716
- if not assignment :
717
- mx .chk .check_deprecated (dunder_get , mx .context )
718
- mx .chk .warn_deprecated_overload_item (
719
- dunder_get , mx .context , target = inferred_dunder_get_type , selftype = descriptor_type
720
- )
734
+ mx .chk .check_deprecated (dunder_get , mx .context )
735
+ mx .chk .warn_deprecated_overload_item (
736
+ dunder_get , mx .context , target = inferred_dunder_get_type , selftype = descriptor_type
737
+ )
721
738
722
739
inferred_dunder_get_type = get_proper_type (inferred_dunder_get_type )
723
740
if isinstance (inferred_dunder_get_type , AnyType ):
@@ -736,6 +753,79 @@ def analyze_descriptor_access(
736
753
return inferred_dunder_get_type .ret_type
737
754
738
755
756
+ def analyze_descriptor_assign (descriptor_type : Instance , mx : MemberContext ) -> Type :
757
+ instance_type = get_proper_type (mx .self_type )
758
+ dunder_set = descriptor_type .type .get_method ("__set__" )
759
+ if dunder_set is None :
760
+ mx .chk .fail (
761
+ message_registry .DESCRIPTOR_SET_NOT_CALLABLE .format (
762
+ descriptor_type .str_with_options (mx .msg .options )
763
+ ),
764
+ mx .context ,
765
+ )
766
+ return AnyType (TypeOfAny .from_error )
767
+
768
+ bound_method = analyze_decorator_or_funcbase_access (
769
+ defn = dunder_set ,
770
+ itype = descriptor_type ,
771
+ name = "__set__" ,
772
+ mx = mx .copy_modified (is_lvalue = False , self_type = descriptor_type ),
773
+ )
774
+ typ = map_instance_to_supertype (descriptor_type , dunder_set .info )
775
+ dunder_set_type = expand_type_by_instance (bound_method , typ )
776
+
777
+ callable_name = mx .chk .expr_checker .method_fullname (descriptor_type , "__set__" )
778
+ rvalue = mx .rvalue or TempNode (AnyType (TypeOfAny .special_form ), context = mx .context )
779
+ dunder_set_type = mx .chk .expr_checker .transform_callee_type (
780
+ callable_name ,
781
+ dunder_set_type ,
782
+ [TempNode (instance_type , context = mx .context ), rvalue ],
783
+ [ARG_POS , ARG_POS ],
784
+ mx .context ,
785
+ object_type = descriptor_type ,
786
+ )
787
+
788
+ # For non-overloaded setters, the result should be type-checked like a regular assignment.
789
+ # Hence, we first only try to infer the type by using the rvalue as type context.
790
+ type_context = rvalue
791
+ with mx .msg .filter_errors ():
792
+ _ , inferred_dunder_set_type = mx .chk .expr_checker .check_call (
793
+ dunder_set_type ,
794
+ [TempNode (instance_type , context = mx .context ), type_context ],
795
+ [ARG_POS , ARG_POS ],
796
+ mx .context ,
797
+ object_type = descriptor_type ,
798
+ callable_name = callable_name ,
799
+ )
800
+
801
+ # And now we in fact type check the call, to show errors related to wrong arguments
802
+ # count, etc., replacing the type context for non-overloaded setters only.
803
+ inferred_dunder_set_type = get_proper_type (inferred_dunder_set_type )
804
+ if isinstance (inferred_dunder_set_type , CallableType ):
805
+ type_context = TempNode (AnyType (TypeOfAny .special_form ), context = mx .context )
806
+ mx .chk .expr_checker .check_call (
807
+ dunder_set_type ,
808
+ [TempNode (instance_type , context = mx .context ), type_context ],
809
+ [ARG_POS , ARG_POS ],
810
+ mx .context ,
811
+ object_type = descriptor_type ,
812
+ callable_name = callable_name ,
813
+ )
814
+
815
+ # Search for possible deprecations:
816
+ mx .chk .check_deprecated (dunder_set , mx .context )
817
+ mx .chk .warn_deprecated_overload_item (
818
+ dunder_set , mx .context , target = inferred_dunder_set_type , selftype = descriptor_type
819
+ )
820
+
821
+ # In the following cases, a message already will have been recorded in check_call.
822
+ if (not isinstance (inferred_dunder_set_type , CallableType )) or (
823
+ len (inferred_dunder_set_type .arg_types ) < 2
824
+ ):
825
+ return AnyType (TypeOfAny .from_error )
826
+ return inferred_dunder_set_type .arg_types [1 ]
827
+
828
+
739
829
def is_instance_var (var : Var ) -> bool :
740
830
"""Return if var is an instance variable according to PEP 526."""
741
831
return (
@@ -820,6 +910,7 @@ def analyze_var(
820
910
# A property cannot have an overloaded type => the cast is fine.
821
911
assert isinstance (expanded_signature , CallableType )
822
912
if var .is_settable_property and mx .is_lvalue and var .setter_type is not None :
913
+ # TODO: use check_call() to infer better type, same as for __set__().
823
914
result = expanded_signature .arg_types [0 ]
824
915
else :
825
916
result = expanded_signature .ret_type
@@ -832,7 +923,7 @@ def analyze_var(
832
923
result = AnyType (TypeOfAny .special_form )
833
924
fullname = f"{ var .info .fullname } .{ name } "
834
925
hook = mx .chk .plugin .get_attribute_hook (fullname )
835
- if result and not mx . is_lvalue and not implicit :
926
+ if result and not ( implicit or var . info . is_protocol and is_instance_var ( var )) :
836
927
result = analyze_descriptor_access (result , mx )
837
928
if hook :
838
929
result = hook (
@@ -1106,6 +1197,7 @@ def analyze_class_attribute_access(
1106
1197
result = add_class_tvars (
1107
1198
t , isuper , is_classmethod , is_staticmethod , mx .self_type , original_vars = original_vars
1108
1199
)
1200
+ # __set__ is not called on class objects.
1109
1201
if not mx .is_lvalue :
1110
1202
result = analyze_descriptor_access (result , mx )
1111
1203
0 commit comments