Skip to content

Commit 18ab589

Browse files
authored
[mypyc] Speed up and improve multiple assignment (#9800)
Speed up multiple assignment from variable-length lists and tuples. This speeds up the `multiple_assignment` benchmark by around 80%. Fix multiple lvalues in fixed-length sequence assignments. Optimize some cases of list expressions in assignments. Fixes mypyc/mypyc#729.
1 parent 7c0c1e7 commit 18ab589

File tree

8 files changed

+218
-61
lines changed

8 files changed

+218
-61
lines changed

mypyc/irbuild/builder.py

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,15 +39,15 @@
3939
from mypyc.ir.rtypes import (
4040
RType, RTuple, RInstance, int_rprimitive, dict_rprimitive,
4141
none_rprimitive, is_none_rprimitive, object_rprimitive, is_object_rprimitive,
42-
str_rprimitive, is_tagged
42+
str_rprimitive, is_tagged, is_list_rprimitive, is_tuple_rprimitive, c_pyssize_t_rprimitive
4343
)
4444
from mypyc.ir.func_ir import FuncIR, INVALID_FUNC_DEF
4545
from mypyc.ir.class_ir import ClassIR, NonExtClassInfo
4646
from mypyc.primitives.registry import CFunctionDescription, function_ops
47-
from mypyc.primitives.list_ops import to_list, list_pop_last
47+
from mypyc.primitives.list_ops import to_list, list_pop_last, list_get_item_unsafe_op
4848
from mypyc.primitives.dict_ops import dict_get_item_op, dict_set_item_op
4949
from mypyc.primitives.generic_ops import py_setattr_op, iter_op, next_op
50-
from mypyc.primitives.misc_ops import import_op
50+
from mypyc.primitives.misc_ops import import_op, check_unpack_count_op
5151
from mypyc.crash import catch_errors
5252
from mypyc.options import CompilerOptions
5353
from mypyc.errors import Errors
@@ -460,8 +460,10 @@ def read(self, target: Union[Value, AssignmentTarget], line: int = -1) -> Value:
460460

461461
assert False, 'Unsupported lvalue: %r' % target
462462

463-
def assign(self, target: Union[Register, AssignmentTarget],
464-
rvalue_reg: Value, line: int) -> None:
463+
def assign(self,
464+
target: Union[Register, AssignmentTarget],
465+
rvalue_reg: Value,
466+
line: int) -> None:
465467
if isinstance(target, Register):
466468
self.add(Assign(target, rvalue_reg))
467469
elif isinstance(target, AssignmentTargetRegister):
@@ -486,11 +488,39 @@ def assign(self, target: Union[Register, AssignmentTarget],
486488
for i in range(len(rtypes)):
487489
item_value = self.add(TupleGet(rvalue_reg, i, line))
488490
self.assign(target.items[i], item_value, line)
491+
elif ((is_list_rprimitive(rvalue_reg.type) or is_tuple_rprimitive(rvalue_reg.type))
492+
and target.star_idx is None):
493+
self.process_sequence_assignment(target, rvalue_reg, line)
489494
else:
490495
self.process_iterator_tuple_assignment(target, rvalue_reg, line)
491496
else:
492497
assert False, 'Unsupported assignment target'
493498

499+
def process_sequence_assignment(self,
500+
target: AssignmentTargetTuple,
501+
rvalue: Value,
502+
line: int) -> None:
503+
"""Process assignment like 'x, y = s', where s is a variable-length list or tuple."""
504+
# Check the length of sequence.
505+
expected_len = self.add(LoadInt(len(target.items), rtype=c_pyssize_t_rprimitive))
506+
self.builder.call_c(check_unpack_count_op, [rvalue, expected_len], line)
507+
508+
# Read sequence items.
509+
values = []
510+
for i in range(len(target.items)):
511+
item = target.items[i]
512+
index = self.builder.load_static_int(i)
513+
if is_list_rprimitive(rvalue.type):
514+
item_value = self.call_c(list_get_item_unsafe_op, [rvalue, index], line)
515+
else:
516+
item_value = self.builder.gen_method_call(
517+
rvalue, '__getitem__', [index], item.type, line)
518+
values.append(item_value)
519+
520+
# Assign sequence items to the target lvalues.
521+
for lvalue, value in zip(target.items, values):
522+
self.assign(lvalue, value, line)
523+
494524
def process_iterator_tuple_assignment_helper(self,
495525
litem: AssignmentTarget,
496526
ritem: Value, line: int) -> None:

mypyc/irbuild/statement.py

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@
1212
from mypy.nodes import (
1313
Block, ExpressionStmt, ReturnStmt, AssignmentStmt, OperatorAssignmentStmt, IfStmt, WhileStmt,
1414
ForStmt, BreakStmt, ContinueStmt, RaiseStmt, TryStmt, WithStmt, AssertStmt, DelStmt,
15-
Expression, StrExpr, TempNode, Lvalue, Import, ImportFrom, ImportAll, TupleExpr
15+
Expression, StrExpr, TempNode, Lvalue, Import, ImportFrom, ImportAll, TupleExpr, ListExpr,
16+
StarExpr
1617
)
1718

1819
from mypyc.ir.ops import (
@@ -69,39 +70,47 @@ def transform_return_stmt(builder: IRBuilder, stmt: ReturnStmt) -> None:
6970

7071

7172
def transform_assignment_stmt(builder: IRBuilder, stmt: AssignmentStmt) -> None:
72-
assert len(stmt.lvalues) >= 1
73-
builder.disallow_class_assignments(stmt.lvalues, stmt.line)
74-
lvalue = stmt.lvalues[0]
73+
lvalues = stmt.lvalues
74+
assert len(lvalues) >= 1
75+
builder.disallow_class_assignments(lvalues, stmt.line)
76+
first_lvalue = lvalues[0]
7577
if stmt.type and isinstance(stmt.rvalue, TempNode):
7678
# This is actually a variable annotation without initializer. Don't generate
7779
# an assignment but we need to call get_assignment_target since it adds a
7880
# name binding as a side effect.
79-
builder.get_assignment_target(lvalue, stmt.line)
81+
builder.get_assignment_target(first_lvalue, stmt.line)
8082
return
8183

82-
# multiple assignment
83-
if (isinstance(lvalue, TupleExpr) and isinstance(stmt.rvalue, TupleExpr)
84-
and len(lvalue.items) == len(stmt.rvalue.items)):
84+
# Special case multiple assignments like 'x, y = e1, e2'.
85+
if (isinstance(first_lvalue, (TupleExpr, ListExpr))
86+
and isinstance(stmt.rvalue, (TupleExpr, ListExpr))
87+
and len(first_lvalue.items) == len(stmt.rvalue.items)
88+
and all(is_simple_lvalue(item) for item in first_lvalue.items)
89+
and len(lvalues) == 1):
8590
temps = []
8691
for right in stmt.rvalue.items:
8792
rvalue_reg = builder.accept(right)
8893
temp = Register(rvalue_reg.type)
8994
builder.assign(temp, rvalue_reg, stmt.line)
9095
temps.append(temp)
91-
for (left, temp) in zip(lvalue.items, temps):
96+
for (left, temp) in zip(first_lvalue.items, temps):
9297
assignment_target = builder.get_assignment_target(left)
9398
builder.assign(assignment_target, temp, stmt.line)
9499
return
95100

96101
line = stmt.rvalue.line
97102
rvalue_reg = builder.accept(stmt.rvalue)
98103
if builder.non_function_scope() and stmt.is_final_def:
99-
builder.init_final_static(lvalue, rvalue_reg)
100-
for lvalue in stmt.lvalues:
104+
builder.init_final_static(first_lvalue, rvalue_reg)
105+
for lvalue in lvalues:
101106
target = builder.get_assignment_target(lvalue)
102107
builder.assign(target, rvalue_reg, line)
103108

104109

110+
def is_simple_lvalue(expr: Expression) -> bool:
111+
return not isinstance(expr, (StarExpr, ListExpr, TupleExpr))
112+
113+
105114
def transform_operator_assignment_stmt(builder: IRBuilder, stmt: OperatorAssignmentStmt) -> None:
106115
"""Operator assignment statement such as x += 1"""
107116
builder.disallow_class_assignments([stmt.lvalue], stmt.line)

mypyc/lib-rt/CPy.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -497,6 +497,7 @@ void CPyDebug_Print(const char *msg);
497497
void CPy_Init(void);
498498
int CPyArg_ParseTupleAndKeywords(PyObject *, PyObject *,
499499
const char *, char **, ...);
500+
int CPySequence_CheckUnpackCount(PyObject *sequence, Py_ssize_t expected);
500501

501502

502503
#ifdef __cplusplus

mypyc/lib-rt/misc_ops.c

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -495,3 +495,17 @@ void CPyDebug_Print(const char *msg) {
495495
printf("%s\n", msg);
496496
fflush(stdout);
497497
}
498+
499+
int CPySequence_CheckUnpackCount(PyObject *sequence, Py_ssize_t expected) {
500+
Py_ssize_t actual = Py_SIZE(sequence);
501+
if (unlikely(actual != expected)) {
502+
if (actual < expected) {
503+
PyErr_Format(PyExc_ValueError, "not enough values to unpack (expected %zd, got %zd)",
504+
expected, actual);
505+
} else {
506+
PyErr_Format(PyExc_ValueError, "too many values to unpack (expected %zd)", expected);
507+
}
508+
return -1;
509+
}
510+
return 0;
511+
}

mypyc/primitives/misc_ops.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from mypyc.ir.ops import ERR_NEVER, ERR_MAGIC, ERR_FALSE
44
from mypyc.ir.rtypes import (
55
bool_rprimitive, object_rprimitive, str_rprimitive, object_pointer_rprimitive,
6-
int_rprimitive, dict_rprimitive, c_int_rprimitive, bit_rprimitive
6+
int_rprimitive, dict_rprimitive, c_int_rprimitive, bit_rprimitive, c_pyssize_t_rprimitive
77
)
88
from mypyc.primitives.registry import (
99
function_op, custom_op, load_address_op, ERR_NEG_INT
@@ -176,3 +176,11 @@
176176
return_type=bit_rprimitive,
177177
c_function_name='CPyDataclass_SleightOfHand',
178178
error_kind=ERR_FALSE)
179+
180+
# Raise ValueError if length of first argument is not equal to the second argument.
181+
# The first argument must be a list or a variable-length tuple.
182+
check_unpack_count_op = custom_op(
183+
arg_types=[object_rprimitive, c_pyssize_t_rprimitive],
184+
return_type=c_int_rprimitive,
185+
c_function_name='CPySequence_CheckUnpackCount',
186+
error_kind=ERR_NEG_INT)

mypyc/test-data/irbuild-basic.test

Lines changed: 0 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -3605,48 +3605,6 @@ L0:
36053605
r2 = truncate r0: int32 to builtins.bool
36063606
return r2
36073607

3608-
[case testMultipleAssignment]
3609-
from typing import Tuple
3610-
3611-
def f(x: int, y: int) -> Tuple[int, int]:
3612-
x, y = y, x
3613-
return (x, y)
3614-
3615-
def f2(x: int, y: str, z: float) -> Tuple[float, str, int]:
3616-
a, b, c = x, y, z
3617-
return (c, b, a)
3618-
[out]
3619-
def f(x, y):
3620-
x, y, r0, r1 :: int
3621-
r2 :: tuple[int, int]
3622-
L0:
3623-
r0 = y
3624-
r1 = x
3625-
x = r0
3626-
y = r1
3627-
r2 = (x, y)
3628-
return r2
3629-
def f2(x, y, z):
3630-
x :: int
3631-
y :: str
3632-
z :: float
3633-
r0 :: int
3634-
r1 :: str
3635-
r2 :: float
3636-
a :: int
3637-
b :: str
3638-
c :: float
3639-
r3 :: tuple[float, str, int]
3640-
L0:
3641-
r0 = x
3642-
r1 = y
3643-
r2 = z
3644-
a = r0
3645-
b = r1
3646-
c = r2
3647-
r3 = (c, b, a)
3648-
return r3
3649-
36503608
[case testLocalImportSubmodule]
36513609
def f() -> int:
36523610
import p.m

mypyc/test-data/irbuild-statements.test

Lines changed: 96 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -456,7 +456,63 @@ L9:
456456
L10:
457457
return s
458458

459-
[case testMultipleAssignment]
459+
[case testMultipleAssignmentWithNoUnpacking]
460+
from typing import Tuple
461+
462+
def f(x: int, y: int) -> Tuple[int, int]:
463+
x, y = y, x
464+
return (x, y)
465+
466+
def f2(x: int, y: str, z: float) -> Tuple[float, str, int]:
467+
a, b, c = x, y, z
468+
return (c, b, a)
469+
470+
def f3(x: int, y: int) -> Tuple[int, int]:
471+
[x, y] = [y, x]
472+
return (x, y)
473+
[out]
474+
def f(x, y):
475+
x, y, r0, r1 :: int
476+
r2 :: tuple[int, int]
477+
L0:
478+
r0 = y
479+
r1 = x
480+
x = r0
481+
y = r1
482+
r2 = (x, y)
483+
return r2
484+
def f2(x, y, z):
485+
x :: int
486+
y :: str
487+
z :: float
488+
r0 :: int
489+
r1 :: str
490+
r2 :: float
491+
a :: int
492+
b :: str
493+
c :: float
494+
r3 :: tuple[float, str, int]
495+
L0:
496+
r0 = x
497+
r1 = y
498+
r2 = z
499+
a = r0
500+
b = r1
501+
c = r2
502+
r3 = (c, b, a)
503+
return r3
504+
def f3(x, y):
505+
x, y, r0, r1 :: int
506+
r2 :: tuple[int, int]
507+
L0:
508+
r0 = y
509+
r1 = x
510+
x = r0
511+
y = r1
512+
r2 = (x, y)
513+
return r2
514+
515+
[case testMultipleAssignmentBasicUnpacking]
460516
from typing import Tuple, Any
461517

462518
def from_tuple(t: Tuple[int, str]) -> None:
@@ -596,6 +652,45 @@ L0:
596652
z = r6
597653
return 1
598654

655+
[case testMultipleAssignmentUnpackFromSequence]
656+
from typing import List, Tuple
657+
658+
def f(l: List[int], t: Tuple[int, ...]) -> None:
659+
x: object
660+
y: int
661+
x, y = l
662+
x, y = t
663+
[out]
664+
def f(l, t):
665+
l :: list
666+
t :: tuple
667+
x :: object
668+
y :: int
669+
r0 :: int32
670+
r1 :: bit
671+
r2, r3 :: object
672+
r4 :: int
673+
r5 :: int32
674+
r6 :: bit
675+
r7, r8 :: object
676+
r9 :: int
677+
L0:
678+
r0 = CPySequence_CheckUnpackCount(l, 2)
679+
r1 = r0 >= 0 :: signed
680+
r2 = CPyList_GetItemUnsafe(l, 0)
681+
r3 = CPyList_GetItemUnsafe(l, 2)
682+
x = r2
683+
r4 = unbox(int, r3)
684+
y = r4
685+
r5 = CPySequence_CheckUnpackCount(t, 2)
686+
r6 = r5 >= 0 :: signed
687+
r7 = CPySequenceTuple_GetItem(t, 0)
688+
r8 = CPySequenceTuple_GetItem(t, 2)
689+
r9 = unbox(int, r8)
690+
x = r7
691+
y = r9
692+
return 1
693+
599694
[case testAssert]
600695
from typing import Optional
601696

0 commit comments

Comments
 (0)