Skip to content

Tweaks to --strict-equality based on user feedback #6674

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Apr 27, 2019
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 41 additions & 8 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1938,7 +1938,8 @@ def visit_comparison_expr(self, e: ComparisonExpr) -> Type:
self.msg.unsupported_operand_types('in', left_type, right_type, e)
# Only show dangerous overlap if there are no other errors.
elif (not local_errors.is_errors() and cont_type and
self.dangerous_comparison(left_type, cont_type)):
self.dangerous_comparison(left_type, cont_type,
original_cont_type=right_type)):
self.msg.dangerous_comparison(left_type, cont_type, 'container', e)
else:
self.msg.add_errors(local_errors)
Expand All @@ -1951,8 +1952,13 @@ def visit_comparison_expr(self, e: ComparisonExpr) -> Type:
# testCustomEqCheckStrictEquality for an example.
if self.msg.errors.total_errors() == err_count and operator in ('==', '!='):
right_type = self.accept(right)
if self.dangerous_comparison(left_type, right_type):
self.msg.dangerous_comparison(left_type, right_type, 'equality', e)
if (not self.custom_equality_method(left_type) and
not self.custom_equality_method(right_type)):
# We suppress the error if there is a custom __eq__() method on either
# side. User defined (or even standard library) classes can define this
# to return True for comparisons between non-overlapping types.
if self.dangerous_comparison(left_type, right_type):
self.msg.dangerous_comparison(left_type, right_type, 'equality', e)

elif operator == 'is' or operator == 'is not':
right_type = self.accept(right) # validate the right operand
Expand All @@ -1974,9 +1980,33 @@ def visit_comparison_expr(self, e: ComparisonExpr) -> Type:
assert result is not None
return result

def dangerous_comparison(self, left: Type, right: Type) -> bool:
def custom_equality_method(self, typ: Type) -> bool:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Style nit: This should be a module-level function, since it doesn't depend on self.

"""Does this type have a custom __eq__() method?"""
if isinstance(typ, UnionType):
return any(self.custom_equality_method(t) for t in typ.items)
if isinstance(typ, Instance):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A really minor nit: maybe handle Instance first, since it's probably the most common case.

method = typ.type.get_method('__eq__')
if method and method.info:
return not method.info.fullname().startswith('builtins.')
return False
# TODO: support other types (see has_member())?
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that it would be good to support some additional types -- at least TupleType (use fallback). Maybe Any should return true as well?

return False

def has_bytes_component(self, typ: Type) -> bool:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Style nit: This should be a module-level function.

"""Is this the builtin bytes type, or a union that contains it?"""
if isinstance(typ, UnionType):
return any(self.has_bytes_component(t) for t in typ.items)
if isinstance(typ, Instance) and typ.type.fullname() == 'builtins.bytes':
return True
return False

def dangerous_comparison(self, left: Type, right: Type,
original_cont_type: Optional[Type] = None) -> bool:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: original_cont_type was not immediately clear. Maybe rename to original_container or container_type?

"""Check for dangerous non-overlapping comparisons like 42 == 'no'.

The original_cont_type is the original container type for 'in' checks
(and None for equality checks).

Rules:
* X and None are overlapping even in strict-optional mode. This is to allow
'assert x is not None' for x defined as 'x = None # type: str' in class body
Expand All @@ -1985,9 +2015,7 @@ def dangerous_comparison(self, left: Type, right: Type) -> bool:
non-overlapping, although technically None is overlap, it is most
likely an error.
* Any overlaps with everything, i.e. always safe.
* Promotions are ignored, so both 'abc' == b'abc' and 1 == 1.0
are errors. This is mostly needed for bytes vs unicode, and
int vs float are added just for consistency.
* Special case: b'abc' in b'cde' is safe.
"""
if not self.chk.options.strict_equality:
return False
Expand All @@ -1996,7 +2024,12 @@ def dangerous_comparison(self, left: Type, right: Type) -> bool:
if isinstance(left, UnionType) and isinstance(right, UnionType):
left = remove_optional(left)
right = remove_optional(right)
return not is_overlapping_types(left, right, ignore_promotions=True)
if (original_cont_type and self.has_bytes_component(original_cont_type) and
self.has_bytes_component(left)):
# We need to special case bytes, because both 97 in b'abc' and b'a' in b'abc'
# return True (and we want to show the error only if the check can _never_ be True).
return False
return not is_overlapping_types(left, right, ignore_promotions=False)

def get_operator_method(self, op: str) -> str:
if op == '/' and self.chk.options.python_version[0] == 2:
Expand Down
52 changes: 51 additions & 1 deletion test-data/unit/check-expressions.test
Original file line number Diff line number Diff line change
Expand Up @@ -2024,7 +2024,23 @@ cb: Union[Container[A], Container[B]]
[builtins fixtures/bool.pyi]
[typing fixtures/typing-full.pyi]

[case testStrictEqualityNoPromote]
[case testStrictEqualityBytesSpecial]
# flags: --strict-equality
b'abc' in b'abcde'
[builtins fixtures/primitives.pyi]
[typing fixtures/typing-full.pyi]

[case testStrictEqualityBytesSpecialUnion]
# flags: --strict-equality
from typing import Union
x: Union[bytes, str]

b'abc' in x
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

b'a' in 'b' fails at runtime. Should this generate an error?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, but this is independent of this flag. The error message says "Non-overlapping container check ..." while in this example the check may return True. I think this can be tightened in typeshed, we can just define str.__contains__ as accepting str, because 42 in 'b' etc. all fail as well at runtime.

x in b'abc'
[builtins fixtures/primitives.pyi]
[typing fixtures/typing-full.pyi]

[case testStrictEqualityNoPromotePy3]
# flags: --strict-equality
'a' == b'a' # E: Non-overlapping equality check (left operand type: "str", right operand type: "bytes")
b'a' in 'abc' # E: Non-overlapping container check (element type: "bytes", container item type: "str")
Expand All @@ -2035,6 +2051,16 @@ x != y # E: Non-overlapping equality check (left operand type: "str", right ope
[builtins fixtures/primitives.pyi]
[typing fixtures/typing-full.pyi]

[case testStrictEqualityOkPromote]
# flags: --strict-equality
from typing import Container
c: Container[int]

1 == 1.0 # OK
1.0 in c # OK
[builtins fixtures/primitives.pyi]
[typing fixtures/typing-full.pyi]

[case testStrictEqualityAny]
# flags: --strict-equality
from typing import Any, Container
Expand Down Expand Up @@ -2086,6 +2112,30 @@ class B:
A() == B() # E: Unsupported operand types for == ("A" and "B")
[builtins fixtures/bool.pyi]

[case testCustomEqCheckStrictEqualityOKInstance]
# flags: --strict-equality
class A:
def __eq__(self, other: object) -> bool:
...
class B:
def __eq__(self, other: object) -> bool:
...

A() == int() # OK
int() != B() # OK
[builtins fixtures/bool.pyi]

[case testCustomEqCheckStrictEqualityOKUnion]
# flags: --strict-equality
from typing import Union
class A:
def __eq__(self, other: object) -> bool:
...

x: Union[A, str]
x == int()
[builtins fixtures/bool.pyi]

[case testCustomContainsCheckStrictEquality]
# flags: --strict-equality
class A:
Expand Down
5 changes: 4 additions & 1 deletion test-data/unit/fixtures/primitives.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@ class str(Sequence[str]):
def __contains__(self, other: object) -> bool: pass
def __getitem__(self, item: int) -> str: pass
def format(self, *args) -> str: pass
class bytes: pass
class bytes(Sequence[int]):
def __iter__(self) -> Iterator[int]: pass
def __contains__(self, other: object) -> bool: pass
def __getitem__(self, item: int) -> int: pass
class bytearray: pass
class tuple(Generic[T]): pass
class function: pass
Expand Down