Skip to content

Commit 50f6215

Browse files
gvanrossumJukkaL
authored andcommitted
Support functional API for Enum. (#2805)
Fixes #2306. Also move Enum tests from runtests to pytest.
1 parent 591a98c commit 50f6215

12 files changed

+606
-187
lines changed

mypy/checker.py

Lines changed: 14 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,17 @@
1717
TupleExpr, ListExpr, ExpressionStmt, ReturnStmt, IfStmt,
1818
WhileStmt, OperatorAssignmentStmt, WithStmt, AssertStmt,
1919
RaiseStmt, TryStmt, ForStmt, DelStmt, CallExpr, IntExpr, StrExpr,
20-
UnicodeExpr, OpExpr, UnaryExpr, LambdaExpr, TempNode, SymbolTableNode,
21-
Context, Decorator, PrintStmt, LITERAL_TYPE, BreakStmt, PassStmt, ContinueStmt,
22-
ComparisonExpr, StarExpr, EllipsisExpr, RefExpr, ImportFrom, ImportAll, ImportBase,
23-
ARG_POS, CONTRAVARIANT, COVARIANT, ExecStmt, GlobalDecl, Import, NonlocalDecl,
24-
MDEF, Node
25-
)
20+
BytesExpr, UnicodeExpr, FloatExpr, OpExpr, UnaryExpr, CastExpr, RevealTypeExpr, SuperExpr,
21+
TypeApplication, DictExpr, SliceExpr, LambdaExpr, TempNode, SymbolTableNode,
22+
Context, ListComprehension, ConditionalExpr, GeneratorExpr,
23+
Decorator, SetExpr, TypeVarExpr, NewTypeExpr, PrintStmt,
24+
LITERAL_TYPE, BreakStmt, PassStmt, ContinueStmt, ComparisonExpr, StarExpr,
25+
YieldFromExpr, NamedTupleExpr, TypedDictExpr, SetComprehension,
26+
DictionaryComprehension, ComplexExpr, EllipsisExpr, TypeAliasExpr,
27+
RefExpr, YieldExpr, BackquoteExpr, Import, ImportFrom, ImportAll, ImportBase,
28+
AwaitExpr, PromoteExpr, Node, EnumCallExpr,
29+
ARG_POS, MDEF,
30+
CONTRAVARIANT, COVARIANT)
2631
from mypy import nodes
2732
from mypy.types import (
2833
Type, AnyType, CallableType, FunctionLike, Overloaded, TupleType, TypedDictType,
@@ -45,7 +50,7 @@
4550
from mypy.semanal import set_callable_name, refers_to_fullname
4651
from mypy.erasetype import erase_typevars
4752
from mypy.expandtype import expand_type, expand_type_by_instance
48-
from mypy.visitor import StatementVisitor
53+
from mypy.visitor import NodeVisitor
4954
from mypy.join import join_types
5055
from mypy.treetransform import TransformVisitor
5156
from mypy.binder import ConditionalTypeBinder, get_declaration
@@ -70,7 +75,7 @@
7075
])
7176

7277

73-
class TypeChecker(StatementVisitor[None]):
78+
class TypeChecker(NodeVisitor[None]):
7479
"""Mypy type checker.
7580
7681
Type check mypy source files that have been semantically analyzed.
@@ -2259,21 +2264,7 @@ def visit_break_stmt(self, s: BreakStmt) -> None:
22592264

22602265
def visit_continue_stmt(self, s: ContinueStmt) -> None:
22612266
self.binder.handle_continue()
2262-
2263-
def visit_exec_stmt(self, s: ExecStmt) -> None:
2264-
pass
2265-
2266-
def visit_global_decl(self, s: GlobalDecl) -> None:
2267-
pass
2268-
2269-
def visit_nonlocal_decl(self, s: NonlocalDecl) -> None:
2270-
pass
2271-
2272-
def visit_var(self, s: Var) -> None:
2273-
pass
2274-
2275-
def visit_pass_stmt(self, s: PassStmt) -> None:
2276-
pass
2267+
return None
22772268

22782269
#
22792270
# Helpers

mypy/checkexpr.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@
2020
ConditionalExpr, ComparisonExpr, TempNode, SetComprehension,
2121
DictionaryComprehension, ComplexExpr, EllipsisExpr, StarExpr, AwaitExpr, YieldExpr,
2222
YieldFromExpr, TypedDictExpr, PromoteExpr, NewTypeExpr, NamedTupleExpr, TypeVarExpr,
23-
TypeAliasExpr, BackquoteExpr, ARG_POS, ARG_NAMED, ARG_STAR, ARG_STAR2, MODULE_REF,
23+
TypeAliasExpr, BackquoteExpr, EnumCallExpr,
24+
ARG_POS, ARG_NAMED, ARG_STAR, ARG_STAR2, MODULE_REF,
2425
UNBOUND_TVAR, BOUND_TVAR, LITERAL_TYPE
2526
)
2627
from mypy import nodes
@@ -349,6 +350,12 @@ def check_call(self, callee: Type, args: List[Expression],
349350
"""
350351
arg_messages = arg_messages or self.msg
351352
if isinstance(callee, CallableType):
353+
if (isinstance(callable_node, RefExpr)
354+
and callable_node.fullname in ('enum.Enum', 'enum.IntEnum',
355+
'enum.Flag', 'enum.IntFlag')):
356+
# An Enum() call that failed SemanticAnalyzer.check_enum_call().
357+
return callee.ret_type, callee
358+
352359
if (callee.is_type_obj() and callee.type_object().is_abstract
353360
# Exceptions for Type[...] and classmethod first argument
354361
and not callee.from_type_type and not callee.is_classmethod_class):
@@ -2199,6 +2206,22 @@ def visit_namedtuple_expr(self, e: NamedTupleExpr) -> Type:
21992206
# TODO: Perhaps return a type object type?
22002207
return AnyType()
22012208

2209+
def visit_enum_call_expr(self, e: EnumCallExpr) -> Type:
2210+
for name, value in zip(e.items, e.values):
2211+
if value is not None:
2212+
typ = self.accept(value)
2213+
if not isinstance(typ, AnyType):
2214+
var = e.info.names[name].node
2215+
if isinstance(var, Var):
2216+
# Inline TypeCheker.set_inferred_type(),
2217+
# without the lvalue. (This doesn't really do
2218+
# much, since the value attribute is defined
2219+
# to have type Any in the typeshed stub.)
2220+
var.type = typ
2221+
var.is_inferred = True
2222+
# TODO: Perhaps return a type object type?
2223+
return AnyType()
2224+
22022225
def visit_typeddict_expr(self, e: TypedDictExpr) -> Type:
22032226
# TODO: Perhaps return a type object type?
22042227
return AnyType()

mypy/nodes.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1830,6 +1830,25 @@ def accept(self, visitor: ExpressionVisitor[T]) -> T:
18301830
return visitor.visit_typeddict_expr(self)
18311831

18321832

1833+
class EnumCallExpr(Expression):
1834+
"""Named tuple expression Enum('name', 'val1 val2 ...')."""
1835+
1836+
# The class representation of this enumerated type
1837+
info = None # type: TypeInfo
1838+
# The item names (for debugging)
1839+
items = None # type: List[str]
1840+
values = None # type: List[Optional[Expression]]
1841+
1842+
def __init__(self, info: 'TypeInfo', items: List[str],
1843+
values: List[Optional[Expression]]) -> None:
1844+
self.info = info
1845+
self.items = items
1846+
self.values = values
1847+
1848+
def accept(self, visitor: ExpressionVisitor[T]) -> T:
1849+
return visitor.visit_enum_call_expr(self)
1850+
1851+
18331852
class PromoteExpr(Expression):
18341853
"""Ducktype class decorator expression _promote(...)."""
18351854

mypy/semanal.py

Lines changed: 135 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@
6464
YieldFromExpr, NamedTupleExpr, TypedDictExpr, NonlocalDecl, SymbolNode,
6565
SetComprehension, DictionaryComprehension, TYPE_ALIAS, TypeAliasExpr,
6666
YieldExpr, ExecStmt, Argument, BackquoteExpr, ImportBase, AwaitExpr,
67-
IntExpr, FloatExpr, UnicodeExpr, EllipsisExpr, TempNode,
67+
IntExpr, FloatExpr, UnicodeExpr, EllipsisExpr, TempNode, EnumCallExpr,
6868
COVARIANT, CONTRAVARIANT, INVARIANT, UNBOUND_IMPORTED, LITERAL_YES, ARG_OPT, nongen_builtins,
6969
collections_type_aliases, get_member_expr_fullname,
7070
)
@@ -1498,6 +1498,7 @@ def visit_assignment_stmt(self, s: AssignmentStmt) -> None:
14981498
self.process_typevar_declaration(s)
14991499
self.process_namedtuple_definition(s)
15001500
self.process_typeddict_definition(s)
1501+
self.process_enum_call(s)
15011502

15021503
if (len(s.lvalues) == 1 and isinstance(s.lvalues[0], NameExpr) and
15031504
s.lvalues[0].name == '__all__' and s.lvalues[0].kind == GDEF and
@@ -2331,6 +2332,139 @@ def is_classvar(self, typ: Type) -> bool:
23312332
def fail_invalid_classvar(self, context: Context) -> None:
23322333
self.fail('ClassVar can only be used for assignments in class body', context)
23332334

2335+
def process_enum_call(self, s: AssignmentStmt) -> None:
2336+
"""Check if s defines an Enum; if yes, store the definition in symbol table."""
2337+
if len(s.lvalues) != 1 or not isinstance(s.lvalues[0], NameExpr):
2338+
return
2339+
lvalue = s.lvalues[0]
2340+
name = lvalue.name
2341+
enum_call = self.check_enum_call(s.rvalue, name)
2342+
if enum_call is None:
2343+
return
2344+
# Yes, it's a valid Enum definition. Add it to the symbol table.
2345+
node = self.lookup(name, s)
2346+
if node:
2347+
node.kind = GDEF # TODO locally defined Enum
2348+
node.node = enum_call
2349+
2350+
def check_enum_call(self, node: Expression, var_name: str = None) -> Optional[TypeInfo]:
2351+
"""Check if a call defines an Enum.
2352+
2353+
Example:
2354+
2355+
A = enum.Enum('A', 'foo bar')
2356+
2357+
is equivalent to:
2358+
2359+
class A(enum.Enum):
2360+
foo = 1
2361+
bar = 2
2362+
"""
2363+
if not isinstance(node, CallExpr):
2364+
return None
2365+
call = node
2366+
callee = call.callee
2367+
if not isinstance(callee, RefExpr):
2368+
return None
2369+
fullname = callee.fullname
2370+
if fullname not in ('enum.Enum', 'enum.IntEnum', 'enum.Flag', 'enum.IntFlag'):
2371+
return None
2372+
items, values, ok = self.parse_enum_call_args(call, fullname.split('.')[-1])
2373+
if not ok:
2374+
# Error. Construct dummy return value.
2375+
return self.build_enum_call_typeinfo('Enum', [], fullname)
2376+
name = cast(StrExpr, call.args[0]).value
2377+
if name != var_name or self.is_func_scope():
2378+
# Give it a unique name derived from the line number.
2379+
name += '@' + str(call.line)
2380+
info = self.build_enum_call_typeinfo(name, items, fullname)
2381+
# Store it as a global just in case it would remain anonymous.
2382+
# (Or in the nearest class if there is one.)
2383+
stnode = SymbolTableNode(GDEF, info, self.cur_mod_id)
2384+
if self.type:
2385+
self.type.names[name] = stnode
2386+
else:
2387+
self.globals[name] = stnode
2388+
call.analyzed = EnumCallExpr(info, items, values)
2389+
call.analyzed.set_line(call.line, call.column)
2390+
return info
2391+
2392+
def build_enum_call_typeinfo(self, name: str, items: List[str], fullname: str) -> TypeInfo:
2393+
base = self.named_type_or_none(fullname)
2394+
assert base is not None
2395+
info = self.basic_new_typeinfo(name, base)
2396+
info.is_enum = True
2397+
for item in items:
2398+
var = Var(item)
2399+
var.info = info
2400+
var.is_property = True
2401+
info.names[item] = SymbolTableNode(MDEF, var)
2402+
return info
2403+
2404+
def parse_enum_call_args(self, call: CallExpr,
2405+
class_name: str) -> Tuple[List[str],
2406+
List[Optional[Expression]], bool]:
2407+
args = call.args
2408+
if len(args) < 2:
2409+
return self.fail_enum_call_arg("Too few arguments for %s()" % class_name, call)
2410+
if len(args) > 2:
2411+
return self.fail_enum_call_arg("Too many arguments for %s()" % class_name, call)
2412+
if call.arg_kinds != [ARG_POS, ARG_POS]:
2413+
return self.fail_enum_call_arg("Unexpected arguments to %s()" % class_name, call)
2414+
if not isinstance(args[0], (StrExpr, UnicodeExpr)):
2415+
return self.fail_enum_call_arg(
2416+
"%s() expects a string literal as the first argument" % class_name, call)
2417+
items = []
2418+
values = [] # type: List[Optional[Expression]]
2419+
if isinstance(args[1], (StrExpr, UnicodeExpr)):
2420+
fields = args[1].value
2421+
for field in fields.replace(',', ' ').split():
2422+
items.append(field)
2423+
elif isinstance(args[1], (TupleExpr, ListExpr)):
2424+
seq_items = args[1].items
2425+
if all(isinstance(seq_item, (StrExpr, UnicodeExpr)) for seq_item in seq_items):
2426+
items = [cast(StrExpr, seq_item).value for seq_item in seq_items]
2427+
elif all(isinstance(seq_item, (TupleExpr, ListExpr))
2428+
and len(seq_item.items) == 2
2429+
and isinstance(seq_item.items[0], (StrExpr, UnicodeExpr))
2430+
for seq_item in seq_items):
2431+
for seq_item in seq_items:
2432+
assert isinstance(seq_item, (TupleExpr, ListExpr))
2433+
name, value = seq_item.items
2434+
assert isinstance(name, (StrExpr, UnicodeExpr))
2435+
items.append(name.value)
2436+
values.append(value)
2437+
else:
2438+
return self.fail_enum_call_arg(
2439+
"%s() with tuple or list expects strings or (name, value) pairs" %
2440+
class_name,
2441+
call)
2442+
elif isinstance(args[1], DictExpr):
2443+
for key, value in args[1].items:
2444+
if not isinstance(key, (StrExpr, UnicodeExpr)):
2445+
return self.fail_enum_call_arg(
2446+
"%s() with dict literal requires string literals" % class_name, call)
2447+
items.append(key.value)
2448+
values.append(value)
2449+
else:
2450+
# TODO: Allow dict(x=1, y=2) as a substitute for {'x': 1, 'y': 2}?
2451+
return self.fail_enum_call_arg(
2452+
"%s() expects a string, tuple, list or dict literal as the second argument" %
2453+
class_name,
2454+
call)
2455+
if len(items) == 0:
2456+
return self.fail_enum_call_arg("%s() needs at least one item" % class_name, call)
2457+
if not values:
2458+
values = [None] * len(items)
2459+
assert len(items) == len(values)
2460+
return items, values, True
2461+
2462+
def fail_enum_call_arg(self, message: str,
2463+
context: Context) -> Tuple[List[str],
2464+
List[Optional[Expression]], bool]:
2465+
self.fail(message, context)
2466+
return [], [], False
2467+
23342468
def visit_decorator(self, dec: Decorator) -> None:
23352469
for d in dec.decorators:
23362470
d.accept(self)

mypy/strconv.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -431,6 +431,9 @@ def visit_namedtuple_expr(self, o: 'mypy.nodes.NamedTupleExpr') -> str:
431431
o.info.name(),
432432
o.info.tuple_type)
433433

434+
def visit_enum_call_expr(self, o: 'mypy.nodes.EnumCallExpr') -> str:
435+
return 'EnumCallExpr:{}({}, {})'.format(o.line, o.info.name(), o.items)
436+
434437
def visit_typeddict_expr(self, o: 'mypy.nodes.TypedDictExpr') -> str:
435438
return 'TypedDictExpr:{}({})'.format(o.line,
436439
o.info.name())

mypy/test/testcheck.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@
7373
'check-newsyntax.test',
7474
'check-underscores.test',
7575
'check-classvar.test',
76+
'check-enum.test',
7677
]
7778

7879

mypy/test/testpythoneval.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,7 @@
3232
python_eval_files = ['pythoneval.test',
3333
'python2eval.test']
3434

35-
python_34_eval_files = ['pythoneval-asyncio.test',
36-
'pythoneval-enum.test']
35+
python_34_eval_files = ['pythoneval-asyncio.test']
3736

3837
# Path to Python 3 interpreter
3938
python3_path = sys.executable

mypy/treetransform.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@
1919
ComparisonExpr, TempNode, StarExpr, Statement, Expression,
2020
YieldFromExpr, NamedTupleExpr, TypedDictExpr, NonlocalDecl, SetComprehension,
2121
DictionaryComprehension, ComplexExpr, TypeAliasExpr, EllipsisExpr,
22-
YieldExpr, ExecStmt, Argument, BackquoteExpr, AwaitExpr, OverloadPart
22+
YieldExpr, ExecStmt, Argument, BackquoteExpr, AwaitExpr,
23+
OverloadPart, EnumCallExpr,
2324
)
2425
from mypy.types import Type, FunctionLike
2526
from mypy.traverser import TraverserVisitor
@@ -486,6 +487,9 @@ def visit_newtype_expr(self, node: NewTypeExpr) -> NewTypeExpr:
486487
def visit_namedtuple_expr(self, node: NamedTupleExpr) -> NamedTupleExpr:
487488
return NamedTupleExpr(node.info)
488489

490+
def visit_enum_call_expr(self, node: EnumCallExpr) -> EnumCallExpr:
491+
return EnumCallExpr(node.info, node.items, node.values)
492+
489493
def visit_typeddict_expr(self, node: TypedDictExpr) -> Node:
490494
return TypedDictExpr(node.info)
491495

mypy/visitor.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,10 @@ def visit_type_alias_expr(self, o: 'mypy.nodes.TypeAliasExpr') -> T:
156156
def visit_namedtuple_expr(self, o: 'mypy.nodes.NamedTupleExpr') -> T:
157157
pass
158158

159+
@abstractmethod
160+
def visit_enum_call_expr(self, o: 'mypy.nodes.EnumCallExpr') -> T:
161+
pass
162+
159163
@abstractmethod
160164
def visit_typeddict_expr(self, o: 'mypy.nodes.TypedDictExpr') -> T:
161165
pass
@@ -514,6 +518,9 @@ def visit_type_alias_expr(self, o: 'mypy.nodes.TypeAliasExpr') -> T:
514518
def visit_namedtuple_expr(self, o: 'mypy.nodes.NamedTupleExpr') -> T:
515519
pass
516520

521+
def visit_enum_call_expr(self, o: 'mypy.nodes.EnumCallExpr') -> T:
522+
pass
523+
517524
def visit_typeddict_expr(self, o: 'mypy.nodes.TypedDictExpr') -> T:
518525
pass
519526

0 commit comments

Comments
 (0)