Skip to content

Commit be8fd1e

Browse files
author
Guido van Rossum
committed
Support functional API for Enum.
Fixes #2306.
1 parent ed7d0c0 commit be8fd1e

File tree

8 files changed

+282
-5
lines changed

8 files changed

+282
-5
lines changed

mypy/checker.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
YieldFromExpr, NamedTupleExpr, TypedDictExpr, SetComprehension,
2525
DictionaryComprehension, ComplexExpr, EllipsisExpr, TypeAliasExpr,
2626
RefExpr, YieldExpr, BackquoteExpr, ImportFrom, ImportAll, ImportBase,
27-
AwaitExpr, PromoteExpr, Node,
27+
AwaitExpr, PromoteExpr, Node, EnumCallExpr,
2828
ARG_POS, MDEF,
2929
CONTRAVARIANT, COVARIANT)
3030
from mypy import nodes
@@ -2244,6 +2244,9 @@ def visit_newtype_expr(self, e: NewTypeExpr) -> Type:
22442244
def visit_namedtuple_expr(self, e: NamedTupleExpr) -> Type:
22452245
return self.expr_checker.visit_namedtuple_expr(e)
22462246

2247+
def visit_enum_call_expr(self, e: EnumCallExpr) -> Type:
2248+
return self.expr_checker.visit_enum_call_expr(e)
2249+
22472250
def visit_typeddict_expr(self, e: TypedDictExpr) -> Type:
22482251
return self.expr_checker.visit_typeddict_expr(e)
22492252

mypy/checkexpr.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@
1919
ConditionalExpr, ComparisonExpr, TempNode, SetComprehension,
2020
DictionaryComprehension, ComplexExpr, EllipsisExpr, StarExpr, AwaitExpr, YieldExpr,
2121
YieldFromExpr, TypedDictExpr, PromoteExpr, NewTypeExpr, NamedTupleExpr, TypeVarExpr,
22-
TypeAliasExpr, BackquoteExpr, ARG_POS, ARG_NAMED, ARG_STAR, ARG_STAR2, MODULE_REF,
23-
UNBOUND_TVAR, BOUND_TVAR,
22+
TypeAliasExpr, BackquoteExpr, EnumCallExpr,
23+
ARG_POS, ARG_NAMED, ARG_STAR, ARG_STAR2, MODULE_REF, UNBOUND_TVAR, BOUND_TVAR,
2424
)
2525
from mypy import nodes
2626
import mypy.checker
@@ -343,6 +343,12 @@ def check_call(self, callee: Type, args: List[Expression],
343343
"""
344344
arg_messages = arg_messages or self.msg
345345
if isinstance(callee, CallableType):
346+
if (isinstance(callable_node, RefExpr)
347+
and callable_node.fullname in ('enum.Enum', 'enum.IntEnum',
348+
'enum.Flag', 'enum.IntFlag')):
349+
# An Enum() call that failed SemanticAnalyzer.check_enum_call().
350+
return callee.ret_type, callee
351+
346352
if callee.is_concrete_type_obj() and callee.type_object().is_abstract:
347353
type = callee.type_object()
348354
self.msg.cannot_instantiate_abstract_class(
@@ -2156,6 +2162,22 @@ def visit_namedtuple_expr(self, e: NamedTupleExpr) -> Type:
21562162
# TODO: Perhaps return a type object type?
21572163
return AnyType()
21582164

2165+
def visit_enum_call_expr(self, e: EnumCallExpr) -> Type:
2166+
for name, value in zip(e.items, e.values):
2167+
if value is not None:
2168+
typ = self.accept(value)
2169+
if not isinstance(typ, AnyType):
2170+
var = e.info.names[name].node
2171+
if isinstance(var, Var):
2172+
# Inline TypeCheker.set_inferred_type(),
2173+
# without the lvalue. (This doesn't really do
2174+
# much, since the value attribute is defined
2175+
# to have type Any in the typeshed stub.)
2176+
var.type = typ
2177+
var.is_inferred = True
2178+
# TODO: Perhaps return a type object type?
2179+
return AnyType()
2180+
21592181
def visit_typeddict_expr(self, e: TypedDictExpr) -> Type:
21602182
# TODO: Perhaps return a type object type?
21612183
return AnyType()

mypy/nodes.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1792,6 +1792,25 @@ def accept(self, visitor: ExpressionVisitor[T]) -> T:
17921792
return visitor.visit_typeddict_expr(self)
17931793

17941794

1795+
class EnumCallExpr(Expression):
1796+
"""Named tuple expression Enum('name', 'val1 val2 ...')."""
1797+
1798+
# The class representation of this enumerated type
1799+
info = None # type: TypeInfo
1800+
# The item names (for debugging)
1801+
items = None # type: List[str]
1802+
values = None # type: List[Optional[Expression]]
1803+
1804+
def __init__(self, info: 'TypeInfo', items: List[str],
1805+
values: List[Optional[Expression]]) -> None:
1806+
self.info = info
1807+
self.items = items
1808+
self.values = values
1809+
1810+
def accept(self, visitor: ExpressionVisitor[T]) -> T:
1811+
return visitor.visit_enum_call_expr(self)
1812+
1813+
17951814
class PromoteExpr(Expression):
17961815
"""Ducktype class decorator expression _promote(...)."""
17971816

mypy/semanal.py

Lines changed: 135 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@
6464
YieldFromExpr, NamedTupleExpr, TypedDictExpr, NonlocalDecl, SymbolNode,
6565
SetComprehension, DictionaryComprehension, TYPE_ALIAS, TypeAliasExpr,
6666
YieldExpr, ExecStmt, Argument, BackquoteExpr, ImportBase, AwaitExpr,
67-
IntExpr, FloatExpr, UnicodeExpr, EllipsisExpr, TempNode,
67+
IntExpr, FloatExpr, UnicodeExpr, EllipsisExpr, TempNode, EnumCallExpr,
6868
COVARIANT, CONTRAVARIANT, INVARIANT, UNBOUND_IMPORTED, LITERAL_YES,
6969
)
7070
from mypy.typevars import has_no_typevars, fill_typevars
@@ -1195,6 +1195,7 @@ def visit_assignment_stmt(self, s: AssignmentStmt) -> None:
11951195
self.process_typevar_declaration(s)
11961196
self.process_namedtuple_definition(s)
11971197
self.process_typeddict_definition(s)
1198+
self.process_enum_call(s)
11981199

11991200
if (len(s.lvalues) == 1 and isinstance(s.lvalues[0], NameExpr) and
12001201
s.lvalues[0].name == '__all__' and s.lvalues[0].kind == GDEF and
@@ -1991,6 +1992,139 @@ def build_typeddict_typeinfo(self, name: str, items: List[str],
19911992

19921993
return info
19931994

1995+
def process_enum_call(self, s: AssignmentStmt) -> None:
1996+
"""Check if s defines an Enum; if yes, store the definition in symbol table."""
1997+
if len(s.lvalues) != 1 or not isinstance(s.lvalues[0], NameExpr):
1998+
return
1999+
lvalue = s.lvalues[0]
2000+
name = lvalue.name
2001+
enum_call = self.check_enum_call(s.rvalue, name)
2002+
if enum_call is None:
2003+
return
2004+
# Yes, it's a valid Enum definition. Add it to the symbol table.
2005+
node = self.lookup(name, s)
2006+
if node:
2007+
node.kind = GDEF # TODO locally defined Enum
2008+
node.node = enum_call
2009+
2010+
def check_enum_call(self, node: Expression, var_name: str = None) -> Optional[TypeInfo]:
2011+
"""Check if a call defines an Enum.
2012+
2013+
Example:
2014+
2015+
A = enum.Enum('A', 'foo bar')
2016+
2017+
is equivalent to:
2018+
2019+
class A(enum.Enum):
2020+
foo = 1
2021+
bar = 2
2022+
"""
2023+
if not isinstance(node, CallExpr):
2024+
return None
2025+
call = node
2026+
callee = call.callee
2027+
if not isinstance(callee, RefExpr):
2028+
return None
2029+
fullname = callee.fullname
2030+
if fullname not in ('enum.Enum', 'enum.IntEnum', 'enum.Flag', 'enum.IntFlag'):
2031+
return None
2032+
items, values, ok = self.parse_enum_call_args(call, fullname.split('.')[-1])
2033+
if not ok:
2034+
# Error. Construct dummy return value.
2035+
return self.build_enum_call_typeinfo('Enum', [], fullname)
2036+
name = cast(StrExpr, call.args[0]).value
2037+
if name != var_name or self.is_func_scope():
2038+
# Give it a unique name derived from the line number.
2039+
name += '@' + str(call.line)
2040+
info = self.build_enum_call_typeinfo(name, items, fullname)
2041+
# Store it as a global just in case it would remain anonymous.
2042+
# (Or in the nearest class if there is one.)
2043+
stnode = SymbolTableNode(GDEF, info, self.cur_mod_id)
2044+
if self.type:
2045+
self.type.names[name] = stnode
2046+
else:
2047+
self.globals[name] = stnode
2048+
call.analyzed = EnumCallExpr(info, items, values)
2049+
call.analyzed.set_line(call.line, call.column)
2050+
return info
2051+
2052+
def build_enum_call_typeinfo(self, name: str, items: List[str], fullname: str) -> TypeInfo:
2053+
base = self.named_type_or_none(fullname)
2054+
assert base is not None
2055+
info = self.basic_new_typeinfo(name, base)
2056+
info.is_enum = True
2057+
for item in items:
2058+
var = Var(item)
2059+
var.info = info
2060+
var.is_property = True
2061+
info.names[item] = SymbolTableNode(MDEF, var)
2062+
return info
2063+
2064+
def parse_enum_call_args(self, call: CallExpr,
2065+
class_name: str) -> Tuple[List[str],
2066+
List[Optional[Expression]], bool]:
2067+
args = call.args
2068+
if len(args) < 2:
2069+
return self.fail_enum_call_arg("Too few arguments for %s()" % class_name, call)
2070+
if len(args) > 2:
2071+
return self.fail_enum_call_arg("Too many arguments for %s()" % class_name, call)
2072+
if call.arg_kinds != [ARG_POS, ARG_POS]:
2073+
return self.fail_enum_call_arg("Unexpected arguments to %s()" % class_name, call)
2074+
if not isinstance(args[0], (StrExpr, UnicodeExpr)):
2075+
return self.fail_enum_call_arg(
2076+
"%s() expects a string literal as the first argument" % class_name, call)
2077+
items = []
2078+
values = [] # type: List[Optional[Expression]]
2079+
if isinstance(args[1], (StrExpr, UnicodeExpr)):
2080+
fields = args[1].value
2081+
for field in fields.replace(',', ' ').split():
2082+
items.append(field)
2083+
elif isinstance(args[1], (TupleExpr, ListExpr)):
2084+
seq_items = args[1].items
2085+
if all(isinstance(seq_item, (StrExpr, UnicodeExpr)) for seq_item in seq_items):
2086+
items = [cast(StrExpr, seq_item).value for seq_item in seq_items]
2087+
elif all(isinstance(seq_item, (TupleExpr, ListExpr))
2088+
and len(seq_item.items) == 2
2089+
and isinstance(seq_item.items[0], (StrExpr, UnicodeExpr))
2090+
for seq_item in seq_items):
2091+
for seq_item in seq_items:
2092+
assert isinstance(seq_item, (TupleExpr, ListExpr))
2093+
name, value = seq_item.items
2094+
assert isinstance(name, (StrExpr, UnicodeExpr))
2095+
items.append(name.value)
2096+
values.append(value)
2097+
else:
2098+
return self.fail_enum_call_arg(
2099+
"%s() with tuple or list of (name, value) pairs not yet supported" %
2100+
class_name,
2101+
call)
2102+
elif isinstance(args[1], DictExpr):
2103+
for key, value in args[1].items:
2104+
if not isinstance(key, (StrExpr, UnicodeExpr)):
2105+
return self.fail_enum_call_arg(
2106+
"%s() with dict literal requires string literals" % class_name, call)
2107+
items.append(key.value)
2108+
values.append(value)
2109+
else:
2110+
# TODO: Allow dict(x=1, y=2) as a substitute for {'x': 1, 'y': 2}?
2111+
return self.fail_enum_call_arg(
2112+
"%s() expects a string, tuple, list or dict literal as the second argument" %
2113+
class_name,
2114+
call)
2115+
if len(items) == 0:
2116+
return self.fail_enum_call_arg("%s() needs at least one item" % class_name, call)
2117+
if not values:
2118+
values = [None] * len(items)
2119+
assert len(items) == len(values)
2120+
return items, values, True
2121+
2122+
def fail_enum_call_arg(self, message: str,
2123+
context: Context) -> Tuple[List[str],
2124+
List[Optional[Expression]], bool]:
2125+
self.fail(message, context)
2126+
return [], [], False
2127+
19942128
def visit_decorator(self, dec: Decorator) -> None:
19952129
for d in dec.decorators:
19962130
d.accept(self)

mypy/strconv.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -429,6 +429,9 @@ def visit_namedtuple_expr(self, o: 'mypy.nodes.NamedTupleExpr') -> str:
429429
o.info.name(),
430430
o.info.tuple_type)
431431

432+
def visit_enum_call_expr(self, o: 'mypy.nodes.EnumCallExpr') -> str:
433+
return 'EnumCallExpr:{}({}, {})'.format(o.line, o.info.name(), o.items)
434+
432435
def visit_typeddict_expr(self, o: 'mypy.nodes.TypedDictExpr') -> str:
433436
return 'TypedDictExpr:{}({})'.format(o.line,
434437
o.info.name())

mypy/treetransform.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
ComparisonExpr, TempNode, StarExpr, Statement, Expression,
2020
YieldFromExpr, NamedTupleExpr, TypedDictExpr, NonlocalDecl, SetComprehension,
2121
DictionaryComprehension, ComplexExpr, TypeAliasExpr, EllipsisExpr,
22-
YieldExpr, ExecStmt, Argument, BackquoteExpr, AwaitExpr,
22+
YieldExpr, ExecStmt, Argument, BackquoteExpr, AwaitExpr, EnumCallExpr,
2323
)
2424
from mypy.types import Type, FunctionLike
2525
from mypy.traverser import TraverserVisitor
@@ -483,6 +483,9 @@ def visit_newtype_expr(self, node: NewTypeExpr) -> NewTypeExpr:
483483
def visit_namedtuple_expr(self, node: NamedTupleExpr) -> NamedTupleExpr:
484484
return NamedTupleExpr(node.info)
485485

486+
def visit_enum_call_expr(self, node: EnumCallExpr) -> EnumCallExpr:
487+
return EnumCallExpr(node.info, node.items, node.values)
488+
486489
def visit_typeddict_expr(self, node: TypedDictExpr) -> Node:
487490
return TypedDictExpr(node.info)
488491

mypy/visitor.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,10 @@ def visit_type_alias_expr(self, o: 'mypy.nodes.TypeAliasExpr') -> T:
156156
def visit_namedtuple_expr(self, o: 'mypy.nodes.NamedTupleExpr') -> T:
157157
pass
158158

159+
@abstractmethod
160+
def visit_enum_call_expr(self, o: 'mypy.nodes.EnumCallExpr') -> T:
161+
pass
162+
159163
@abstractmethod
160164
def visit_typeddict_expr(self, o: 'mypy.nodes.TypedDictExpr') -> T:
161165
pass
@@ -392,6 +396,9 @@ def visit_type_alias_expr(self, o: 'mypy.nodes.TypeAliasExpr') -> T:
392396
def visit_namedtuple_expr(self, o: 'mypy.nodes.NamedTupleExpr') -> T:
393397
pass
394398

399+
def visit_enum_call_expr(self, o: 'mypy.nodes.EnumCallExpr') -> T:
400+
pass
401+
395402
def visit_typeddict_expr(self, o: 'mypy.nodes.TypedDictExpr') -> T:
396403
pass
397404

test-data/unit/pythoneval-enum.test

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,3 +132,89 @@ class E(N, Enum):
132132
def f(x: E) -> None: pass
133133

134134
f(E.X)
135+
136+
[case testFunctionalEnumString]
137+
from enum import Enum, IntEnum
138+
E = Enum('E', 'foo bar')
139+
I = IntEnum('I', ' bar, baz ')
140+
reveal_type(E.foo)
141+
reveal_type(E.bar.value)
142+
reveal_type(I.bar)
143+
reveal_type(I.baz.value)
144+
[out]
145+
_program.py:4: error: Revealed type is '_testFunctionalEnumString.E'
146+
_program.py:5: error: Revealed type is 'Any'
147+
_program.py:6: error: Revealed type is '_testFunctionalEnumString.I'
148+
_program.py:7: error: Revealed type is 'builtins.int'
149+
150+
[case testFunctionalEnumListOfStrings]
151+
from enum import Enum, IntEnum
152+
E = Enum('E', ('foo', 'bar'))
153+
F = Enum('F', ['bar', 'baz'])
154+
reveal_type(E.foo)
155+
reveal_type(F.baz)
156+
[out]
157+
_program.py:4: error: Revealed type is '_testFunctionalEnumListOfStrings.E'
158+
_program.py:5: error: Revealed type is '_testFunctionalEnumListOfStrings.F'
159+
160+
[case testFunctionalEnumListOfPairs]
161+
from enum import Enum, IntEnum
162+
E = Enum('E', [('foo', 1), ['bar', 2]])
163+
F = Enum('F', (['bar', 1], ('baz', 2)))
164+
reveal_type(E.foo)
165+
reveal_type(F.baz)
166+
reveal_type(E.foo.value)
167+
reveal_type(F.bar.name)
168+
[out]
169+
_program.py:4: error: Revealed type is '_testFunctionalEnumListOfPairs.E'
170+
_program.py:5: error: Revealed type is '_testFunctionalEnumListOfPairs.F'
171+
_program.py:6: error: Revealed type is 'Any'
172+
_program.py:7: error: Revealed type is 'builtins.str'
173+
174+
[case testFunctionalEnumDict]
175+
from enum import Enum, IntEnum
176+
E = Enum('E', {'foo': 1, 'bar': 2})
177+
F = IntEnum('F', {'bar': 1, 'baz': 2})
178+
reveal_type(E.foo)
179+
reveal_type(F.baz)
180+
reveal_type(E.foo.value)
181+
reveal_type(F.bar.name)
182+
[out]
183+
_program.py:4: error: Revealed type is '_testFunctionalEnumDict.E'
184+
_program.py:5: error: Revealed type is '_testFunctionalEnumDict.F'
185+
_program.py:6: error: Revealed type is 'Any'
186+
_program.py:7: error: Revealed type is 'builtins.str'
187+
188+
[case testFunctionalEnumErrors]
189+
from enum import Enum, IntEnum
190+
A = Enum('A')
191+
B = Enum('B', 42)
192+
C = Enum('C', 'a b', 'x')
193+
D = Enum('D', foo)
194+
bar = 'x y z'
195+
E = Enum('E', bar)
196+
I = IntEnum('I')
197+
J = IntEnum('I', 42)
198+
K = IntEnum('I', 'p q', 'z')
199+
L = Enum('L', ' ')
200+
M = Enum('M', ())
201+
N = IntEnum('M', [])
202+
P = Enum('P', [42])
203+
Q = Enum('Q', [('a', 42, 0)])
204+
R = IntEnum('R', [[0, 42]])
205+
[out]
206+
_program.py:2: error: Too few arguments for Enum()
207+
_program.py:3: error: Enum() expects a string, tuple, list or dict literal as the second argument
208+
_program.py:4: error: Too many arguments for Enum()
209+
_program.py:5: error: Enum() expects a string, tuple, list or dict literal as the second argument
210+
_program.py:5: error: Name 'foo' is not defined
211+
_program.py:7: error: Enum() expects a string, tuple, list or dict literal as the second argument
212+
_program.py:8: error: Too few arguments for IntEnum()
213+
_program.py:9: error: IntEnum() expects a string, tuple, list or dict literal as the second argument
214+
_program.py:10: error: Too many arguments for IntEnum()
215+
_program.py:11: error: Enum() needs at least one item
216+
_program.py:12: error: Enum() needs at least one item
217+
_program.py:13: error: IntEnum() needs at least one item
218+
_program.py:14: error: Enum() with tuple or list of (name, value) pairs not yet supported
219+
_program.py:15: error: Enum() with tuple or list of (name, value) pairs not yet supported
220+
_program.py:16: error: IntEnum() with tuple or list of (name, value) pairs not yet supported

0 commit comments

Comments
 (0)