Skip to content

Add support for using __bool__ method literal value in union narrowing in if statements #9297

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Aug 14, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 31 additions & 2 deletions mypy/typeops.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,19 @@ def make_simplified_union(items: Sequence[Type],
return UnionType.make_union(simplified_set, line, column)


def get_type_special_method_bool_ret_type(t: Type) -> Optional[Type]:
t = get_proper_type(t)

if isinstance(t, Instance):
bool_method = t.type.names.get("__bool__", None)
if bool_method:
callee = get_proper_type(bool_method.type)
if isinstance(callee, CallableType):
return callee.ret_type

return None


def true_only(t: Type) -> ProperType:
"""
Restricted version of t with only True-ish values
Expand All @@ -393,8 +406,16 @@ def true_only(t: Type) -> ProperType:
elif isinstance(t, UnionType):
# The true version of a union type is the union of the true versions of its components
new_items = [true_only(item) for item in t.items]
return make_simplified_union(new_items, line=t.line, column=t.column)
can_be_true_items = [item for item in new_items if item.can_be_true]
return make_simplified_union(can_be_true_items, line=t.line, column=t.column)
else:
ret_type = get_type_special_method_bool_ret_type(t)

if ret_type and ret_type.can_be_false and not ret_type.can_be_true:
new_t = copy_type(t)
new_t.can_be_true = False
return new_t

new_t = copy_type(t)
new_t.can_be_false = False
return new_t
Expand All @@ -420,8 +441,16 @@ def false_only(t: Type) -> ProperType:
elif isinstance(t, UnionType):
# The false version of a union type is the union of the false versions of its components
new_items = [false_only(item) for item in t.items]
return make_simplified_union(new_items, line=t.line, column=t.column)
can_be_false_items = [item for item in new_items if item.can_be_false]
return make_simplified_union(can_be_false_items, line=t.line, column=t.column)
else:
ret_type = get_type_special_method_bool_ret_type(t)

if ret_type and ret_type.can_be_true and not ret_type.can_be_false:
new_t = copy_type(t)
new_t.can_be_false = False
return new_t

new_t = copy_type(t)
new_t.can_be_true = False
return new_t
Expand Down
63 changes: 63 additions & 0 deletions test-data/unit/check-literal.test
Original file line number Diff line number Diff line change
Expand Up @@ -3243,3 +3243,66 @@ assert c.a is True
c.update()
assert c.a is False
[builtins fixtures/bool.pyi]

[case testConditionalBoolLiteralUnionNarrowing]
# flags: --warn-unreachable

from typing import Union
from typing_extensions import Literal

class Truth:
def __bool__(self) -> Literal[True]: ...

class AlsoTruth:
def __bool__(self) -> Literal[True]: ...

class Lie:
def __bool__(self) -> Literal[False]: ...

class AnyAnswer:
def __bool__(self) -> bool: ...

class NoAnswerSpecified:
pass

x: Union[Truth, Lie]

if x:
reveal_type(x) # N: Revealed type is '__main__.Truth'
else:
reveal_type(x) # N: Revealed type is '__main__.Lie'

if not x:
reveal_type(x) # N: Revealed type is '__main__.Lie'
else:
reveal_type(x) # N: Revealed type is '__main__.Truth'

y: Union[Truth, AlsoTruth, Lie]

if y:
reveal_type(y) # N: Revealed type is 'Union[__main__.Truth, __main__.AlsoTruth]'
else:
reveal_type(y) # N: Revealed type is '__main__.Lie'

z: Union[Truth, AnyAnswer]

if z:
reveal_type(z) # N: Revealed type is 'Union[__main__.Truth, __main__.AnyAnswer]'
else:
reveal_type(z) # N: Revealed type is '__main__.AnyAnswer'

q: Union[Truth, NoAnswerSpecified]

if q:
reveal_type(q) # N: Revealed type is 'Union[__main__.Truth, __main__.NoAnswerSpecified]'
else:
reveal_type(q) # N: Revealed type is '__main__.NoAnswerSpecified'

w: Union[Truth, AlsoTruth]

if w:
reveal_type(w) # N: Revealed type is 'Union[__main__.Truth, __main__.AlsoTruth]'
else:
reveal_type(w) # E: Statement is unreachable

[builtins fixtures/bool.pyi]