diff --git a/astroid/nodes/node_classes.py b/astroid/nodes/node_classes.py index a4aabe380a..7d01ecfa8e 100644 --- a/astroid/nodes/node_classes.py +++ b/astroid/nodes/node_classes.py @@ -1039,6 +1039,40 @@ def _format_args( return ", ".join(values) +def _infer_attribute( + node: nodes.AssignAttr | nodes.Attribute, + context: InferenceContext | None = None, + **kwargs: Any, +) -> Generator[InferenceResult, None, InferenceErrorInfo]: + """Infer an AssignAttr/Attribute node by using getattr on the associated object.""" + # pylint: disable=import-outside-toplevel + from astroid.constraint import get_constraints + from astroid.nodes import ClassDef + + for owner in node.expr.infer(context): + if isinstance(owner, util.UninferableBase): + yield owner + continue + + context = copy_context(context) + old_boundnode = context.boundnode + try: + context.boundnode = owner + if isinstance(owner, (ClassDef, Instance)): + frame = owner if isinstance(owner, ClassDef) else owner._proxied + context.constraints[node.attrname] = get_constraints(node, frame=frame) + yield from owner.igetattr(node.attrname, context) + except ( + AttributeInferenceError, + InferenceError, + AttributeError, + ): + pass + finally: + context.boundnode = old_boundnode + return InferenceErrorInfo(node=node, context=context) + + class AssignAttr(_base_nodes.LookupMixIn, _base_nodes.ParentAssignNode): """Variation of :class:`ast.Assign` representing assignment to an attribute. @@ -1103,6 +1137,13 @@ def _infer( stmts = list(self.assigned_stmts(context=context)) return _infer_stmts(stmts, context) + @decorators.raise_if_nothing_inferred + @decorators.path_wrapper + def infer_lhs( + self, context: InferenceContext | None = None, **kwargs: Any + ) -> Generator[InferenceResult, None, InferenceErrorInfo | None]: + return _infer_attribute(self, context, **kwargs) + class Assert(_base_nodes.Statement): """Class representing an :class:`ast.Assert` node. @@ -2819,35 +2860,7 @@ def get_children(self): def _infer( self, context: InferenceContext | None = None, **kwargs: Any ) -> Generator[InferenceResult, None, InferenceErrorInfo]: - """Infer an Attribute node by using getattr on the associated object.""" - # pylint: disable=import-outside-toplevel - from astroid.constraint import get_constraints - from astroid.nodes import ClassDef - - for owner in self.expr.infer(context): - if isinstance(owner, util.UninferableBase): - yield owner - continue - - context = copy_context(context) - old_boundnode = context.boundnode - try: - context.boundnode = owner - if isinstance(owner, (ClassDef, Instance)): - frame = owner if isinstance(owner, ClassDef) else owner._proxied - context.constraints[self.attrname] = get_constraints( - self, frame=frame - ) - yield from owner.igetattr(self.attrname, context) - except ( - AttributeInferenceError, - InferenceError, - AttributeError, - ): - pass - finally: - context.boundnode = old_boundnode - return InferenceErrorInfo(node=self, context=context) + return _infer_attribute(self, context, **kwargs) class Global(_base_nodes.NoChildrenNode, _base_nodes.Statement): diff --git a/tests/test_inference.py b/tests/test_inference.py index 4f2c14b40a..7e25ea0dae 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -6319,6 +6319,20 @@ def __init__(self, index): assert isinstance(index[0], nodes.AssignAttr) +def test_infer_assign_attr() -> None: + code = """ + class Counter: + def __init__(self): + self.count = 0 + + def increment(self): + self.count += 1 #@ + """ + node = extract_node(code) + inferred = next(node.infer()) + assert inferred.value == 1 + + @pytest.mark.parametrize( "code,instance_name", [