Skip to content

Commit b4492b7

Browse files
gvanrossumJukkaL
authored andcommitted
Allow E[<str>] where E is an Enum type (#2812)
* Allow E[<str>] where E is an Enum type. Fixes #1381. * Respond to review by @ilevkivskyi
1 parent 9406886 commit b4492b7

File tree

3 files changed

+44
-2
lines changed

3 files changed

+44
-2
lines changed

mypy/checkexpr.py

+13
Original file line numberDiff line numberDiff line change
@@ -1442,6 +1442,9 @@ def visit_index_expr_helper(self, e: IndexExpr) -> Type:
14421442
return AnyType()
14431443
elif isinstance(left_type, TypedDictType):
14441444
return self.visit_typeddict_index_expr(left_type, e.index)
1445+
elif (isinstance(left_type, CallableType)
1446+
and left_type.is_type_obj() and left_type.type_object().is_enum):
1447+
return self.visit_enum_index_expr(left_type.type_object(), e.index, e)
14451448
else:
14461449
result, method_type = self.check_op('__getitem__', left_type, e.index, e)
14471450
e.method_type = method_type
@@ -1500,6 +1503,16 @@ def visit_typeddict_index_expr(self, td_type: TypedDictType, index: Expression)
15001503
return AnyType()
15011504
return item_type
15021505

1506+
def visit_enum_index_expr(self, enum_type: TypeInfo, index: Expression,
1507+
context: Context) -> Type:
1508+
string_type = self.named_type('builtins.str') # type: Type
1509+
if self.chk.options.python_version[0] < 3:
1510+
string_type = UnionType.make_union([string_type,
1511+
self.named_type('builtins.unicode')])
1512+
self.chk.check_subtype(self.accept(index), string_type, context,
1513+
"Enum index should be a string", "actual index type")
1514+
return Instance(enum_type, [])
1515+
15031516
def visit_cast_expr(self, expr: CastExpr) -> Type:
15041517
"""Type check a cast expression."""
15051518
source_type = self.accept(expr.expr, context=AnyType())

mypy/semanal.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -2617,7 +2617,11 @@ def visit_unary_expr(self, expr: UnaryExpr) -> None:
26172617

26182618
def visit_index_expr(self, expr: IndexExpr) -> None:
26192619
expr.base.accept(self)
2620-
if isinstance(expr.base, RefExpr) and expr.base.kind == TYPE_ALIAS:
2620+
if (isinstance(expr.base, RefExpr)
2621+
and isinstance(expr.base.node, TypeInfo)
2622+
and expr.base.node.is_enum):
2623+
expr.index.accept(self)
2624+
elif isinstance(expr.base, RefExpr) and expr.base.kind == TYPE_ALIAS:
26212625
# Special form -- subscripting a generic type alias.
26222626
# Perform the type substitution and create a new alias.
26232627
res = analyze_type_alias(expr,

test-data/unit/pythoneval-enum.test

+26-1
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,6 @@ def takes_some_ext_int_enum(s: SomeExtIntEnum):
119119
pass
120120
takes_some_ext_int_enum(SomeExtIntEnum.x)
121121

122-
123122
[case testNamedTupleEnum]
124123
from typing import NamedTuple
125124
from enum import Enum
@@ -132,3 +131,29 @@ class E(N, Enum):
132131
def f(x: E) -> None: pass
133132

134133
f(E.X)
134+
135+
[case testEnumCall]
136+
from enum import IntEnum
137+
class E(IntEnum):
138+
a = 1
139+
x = None # type: int
140+
reveal_type(E(x))
141+
[out]
142+
_program.py:5: error: Revealed type is '_testEnumCall.E'
143+
144+
[case testEnumIndex]
145+
from enum import IntEnum
146+
class E(IntEnum):
147+
a = 1
148+
s = None # type: str
149+
reveal_type(E[s])
150+
[out]
151+
_program.py:5: error: Revealed type is '_testEnumIndex.E'
152+
153+
[case testEnumIndexError]
154+
from enum import IntEnum
155+
class E(IntEnum):
156+
a = 1
157+
E[1]
158+
[out]
159+
_program.py:4: error: Enum index should be a string (actual index type "int")

0 commit comments

Comments
 (0)