Skip to content

Commit 82d7faf

Browse files
Add inference of Compare nodes (#979)
* Add inference for Compare nodes Ref #846. Identity checks are currently Uninferable as there is no sensible way to infer that two Instances refer to the same object without accurately modelling control flow. Co-authored-by: Pierre Sassoulas <[email protected]>
1 parent 24a1118 commit 82d7faf

File tree

2 files changed

+351
-2
lines changed

2 files changed

+351
-2
lines changed

astroid/inference.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,11 @@
2727
"""this module contains a set of functions to handle inference on astroid trees
2828
"""
2929

30+
import ast
3031
import functools
3132
import itertools
3233
import operator
34+
from typing import Any, Iterable
3335

3436
import wrapt
3537

@@ -790,6 +792,98 @@ def infer_binop(self, context=None):
790792
nodes.BinOp._infer_binop = _infer_binop
791793
nodes.BinOp._infer = infer_binop
792794

795+
COMPARE_OPS = {
796+
"==": operator.eq,
797+
"!=": operator.ne,
798+
"<": operator.lt,
799+
"<=": operator.le,
800+
">": operator.gt,
801+
">=": operator.ge,
802+
"in": lambda a, b: a in b,
803+
"not in": lambda a, b: a not in b,
804+
}
805+
UNINFERABLE_OPS = {
806+
"is",
807+
"is not",
808+
}
809+
810+
811+
def _to_literal(node: nodes.NodeNG) -> Any:
812+
# Can raise SyntaxError or ValueError from ast.literal_eval
813+
# Is this the stupidest idea or the simplest idea?
814+
return ast.literal_eval(node.as_string())
815+
816+
817+
def _do_compare(
818+
left_iter: Iterable[nodes.NodeNG], op: str, right_iter: Iterable[nodes.NodeNG]
819+
) -> "bool | type[util.Uninferable]":
820+
"""
821+
If all possible combinations are either True or False, return that:
822+
>>> _do_compare([1, 2], '<=', [3, 4])
823+
True
824+
>>> _do_compare([1, 2], '==', [3, 4])
825+
False
826+
827+
If any item is uninferable, or if some combinations are True and some
828+
are False, return Uninferable:
829+
>>> _do_compare([1, 3], '<=', [2, 4])
830+
util.Uninferable
831+
"""
832+
retval = None
833+
if op in UNINFERABLE_OPS:
834+
return util.Uninferable
835+
op_func = COMPARE_OPS[op]
836+
837+
for left, right in itertools.product(left_iter, right_iter):
838+
if left is util.Uninferable or right is util.Uninferable:
839+
return util.Uninferable
840+
841+
try:
842+
left, right = _to_literal(left), _to_literal(right)
843+
except (SyntaxError, ValueError):
844+
return util.Uninferable
845+
846+
try:
847+
expr = op_func(left, right)
848+
except TypeError as exc:
849+
raise AstroidTypeError from exc
850+
851+
if retval is None:
852+
retval = expr
853+
elif retval != expr:
854+
return util.Uninferable
855+
# (or both, but "True | False" is basically the same)
856+
857+
return retval # it was all the same value
858+
859+
860+
def _infer_compare(self: nodes.Compare, context: InferenceContext) -> Any:
861+
"""Chained comparison inference logic."""
862+
retval = True
863+
864+
ops = self.ops
865+
left_node = self.left
866+
lhs = list(left_node.infer(context=context))
867+
# should we break early if first element is uninferable?
868+
for op, right_node in ops:
869+
# eagerly evaluate rhs so that values can be re-used as lhs
870+
rhs = list(right_node.infer(context=context))
871+
try:
872+
retval = _do_compare(lhs, op, rhs)
873+
except AstroidTypeError:
874+
retval = util.Uninferable
875+
break
876+
if retval is not True:
877+
break # short-circuit
878+
lhs = rhs # continue
879+
if retval is util.Uninferable:
880+
yield retval
881+
else:
882+
yield nodes.Const(retval)
883+
884+
885+
nodes.Compare._infer = _infer_compare
886+
793887

794888
def _infer_augassign(self, context=None):
795889
"""Inference logic for augmented binary operations."""

tests/unittest_inference.py

Lines changed: 257 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5280,6 +5280,261 @@ def f(**kwargs):
52805280
assert next(extract_node(code).infer()).as_string() == "{'f': 1}"
52815281

52825282

5283+
@pytest.mark.parametrize(
5284+
"op,result",
5285+
[
5286+
("<", False),
5287+
("<=", True),
5288+
("==", True),
5289+
(">=", True),
5290+
(">", False),
5291+
("!=", False),
5292+
],
5293+
)
5294+
def test_compare(op, result) -> None:
5295+
code = """
5296+
123 {} 123
5297+
""".format(
5298+
op
5299+
)
5300+
node = extract_node(code)
5301+
inferred = next(node.infer())
5302+
assert inferred.value == result
5303+
5304+
5305+
@pytest.mark.xfail(reason="uninferable")
5306+
@pytest.mark.parametrize(
5307+
"op,result",
5308+
[
5309+
("is", True),
5310+
("is not", False),
5311+
],
5312+
)
5313+
def test_compare_identity(op, result) -> None:
5314+
code = """
5315+
obj = object()
5316+
obj {} obj
5317+
""".format(
5318+
op
5319+
)
5320+
node = extract_node(code)
5321+
inferred = next(node.infer())
5322+
assert inferred.value == result
5323+
5324+
5325+
@pytest.mark.parametrize(
5326+
"op,result",
5327+
[
5328+
("in", True),
5329+
("not in", False),
5330+
],
5331+
)
5332+
def test_compare_membership(op, result) -> None:
5333+
code = """
5334+
1 {} [1, 2, 3]
5335+
""".format(
5336+
op
5337+
)
5338+
node = extract_node(code)
5339+
inferred = next(node.infer())
5340+
assert inferred.value == result
5341+
5342+
5343+
@pytest.mark.parametrize(
5344+
"lhs,rhs,result",
5345+
[
5346+
(1, 1, True),
5347+
(1, 1.1, True),
5348+
(1.1, 1, False),
5349+
(1.0, 1.0, True),
5350+
("abc", "def", True),
5351+
("abc", "", False),
5352+
([], [1], True),
5353+
((1, 2), (2, 3), True),
5354+
((1, 0), (1,), False),
5355+
(True, True, True),
5356+
(True, False, False),
5357+
(False, 1, True),
5358+
(1 + 0j, 2 + 0j, util.Uninferable),
5359+
(+0.0, -0.0, True),
5360+
(0, "1", util.Uninferable),
5361+
(b"\x00", b"\x01", True),
5362+
],
5363+
)
5364+
def test_compare_lesseq_types(lhs, rhs, result) -> None:
5365+
code = """
5366+
{lhs!r} <= {rhs!r}
5367+
""".format(
5368+
lhs=lhs, rhs=rhs
5369+
)
5370+
node = extract_node(code)
5371+
inferred = next(node.infer())
5372+
assert inferred.value == result
5373+
5374+
5375+
def test_compare_chained() -> None:
5376+
code = """
5377+
3 < 5 > 3
5378+
"""
5379+
node = extract_node(code)
5380+
inferred = next(node.infer())
5381+
assert inferred.value is True
5382+
5383+
5384+
def test_compare_inferred_members() -> None:
5385+
code = """
5386+
a = 11
5387+
b = 13
5388+
a < b
5389+
"""
5390+
node = extract_node(code)
5391+
inferred = next(node.infer())
5392+
assert inferred.value is True
5393+
5394+
5395+
def test_compare_instance_members() -> None:
5396+
code = """
5397+
class A:
5398+
value = 123
5399+
class B:
5400+
@property
5401+
def value(self):
5402+
return 456
5403+
A().value < B().value
5404+
"""
5405+
node = extract_node(code)
5406+
inferred = next(node.infer())
5407+
assert inferred.value is True
5408+
5409+
5410+
@pytest.mark.xfail(reason="unimplemented")
5411+
def test_compare_dynamic() -> None:
5412+
code = """
5413+
class A:
5414+
def __le__(self, other):
5415+
return True
5416+
A() <= None
5417+
"""
5418+
node = extract_node(code)
5419+
inferred = next(node.infer())
5420+
assert inferred.value is True
5421+
5422+
5423+
def test_compare_uninferable_member() -> None:
5424+
code = """
5425+
from unknown import UNKNOWN
5426+
0 <= UNKNOWN
5427+
"""
5428+
node = extract_node(code)
5429+
inferred = next(node.infer())
5430+
assert inferred is util.Uninferable
5431+
5432+
5433+
def test_compare_chained_comparisons_shortcircuit_on_false() -> None:
5434+
code = """
5435+
from unknown import UNKNOWN
5436+
2 < 1 < UNKNOWN
5437+
"""
5438+
node = extract_node(code)
5439+
inferred = next(node.infer())
5440+
assert inferred.value is False
5441+
5442+
5443+
def test_compare_chained_comparisons_continue_on_true() -> None:
5444+
code = """
5445+
from unknown import UNKNOWN
5446+
1 < 2 < UNKNOWN
5447+
"""
5448+
node = extract_node(code)
5449+
inferred = next(node.infer())
5450+
assert inferred is util.Uninferable
5451+
5452+
5453+
@pytest.mark.xfail(reason="unimplemented")
5454+
def test_compare_known_false_branch() -> None:
5455+
code = """
5456+
a = 'hello'
5457+
if 1 < 2:
5458+
a = 'goodbye'
5459+
a
5460+
"""
5461+
node = extract_node(code)
5462+
inferred = list(node.infer())
5463+
assert len(inferred) == 1
5464+
assert isinstance(inferred[0], nodes.Const)
5465+
assert inferred[0].value == "hello"
5466+
5467+
5468+
def test_compare_ifexp_constant() -> None:
5469+
code = """
5470+
a = 'hello' if 1 < 2 else 'goodbye'
5471+
a
5472+
"""
5473+
node = extract_node(code)
5474+
inferred = list(node.infer())
5475+
assert len(inferred) == 1
5476+
assert isinstance(inferred[0], nodes.Const)
5477+
assert inferred[0].value == "hello"
5478+
5479+
5480+
def test_compare_typeerror() -> None:
5481+
code = """
5482+
123 <= "abc"
5483+
"""
5484+
node = extract_node(code)
5485+
inferred = list(node.infer())
5486+
assert len(inferred) == 1
5487+
assert inferred[0] is util.Uninferable
5488+
5489+
5490+
def test_compare_multiple_possibilites() -> None:
5491+
code = """
5492+
from unknown import UNKNOWN
5493+
a = 1
5494+
if UNKNOWN:
5495+
a = 2
5496+
b = 3
5497+
if UNKNOWN:
5498+
b = 4
5499+
a < b
5500+
"""
5501+
node = extract_node(code)
5502+
inferred = list(node.infer())
5503+
assert len(inferred) == 1
5504+
# All possible combinations are true: (1 < 3), (1 < 4), (2 < 3), (2 < 4)
5505+
assert inferred[0].value is True
5506+
5507+
5508+
def test_compare_ambiguous_multiple_possibilites() -> None:
5509+
code = """
5510+
from unknown import UNKNOWN
5511+
a = 1
5512+
if UNKNOWN:
5513+
a = 3
5514+
b = 2
5515+
if UNKNOWN:
5516+
b = 4
5517+
a < b
5518+
"""
5519+
node = extract_node(code)
5520+
inferred = list(node.infer())
5521+
assert len(inferred) == 1
5522+
# Not all possible combinations are true: (1 < 2), (1 < 4), (3 !< 2), (3 < 4)
5523+
assert inferred[0] is util.Uninferable
5524+
5525+
5526+
def test_compare_nonliteral() -> None:
5527+
code = """
5528+
def func(a, b):
5529+
return (a, b) <= (1, 2) #@
5530+
"""
5531+
return_node = extract_node(code)
5532+
node = return_node.value
5533+
inferred = list(node.infer()) # should not raise ValueError
5534+
assert len(inferred) == 1
5535+
assert inferred[0] is util.Uninferable
5536+
5537+
52835538
def test_limit_inference_result_amount() -> None:
52845539
"""Test setting limit inference result amount"""
52855540
code = """
@@ -5560,7 +5815,7 @@ def method(self):
55605815
""",
55615816
],
55625817
)
5563-
def test_subclass_of_exception(code):
5818+
def test_subclass_of_exception(code) -> None:
55645819
inferred = next(extract_node(code).infer())
55655820
assert isinstance(inferred, Instance)
55665821
args = next(inferred.igetattr("args"))
@@ -5721,7 +5976,7 @@ def test(self):
57215976
),
57225977
],
57235978
)
5724-
def test_inference_is_limited_to_the_boundnode(code, instance_name):
5979+
def test_inference_is_limited_to_the_boundnode(code, instance_name) -> None:
57255980
node = extract_node(code)
57265981
inferred = next(node.infer())
57275982
assert isinstance(inferred, Instance)

0 commit comments

Comments
 (0)