Skip to content

[mypyc] Use optimized implementation for builtins.sum #10268

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 20 commits into from
Nov 10, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 37 additions & 7 deletions mypyc/irbuild/specialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from mypy.nodes import (
CallExpr, RefExpr, MemberExpr, NameExpr, TupleExpr, GeneratorExpr,
ListExpr, DictExpr, StrExpr, ARG_POS, Expression
ListExpr, DictExpr, StrExpr, IntExpr, ARG_POS, ARG_NAMED, Expression
)
from mypy.types import AnyType, TypeOfAny

Expand Down Expand Up @@ -188,7 +188,6 @@ def translate_set_from_generator_call(
@specialize_function('builtins.tuple')
@specialize_function('builtins.frozenset')
@specialize_function('builtins.dict')
@specialize_function('builtins.sum')
@specialize_function('builtins.min')
@specialize_function('builtins.max')
@specialize_function('builtins.sorted')
Expand Down Expand Up @@ -266,6 +265,41 @@ def gen_inner_stmts() -> None:
return retval


@specialize_function('builtins.sum')
def translate_sum_call(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> Optional[Value]:
    """Specialize sum(<generator expression>[, start]) into an inline accumulation loop.

    The specialized implementation is used only if:
    - one or two arguments are given (anything else is an invalid call to sum()), and
    - the first argument is positional and is a generator expression. Other
      iterables, e.g. sum([1, 2, 3]), are not handled because there is no
      benefit to optimizing them.

    Returns the register holding the running total, or None when the call is
    left unspecialized.
    """
    arg_count = len(expr.args)
    if arg_count not in (1, 2):
        return None
    if expr.arg_kinds[0] != ARG_POS or not isinstance(expr.args[0], GeneratorExpr):
        return None

    # Work out the initial value of the accumulator.
    if arg_count == 2:
        # 'start' may be passed positionally or by keyword; reject any other kind.
        if expr.arg_kinds[1] not in (ARG_POS, ARG_NAMED):
            return None
        initial = expr.args[1]
    else:
        # No 'start' argument: begin from the integer literal 0.
        initial = IntExpr(0)

    generator = expr.args[0]
    result_type = builder.node_type(expr)
    total = Register(result_type)
    # Coerce the initial value to the type of the whole sum() expression.
    initial_value = builder.coerce(builder.accept(initial), result_type, -1)
    builder.assign(total, initial_value, -1)

    def add_element() -> None:
        # Loop body: evaluate the generator's element expression and add it on.
        element = builder.accept(generator.left_expr)
        updated = builder.binary_op(total, element, '+', -1)
        builder.assign(total, updated, -1)

    comprehension_helper(
        builder,
        list(zip(generator.indices, generator.sequences, generator.condlists)),
        add_element,
        generator.line,
    )

    return total


@specialize_function('dataclasses.field')
@specialize_function('attr.Factory')
def translate_dataclasses_field_call(
Expand Down Expand Up @@ -295,12 +329,8 @@ def translate_next_call(builder: IRBuilder, expr: CallExpr, callee: RefExpr) ->
return None

gen = expr.args[0]

retval = Register(builder.node_type(expr))
default_val = None
if len(expr.args) > 1:
default_val = builder.accept(expr.args[1])

default_val = builder.accept(expr.args[1]) if len(expr.args) > 1 else None
exit_block = BasicBlock()

def gen_inner_stmts() -> None:
Expand Down
1 change: 1 addition & 0 deletions mypyc/test-data/fixtures/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,7 @@ class ZeroDivisionError(Exception): pass

# Signature-only stubs: the bodies are just `pass`.
def any(i: Iterable[T]) -> bool: pass
def all(i: Iterable[T]) -> bool: pass
# Signature-only stub. Takes the optional 'start' value so that two-argument
# calls such as sum(gen, 1) are accepted.
def sum(i: Iterable[T], start: int = 0) -> int: pass
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, this will need to take a start argument. It looks like that is the cause of a bunch of the test failures.

# Signature-only stubs: neither body does any work.
def reversed(object: Sequence[T]) -> Iterator[T]: ...
def id(o: object) -> int: pass
# This type is obviously wrong but the test stubs don't have Sized anymore
Expand Down
40 changes: 40 additions & 0 deletions mypyc/test-data/irbuild-basic.test
Original file line number Diff line number Diff line change
Expand Up @@ -3287,6 +3287,46 @@ L10:
L11:
return r0

[case testSum]
from typing import Callable, Iterable

# Sums the results of calling 'comparison' on each item of 'l', using a
# generator expression as the sole argument to sum().
def call_sum(l: Iterable[int], comparison: Callable[[int], bool]) -> int:
    return sum(comparison(x) for x in l)

[out]
def call_sum(l, comparison):
l, comparison :: object
r0 :: int
r1, r2 :: object
r3, x :: int
r4, r5 :: object
r6 :: bool
r7 :: object
r8, r9 :: int
r10 :: bit
L0:
r0 = 0
r1 = PyObject_GetIter(l)
L1:
r2 = PyIter_Next(r1)
if is_error(r2) goto L4 else goto L2
L2:
r3 = unbox(int, r2)
x = r3
r4 = box(int, x)
r5 = PyObject_CallFunctionObjArgs(comparison, r4, 0)
r6 = unbox(bool, r5)
r7 = box(bool, r6)
r8 = unbox(int, r7)
r9 = CPyTagged_Add(r0, r8)
r0 = r9
L3:
goto L1
L4:
r10 = CPy_NoErrOccured()
L5:
return r0

[case testSetAttr1]
from typing import Any, Dict, List
def lol(x: Any):
Expand Down
46 changes: 45 additions & 1 deletion mypyc/test-data/run-misc.test
Original file line number Diff line number Diff line change
Expand Up @@ -831,6 +831,51 @@ assert call_all(mixed_110) == 1
assert call_any_nested([[1, 1, 1], [1, 1], []]) == 1
assert call_any_nested([[1, 1, 1], [0, 1], []]) == 0

[case testSum]
[typing fixtures/typing-full.pyi]
from typing import Any, List
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Adding [typing fixtures/typing-full.pyi] here is the simplest way to solve test errors.


def test_sum_of_numbers() -> None:
    """Check sum() over generator expressions yielding int, float and complex values."""
    assert sum(x for x in [1, 2, 3]) == 6
    # 0.0 + 1.2 + 2 is 3.2 (the double addition rounds to exactly the literal
    # 3.2, so comparing with == is safe here).
    assert sum(x for x in [0.0, 1.2, 2]) == 3.2
    assert sum(x for x in [1, 1j]) == 1 + 1j

def test_sum_callables() -> None:
    """Check sum() over generator expressions whose element is a call to a lambda.

    Each lambda returns a bool, so the sum counts the items for which it is True.
    """
    assert sum((lambda x: x == 0)(x) for x in []) == 0
    assert sum((lambda x: x == 0)(x) for x in [0]) == 1
    assert sum((lambda x: x == 0)(x) for x in [0, 0, 0]) == 3
    assert sum((lambda x: x == 0)(x) for x in [0, 1, 0]) == 2
    # Half of the integers in range(2**10) are even.
    assert sum((lambda x: x % 2 == 0)(x) for x in range(2**10)) == 2**9

def test_sum_comparisons() -> None:
    """Check sum() over generator expressions whose element is a comparison.

    Each comparison yields a bool, so the sum counts the items that satisfy it.
    """
    assert sum(x == 0 for x in []) == 0
    assert sum(x == 0 for x in [0]) == 1
    assert sum(x == 0 for x in [0, 0, 0]) == 3
    assert sum(x == 0 for x in [0, 1, 0]) == 2
    # Half of the integers in range(2**10) are even.
    assert sum(x % 2 == 0 for x in range(2**10)) == 2**9

def test_sum_multi() -> None:
    """Check sum() over a generator expression that unpacks two loop variables."""
    # The pairs are (0, 0), (0, 1), (0, 0); two of them add up to 0.
    assert sum(i + j == 0 for i, j in zip([0, 0, 0], [0, 1, 0])) == 2

def test_sum_misc() -> None:
    """Check assorted sum() calls, both generator-expression and plain-list forms."""
    # misc cases we do optimize (note: according to sum's help text, we don't need to
    # support non-numeric cases, but CPython and mypyc both do anyway)
    assert sum(c == 'd' for c in 'abcdd') == 2
    # misc cases we do not optimize (the argument is a list, not a generator expression)
    assert sum([0, 1]) == 1
    assert sum([0, 1], 1) == 2

def test_sum_start_given() -> None:
    """Check sum() over generator expressions when a second 'start' argument is passed."""
    # 'start' given as a local variable rather than a literal.
    a = 1
    assert sum((x == 0 for x in [0, 1]), a) == 2
    # With an empty iterable the result is just the start value.
    assert sum(((lambda x: x == 0)(x) for x in []), 1) == 1
    assert sum(((lambda x: x == 0)(x) for x in [0]), 1) == 2
    assert sum(((lambda x: x == 0)(x) for x in [0, 0, 0]), 1) == 4
    assert sum(((lambda x: x == 0)(x) for x in [0, 1, 0]), 1) == 3
    assert sum(((lambda x: x % 2 == 0)(x) for x in range(2**10)), 1) == 2**9 + 1
    # Complex start value combined with int and complex elements.
    assert sum((x for x in [1, 1j]), 2j) == 1 + 3j
    assert sum((c == 'd' for c in 'abcdd'), 1) == 3

[case testNoneStuff]
from typing import Optional
class A:
Expand All @@ -845,7 +890,6 @@ def none() -> None:
def arg(x: Optional[A]) -> bool:
    # True exactly when the optional argument is None.
    return x is None


[file driver.py]
import native
native.lol(native.A())
Expand Down