Skip to content

Commit 47b22c5

Browse files
[mypyc] Use optimized implementation for builtins.sum (#10268)
Partially fixes mypyc/mypyc#796 This PR forces the computation of sum() to be been shown through benchmarking to speed up the execution of sum() over integers by about 2-5x. There is some support for the start argument, but only for if 'start' is a literal expression (has a .value attribute). The current implementation doesn't work with arbitrary values for start, because I couldn't figure out how to get any Expression that could be given to be evaluated fully into something that you can initialize the retval Register to. So for example, these cases will not get optimized: a = 1 sum((x == 0 for x in [0]), a) # won't get evaluated because a is a NameExpr sum((x == 0 for x in [0], 0 + j) # won't get evaluated because 0 + j is an OpExpr Co-authored-by: 97littleleaf11 <[email protected]>
1 parent 6b9cc47 commit 47b22c5

File tree

4 files changed

+123
-8
lines changed

4 files changed

+123
-8
lines changed

mypyc/irbuild/specialize.py

+37-7
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
from mypy.nodes import (
1818
CallExpr, RefExpr, MemberExpr, NameExpr, TupleExpr, GeneratorExpr,
19-
ListExpr, DictExpr, StrExpr, ARG_POS, Expression
19+
ListExpr, DictExpr, StrExpr, IntExpr, ARG_POS, ARG_NAMED, Expression
2020
)
2121
from mypy.types import AnyType, TypeOfAny
2222

@@ -188,7 +188,6 @@ def translate_set_from_generator_call(
188188
@specialize_function('builtins.tuple')
189189
@specialize_function('builtins.frozenset')
190190
@specialize_function('builtins.dict')
191-
@specialize_function('builtins.sum')
192191
@specialize_function('builtins.min')
193192
@specialize_function('builtins.max')
194193
@specialize_function('builtins.sorted')
@@ -266,6 +265,41 @@ def gen_inner_stmts() -> None:
266265
return retval
267266

268267

268+
@specialize_function('builtins.sum')
269+
def translate_sum_call(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> Optional[Value]:
270+
# specialized implementation is used if:
271+
# - only one or two arguments given (if not, sum() has been given invalid arguments)
272+
# - first argument is a Generator (there is no benefit to optimizing the performance of eg.
273+
# sum([1, 2, 3]), so non-Generator Iterables are not handled)
274+
if not (len(expr.args) in (1, 2)
275+
and expr.arg_kinds[0] == ARG_POS
276+
and isinstance(expr.args[0], GeneratorExpr)):
277+
return None
278+
279+
# handle 'start' argument, if given
280+
if len(expr.args) == 2:
281+
# ensure call to sum() was properly constructed
282+
if not expr.arg_kinds[1] in (ARG_POS, ARG_NAMED):
283+
return None
284+
start_expr = expr.args[1]
285+
else:
286+
start_expr = IntExpr(0)
287+
288+
gen_expr = expr.args[0]
289+
target_type = builder.node_type(expr)
290+
retval = Register(target_type)
291+
builder.assign(retval, builder.coerce(builder.accept(start_expr), target_type, -1), -1)
292+
293+
def gen_inner_stmts() -> None:
294+
call_expr = builder.accept(gen_expr.left_expr)
295+
builder.assign(retval, builder.binary_op(retval, call_expr, '+', -1), -1)
296+
297+
loop_params = list(zip(gen_expr.indices, gen_expr.sequences, gen_expr.condlists))
298+
comprehension_helper(builder, loop_params, gen_inner_stmts, gen_expr.line)
299+
300+
return retval
301+
302+
269303
@specialize_function('dataclasses.field')
270304
@specialize_function('attr.Factory')
271305
def translate_dataclasses_field_call(
@@ -295,12 +329,8 @@ def translate_next_call(builder: IRBuilder, expr: CallExpr, callee: RefExpr) ->
295329
return None
296330

297331
gen = expr.args[0]
298-
299332
retval = Register(builder.node_type(expr))
300-
default_val = None
301-
if len(expr.args) > 1:
302-
default_val = builder.accept(expr.args[1])
303-
333+
default_val = builder.accept(expr.args[1]) if len(expr.args) > 1 else None
304334
exit_block = BasicBlock()
305335

306336
def gen_inner_stmts() -> None:

mypyc/test-data/fixtures/ir.py

+1
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,7 @@ class ZeroDivisionError(Exception): pass
278278

279279
def any(i: Iterable[T]) -> bool: pass
280280
def all(i: Iterable[T]) -> bool: pass
281+
def sum(i: Iterable[T]) -> int: pass
281282
def reversed(object: Sequence[T]) -> Iterator[T]: ...
282283
def id(o: object) -> int: pass
283284
# This type is obviously wrong but the test stubs don't have Sized anymore

mypyc/test-data/irbuild-basic.test

+40
Original file line numberDiff line numberDiff line change
@@ -3287,6 +3287,46 @@ L10:
32873287
L11:
32883288
return r0
32893289

3290+
[case testSum]
3291+
from typing import Callable, Iterable
3292+
3293+
def call_sum(l: Iterable[int], comparison: Callable[[int], bool]) -> int:
3294+
return sum(comparison(x) for x in l)
3295+
3296+
[out]
3297+
def call_sum(l, comparison):
3298+
l, comparison :: object
3299+
r0 :: int
3300+
r1, r2 :: object
3301+
r3, x :: int
3302+
r4, r5 :: object
3303+
r6 :: bool
3304+
r7 :: object
3305+
r8, r9 :: int
3306+
r10 :: bit
3307+
L0:
3308+
r0 = 0
3309+
r1 = PyObject_GetIter(l)
3310+
L1:
3311+
r2 = PyIter_Next(r1)
3312+
if is_error(r2) goto L4 else goto L2
3313+
L2:
3314+
r3 = unbox(int, r2)
3315+
x = r3
3316+
r4 = box(int, x)
3317+
r5 = PyObject_CallFunctionObjArgs(comparison, r4, 0)
3318+
r6 = unbox(bool, r5)
3319+
r7 = box(bool, r6)
3320+
r8 = unbox(int, r7)
3321+
r9 = CPyTagged_Add(r0, r8)
3322+
r0 = r9
3323+
L3:
3324+
goto L1
3325+
L4:
3326+
r10 = CPy_NoErrOccured()
3327+
L5:
3328+
return r0
3329+
32903330
[case testSetAttr1]
32913331
from typing import Any, Dict, List
32923332
def lol(x: Any):

mypyc/test-data/run-misc.test

+45-1
Original file line numberDiff line numberDiff line change
@@ -831,6 +831,51 @@ assert call_all(mixed_110) == 1
831831
assert call_any_nested([[1, 1, 1], [1, 1], []]) == 1
832832
assert call_any_nested([[1, 1, 1], [0, 1], []]) == 0
833833

834+
[case testSum]
835+
[typing fixtures/typing-full.pyi]
836+
from typing import Any, List
837+
838+
def test_sum_of_numbers() -> None:
839+
assert sum(x for x in [1, 2, 3]) == 6
840+
assert sum(x for x in [0.0, 1.2, 2]) == 6.2
841+
assert sum(x for x in [1, 1j]) == 1 + 1j
842+
843+
def test_sum_callables() -> None:
844+
assert sum((lambda x: x == 0)(x) for x in []) == 0
845+
assert sum((lambda x: x == 0)(x) for x in [0]) == 1
846+
assert sum((lambda x: x == 0)(x) for x in [0, 0, 0]) == 3
847+
assert sum((lambda x: x == 0)(x) for x in [0, 1, 0]) == 2
848+
assert sum((lambda x: x % 2 == 0)(x) for x in range(2**10)) == 2**9
849+
850+
def test_sum_comparisons() -> None:
851+
assert sum(x == 0 for x in []) == 0
852+
assert sum(x == 0 for x in [0]) == 1
853+
assert sum(x == 0 for x in [0, 0, 0]) == 3
854+
assert sum(x == 0 for x in [0, 1, 0]) == 2
855+
assert sum(x % 2 == 0 for x in range(2**10)) == 2**9
856+
857+
def test_sum_multi() -> None:
858+
assert sum(i + j == 0 for i, j in zip([0, 0, 0], [0, 1, 0])) == 2
859+
860+
def test_sum_misc() -> None:
861+
# misc cases we do optimize (note, according to sum's helptext, we don't need to support
862+
# non-numeric cases, but CPython and mypyc both do anyway)
863+
assert sum(c == 'd' for c in 'abcdd') == 2
864+
# misc cases we do not optimize
865+
assert sum([0, 1]) == 1
866+
assert sum([0, 1], 1) == 2
867+
868+
def test_sum_start_given() -> None:
869+
a = 1
870+
assert sum((x == 0 for x in [0, 1]), a) == 2
871+
assert sum(((lambda x: x == 0)(x) for x in []), 1) == 1
872+
assert sum(((lambda x: x == 0)(x) for x in [0]), 1) == 2
873+
assert sum(((lambda x: x == 0)(x) for x in [0, 0, 0]), 1) == 4
874+
assert sum(((lambda x: x == 0)(x) for x in [0, 1, 0]), 1) == 3
875+
assert sum(((lambda x: x % 2 == 0)(x) for x in range(2**10)), 1) == 2**9 + 1
876+
assert sum((x for x in [1, 1j]), 2j) == 1 + 3j
877+
assert sum((c == 'd' for c in 'abcdd'), 1) == 3
878+
834879
[case testNoneStuff]
835880
from typing import Optional
836881
class A:
@@ -845,7 +890,6 @@ def none() -> None:
845890
def arg(x: Optional[A]) -> bool:
846891
return x is None
847892

848-
849893
[file driver.py]
850894
import native
851895
native.lol(native.A())

0 commit comments

Comments
 (0)