Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions ChangeLog
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,13 @@ Release date: TBA

* Reduce file system access in ``ast_from_file()``.

* Fix incorrect cache keys for inference results, thereby correctly inferring types
for calls instantiating types dynamically.

Closes #1828
Closes pylint-dev/pylint#7464
Closes pylint-dev/pylint#8074

* ``nodes.FunctionDef`` no longer inherits from ``nodes.Lambda``.
This is a breaking change but considered a bug fix as the nodes did not share the same
API and were not interchangeable.
Expand Down
12 changes: 8 additions & 4 deletions astroid/inference_tip.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,16 @@

from typing_extensions import ParamSpec

from astroid.context import InferenceContext
from astroid.exceptions import InferenceOverwriteError, UseInferenceDefault
from astroid.nodes import NodeNG
from astroid.typing import InferenceResult, InferFn

_P = ParamSpec("_P")

_cache: dict[tuple[InferFn, NodeNG], list[InferenceResult] | None] = {}
_cache: dict[
tuple[InferFn, NodeNG, InferenceContext | None], list[InferenceResult] | None
] = {}


def clear_inference_tip_cache() -> None:
Expand All @@ -31,15 +34,16 @@ def _inference_tip_cached(

def inner(*args: _P.args, **kwargs: _P.kwargs) -> Iterator[InferenceResult]:
node = args[0]
context = args[1]
try:
result = _cache[func, node]
result = _cache[func, node, context]
# If through recursion we end up trying to infer the same
# func + node we raise here.
if result is None:
raise UseInferenceDefault()
except KeyError:
_cache[func, node] = None
result = _cache[func, node] = list(func(*args, **kwargs))
_cache[func, node, context] = None
result = _cache[func, node, context] = list(func(*args, **kwargs))
assert result
return iter(result)

Expand Down
10 changes: 2 additions & 8 deletions tests/brain/test_brain.py
Original file line number Diff line number Diff line change
Expand Up @@ -930,13 +930,7 @@ class A:
assert inferred.value == 42

def test_typing_cast_multiple_inference_calls(self) -> None:
"""Inference of an outer function should not store the result for cast.
https://github.com/pylint-dev/pylint/issues/8074
Possible solution caused RecursionErrors with Python 3.8 and CPython + PyPy.
https://github.com/pylint-dev/astroid/pull/1982
"""
"""Inference of an outer function should not store the result for cast."""
ast_nodes = builder.extract_node(
"""
from typing import TypeVar, cast
Expand All @@ -954,7 +948,7 @@ def ident(var: T) -> T:

i1 = next(ast_nodes[1].infer())
assert isinstance(i1, nodes.Const)
assert i1.value == 2 # should be "Hello"!
assert i1.value == "Hello"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

πŸŽ‰



class ReBrainTest(unittest.TestCase):
Expand Down
21 changes: 21 additions & 0 deletions tests/test_regrtest.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,27 @@ def d(self):
assert isinstance(inferred, Instance)
assert inferred.qname() == ".A"

def test_inference_context_consideration(self) -> None:
"""https://github.com/PyCQA/astroid/issues/1828"""
code = """
class Base:
def return_type(self):
return type(self)()
class A(Base):
def method(self):
return self.return_type()
class B(Base):
def method(self):
return self.return_type()
A().method() #@
B().method() #@
"""
node1, node2 = extract_node(code)
inferred1 = next(node1.infer())
assert inferred1.qname() == ".A"
inferred2 = next(node2.infer())
assert inferred2.qname() == ".B"


class Whatever:
a = property(lambda x: x, lambda x: x) # type: ignore[misc]
Expand Down
4 changes: 1 addition & 3 deletions tests/test_scoped_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1771,9 +1771,7 @@ def __init__(self):
"FinalClass",
"ClassB",
"MixinB",
# We don't recognize what 'cls' is at time of .format() call, only
# what it is at the end.
# "strMixin",
"strMixin",
"ClassA",
"MixinA",
"intMixin",
Expand Down