Skip to content

Add basic support for enum literals #6668

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions mypy/message_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,11 @@
from typing_extensions import Final


# Invalid types

INVALID_TYPE_RAW_ENUM_VALUE = "Invalid type: try using Literal[{}.{}] instead?" # type: Final


# Type checker error message constants --

NO_RETURN_VALUE_EXPECTED = 'No return value expected' # type: Final
Expand Down
6 changes: 5 additions & 1 deletion mypy/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,11 @@ def format_bare(self, typ: Type, verbosity: int = 0) -> str:
s = 'TypedDict({{{}}})'.format(', '.join(items))
return s
elif isinstance(typ, LiteralType):
return str(typ)
if typ.is_enum_literal():
underlying_type = self.format_bare(typ.fallback, verbosity=verbosity)
return 'Literal[{}.{}]'.format(underlying_type, typ.value)
else:
return str(typ)
elif isinstance(typ, UnionType):
# Only print Unions as Optionals if the Optional wouldn't have to contain another Union
print_as_optional = (len(typ.items) -
Expand Down
41 changes: 35 additions & 6 deletions mypy/newsemanal/typeanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,15 +151,15 @@ def __init__(self,
# Names of type aliases encountered while analysing a type will be collected here.
self.aliases_used = set() # type: Set[str]

def visit_unbound_type(self, t: UnboundType) -> Type:
typ = self.visit_unbound_type_nonoptional(t)
def visit_unbound_type(self, t: UnboundType, defining_literal: bool = False) -> Type:
typ = self.visit_unbound_type_nonoptional(t, defining_literal)
if t.optional:
# We don't need to worry about double-wrapping Optionals or
# wrapping Anys: Union simplification will take care of that.
return make_optional_type(typ)
return typ

def visit_unbound_type_nonoptional(self, t: UnboundType) -> Type:
def visit_unbound_type_nonoptional(self, t: UnboundType, defining_literal: bool) -> Type:
sym = self.lookup_qualified(t.name, t, suppress_errors=self.third_pass)
if sym is not None:
node = sym.node
Expand Down Expand Up @@ -217,7 +217,7 @@ def visit_unbound_type_nonoptional(self, t: UnboundType) -> Type:
elif isinstance(node, TypeInfo):
return self.analyze_type_with_type_info(node, t.args, t)
else:
return self.analyze_unbound_type_without_type_info(t, sym)
return self.analyze_unbound_type_without_type_info(t, sym, defining_literal)
else: # sym is None
if self.third_pass:
self.fail('Invalid type "{}"'.format(t.name), t)
Expand Down Expand Up @@ -348,7 +348,8 @@ def analyze_type_with_type_info(self, info: TypeInfo, args: List[Type], ctx: Con
fallback=instance)
return instance

def analyze_unbound_type_without_type_info(self, t: UnboundType, sym: SymbolTableNode) -> Type:
def analyze_unbound_type_without_type_info(self, t: UnboundType, sym: SymbolTableNode,
defining_literal: bool) -> Type:
"""Figure out what an unbound type that doesn't refer to a TypeInfo node means.

This is something unusual. We try our best to find out what it is.
Expand All @@ -373,6 +374,30 @@ def analyze_unbound_type_without_type_info(self, t: UnboundType, sym: SymbolTabl
if self.allow_unbound_tvars and unbound_tvar and not self.third_pass:
return t

# Option 3:
# Enum value. Note: we only want to return a LiteralType when
# we're using this enum value specifically within context of
# a "Literal[...]" type. So, if `defining_literal` is not set,
# we bail out early with an error.
#
# If, in the distant future, we decide to permit things like
# `def foo(x: Color.RED) -> None: ...`, we can remove that
# check entirely.
if isinstance(sym.node, Var) and sym.node.info and sym.node.info.is_enum:
value = sym.node.name()
base_enum_short_name = sym.node.info.name()
if not defining_literal:
msg = message_registry.INVALID_TYPE_RAW_ENUM_VALUE.format(
base_enum_short_name, value)
self.fail(msg, t)
return AnyType(TypeOfAny.from_error)
return LiteralType(
value=value,
fallback=Instance(sym.node.info, [], line=t.line, column=t.column),
line=t.line,
column=t.column,
)

# None of the above options worked, we give up.
self.fail('Invalid type "{}"'.format(name), t)

Expand Down Expand Up @@ -631,7 +656,11 @@ def analyze_literal_param(self, idx: int, arg: Type, ctx: Context) -> Optional[L
# If arg is an UnboundType that was *not* originally defined as
# a string, try expanding it in case it's a type alias or something.
if isinstance(arg, UnboundType):
arg = self.anal_type(arg)
self.nesting_level += 1
try:
arg = self.visit_unbound_type(arg, defining_literal=True)
finally:
self.nesting_level -= 1

# Literal[...] cannot contain Any. Give up and add an error message
# (if we haven't already).
Expand Down
45 changes: 39 additions & 6 deletions mypy/typeanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,15 +187,15 @@ def __init__(self,
# Names of type aliases encountered while analysing a type will be collected here.
self.aliases_used = set() # type: Set[str]

def visit_unbound_type(self, t: UnboundType) -> Type:
typ = self.visit_unbound_type_nonoptional(t)
def visit_unbound_type(self, t: UnboundType, defining_literal: bool = False) -> Type:
typ = self.visit_unbound_type_nonoptional(t, defining_literal)
if t.optional:
# We don't need to worry about double-wrapping Optionals or
# wrapping Anys: Union simplification will take care of that.
return make_optional_type(typ)
return typ

def visit_unbound_type_nonoptional(self, t: UnboundType) -> Type:
def visit_unbound_type_nonoptional(self, t: UnboundType, defining_literal: bool) -> Type:
sym = self.lookup(t.name, t, suppress_errors=self.third_pass)
if '.' in t.name:
# Handle indirect references to imported names.
Expand Down Expand Up @@ -249,7 +249,7 @@ def visit_unbound_type_nonoptional(self, t: UnboundType) -> Type:
elif isinstance(node, TypeInfo):
return self.analyze_unbound_type_with_type_info(t, node)
else:
return self.analyze_unbound_type_without_type_info(t, sym)
return self.analyze_unbound_type_without_type_info(t, sym, defining_literal)
else: # sym is None
if self.third_pass:
self.fail('Invalid type "{}"'.format(t.name), t)
Expand Down Expand Up @@ -368,7 +368,8 @@ def analyze_unbound_type_with_type_info(self, t: UnboundType, info: TypeInfo) ->
fallback=instance)
return instance

def analyze_unbound_type_without_type_info(self, t: UnboundType, sym: SymbolTableNode) -> Type:
def analyze_unbound_type_without_type_info(self, t: UnboundType, sym: SymbolTableNode,
defining_literal: bool) -> Type:
"""Figure out what an unbound type that doesn't refer to a TypeInfo node means.

This is something unusual. We try our best to find out what it is.
Expand All @@ -377,6 +378,7 @@ def analyze_unbound_type_without_type_info(self, t: UnboundType, sym: SymbolTabl
if name is None:
assert sym.node is not None
name = sym.node.name()

# Option 1:
# Something with an Any type -- make it an alias for Any in a type
# context. This is slightly problematic as it allows using the type 'Any'
Expand All @@ -385,14 +387,40 @@ def analyze_unbound_type_without_type_info(self, t: UnboundType, sym: SymbolTabl
if isinstance(sym.node, Var) and isinstance(sym.node.type, AnyType):
return AnyType(TypeOfAny.from_unimported_type,
missing_import_name=sym.node.type.missing_import_name)

# Option 2:
# Unbound type variable. Currently these may be still valid,
# for example when defining a generic type alias.
unbound_tvar = (isinstance(sym.node, TypeVarExpr) and
(not self.tvar_scope or self.tvar_scope.get_binding(sym) is None))
if self.allow_unbound_tvars and unbound_tvar and not self.third_pass:
return t

# Option 3:
# Enum value. Note: we only want to return a LiteralType when
# we're using this enum value specifically within context of
# a "Literal[...]" type. So, if `defining_literal` is not set,
# we bail out early with an error.
#
# If, in the distant future, we decide to permit things like
# `def foo(x: Color.RED) -> None: ...`, we can remove that
# check entirely.
if isinstance(sym.node, Var) and not t.args and sym.node.info and sym.node.info.is_enum:
value = sym.node.name()
base_enum_short_name = sym.node.info.name()
if not defining_literal:
msg = message_registry.INVALID_TYPE_RAW_ENUM_VALUE.format(
base_enum_short_name, value)
self.fail(msg, t)
return AnyType(TypeOfAny.from_error)
return LiteralType(
value=value,
fallback=Instance(sym.node.info, [], line=t.line, column=t.column),
line=t.line,
column=t.column,
)

# Option 4:
# If it is not something clearly bad (like a known function, variable,
# type variable, or module), and it is still not too late, we try deferring
# this type using a forward reference wrapper. It will be revisited in
Expand All @@ -410,6 +438,7 @@ def analyze_unbound_type_without_type_info(self, t: UnboundType, sym: SymbolTabl
self.fail('Unsupported forward reference to "{}"'.format(t.name), t)
return AnyType(TypeOfAny.from_error)
return ForwardRef(t)

# None of the above options worked, we give up.
self.fail('Invalid type "{}"'.format(name), t)
if self.third_pass and isinstance(sym.node, TypeVarExpr):
Expand Down Expand Up @@ -657,7 +686,11 @@ def analyze_literal_param(self, idx: int, arg: Type, ctx: Context) -> Optional[L
# If arg is an UnboundType that was *not* originally defined as
# a string, try expanding it in case it's a type alias or something.
if isinstance(arg, UnboundType):
arg = self.anal_type(arg)
self.nesting_level += 1
try:
arg = self.visit_unbound_type(arg, defining_literal=True)
finally:
self.nesting_level -= 1

# Literal[...] cannot contain Any. Give up and add an error message
# (if we haven't already).
Expand Down
13 changes: 12 additions & 1 deletion mypy/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1443,6 +1443,9 @@ class LiteralType(Type):

For example, 'Literal[42]' is represented as
'LiteralType(value=42, fallback=instance_of_int)'

As another example, `Literal[Color.RED]` (where Color is an enum) is
represented as `LiteralType(value="RED", fallback=instance_of_color)'.
"""
__slots__ = ('value', 'fallback')

Expand All @@ -1464,15 +1467,23 @@ def __eq__(self, other: object) -> bool:
else:
return NotImplemented

def is_enum_literal(self) -> bool:
return self.fallback.type.is_enum

def value_repr(self) -> str:
"""Returns the string representation of the underlying type.

This function is almost equivalent to running `repr(self.value)`,
except it includes some additional logic to correctly handle cases
where the value is a string, byte string, or a unicode string.
where the value is a string, byte string, a unicode string, or an enum.
"""
raw = repr(self.value)
fallback_name = self.fallback.type.fullname()

# If this is backed by an enum,
if self.is_enum_literal():
return '{}.{}'.format(fallback_name, self.value)

if fallback_name == 'builtins.bytes':
# Note: 'builtins.bytes' only appears in Python 3, so we want to
# explicitly prefix with a "b"
Expand Down
Loading