diff --git a/docs/source/tutorial.rst b/docs/source/tutorial.rst index de858ccef7a9..9d30fc455693 100644 --- a/docs/source/tutorial.rst +++ b/docs/source/tutorial.rst @@ -248,6 +248,24 @@ The type Tuple[t, ...] represents a tuple with the item types t, ...: t = 1, 'foo' # OK t = 'foo', 1 # Type check error +Starred expressions +****************************** + +In most cases, mypy can infer the type of starred expressions from the right-hand side of an assignment, but not always: + +.. code-block:: python + + a, *bs = 1, 2, 3 # OK + p, q, *rs = 1, 2 # Error: Type of cs cannot be inferred + +On first line, the type of ``bs`` is inferred to be ``List[int]``. However, on the second line, mypy cannot infer the type of ``rs``, because there is no right-hand side value for ``rs`` to infer the type from. In cases like these, the starred expression needs to be annotated with a starred type: + +.. code-block:: python + + p, q, *rs = 1, 2 # type: int, int, *List[int] + +Here, the type of ``rs`` is set to ``List[int]``. + Class name forward references ***************************** diff --git a/mypy/checker.py b/mypy/checker.py index 2cecb7a9189b..be20e0bd10c2 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -16,7 +16,7 @@ TypeApplication, DictExpr, SliceExpr, FuncExpr, TempNode, SymbolTableNode, Context, ListComprehension, ConditionalExpr, GeneratorExpr, Decorator, SetExpr, PassStmt, TypeVarExpr, UndefinedExpr, PrintStmt, - LITERAL_TYPE, BreakStmt, ContinueStmt, ComparisonExpr + LITERAL_TYPE, BreakStmt, ContinueStmt, ComparisonExpr, StarExpr ) from mypy.nodes import function_type, method_type from mypy import nodes @@ -956,17 +956,44 @@ def check_assignment_to_multiple_lvalues(self, lvalues: List[Node], rvalue: Node # control in cases like: a, b = [int, str] where rhs would get # type List[object] - rtuple = cast(Union[TupleExpr, ListExpr], rvalue) + rvalues = cast(Union[TupleExpr, ListExpr], rvalue).items - if len(rtuple.items) != len(lvalues): - self.msg.incompatible_value_count_in_assignment( - len(lvalues), len(rtuple.items), context) - else: - for lv, rv in zip(lvalues, rtuple.items): + if self.check_rvalue_count_in_assignment(lvalues, len(rvalues), context): + star_index = next((i for i, lv in enumerate(lvalues) if + isinstance(lv, StarExpr)), len(lvalues)) + + left_lvs = lvalues[:star_index] + star_lv = cast(StarExpr, lvalues[star_index]) if star_index != len(lvalues) else None + right_lvs = lvalues[star_index+1:] + + left_rvs, star_rvs, right_rvs = self.split_around_star( + rvalues, star_index, len(lvalues)) + + lr_pairs = list(zip(left_lvs, left_rvs)) + if star_lv: + rv_list = ListExpr(star_rvs) + rv_list.set_line(rvalue.get_line()) + lr_pairs.append( (star_lv.expr, rv_list) ) + lr_pairs.extend(zip(right_lvs, right_rvs)) + + for lv, rv in lr_pairs: self.check_assignment(lv, rv, infer_lvalue_type) else: self.check_multi_assignment(lvalues, rvalue, context, infer_lvalue_type) + def check_rvalue_count_in_assignment(self, lvalues: List[Node], rvalue_count: int, + context: Context) -> bool: + if any(isinstance(lvalue, StarExpr) for lvalue in lvalues): + if len(lvalues)-1 > rvalue_count: + self.msg.wrong_number_values_to_unpack(rvalue_count, + len(lvalues)-1, context) + return False + elif rvalue_count != len(lvalues): + self.msg.wrong_number_values_to_unpack(rvalue_count, + len(lvalues), context) + return False + return True + def remove_parens(self, node: Node) -> Node: if isinstance(node, ParenExpr): return self.remove_parens(node.expr) @@ -978,70 +1005,125 @@ def check_multi_assignment(self, lvalues: List[Node], context: Context, infer_lvalue_type: bool = True, msg: str = None) -> None: - """Check the assignment of one rvalue to a number of lvalues - for example from a ListExpr or TupleExpr. - """ - - if not msg: - msg = messages.INCOMPATIBLE_TYPES_IN_ASSIGNMENT - - # First handle case where rvalue is of form Undefined, ... - rvalue_type = get_undefined_tuple(rvalue, self.named_type('builtins.tuple')) - undefined_rvalue = True - if not rvalue_type: - # Infer the type of an ordinary rvalue expression. - rvalue_type = self.accept(rvalue) # TODO maybe elsewhere; redundant - undefined_rvalue = False - - if isinstance(rvalue_type, AnyType): - for lv in lvalues: - self.check_assignment(lv, self.temp_node(AnyType(), context), infer_lvalue_type) - elif isinstance(rvalue_type, TupleType): - self.check_multi_assignment_from_tuple(lvalues, rvalue, cast(TupleType, rvalue_type), + """Check the assignment of one rvalue to a number of lvalues + for example from a ListExpr or TupleExpr. + """ + + if not msg: + msg = messages.INCOMPATIBLE_TYPES_IN_ASSIGNMENT + + # First handle case where rvalue is of form Undefined, ... + rvalue_type = get_undefined_tuple(rvalue, self.named_type('builtins.tuple')) + undefined_rvalue = True + if not rvalue_type: + # Infer the type of an ordinary rvalue expression. + rvalue_type = self.accept(rvalue) # TODO maybe elsewhere; redundant + undefined_rvalue = False + + if isinstance(rvalue_type, AnyType): + for lv in lvalues: + if isinstance(lv, StarExpr): + lv = lv.expr + self.check_assignment(lv, self.temp_node(AnyType(), context), infer_lvalue_type) + elif isinstance(rvalue_type, TupleType): + self.check_multi_assignment_from_tuple(lvalues, rvalue, cast(TupleType, rvalue_type), context, undefined_rvalue, infer_lvalue_type) - else: - self.check_multi_assignment_from_iterable(lvalues, rvalue_type, + else: + self.check_multi_assignment_from_iterable(lvalues, rvalue_type, context, infer_lvalue_type) def check_multi_assignment_from_tuple(self, lvalues: List[Node], rvalue: Node, rvalue_type: TupleType, context: Context, - undefined_rvalue: bool, infer_lvalue_type: bool = True) -> None: - if len(rvalue_type.items) != len(lvalues): - self.msg.wrong_number_values_to_unpack(len(rvalue_type.items), len(lvalues), context) - else: - if not undefined_rvalue: - # Create lvalue_type for type inference + undefined_rvalue: bool, infer_lvalue_type: bool=True) -> None: + if self.check_rvalue_count_in_assignment(lvalues, len(rvalue_type.items), context): + star_index = next((i for i, lv in enumerate(lvalues) if isinstance(lv, StarExpr)), len(lvalues)) - type_parameters = [] # type: List[Type] - for i in range(len(lvalues)): - sub_lvalue_type, index_expr, inferred = self.check_lvalue(lvalues[i]) - - if sub_lvalue_type: - type_parameters.append(sub_lvalue_type) - else: # index lvalue - # TODO Figure out more precise type context, probably - # based on the type signature of the _set method. - type_parameters.append(rvalue_type.items[i]) - - lvalue_type = TupleType(type_parameters, self.named_type('builtins.tuple')) + left_lvs = lvalues[:star_index] + star_lv = cast(StarExpr, lvalues[star_index]) if star_index != len(lvalues) else None + right_lvs = lvalues[star_index+1:] + if not undefined_rvalue: # Infer rvalue again, now in the correct type context. + lvalue_type = self.lvalue_type_for_inference(lvalues, rvalue_type) rvalue_type = cast(TupleType, self.accept(rvalue, lvalue_type)) - for lv, rv_type in zip(lvalues, rvalue_type.items): + left_rv_types, star_rv_types, right_rv_types = self.split_around_star( + rvalue_type.items, star_index, len(lvalues)) + + for lv, rv_type in zip(left_lvs, left_rv_types): + self.check_assignment(lv, self.temp_node(rv_type, context), infer_lvalue_type) + if star_lv: + nodes = [self.temp_node(rv_type, context) for rv_type in star_rv_types] + list_expr = ListExpr(nodes) + list_expr.set_line(context.get_line()) + self.check_assignment(star_lv.expr, list_expr, infer_lvalue_type) + for lv, rv_type in zip(right_lvs, right_rv_types): self.check_assignment(lv, self.temp_node(rv_type, context), infer_lvalue_type) + def lvalue_type_for_inference(self, lvalues: List[Node], rvalue_type: TupleType) -> Type: + star_index = next((i for i, lv in enumerate(lvalues) if isinstance(lv, StarExpr)), len(lvalues)) + left_lvs = lvalues[:star_index] + star_lv = cast(StarExpr, lvalues[star_index]) if star_index != len(lvalues) else None + right_lvs = lvalues[star_index+1:] + left_rv_types, star_rv_types, right_rv_types = self.split_around_star( + rvalue_type.items, star_index, len(lvalues)) + + type_parameters = [] # type: List[Type] + + def append_types_for_inference(lvs: List[Node], rv_types: List[Type]) -> None: + for lv, rv_type in zip(lvs, rv_types): + sub_lvalue_type, index_expr, inferred = self.check_lvalue(lv) + if sub_lvalue_type: + type_parameters.append(sub_lvalue_type) + else: # index lvalue + # TODO Figure out more precise type context, probably + # based on the type signature of the _set method. + type_parameters.append(rv_type) + + append_types_for_inference(left_lvs, left_rv_types) + + if star_lv: + sub_lvalue_type, index_expr, inferred = self.check_lvalue(star_lv.expr) + if sub_lvalue_type: + type_parameters.extend([sub_lvalue_type] * len(star_rv_types)) + else: # index lvalue + # TODO Figure out more precise type context, probably + # based on the type signature of the _set method. + type_parameters.extend(star_rv_types) + + append_types_for_inference(right_lvs, right_rv_types) + + return TupleType(type_parameters, self.named_type('builtins.tuple')) + + def split_around_star(self, items: List[T], star_index: int, + length: int) -> Tuple[List[T], List[T], List[T]]: + """Splits a list of items in three to match another list of length 'length' + that contains a starred expression at 'star_index' in the following way: + + star_index = 2, length = 5 (i.e., [a,b,*,c,d]), items = [1,2,3,4,5,6,7] + returns in: ([1,2], [3,4,5], [6,7]) + """ + nr_right_of_star = length - star_index - 1 + right_index = nr_right_of_star if -nr_right_of_star != 0 else len(items) + left = items[:star_index] + star = items[star_index:right_index] + right = items[right_index:] + return (left, star, right) + def type_is_iterable(self, type: Type) -> bool: return (is_subtype(type, self.named_generic_type('typing.Iterable', [AnyType()])) and isinstance(type, Instance)) def check_multi_assignment_from_iterable(self, lvalues: List[Node], rvalue_type: Type, - context: Context, infer_lvalue_type: bool = True) -> None: + context: Context, infer_lvalue_type: bool=True) -> None: if self.type_is_iterable(rvalue_type): item_type = self.iterable_item_type(cast(Instance,rvalue_type)) for lv in lvalues: - self.check_assignment(lv, self.temp_node(item_type, context), infer_lvalue_type) + if isinstance(lv, StarExpr): + self.check_assignment(lv.expr, self.temp_node(rvalue_type, context), infer_lvalue_type) + else: + self.check_assignment(lv, self.temp_node(item_type, context), infer_lvalue_type) else: self.msg.type_not_iterable(rvalue_type, context) @@ -1452,21 +1534,10 @@ def analyse_iterable_item_type(self, expr: Node) -> Type: expr) return echk.check_call(method, [], [], expr)[0] - def analyse_index_variables(self, index: List[Node], - item_type: Type, context: Context) -> None: + def analyse_index_variables(self, index: Node, item_type: Type, + context: Context) -> None: """Type check or infer for loop or list comprehension index vars.""" - # Create a temporary copy of variables with Node item type. - # TODO this is ugly - node_index = [] # type: List[Node] - for i in index: - node_index.append(i) - - if len(node_index) == 1: - self.check_assignment(node_index[0], self.temp_node(item_type, context)) - else: - self.check_multi_assignment(node_index, - self.temp_node(item_type, context), - context) + self.check_assignment(index, self.temp_node(item_type, context)) def visit_del_stmt(self, s: DelStmt) -> Type: if isinstance(s.expr, IndexExpr): diff --git a/mypy/noderepr.py b/mypy/noderepr.py index 102263cc839c..2fd73db99657 100644 --- a/mypy/noderepr.py +++ b/mypy/noderepr.py @@ -144,10 +144,9 @@ def __init__(self, while_tok: Any, else_tok: Any) -> None: class ForStmtRepr: - def __init__(self, for_tok: Any, commas: Any, in_tok: Any, + def __init__(self, for_tok: Any, in_tok: Any, else_tok: Any) -> None: self.for_tok = for_tok - self.commas = commas self.in_tok = in_tok self.else_tok = else_tok @@ -213,6 +212,11 @@ def __init__(self, lparen: Any, rparen: Any) -> None: self.rparen = rparen +class StarExprRepr: + def __init__(self, star: Any) -> None: + self.star = star + + class NameExprRepr: def __init__(self, id: Any) -> None: self.id = id @@ -326,10 +330,9 @@ def __init__(self, langle: Any, commas: Any, rangle: Any) -> None: class GeneratorExprRepr: - def __init__(self, for_toks: List[Token], commas: List[Token], in_toks: List[Token], + def __init__(self, for_toks: List[Token], in_toks: List[Token], if_toklists: List[List[Token]]) -> None: self.for_toks = for_toks - self.commas = commas self.in_toks = in_toks self.if_toklists = if_toklists diff --git a/mypy/nodes.py b/mypy/nodes.py index b6380fbd52e4..d5f3f18a08f0 100644 --- a/mypy/nodes.py +++ b/mypy/nodes.py @@ -568,13 +568,13 @@ def accept(self, visitor: NodeVisitor[T]) -> T: class ForStmt(Node): # Index variables - index = Undefined(List['Node']) + index = Undefined(Node) # Expression to iterate expr = Undefined(Node) body = Undefined(Block) else_body = Undefined(Block) - def __init__(self, index: List['Node'], expr: Node, body: Block, + def __init__(self, index: Node, expr: Node, body: Block, else_body: Block) -> None: self.index = index self.expr = expr @@ -805,6 +805,23 @@ def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_paren_expr(self) +class StarExpr(Node): + """Star expression""" + + expr = Undefined(Node) + + def __init__(self, expr: Node) -> None: + self.expr = expr + self.literal = self.expr.literal + self.literal_hash = ('Star', expr.literal_hash,) + + # Whether this starred expression is used in a tuple/list and as lvalue + self.valid = False + + def accept(self, visitor: NodeVisitor[T]) -> T: + return visitor.visit_star_expr(self) + + class RefExpr(Node): """Abstract base class for name-like constructs""" @@ -1175,9 +1192,9 @@ class GeneratorExpr(Node): left_expr = Undefined(Node) sequences_expr = Undefined(List[Node]) condlists = Undefined(List[List[Node]]) - indices = Undefined(List[List[Node]]) + indices = Undefined(List[Node]) - def __init__(self, left_expr: Node, indices: List[List[Node]], + def __init__(self, left_expr: Node, indices: List[Node], sequences: List[Node], condlists: List[List[Node]]) -> None: self.left_expr = left_expr self.sequences = sequences diff --git a/mypy/output.py b/mypy/output.py index 3668c92804c3..3673a65a38b0 100644 --- a/mypy/output.py +++ b/mypy/output.py @@ -262,9 +262,7 @@ def visit_while_stmt(self, o): def visit_for_stmt(self, o): r = o.repr self.token(r.for_tok) - for i in range(len(o.index)): - self.node(o.index[i]) - self.token(r.commas[i]) + self.node(o.index) self.token(r.in_tok) self.node(o.expr) @@ -332,6 +330,10 @@ def visit_paren_expr(self, o): self.node(o.expr) self.token(o.repr.rparen) + def visit_star_expr(self, o): + self.token(o.repr.star) + self.node(o.expr) + def visit_name_expr(self, o): # Supertype references may not have a representation. if o.repr: @@ -446,10 +448,7 @@ def visit_generator_expr(self, o): self.node(o.left_expr) for i in range(len(o.indices)): self.token(r.for_toks[i]) - for j in range(len(o.indices[i])): - self.node(o.indices[i][j]) - if j < len(o.indices[i]) - 1: - self.token(r.commas[0]) + self.node(o.indices[i]) self.token(r.in_toks[i]) self.node(o.sequences[i]) for cond, if_tok in zip(o.condlists[i], r.if_toklists[i]): diff --git a/mypy/parse.py b/mypy/parse.py index 443b642554dd..0e8d7308b020 100755 --- a/mypy/parse.py +++ b/mypy/parse.py @@ -23,7 +23,8 @@ TupleExpr, GeneratorExpr, ListComprehension, ListExpr, ConditionalExpr, DictExpr, SetExpr, NameExpr, IntExpr, StrExpr, BytesExpr, UnicodeExpr, FloatExpr, CallExpr, SuperExpr, MemberExpr, IndexExpr, SliceExpr, OpExpr, - UnaryExpr, FuncExpr, TypeApplication, PrintStmt, ImportBase, ComparisonExpr + UnaryExpr, FuncExpr, TypeApplication, PrintStmt, ImportBase, ComparisonExpr, + StarExpr ) from mypy import nodes from mypy import noderepr @@ -44,7 +45,7 @@ '&': 10, '^': 9, '|': 8, - '==': 7, '!=': 7, '<': 7, '>': 7, '<=': 7, '>=': 7, 'is': 7, 'in': 7, + '==': 7, '!=': 7, '<': 7, '>': 7, '<=': 7, '>=': 7, 'is': 7, 'in': 7, '*u': 7, # unary * for star expressions 'not': 6, 'and': 5, 'or': 4, @@ -697,7 +698,7 @@ def parse_statement(self) -> Node: return stmt def parse_expression_or_assignment(self) -> Node: - e = self.parse_expression() + e = self.parse_expression(star_expr_allowed=True) if self.current_str() == '=': return self.parse_assignment(e) elif self.current_str() in op_assign: @@ -726,11 +727,11 @@ def parse_assignment(self, lv: Any) -> Node: assigns = [self.expect('=')] lvalues = [lv] - e = self.parse_expression() + e = self.parse_expression(star_expr_allowed=True) while self.current_str() == '=': lvalues.append(e) assigns.append(self.skip()) - e = self.parse_expression() + e = self.parse_expression(star_expr_allowed=True) br = self.expect_break() type = self.parse_type_comment(br, signature=False) @@ -851,7 +852,7 @@ def parse_while_stmt(self) -> WhileStmt: def parse_for_stmt(self) -> ForStmt: for_tok = self.expect('for') - index, commas = self.parse_for_index_variables() + index = self.parse_for_index_variables() in_tok = self.expect('in') expr = self.parse_expression() @@ -865,31 +866,30 @@ def parse_for_stmt(self) -> ForStmt: else_tok = none node = ForStmt(index, expr, body, else_body) - self.set_repr(node, noderepr.ForStmtRepr(for_tok, commas, in_tok, - else_tok)) + self.set_repr(node, noderepr.ForStmtRepr(for_tok, in_tok, else_tok)) return node - def parse_for_index_variables(self) -> Tuple[List[Node], List[Token]]: + def parse_for_index_variables(self) -> Node: # Parse index variables of a 'for' statement. - index = List[Node]() + index_items = List[Node]() commas = List[Token]() - is_paren = self.current_str() == '(' - if is_paren: - self.skip() - while True: - v = self.parse_expression(precedence['in']) # prevent parsing of for's 'in' - index.append(v) + v = self.parse_expression(precedence['in'], star_expr_allowed=True) # prevent parsing of for-stmt's 'in' + index_items.append(v) if self.current_str() != ',': commas.append(none) break commas.append(self.skip()) - if is_paren: - self.expect(')') + if len(index_items) == 1: + index = index_items[0] + else: + index = TupleExpr(index_items) + index.set_line(index_items[0].get_line()) + self.set_repr(index, noderepr.TupleExprRepr(none, commas, none)) - return index, commas + return index def parse_if_stmt(self) -> IfStmt: is_error = False @@ -1020,7 +1020,7 @@ def parse_print_stmt(self) -> PrintStmt: # Parsing expressions - def parse_expression(self, prec: int = 0) -> Node: + def parse_expression(self, prec: int = 0, star_expr_allowed: bool = False) -> Node: """Parse a subexpression within a specific precedence context.""" expr = Undefined # type: Node t = self.current() # Remember token for setting the line number. @@ -1040,6 +1040,8 @@ def parse_expression(self, prec: int = 0) -> Node: expr = self.parse_lambda_expr() elif s == '{': expr = self.parse_dict_or_set_expr() + elif s == '*' and star_expr_allowed: + expr = self.parse_star_expr() else: if isinstance(self.current(), Name): # Name expression. @@ -1137,12 +1139,21 @@ def parse_parentheses(self) -> Node: expr = self.parse_empty_tuple_expr(lparen) # type: Node else: # Parenthesised expression. - expr = self.parse_expression(0) + expr = self.parse_expression(0, star_expr_allowed=True) rparen = self.expect(')') expr = ParenExpr(expr) self.set_repr(expr, noderepr.ParenExprRepr(lparen, rparen)) return expr + def parse_star_expr(self) -> Node: + star = self.expect('*') + expr = self.parse_expression(precedence['*u']) + expr = StarExpr(expr) + if expr.line < 0: + expr.set_line(star) + self.set_repr(expr, noderepr.StarExprRepr(star)) + return expr + def parse_empty_tuple_expr(self, lparen: Any) -> TupleExpr: rparen = self.expect(')') node = TupleExpr([]) @@ -1155,7 +1166,7 @@ def parse_list_expr(self) -> Node: lbracket = self.expect('[') commas = List[Token]() while self.current_str() != ']' and not self.eol(): - items.append(self.parse_expression(precedence[''])) + items.append(self.parse_expression(precedence[''], star_expr_allowed=True)) if self.current_str() != ',': break commas.append(self.expect(',')) @@ -1174,7 +1185,7 @@ def parse_list_expr(self) -> Node: return expr def parse_generator_expr(self, left_expr: Node) -> GeneratorExpr: - indices = List[List[Node]]() + indices = List[Node]() sequences = List[Node]() for_toks = List[Token]() in_toks = List[Token]() @@ -1184,7 +1195,7 @@ def parse_generator_expr(self, left_expr: Node) -> GeneratorExpr: if_toks = List[Token]() conds = List[Node]() for_toks.append(self.expect('for')) - index, commas = self.parse_for_index_variables() + index = self.parse_for_index_variables() indices.append(index) in_toks.append(self.expect('in')) sequence = self.parse_expression_list() @@ -1197,8 +1208,7 @@ def parse_generator_expr(self, left_expr: Node) -> GeneratorExpr: gen = GeneratorExpr(left_expr, indices, sequences, condlists) gen.set_line(for_toks[0]) - self.set_repr(gen, noderepr.GeneratorExprRepr(for_toks, commas, in_toks, - if_toklists)) + self.set_repr(gen, noderepr.GeneratorExprRepr(for_toks, in_toks, if_toklists)) return gen def parse_expression_list(self) -> Node: @@ -1263,7 +1273,7 @@ def parse_tuple_expr(self, expr: Node, if (self.current_str() in [')', ']', '='] or isinstance(self.current(), Break)): break - items.append(self.parse_expression(prec)) + items.append(self.parse_expression(prec, star_expr_allowed=True)) if self.current_str() != ',': break node = TupleExpr(items) self.set_repr(node, noderepr.TupleExprRepr(none, commas, none)) diff --git a/mypy/parsetype.py b/mypy/parsetype.py index 2143b108a4f7..46c5c60d1287 100644 --- a/mypy/parsetype.py +++ b/mypy/parsetype.py @@ -3,7 +3,7 @@ from typing import List, Tuple, Union, cast from mypy.types import ( - Type, UnboundType, TupleType, UnionType, TypeList, AnyType, Callable + Type, UnboundType, TupleType, UnionType, TypeList, AnyType, Callable, StarType ) from mypy.typerepr import CommonTypeRepr, ListTypeRepr from mypy.lex import Token, Name, StrLit, Break, lex @@ -57,6 +57,8 @@ def parse_type(self) -> Type: return self.parse_named_type() elif t.string == '[': return self.parse_type_list() + elif t.string == '*': + return self.parse_star_type() elif isinstance(t, StrLit): # Type escaped as string literal. typestr = t.parsed() @@ -94,7 +96,7 @@ def parse_types(self) -> Type: if self.current_token_str() == ')': break items.append(self.parse_type()) - type = TupleType(items, None) + type = TupleType(items, None, type.line) return type def parse_type_list(self) -> TypeList: @@ -145,6 +147,11 @@ def parse_named_type(self) -> Type: commas, rangle)) return typ + def parse_star_type(self) -> Type: + star = self.expect('*') + type = self.parse_type() + return StarType(type, star.line) + # Helpers def skip(self) -> Token: diff --git a/mypy/semanal.py b/mypy/semanal.py index e8d810308eaa..ade9a57a57e1 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -41,7 +41,7 @@ """ from typing import ( - Undefined, List, Dict, Set, Tuple, cast, Any, overload, typevar + Undefined, List, Dict, Set, Tuple, cast, Any, overload, typevar, Union ) from mypy.nodes import ( @@ -56,7 +56,7 @@ SymbolTableNode, TVAR, UNBOUND_TVAR, ListComprehension, GeneratorExpr, FuncExpr, MDEF, FuncBase, Decorator, SetExpr, UndefinedExpr, TypeVarExpr, StrExpr, PrintStmt, ConditionalExpr, DucktypeExpr, DisjointclassExpr, - ComparisonExpr, ARG_POS, ARG_NAMED, MroError, type_aliases + ComparisonExpr, StarExpr, ARG_POS, ARG_NAMED, MroError, type_aliases ) from mypy.visitor import NodeVisitor from mypy.traverser import TraverserVisitor @@ -64,7 +64,7 @@ from mypy.types import ( NoneTyp, Callable, Overloaded, Instance, Type, TypeVar, AnyType, FunctionLike, UnboundType, TypeList, ErrorType, TypeVarDef, - replace_leading_arg_type, TupleType, UnionType + replace_leading_arg_type, TupleType, UnionType, StarType ) from mypy.nodes import function_type, implicit_module_attrs from mypy.typeanal import TypeAnalyser, TypeAnalyserPass3, analyse_node @@ -800,11 +800,39 @@ def analyse_lvalue(self, lval: Node, nested: bool = False, items = (Any(lval)).items if len(items) == 0 and isinstance(lval, TupleExpr): self.fail("Can't assign to ()", lval) + self.analyse_tuple_or_list_lvalue(cast(Union[ListExpr, TupleExpr], lval), + add_global, explicit_type) + elif isinstance(lval, StarExpr): + if nested: + self.analyse_lvalue(lval.expr, nested, add_global, explicit_type) + else: + self.fail('Starred assignment target must be in a list or tuple', lval) + else: + self.fail('Invalid assignment target', lval) + + def analyse_tuple_or_list_lvalue(self, lval: Union[ListExpr, TupleExpr], + add_global: bool = False, + explicit_type: bool = False) -> None: + """Analyze an lvalue or assignment target that is a list or tuple.""" + items = lval.items + + def strip_parens(node: Node) -> Node: + if isinstance(node, ParenExpr): + return strip_parens(node.expr) + else: + return node + + star_exprs = [cast(StarExpr, strip_parens(item)) for item in items + if isinstance(strip_parens(item), StarExpr)] + + if len(star_exprs) > 1: + self.fail('Two starred expressions in assignment', lval) + else: + if len(star_exprs) == 1: + star_exprs[0].valid = True for i in items: self.analyse_lvalue(i, nested=True, add_global=add_global, explicit_type = explicit_type) - else: - self.fail('Invalid assignment target', lval) def analyse_member_lvalue(self, lval: MemberExpr) -> None: lval.accept(self) @@ -839,6 +867,8 @@ def infer_type_from_undefined(self, rvalue: Node) -> Type: return None def store_declared_types(self, lvalue: Node, typ: Type) -> None: + if isinstance(typ, StarType) and not isinstance(lvalue, StarExpr): + self.fail('Star type only allowed for starred expressions', lvalue) if isinstance(lvalue, RefExpr): lvalue.is_def = False if isinstance(lvalue.node, Var): @@ -856,6 +886,11 @@ def store_declared_types(self, lvalue: Node, typ: Type) -> None: else: self.fail('Tuple type expected for multiple variables', lvalue) + elif isinstance(lvalue, StarExpr): + if isinstance(typ, StarType): + self.store_declared_types(lvalue.expr, typ.type) + else: + self.fail('Star type expected for starred expression', lvalue) elif isinstance(lvalue, ParenExpr): self.store_declared_types(lvalue.expr, typ) else: @@ -1027,8 +1062,7 @@ def visit_for_stmt(self, s: ForStmt) -> None: s.expr.accept(self) # Bind index variables and check if they define new names. - for n in s.index: - self.analyse_lvalue(n) + self.analyse_lvalue(s.index) self.loop_depth += 1 self.visit_block(s.body) @@ -1130,6 +1164,12 @@ def visit_dict_expr(self, expr: DictExpr) -> None: def visit_paren_expr(self, expr: ParenExpr) -> None: expr.expr.accept(self) + def visit_star_expr(self, expr: StarExpr) -> None: + if not expr.valid: + self.fail('Can use starred expression only as assignment target', expr) + else: + expr.expr.accept(self) + def visit_call_expr(self, expr: CallExpr) -> None: """Analyze a call expression. @@ -1299,8 +1339,7 @@ def visit_generator_expr(self, expr: GeneratorExpr) -> None: expr.condlists): sequence.accept(self) # Bind index variables. - for n in index: - self.analyse_lvalue(n) + self.analyse_lvalue(index) for cond in conditions: cond.accept(self) @@ -1561,8 +1600,7 @@ def visit_var_def(self, d: VarDef) -> None: self.sem.cur_mod_id) def visit_for_stmt(self, s: ForStmt) -> None: - for n in s.index: - self.sem.analyse_lvalue(n, add_global=True) + self.sem.analyse_lvalue(s.index, add_global=True) def visit_with_stmt(self, s: WithStmt) -> None: for n in s.name: diff --git a/mypy/strconv.py b/mypy/strconv.py index c0ba5c17979f..12333ed11f2d 100644 --- a/mypy/strconv.py +++ b/mypy/strconv.py @@ -294,6 +294,9 @@ def visit_float_expr(self, o): def visit_paren_expr(self, o): return self.dump([o.expr], o) + def visit_star_expr(self, o): + return self.dump([o.expr], o) + def visit_name_expr(self, o): return (short_type(o) + '(' + self.pretty_name(o.name, o.kind, o.fullname, o.is_def) diff --git a/mypy/test/data/check-dynamic-typing.test b/mypy/test/data/check-dynamic-typing.test index c474756b3b62..550dff4ec854 100644 --- a/mypy/test/data/check-dynamic-typing.test +++ b/mypy/test/data/check-dynamic-typing.test @@ -21,7 +21,7 @@ d = Undefined # type: Any a, b = Undefined, Undefined # type: (A, B) d, a = b, b # E: Incompatible types in assignment (expression has type "B", variable has type "A") -d, d = d, d, d # E: Too many values to assign +d, d = d, d, d # E: Too many values to unpack (2 expected, 3 provided) a, b = d, d d, d = a, b diff --git a/mypy/test/data/check-lists.test b/mypy/test/data/check-lists.test index 19e20ae987c0..61fc351a82fa 100644 --- a/mypy/test/data/check-lists.test +++ b/mypy/test/data/check-lists.test @@ -22,8 +22,8 @@ from typing import Undefined, List a, b, c = Undefined, Undefined, Undefined # type: (A, B, C) a, b = [a, b] -a, b = [a] # E: Need 2 values to assign -a, b = [a, b, c] # E: Too many values to assign +a, b = [a] # E: Need more than 1 value to unpack (2 expected) +a, b = [a, b, c] # E: Too many values to unpack (2 expected, 3 provided) class A: pass class B: pass @@ -51,12 +51,14 @@ class C: pass from typing import Undefined, List a, b, c = Undefined, Undefined, Undefined # type: (A, B, C) -a, b = [a, b] -a, b = [a] # E: Need 2 values to assign -a, b = [a, b, c] # E: Too many values to assign +def f() -> None: # needed because test parser tries to parse [a, b] as section header + [a, b] = [a, b] + [a, b] = [a] # E: Need more than 1 value to unpack (2 expected) + [a, b] = [a, b, c] # E: Too many values to unpack (2 expected, 3 provided) class A: pass class B: pass class C: pass [builtins fixtures/list.py] -[out] \ No newline at end of file +[out] +main: In function "f": \ No newline at end of file diff --git a/mypy/test/data/check-statements.test b/mypy/test/data/check-statements.test index f99786c4057b..936961be842f 100644 --- a/mypy/test/data/check-statements.test +++ b/mypy/test/data/check-statements.test @@ -757,6 +757,23 @@ x = 1 y = 1 +-- Star assignment +-- --------------- + + +[case testAssignListToStarExpr] +from typing import Undefined, List +bs, cs = Undefined, Undefined # type: List[A], List[B] +*bs, b = bs +*bs, c = cs # E: Incompatible types in assignment (expression has type List[B], variable has type List[A]) +*ns, c = cs +nc = cs + +class A: pass +class B: pass +[builtins fixtures/list.py] + + -- Type aliases -- ------------ diff --git a/mypy/test/data/check-tuples.test b/mypy/test/data/check-tuples.test index 031992992fec..b197da2b7736 100644 --- a/mypy/test/data/check-tuples.test +++ b/mypy/test/data/check-tuples.test @@ -279,7 +279,7 @@ a1, b1 = a, a # type: (A, B) # E: Incompatible types in assignment (expression a2, b2 = b, b # type: (A, B) # E: Incompatible types in assignment (expression has type "B", variable has type "A") a3, b3 = a # type: (A, B) # E: '__main__.A' object is not iterable a4, b4 = None # type: (A, B) # E: ''None'' object is not iterable -a5, b5 = a, b, a # type: (A, B) # E: Too many values to assign +a5, b5 = a, b, a # type: (A, B) # E: Too many values to unpack (2 expected, 3 provided) ax, bx = a, b # type: (A, B) @@ -372,9 +372,139 @@ a = '' # E: Incompatible types in assignment (expression has type "str", variabl b = 1 # E: Incompatible types in assignment (expression has type "int", variable has type "str") +-- Assignment to starred expressions +-- --------------------------------- + + +[case testAssignmentToStarMissingAnnotation] +from typing import Undefined, List +t = 1, 2 +a, b, *c = 1, 2 # E: Need type annotation for variable +aa, bb, *cc = t # E: Need type annotation for variable +[builtins fixtures/list.py] + +[case testAssignmentToStarAnnotation] +from typing import Undefined, List +li, lo = Undefined, Undefined # type: List[int], List[object] +a, b, *c = 1, 2 # type: int, int, *List[int] +c = lo # E: Incompatible types in assignment (expression has type List[object], variable has type List[int]) +c = li +[builtins fixtures/list.py] + +[case testAssignmentToStarCount1] +from typing import Undefined, List +ca = Undefined # type: List[int] +c = [1] +a, b, *c = 1, # E: Need more than 1 value to unpack (2 expected) +a, b, *c = 1, 2 +a, b, *c = 1, 2, 3 +a, b, *c = 1, 2, 3, 4 +[builtins fixtures/list.py] + +[case testAssignmentToStarCount2] +from typing import Undefined, List +ca = Undefined # type: List[int] +t1 = 1, +t2 = 1, 2 +t3 = 1, 2, 3 +t4 = 1, 2, 3, 4 +c = [1] +a, b, *c = t1 # E: Need more than 1 value to unpack (2 expected) +a, b, *c = t2 +a, b, *c = t3 +a, b, *c = t4 +[builtins fixtures/list.py] + +[case testAssignmentToStarFromAny] +from typing import Undefined, Any +a, c = Any(1), C() +p, *q = a +c = a +c = q + +class C: pass + +[case testAssignmentToComplexStar] +from typing import Undefined, List +li = Undefined # type: List[int] +a, *(li) = 1, +a, *(b, c) = 1, 2 # E: Need more than 1 value to unpack (2 expected) +a, *(b, c) = 1, 2, 3 +a, *(b, c) = 1, 2, 3, 4 # E: Too many values to unpack (2 expected, 3 provided) +[builtins fixtures/list.py] + +[case testAssignmentToStarFromTupleType] +from typing import Undefined, List, Tuple +li = Undefined # type: List[int] +la = Undefined # type: List[A] +ta = Undefined # type: Tuple[A, A, A] +a, *la = ta +a, *li = ta # E +a, *na = ta +na = la +na = a # E + +class A: pass +[builtins fixtures/list.py] +[out] +main, line 6: List item 1 has incompatible type "A" +main, line 6: List item 2 has incompatible type "A" +main, line 9: Incompatible types in assignment (expression has type "A", variable has type List[A]) + +[case testAssignmentToStarFromTupleInference] +from typing import Undefined, List +li = Undefined # type: List[int] +la = Undefined # type: List[A] +a, *l = A(), A() +l = li # E: Incompatible types in assignment (expression has type List[int], variable has type List[A]) +l = la + +class A: pass +[builtins fixtures/list.py] +[out] + +[case testAssignmentToStarFromListInference] +from typing import Undefined, List +li = Undefined # type: List[int] +la = Undefined # type: List[A] +a, *l = [A(), A()] +l = li # E: Incompatible types in assignment (expression has type List[int], variable has type List[A]) +l = la + +class A: pass +[builtins fixtures/list.py] +[out] + +[case testAssignmentToStarFromTupleTypeInference] +from typing import Undefined, List, Tuple +li = Undefined # type: List[int] +la = Undefined # type: List[A] +ta = Undefined # type: Tuple[A, A, A] +a, *l = ta +l = li # E: Incompatible types in assignment (expression has type List[int], variable has type List[A]) +l = la + +class A: pass +[builtins fixtures/list.py] +[out] + +[case testAssignmentToStarFromListTypeInference] +from typing import Undefined, List +li = Undefined # type: List[int] +la = Undefined # type: List[A] +a, *l = la +l = li # E: Incompatible types in assignment (expression has type List[int], variable has type List[A]) +l = la + +class A: pass +[builtins fixtures/list.py] +[out] + + -- Nested tuple assignment -- ---------------------------- + [case testNestedTupleAssignment1] from typing import Undefined a1, b1, c1 = Undefined, Undefined, Undefined # type: (A, B, C) diff --git a/mypy/test/data/parse-errors.test b/mypy/test/data/parse-errors.test index 6ce444272a8f..e4ab097927a1 100644 --- a/mypy/test/data/parse-errors.test +++ b/mypy/test/data/parse-errors.test @@ -52,6 +52,11 @@ file, line 1: Parse error before end of line file, line 2: Parse error before end of line file, line 3: Parse error before end of line +[case testDoubleStar] +**a +[out] +file, line 1: Parse error before ** + [case testInvalidSuperClass] class A(C[): pass @@ -191,6 +196,12 @@ x = 0 # type: [out] file, line 2: Empty type annotation +[case testInvalidTypeComment4] +0 +x = 0 # type: * +[out] +file, line 2: Parse error before end of line + [case testInvalidMultilineLiteralType] def f() -> "A\nB": pass [out] diff --git a/mypy/test/data/parse.test b/mypy/test/data/parse.test index f56a2084115c..17049d9c6310 100644 --- a/mypy/test/data/parse.test +++ b/mypy/test/data/parse.test @@ -843,11 +843,12 @@ MypyFile:1( Block:1( PassStmt:2())) ForStmt:3( - NameExpr(x) - ParenExpr:3( - TupleExpr:3( - NameExpr(y) - NameExpr(w))) + TupleExpr:3( + NameExpr(x) + ParenExpr:3( + TupleExpr:3( + NameExpr(y) + NameExpr(w)))) NameExpr(z) Block:3( ExpressionStmt:4( @@ -1196,11 +1197,12 @@ MypyFile:1( ExpressionStmt:1( GeneratorExpr:1( NameExpr(x) - NameExpr(y) - ParenExpr:1( - TupleExpr:1( - NameExpr(p) - NameExpr(q))) + TupleExpr:1( + NameExpr(y) + ParenExpr:1( + TupleExpr:1( + NameExpr(p) + NameExpr(q)))) NameExpr(z)))) [case testListComprehension] @@ -1227,8 +1229,9 @@ MypyFile:1( TupleExpr:1( NameExpr(x) NameExpr(y))) - NameExpr(y) - NameExpr(z) + TupleExpr:1( + NameExpr(y) + NameExpr(z)) TupleExpr:1( IntExpr(1) IntExpr(2)))))) @@ -1458,8 +1461,10 @@ for (i, j) in x: [out] MypyFile:1( ForStmt:1( - NameExpr(i) - NameExpr(j) + ParenExpr:1( + TupleExpr:1( + NameExpr(i) + NameExpr(j))) NameExpr(x) Block:1( PassStmt:2()))) @@ -2714,3 +2719,123 @@ MypyFile:1( ClassDef:1( A PassStmt:2())) + +[case testStarExpression] +*a +*a, b +a, *b +a, (*x, y) +a, (x, *y) +[out] +MypyFile:1( + ExpressionStmt:1( + StarExpr:1( + NameExpr(a))) + ExpressionStmt:2( + TupleExpr:2( + StarExpr:2( + NameExpr(a)) + NameExpr(b))) + ExpressionStmt:3( + TupleExpr:3( + NameExpr(a) + StarExpr:3( + NameExpr(b)))) + ExpressionStmt:4( + TupleExpr:4( + NameExpr(a) + ParenExpr:4( + TupleExpr:4( + StarExpr:4( + NameExpr(x)) + NameExpr(y))))) + ExpressionStmt:5( + TupleExpr:5( + NameExpr(a) + ParenExpr:5( + TupleExpr:5( + NameExpr(x) + StarExpr:5( + NameExpr(y))))))) + +[case testStarExpressionParenthesis] +*(a) +*(a,b) +[out] +MypyFile:1( + ExpressionStmt:1( + StarExpr:1( + ParenExpr:1( + NameExpr(a)))) + ExpressionStmt:2( + StarExpr:2( + ParenExpr:2( + TupleExpr:2( + NameExpr(a) + NameExpr(b)))))) + +[case testStarExpressionInFor] +for *a in b: + pass + +for a, *b in c: + pass + +for *a, b in c: + pass +[out] +MypyFile:1( + ForStmt:1( + StarExpr:1( + NameExpr(a)) + NameExpr(b) + Block:1( + PassStmt:2())) + ForStmt:4( + TupleExpr:4( + NameExpr(a) + StarExpr:4( + NameExpr(b))) + NameExpr(c) + Block:4( + PassStmt:5())) + ForStmt:7( + TupleExpr:7( + StarExpr:7( + NameExpr(a)) + NameExpr(b)) + NameExpr(c) + Block:7( + PassStmt:8()))) + +[case testStarExprInGeneratorExpr] +x for y, *p in z +x for *p, y in z +x for y, *p, q in z +[out] +MypyFile:1( + ExpressionStmt:1( + GeneratorExpr:1( + NameExpr(x) + TupleExpr:1( + NameExpr(y) + StarExpr:1( + NameExpr(p))) + NameExpr(z))) + ExpressionStmt:2( + GeneratorExpr:2( + NameExpr(x) + TupleExpr:2( + StarExpr:2( + NameExpr(p)) + NameExpr(y)) + NameExpr(z))) + ExpressionStmt:3( + GeneratorExpr:3( + NameExpr(x) + TupleExpr:3( + NameExpr(y) + StarExpr:3( + NameExpr(p)) + NameExpr(q)) + NameExpr(z)))) \ No newline at end of file diff --git a/mypy/test/data/semanal-errors.test b/mypy/test/data/semanal-errors.test index f7d2401f52f4..6b9e950835b1 100644 --- a/mypy/test/data/semanal-errors.test +++ b/mypy/test/data/semanal-errors.test @@ -159,6 +159,15 @@ z = 0 # type: x main, line 4: Invalid type "__main__.f" main, line 5: Invalid type "__main__.x" +[case testTwoStarsInType] +import typing +x = 1 # type: *object, *object +y = 1 # type: object, (*object, *object) +z = 1 # type: *object, (object, *object) +[out] +main, line 2: At most one star type allowed in a tuple +main, line 3: At most one star type allowed in a tuple + [case testGlobalVarRedefinition] import typing class A: pass @@ -389,6 +398,65 @@ main, line 4: Invalid assignment target main, line 5: Invalid assignment target main, line 6: Invalid assignment target +[case testInvalidStarType] +a = 1 # type: *int +[out] +main, line 1: Star type only allowed for starred expressions + +[case testInvalidStarType] +*a, b = 1 # type: int, int +[out] +main, line 1: Star type expected for starred expression + +[case testTwoStarExpressions] +a, *b, *c = 1 +*a, (*b, c) = 1 +a, (*b, *c) = 1 +[*a, *b] = 1 +[out] +main, line 1: Two starred expressions in assignment +main, line 3: Two starred expressions in assignment +main, line 4: Two starred expressions in assignment + +[case testTwoStarExpressionsInForStmt] +for a, *b, *c in z: + pass +for *a, (*b, c) in z: + pass +for a, (*b, *c) in z: + pass +for [*a, *b] in z: + pass +[out] +main, line 1: Two starred expressions in assignment +main, line 5: Two starred expressions in assignment +main, line 7: Two starred expressions in assignment + +[case testTwoStarExpressionsInGeneratorExpr] +(a for a, *b, *c in []) +(a for *a, (*b, c) in []) +(a for a, (*b, *c) in []) +[out] +main, line 1: Two starred expressions in assignment +main, line 1: Name 'a' is not defined +main, line 3: Two starred expressions in assignment + +[case testStarExpressionRhs] +b = 1 +c = 1 +d = 1 +a = *b +a = b, (c, *d) +[out] +main, line 4: Can use starred expression only as assignment target +main, line 5: Can use starred expression only as assignment target + +[case testStarExpressionInExp] +a = 1 +*a + 1 +[out] +main, line 2: Can use starred expression only as assignment target + [case testInvalidDel] import typing x = 1 diff --git a/mypy/test/data/semanal-expressions.test b/mypy/test/data/semanal-expressions.test index f2dc78eee68e..f79676f81924 100644 --- a/mypy/test/data/semanal-expressions.test +++ b/mypy/test/data/semanal-expressions.test @@ -261,11 +261,12 @@ MypyFile:1( ExpressionStmt:2( GeneratorExpr:2( NameExpr(x [l]) - NameExpr(x* [l]) - ParenExpr:2( - TupleExpr:2( - NameExpr(y* [l]) - NameExpr(z* [l]))) + TupleExpr:2( + NameExpr(x* [l]) + ParenExpr:2( + TupleExpr:2( + NameExpr(y* [l]) + NameExpr(z* [l])))) NameExpr(a [__main__.a])))) [case testLambda] diff --git a/mypy/test/data/semanal-statements.test b/mypy/test/data/semanal-statements.test index 502b37a9a485..26555adf3596 100644 --- a/mypy/test/data/semanal-statements.test +++ b/mypy/test/data/semanal-statements.test @@ -113,8 +113,9 @@ for x, y in []: [out] MypyFile:1( ForStmt:1( - NameExpr(x* [__main__.x]) - NameExpr(y* [__main__.y]) + TupleExpr:1( + NameExpr(x* [__main__.x]) + NameExpr(y* [__main__.y])) ListExpr:1() Block:1( ExpressionStmt:2( @@ -347,6 +348,38 @@ MypyFile:1( NameExpr(y [__main__.y]))) IntExpr(1))) +[case testStarLvalues] +*x, y = 1 +*x, (y, *z) = 1 +*(x, q), r = 1 +[out] +MypyFile:1( + AssignmentStmt:1( + TupleExpr:1( + StarExpr:1( + NameExpr(x* [__main__.x])) + NameExpr(y* [__main__.y])) + IntExpr(1)) + AssignmentStmt:2( + TupleExpr:2( + StarExpr:2( + NameExpr(x [__main__.x])) + ParenExpr:2( + TupleExpr:2( + NameExpr(y [__main__.y]) + StarExpr:2( + NameExpr(z* [__main__.z]))))) + IntExpr(1)) + AssignmentStmt:3( + TupleExpr:3( + StarExpr:3( + ParenExpr:3( + TupleExpr:3( + NameExpr(x [__main__.x]) + NameExpr(q* [__main__.q])))) + NameExpr(r* [__main__.r])) + IntExpr(1))) + [case testMultipleDefinition] x, y = 1 x, y = 2 diff --git a/mypy/traverser.py b/mypy/traverser.py index da3e4e31c10e..8792edbbb880 100644 --- a/mypy/traverser.py +++ b/mypy/traverser.py @@ -85,8 +85,7 @@ def visit_while_stmt(self, o: WhileStmt) -> T: o.else_body.accept(self) def visit_for_stmt(self, o: ForStmt) -> T: - for ind in o.index: - ind.accept(self) + o.index.accept(self) o.expr.accept(self) o.body.accept(self) if o.else_body: @@ -202,8 +201,7 @@ def visit_generator_expr(self, o: GeneratorExpr) -> T: for index, sequence, conditions in zip(o.indices, o.sequences, o.condlists): sequence.accept(self) - for ind in index: - ind.accept(self) + index.accept(self) for cond in conditions: cond.accept(self) o.left_expr.accept(self) diff --git a/mypy/treetransform.py b/mypy/treetransform.py index 8d3113e68b1b..ca73797ec154 100644 --- a/mypy/treetransform.py +++ b/mypy/treetransform.py @@ -16,7 +16,7 @@ UnicodeExpr, FloatExpr, CallExpr, SuperExpr, MemberExpr, IndexExpr, SliceExpr, OpExpr, UnaryExpr, FuncExpr, TypeApplication, PrintStmt, SymbolTable, RefExpr, UndefinedExpr, TypeVarExpr, DucktypeExpr, - DisjointclassExpr, ComparisonExpr, TempNode + DisjointclassExpr, ComparisonExpr, TempNode, StarExpr ) from mypy.types import Type, FunctionLike from mypy.visitor import NodeVisitor @@ -203,7 +203,7 @@ def visit_while_stmt(self, node: WhileStmt) -> Node: self.optional_block(node.else_body)) def visit_for_stmt(self, node: ForStmt) -> Node: - return ForStmt(self.nodes(node.index), + return ForStmt(self.node(node.index), self.node(node.expr), self.block(node.body), self.optional_block(node.else_body)) @@ -255,6 +255,9 @@ def visit_print_stmt(self, node: PrintStmt) -> Node: return PrintStmt(self.nodes(node.args), node.newline) + def visit_star_expr(self, node: StarExpr) -> Node: + return StarExpr(node.expr) + def visit_int_expr(self, node: IntExpr) -> Node: return IntExpr(node.value) @@ -371,7 +374,7 @@ def visit_generator_expr(self, node: GeneratorExpr) -> Node: def duplicate_generator(self, node: GeneratorExpr) -> GeneratorExpr: return GeneratorExpr(self.node(node.left_expr), - [self.nodes(index) for index in node.indices], + [self.node(index) for index in node.indices], [self.node(s) for s in node.sequences], [[self.node(cond) for cond in conditions] for conditions in node.condlists]) diff --git a/mypy/typeanal.py b/mypy/typeanal.py index d01d1a10bc89..7387aafd8418 100644 --- a/mypy/typeanal.py +++ b/mypy/typeanal.py @@ -4,7 +4,7 @@ from mypy.types import ( Type, UnboundType, TypeVar, TupleType, UnionType, Instance, AnyType, Callable, - Void, NoneTyp, TypeList, TypeVarDef, TypeVisitor + Void, NoneTyp, TypeList, TypeVarDef, TypeVisitor, StarType ) from mypy.typerepr import TypeVarRepr from mypy.nodes import ( @@ -153,10 +153,17 @@ def visit_callable(self, t: Callable) -> Type: return res def visit_tuple_type(self, t: TupleType) -> Type: + star_count = sum(1 for item in t.items if isinstance(item, StarType)) + if star_count > 1: + self.fail('At most one star type allowed in a tuple', t) + return AnyType() return TupleType(self.anal_array(t.items), self.builtin_type('builtins.tuple'), t.line, t.repr) + def visit_star_type(self, t: StarType) -> Type: + return StarType(t.type.accept(self), t.line, t.repr) + def visit_union_type(self, t: UnionType) -> Type: return UnionType(self.anal_array(t.items), t.line, t.repr) diff --git a/mypy/types.py b/mypy/types.py index ea3fc740420e..2b4891277d0a 100644 --- a/mypy/types.py +++ b/mypy/types.py @@ -390,6 +390,19 @@ def accept(self, visitor: 'TypeVisitor[T]') -> T: return visitor.visit_tuple_type(self) +class StarType(Type): + """The star type *type_parameter""" + + type = Undefined(Type) + + def __init__(self, type: Type, line: int = -1, repr: Any = None) -> None: + self.type = type + super().__init__(line, repr) + + def accept(self, visitor: 'TypeVisitor[T]') -> T: + return visitor.visit_star_type(self) + + class UnionType(Type): """The union type Union[T1, ..., Tn] (at least one type argument).""" @@ -512,6 +525,9 @@ def visit_overloaded(self, t: Overloaded) -> T: def visit_tuple_type(self, t: TupleType) -> T: pass + def visit_star_type(self, t: StarType) -> T: + pass + def visit_union_type(self, t: UnionType) -> T: assert(0) # XXX catch visitors that don't have Union cases yet @@ -569,6 +585,9 @@ def visit_tuple_type(self, t: TupleType) -> Type: Any(t.fallback.accept(self)), t.line, t.repr) + def visit_star_type(self, t: StarType) -> Type: + return StarType(t.type.accept(self), t.line, t.repr) + def visit_union_type(self, t: UnionType) -> Type: return UnionType(self.translate_types(t.items), t.line, t.repr) @@ -682,6 +701,10 @@ def visit_tuple_type(self, t): s = self.list_str(t.items) return 'Tuple[{}]'.format(s) + def visit_star_type(self, t): + s = t.type.accept(self) + return '*{}'.format(s) + def visit_union_type(self, t): s = self.list_str(t.items) return 'Union[{}]'.format(s) @@ -763,6 +786,9 @@ def visit_callable(self, t: Callable) -> bool: def visit_tuple_type(self, t: TupleType) -> bool: return self.query_types(t.items) + def visit_star_type(self, t: StarType) -> bool: + return t.type.accept(self) + def visit_union_type(self, t: UnionType) -> bool: return self.query_types(t.items) diff --git a/mypy/visitor.py b/mypy/visitor.py index 52f47167d37e..ae0ee1de55d9 100644 --- a/mypy/visitor.py +++ b/mypy/visitor.py @@ -133,6 +133,9 @@ def visit_float_expr(self, o: 'mypy.nodes.FloatExpr') -> T: def visit_paren_expr(self, o: 'mypy.nodes.ParenExpr') -> T: pass + def visit_star_expr(self, o: 'mypy.nodes.StarExpr') -> T: + pass + def visit_name_expr(self, o: 'mypy.nodes.NameExpr') -> T: pass