Skip to content

Commit 030f4d0

Browse files
committed
Add inference for Compare nodes
Ref #846. Identity checks are currently Uninferable as there is no sensible way to infer that two Instances refer to the same object without accurately modelling control flow.
1 parent e839f57 commit 030f4d0

File tree

2 files changed

+349
-0
lines changed

2 files changed

+349
-0
lines changed

astroid/inference.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,11 @@
2626
"""this module contains a set of functions to handle inference on astroid trees
2727
"""
2828

29+
import ast
2930
import functools
3031
import itertools
3132
import operator
33+
from typing import Any, Iterable
3234

3335
import wrapt
3436

@@ -798,6 +800,98 @@ def infer_binop(self, context=None):
798800
nodes.BinOp._infer_binop = _infer_binop
799801
nodes.BinOp._infer = infer_binop
800802

803+
COMPARE_OPS = {
804+
"==": operator.eq,
805+
"!=": operator.ne,
806+
"<": operator.lt,
807+
"<=": operator.le,
808+
">": operator.gt,
809+
">=": operator.ge,
810+
"in": lambda a, b: a in b,
811+
"not in": lambda a, b: a not in b,
812+
}
813+
UNINFERABLE_OPS = {
814+
"is",
815+
"is not",
816+
}
817+
818+
819+
def _to_literal(node: nodes.NodeNG) -> Any:
820+
# Can raise SyntaxError or ValueError from ast.literal_eval
821+
# Is this the stupidest idea or the simplest idea?
822+
return ast.literal_eval(node.as_string())
823+
824+
825+
def _do_compare(
826+
left_iter: Iterable[nodes.NodeNG], op: str, right_iter: Iterable[nodes.NodeNG]
827+
) -> "bool | type[util.Uninferable]":
828+
"""
829+
If all possible combinations are either True or False, return that:
830+
>>> _do_compare([1, 2], '<=', [3, 4])
831+
True
832+
>>> _do_compare([1, 2], '==', [3, 4])
833+
False
834+
835+
If any item is uninferable, or if some combinations are True and some
836+
are False, return Uninferable:
837+
>>> _do_compare([1, 3], '<=', [2, 4])
838+
util.Uninferable
839+
"""
840+
retval = None
841+
if op in UNINFERABLE_OPS:
842+
return util.Uninferable
843+
op_func = COMPARE_OPS[op]
844+
845+
for left, right in itertools.product(left_iter, right_iter):
846+
if left is util.Uninferable or right is util.Uninferable:
847+
return util.Uninferable
848+
849+
try:
850+
left, right = _to_literal(left), _to_literal(right)
851+
except (SyntaxError, ValueError):
852+
return util.Uninferable
853+
854+
try:
855+
expr = op_func(left, right)
856+
except TypeError as exc:
857+
raise AstroidTypeError from exc
858+
859+
if retval is None:
860+
retval = expr
861+
elif retval != expr:
862+
return util.Uninferable
863+
# (or both, but "True | False" is basically the same)
864+
865+
return retval # it was all the same value
866+
867+
868+
def _infer_compare(self: nodes.Compare, context: contextmod.InferenceContext) -> Any:
869+
"""Chained comparison inference logic."""
870+
retval = True
871+
872+
ops = self.ops
873+
left_node = self.left
874+
lhs = list(left_node.infer(context=context))
875+
# should we break early if first element is uninferable?
876+
for op, right_node in ops:
877+
# eagerly evaluate rhs so that values can be re-used as lhs
878+
rhs = list(right_node.infer(context=context))
879+
try:
880+
retval = _do_compare(lhs, op, rhs)
881+
except AstroidTypeError:
882+
retval = util.Uninferable
883+
break
884+
if retval is not True:
885+
break # short-circuit
886+
lhs = rhs # continue
887+
if retval is util.Uninferable:
888+
yield retval
889+
else:
890+
yield nodes.Const(retval)
891+
892+
893+
nodes.Compare._infer = _infer_compare
894+
801895

802896
def _infer_augassign(self, context=None):
803897
"""Inference logic for augmented binary operations."""

tests/unittest_inference.py

Lines changed: 255 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5216,6 +5216,261 @@ def f(**kwargs):
52165216
assert next(extract_node(code).infer()).as_string() == "{'f': 1}"
52175217

52185218

5219+
@pytest.mark.parametrize(
5220+
"op,result",
5221+
[
5222+
("<", False),
5223+
("<=", True),
5224+
("==", True),
5225+
(">=", True),
5226+
(">", False),
5227+
("!=", False),
5228+
],
5229+
)
5230+
def test_compare(op, result):
5231+
code = """
5232+
123 {} 123
5233+
""".format(
5234+
op
5235+
)
5236+
node = extract_node(code)
5237+
inferred = next(node.infer())
5238+
assert inferred.value == result
5239+
5240+
5241+
@pytest.mark.xfail(reason="uninferable")
5242+
@pytest.mark.parametrize(
5243+
"op,result",
5244+
[
5245+
("is", True),
5246+
("is not", False),
5247+
],
5248+
)
5249+
def test_compare_identity(op, result):
5250+
code = """
5251+
obj = object()
5252+
obj {} obj
5253+
""".format(
5254+
op
5255+
)
5256+
node = extract_node(code)
5257+
inferred = next(node.infer())
5258+
assert inferred.value == result
5259+
5260+
5261+
@pytest.mark.parametrize(
5262+
"op,result",
5263+
[
5264+
("in", True),
5265+
("not in", False),
5266+
],
5267+
)
5268+
def test_compare_membership(op, result):
5269+
code = """
5270+
1 {} [1, 2, 3]
5271+
""".format(
5272+
op
5273+
)
5274+
node = extract_node(code)
5275+
inferred = next(node.infer())
5276+
assert inferred.value == result
5277+
5278+
5279+
@pytest.mark.parametrize(
5280+
"lhs,rhs,result",
5281+
[
5282+
(1, 1, True),
5283+
(1, 1.1, True),
5284+
(1.1, 1, False),
5285+
(1.0, 1.0, True),
5286+
("abc", "def", True),
5287+
("abc", "", False),
5288+
([], [1], True),
5289+
((1, 2), (2, 3), True),
5290+
((1, 0), (1,), False),
5291+
(True, True, True),
5292+
(True, False, False),
5293+
(False, 1, True),
5294+
(1 + 0j, 2 + 0j, util.Uninferable),
5295+
(+0.0, -0.0, True),
5296+
(0, "1", util.Uninferable),
5297+
(b"\x00", b"\x01", True),
5298+
],
5299+
)
5300+
def test_compare_lesseq_types(lhs, rhs, result):
5301+
code = """
5302+
{lhs!r} <= {rhs!r}
5303+
""".format(
5304+
lhs=lhs, rhs=rhs
5305+
)
5306+
node = extract_node(code)
5307+
inferred = next(node.infer())
5308+
assert inferred.value == result
5309+
5310+
5311+
def test_compare_chained():
5312+
code = """
5313+
3 < 5 > 3
5314+
"""
5315+
node = extract_node(code)
5316+
inferred = next(node.infer())
5317+
assert inferred.value is True
5318+
5319+
5320+
def test_compare_inferred_members():
5321+
code = """
5322+
a = 11
5323+
b = 13
5324+
a < b
5325+
"""
5326+
node = extract_node(code)
5327+
inferred = next(node.infer())
5328+
assert inferred.value is True
5329+
5330+
5331+
def test_compare_instance_members():
5332+
code = """
5333+
class A:
5334+
value = 123
5335+
class B:
5336+
@property
5337+
def value(self):
5338+
return 456
5339+
A().value < B().value
5340+
"""
5341+
node = extract_node(code)
5342+
inferred = next(node.infer())
5343+
assert inferred.value is True
5344+
5345+
5346+
@pytest.mark.xfail(reason="unimplemented")
5347+
def test_compare_dynamic():
5348+
code = """
5349+
class A:
5350+
def __le__(self, other):
5351+
return True
5352+
A() <= None
5353+
"""
5354+
node = extract_node(code)
5355+
inferred = next(node.infer())
5356+
assert inferred.value is True
5357+
5358+
5359+
def test_compare_uninferable_member():
5360+
code = """
5361+
from unknown import UNKNOWN
5362+
0 <= UNKNOWN
5363+
"""
5364+
node = extract_node(code)
5365+
inferred = next(node.infer())
5366+
assert inferred is util.Uninferable
5367+
5368+
5369+
def test_compare_chained_comparisons_shortcircuit_on_false():
5370+
code = """
5371+
from unknown import UNKNOWN
5372+
2 < 1 < UNKNOWN
5373+
"""
5374+
node = extract_node(code)
5375+
inferred = next(node.infer())
5376+
assert inferred.value is False
5377+
5378+
5379+
def test_compare_chained_comparisons_continue_on_true():
5380+
code = """
5381+
from unknown import UNKNOWN
5382+
1 < 2 < UNKNOWN
5383+
"""
5384+
node = extract_node(code)
5385+
inferred = next(node.infer())
5386+
assert inferred is util.Uninferable
5387+
5388+
5389+
@pytest.mark.xfail(reason="unimplemented")
5390+
def test_compare_known_false_branch():
5391+
code = """
5392+
a = 'hello'
5393+
if 1 < 2:
5394+
a = 'goodbye'
5395+
a
5396+
"""
5397+
node = extract_node(code)
5398+
inferred = list(node.infer())
5399+
assert len(inferred) == 1
5400+
assert isinstance(inferred[0], nodes.Const)
5401+
assert inferred[0].value == "hello"
5402+
5403+
5404+
def test_compare_ifexp_constant():
5405+
code = """
5406+
a = 'hello' if 1 < 2 else 'goodbye'
5407+
a
5408+
"""
5409+
node = extract_node(code)
5410+
inferred = list(node.infer())
5411+
assert len(inferred) == 1
5412+
assert isinstance(inferred[0], nodes.Const)
5413+
assert inferred[0].value == "hello"
5414+
5415+
5416+
def test_compare_typeerror():
5417+
code = """
5418+
123 <= "abc"
5419+
"""
5420+
node = extract_node(code)
5421+
inferred = list(node.infer())
5422+
assert len(inferred) == 1
5423+
assert inferred[0] is util.Uninferable
5424+
5425+
5426+
def test_compare_multiple_possibilites():
5427+
code = """
5428+
from unknown import UNKNOWN
5429+
a = 1
5430+
if UNKNOWN:
5431+
a = 2
5432+
b = 3
5433+
if UNKNOWN:
5434+
b = 4
5435+
a < b
5436+
"""
5437+
node = extract_node(code)
5438+
inferred = list(node.infer())
5439+
assert len(inferred) == 1
5440+
# All possible combinations are true: (1 < 3), (1 < 4), (2 < 3), (2 < 4)
5441+
assert inferred[0].value is True
5442+
5443+
5444+
def test_compare_ambiguous_multiple_possibilites():
5445+
code = """
5446+
from unknown import UNKNOWN
5447+
a = 1
5448+
if UNKNOWN:
5449+
a = 3
5450+
b = 2
5451+
if UNKNOWN:
5452+
b = 4
5453+
a < b
5454+
"""
5455+
node = extract_node(code)
5456+
inferred = list(node.infer())
5457+
assert len(inferred) == 1
5458+
# Not all possible combinations are true: (1 < 2), (1 < 4), (3 !< 2), (3 < 4)
5459+
assert inferred[0] is util.Uninferable
5460+
5461+
5462+
def test_compare_nonliteral():
5463+
code = """
5464+
def func(a, b):
5465+
return (a, b) <= (1, 2) #@
5466+
"""
5467+
return_node = extract_node(code)
5468+
node = return_node.value
5469+
inferred = list(node.infer()) # should not raise ValueError
5470+
assert len(inferred) == 1
5471+
assert inferred[0] is util.Uninferable
5472+
5473+
52195474
def test_limit_inference_result_amount():
52205475
"""Test setting limit inference result amount"""
52215476
code = """

0 commit comments

Comments
 (0)