Skip to content

Checking of starred expressions #483

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 17 commits into from
Nov 21, 2014
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions docs/source/tutorial.rst
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,24 @@ The type Tuple[t, ...] represents a tuple with the item types t, ...:
t = 1, 'foo' # OK
t = 'foo', 1 # Type check error

Starred expressions
******************************

In most cases, mypy can infer the type of starred expressions from the right-hand side of an assignment, but not always:

.. code-block:: python

a, *bs = 1, 2, 3 # OK
p, q, *rs = 1, 2 # Error: Type of cs cannot be inferred

On first line, the type of ``bs`` is inferred to be ``List[int]``. However, on the second line, mypy cannot infer the type of ``rs``, because there is no right-hand side value for ``rs`` to infer the type from. In cases like these, the starred expression needs to be annotated with a starred type:

.. code-block:: python

p, q, *rs = 1, 2 # type: int, int, *List[int]

Here, the type of ``rs`` is set to ``List[int]``.

Class name forward references
*****************************

Expand Down
199 changes: 135 additions & 64 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
TypeApplication, DictExpr, SliceExpr, FuncExpr, TempNode, SymbolTableNode,
Context, ListComprehension, ConditionalExpr, GeneratorExpr,
Decorator, SetExpr, PassStmt, TypeVarExpr, UndefinedExpr, PrintStmt,
LITERAL_TYPE, BreakStmt, ContinueStmt, ComparisonExpr
LITERAL_TYPE, BreakStmt, ContinueStmt, ComparisonExpr, StarExpr
)
from mypy.nodes import function_type, method_type
from mypy import nodes
Expand Down Expand Up @@ -956,17 +956,44 @@ def check_assignment_to_multiple_lvalues(self, lvalues: List[Node], rvalue: Node
# control in cases like: a, b = [int, str] where rhs would get
# type List[object]

rtuple = cast(Union[TupleExpr, ListExpr], rvalue)
rvalues = cast(Union[TupleExpr, ListExpr], rvalue).items

if len(rtuple.items) != len(lvalues):
self.msg.incompatible_value_count_in_assignment(
len(lvalues), len(rtuple.items), context)
else:
for lv, rv in zip(lvalues, rtuple.items):
if self.check_rvalue_count_in_assignment(lvalues, len(rvalues), context):
star_index = next((i for i, lv in enumerate(lvalues) if
isinstance(lv, StarExpr)), len(lvalues))

left_lvs = lvalues[:star_index]
star_lv = cast(StarExpr, lvalues[star_index]) if star_index != len(lvalues) else None
right_lvs = lvalues[star_index+1:]

left_rvs, star_rvs, right_rvs = self.split_around_star(
rvalues, star_index, len(lvalues))

lr_pairs = list(zip(left_lvs, left_rvs))
if star_lv:
rv_list = ListExpr(star_rvs)
rv_list.set_line(rvalue.get_line())
lr_pairs.append( (star_lv.expr, rv_list) )
lr_pairs.extend(zip(right_lvs, right_rvs))

for lv, rv in lr_pairs:
self.check_assignment(lv, rv, infer_lvalue_type)
else:
self.check_multi_assignment(lvalues, rvalue, context, infer_lvalue_type)

def check_rvalue_count_in_assignment(self, lvalues: List[Node], rvalue_count: int,
context: Context) -> bool:
if any(isinstance(lvalue, StarExpr) for lvalue in lvalues):
if len(lvalues)-1 > rvalue_count:
self.msg.wrong_number_values_to_unpack(rvalue_count,
len(lvalues)-1, context)
return False
elif rvalue_count != len(lvalues):
self.msg.wrong_number_values_to_unpack(rvalue_count,
len(lvalues), context)
return False
return True

def remove_parens(self, node: Node) -> Node:
if isinstance(node, ParenExpr):
return self.remove_parens(node.expr)
Expand All @@ -978,70 +1005,125 @@ def check_multi_assignment(self, lvalues: List[Node],
context: Context,
infer_lvalue_type: bool = True,
msg: str = None) -> None:
"""Check the assignment of one rvalue to a number of lvalues
for example from a ListExpr or TupleExpr.
"""

if not msg:
msg = messages.INCOMPATIBLE_TYPES_IN_ASSIGNMENT

# First handle case where rvalue is of form Undefined, ...
rvalue_type = get_undefined_tuple(rvalue, self.named_type('builtins.tuple'))
undefined_rvalue = True
if not rvalue_type:
# Infer the type of an ordinary rvalue expression.
rvalue_type = self.accept(rvalue) # TODO maybe elsewhere; redundant
undefined_rvalue = False

if isinstance(rvalue_type, AnyType):
for lv in lvalues:
self.check_assignment(lv, self.temp_node(AnyType(), context), infer_lvalue_type)
elif isinstance(rvalue_type, TupleType):
self.check_multi_assignment_from_tuple(lvalues, rvalue, cast(TupleType, rvalue_type),
"""Check the assignment of one rvalue to a number of lvalues
for example from a ListExpr or TupleExpr.
"""

if not msg:
msg = messages.INCOMPATIBLE_TYPES_IN_ASSIGNMENT

# First handle case where rvalue is of form Undefined, ...
rvalue_type = get_undefined_tuple(rvalue, self.named_type('builtins.tuple'))
undefined_rvalue = True
if not rvalue_type:
# Infer the type of an ordinary rvalue expression.
rvalue_type = self.accept(rvalue) # TODO maybe elsewhere; redundant
undefined_rvalue = False

if isinstance(rvalue_type, AnyType):
for lv in lvalues:
if isinstance(lv, StarExpr):
lv = lv.expr
self.check_assignment(lv, self.temp_node(AnyType(), context), infer_lvalue_type)
elif isinstance(rvalue_type, TupleType):
self.check_multi_assignment_from_tuple(lvalues, rvalue, cast(TupleType, rvalue_type),
context, undefined_rvalue, infer_lvalue_type)
else:
self.check_multi_assignment_from_iterable(lvalues, rvalue_type,
else:
self.check_multi_assignment_from_iterable(lvalues, rvalue_type,
context, infer_lvalue_type)

def check_multi_assignment_from_tuple(self, lvalues: List[Node], rvalue: Node,
rvalue_type: TupleType, context: Context,
undefined_rvalue: bool, infer_lvalue_type: bool = True) -> None:
if len(rvalue_type.items) != len(lvalues):
self.msg.wrong_number_values_to_unpack(len(rvalue_type.items), len(lvalues), context)
else:
if not undefined_rvalue:
# Create lvalue_type for type inference
undefined_rvalue: bool, infer_lvalue_type: bool=True) -> None:
if self.check_rvalue_count_in_assignment(lvalues, len(rvalue_type.items), context):
star_index = next((i for i, lv in enumerate(lvalues) if isinstance(lv, StarExpr)), len(lvalues))

type_parameters = [] # type: List[Type]
for i in range(len(lvalues)):
sub_lvalue_type, index_expr, inferred = self.check_lvalue(lvalues[i])

if sub_lvalue_type:
type_parameters.append(sub_lvalue_type)
else: # index lvalue
# TODO Figure out more precise type context, probably
# based on the type signature of the _set method.
type_parameters.append(rvalue_type.items[i])

lvalue_type = TupleType(type_parameters, self.named_type('builtins.tuple'))
left_lvs = lvalues[:star_index]
star_lv = cast(StarExpr, lvalues[star_index]) if star_index != len(lvalues) else None
right_lvs = lvalues[star_index+1:]

if not undefined_rvalue:
# Infer rvalue again, now in the correct type context.
lvalue_type = self.lvalue_type_for_inference(lvalues, rvalue_type)
rvalue_type = cast(TupleType, self.accept(rvalue, lvalue_type))

for lv, rv_type in zip(lvalues, rvalue_type.items):
left_rv_types, star_rv_types, right_rv_types = self.split_around_star(
rvalue_type.items, star_index, len(lvalues))

for lv, rv_type in zip(left_lvs, left_rv_types):
self.check_assignment(lv, self.temp_node(rv_type, context), infer_lvalue_type)
if star_lv:
nodes = [self.temp_node(rv_type, context) for rv_type in star_rv_types]
list_expr = ListExpr(nodes)
list_expr.set_line(context.get_line())
self.check_assignment(star_lv.expr, list_expr, infer_lvalue_type)
for lv, rv_type in zip(right_lvs, right_rv_types):
self.check_assignment(lv, self.temp_node(rv_type, context), infer_lvalue_type)

def lvalue_type_for_inference(self, lvalues: List[Node], rvalue_type: TupleType) -> Type:
star_index = next((i for i, lv in enumerate(lvalues) if isinstance(lv, StarExpr)), len(lvalues))
left_lvs = lvalues[:star_index]
star_lv = cast(StarExpr, lvalues[star_index]) if star_index != len(lvalues) else None
right_lvs = lvalues[star_index+1:]
left_rv_types, star_rv_types, right_rv_types = self.split_around_star(
rvalue_type.items, star_index, len(lvalues))

type_parameters = [] # type: List[Type]

def append_types_for_inference(lvs: List[Node], rv_types: List[Type]) -> None:
for lv, rv_type in zip(lvs, rv_types):
sub_lvalue_type, index_expr, inferred = self.check_lvalue(lv)
if sub_lvalue_type:
type_parameters.append(sub_lvalue_type)
else: # index lvalue
# TODO Figure out more precise type context, probably
# based on the type signature of the _set method.
type_parameters.append(rv_type)

append_types_for_inference(left_lvs, left_rv_types)

if star_lv:
sub_lvalue_type, index_expr, inferred = self.check_lvalue(star_lv.expr)
if sub_lvalue_type:
type_parameters.extend([sub_lvalue_type] * len(star_rv_types))
else: # index lvalue
# TODO Figure out more precise type context, probably
# based on the type signature of the _set method.
type_parameters.extend(star_rv_types)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This fails in case of nested star expressions, like a, *(b, *c) = x. I'd like to fix it, but I have trouble coming up with a test case for type inference from lvalues.

I guess it is not very common, though, and not very useful.


append_types_for_inference(right_lvs, right_rv_types)

return TupleType(type_parameters, self.named_type('builtins.tuple'))

def split_around_star(self, items: List[T], star_index: int,
length: int) -> Tuple[List[T], List[T], List[T]]:
"""Splits a list of items in three to match another list of length 'length'
that contains a starred expression at 'star_index' in the following way:

star_index = 2, length = 5 (i.e., [a,b,*,c,d]), items = [1,2,3,4,5,6,7]
returns in: ([1,2], [3,4,5], [6,7])
"""
nr_right_of_star = length - star_index - 1
right_index = nr_right_of_star if -nr_right_of_star != 0 else len(items)
left = items[:star_index]
star = items[star_index:right_index]
right = items[right_index:]
return (left, star, right)

def type_is_iterable(self, type: Type) -> bool:
return (is_subtype(type, self.named_generic_type('typing.Iterable',
[AnyType()])) and
isinstance(type, Instance))

def check_multi_assignment_from_iterable(self, lvalues: List[Node], rvalue_type: Type,
context: Context, infer_lvalue_type: bool = True) -> None:
context: Context, infer_lvalue_type: bool=True) -> None:
if self.type_is_iterable(rvalue_type):
item_type = self.iterable_item_type(cast(Instance,rvalue_type))
for lv in lvalues:
self.check_assignment(lv, self.temp_node(item_type, context), infer_lvalue_type)
if isinstance(lv, StarExpr):
self.check_assignment(lv.expr, self.temp_node(rvalue_type, context), infer_lvalue_type)
else:
self.check_assignment(lv, self.temp_node(item_type, context), infer_lvalue_type)
else:
self.msg.type_not_iterable(rvalue_type, context)

Expand Down Expand Up @@ -1452,21 +1534,10 @@ def analyse_iterable_item_type(self, expr: Node) -> Type:
expr)
return echk.check_call(method, [], [], expr)[0]

def analyse_index_variables(self, index: List[Node],
item_type: Type, context: Context) -> None:
def analyse_index_variables(self, index: Node, item_type: Type,
context: Context) -> None:
"""Type check or infer for loop or list comprehension index vars."""
# Create a temporary copy of variables with Node item type.
# TODO this is ugly
node_index = [] # type: List[Node]
for i in index:
node_index.append(i)

if len(node_index) == 1:
self.check_assignment(node_index[0], self.temp_node(item_type, context))
else:
self.check_multi_assignment(node_index,
self.temp_node(item_type, context),
context)
self.check_assignment(index, self.temp_node(item_type, context))

def visit_del_stmt(self, s: DelStmt) -> Type:
if isinstance(s.expr, IndexExpr):
Expand Down
11 changes: 7 additions & 4 deletions mypy/noderepr.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,10 +144,9 @@ def __init__(self, while_tok: Any, else_tok: Any) -> None:


class ForStmtRepr:
def __init__(self, for_tok: Any, commas: Any, in_tok: Any,
def __init__(self, for_tok: Any, in_tok: Any,
else_tok: Any) -> None:
self.for_tok = for_tok
self.commas = commas
self.in_tok = in_tok
self.else_tok = else_tok

Expand Down Expand Up @@ -213,6 +212,11 @@ def __init__(self, lparen: Any, rparen: Any) -> None:
self.rparen = rparen


class StarExprRepr:
def __init__(self, star: Any) -> None:
self.star = star


class NameExprRepr:
def __init__(self, id: Any) -> None:
self.id = id
Expand Down Expand Up @@ -326,10 +330,9 @@ def __init__(self, langle: Any, commas: Any, rangle: Any) -> None:


class GeneratorExprRepr:
def __init__(self, for_toks: List[Token], commas: List[Token], in_toks: List[Token],
def __init__(self, for_toks: List[Token], in_toks: List[Token],
if_toklists: List[List[Token]]) -> None:
self.for_toks = for_toks
self.commas = commas
self.in_toks = in_toks
self.if_toklists = if_toklists

Expand Down
25 changes: 21 additions & 4 deletions mypy/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -568,13 +568,13 @@ def accept(self, visitor: NodeVisitor[T]) -> T:

class ForStmt(Node):
# Index variables
index = Undefined(List['Node'])
index = Undefined(Node)
# Expression to iterate
expr = Undefined(Node)
body = Undefined(Block)
else_body = Undefined(Block)

def __init__(self, index: List['Node'], expr: Node, body: Block,
def __init__(self, index: Node, expr: Node, body: Block,
else_body: Block) -> None:
self.index = index
self.expr = expr
Expand Down Expand Up @@ -805,6 +805,23 @@ def accept(self, visitor: NodeVisitor[T]) -> T:
return visitor.visit_paren_expr(self)


class StarExpr(Node):
"""Star expression"""

expr = Undefined(Node)

def __init__(self, expr: Node) -> None:
self.expr = expr
self.literal = self.expr.literal
self.literal_hash = ('Star', expr.literal_hash,)

# Whether this starred expression is used in a tuple/list and as lvalue
self.valid = False

def accept(self, visitor: NodeVisitor[T]) -> T:
return visitor.visit_star_expr(self)


class RefExpr(Node):
"""Abstract base class for name-like constructs"""

Expand Down Expand Up @@ -1175,9 +1192,9 @@ class GeneratorExpr(Node):
left_expr = Undefined(Node)
sequences_expr = Undefined(List[Node])
condlists = Undefined(List[List[Node]])
indices = Undefined(List[List[Node]])
indices = Undefined(List[Node])

def __init__(self, left_expr: Node, indices: List[List[Node]],
def __init__(self, left_expr: Node, indices: List[Node],
sequences: List[Node], condlists: List[List[Node]]) -> None:
self.left_expr = left_expr
self.sequences = sequences
Expand Down
Loading