Skip to content

Commit ba80201

Browse files
petvanaKristofferC
authored andcommitted
Fix sum() and prod() for tuples (#41510)
This PR aims to fix #39182 and #39183 by using the universal implementation of `prod` and `sum` from https://github.com/JuliaLang/julia/blob/97f817a379b0c3c5f9bb803427fe88a018ebfe18/base/reduce.jl#L588 However, the file `abstractarray.jl` is included way sooner, and it is crucial to have already a simplified version of `prod` function. We can specify a simplified version or `prod` only for a system-wide `Int` type that is sufficient to compile `Base`. ``` julia prod(x::Tuple{}) = 1 # This is consistent with the regular prod because there is no need for size promotion # if all elements in the tuple are of system size. prod(x::Tuple{Int, Vararg{Int}}) = *(x...) ``` Although the implementations are different, they lead to the same binary code for tuples containing ~~`UInt` and~~ `Int`. ``` julia julia> a = (1,2,3) (1, 2, 3) # Simplified version for tuples containing Int only julia> prod_simplified(x::Tuple{Int, Vararg{Int}}) = *(x...) julia> @code_native prod_simplified(a) .text ; ┌ @ REPL[1]:1 within `prod_simplified' ; │┌ @ operators.jl:560 within `*' @ int.jl:88 movq 8(%rdi), %rax imulq (%rdi), %rax imulq 16(%rdi), %rax ; │└ retq nop ; └ ``` ``` julia # Regular prod without the simplification julia> @code_native prod(a) .text ; ┌ @ reduce.jl:588 within `prod` ; │┌ @ reduce.jl:588 within `#prod#247` ; ││┌ @ reduce.jl:289 within `mapreduce` ; │││┌ @ reduce.jl:289 within `#mapreduce#240` ; ││││┌ @ reduce.jl:162 within `mapfoldl` ; │││││┌ @ reduce.jl:162 within `#mapfoldl#236` ; ││││││┌ @ reduce.jl:44 within `mapfoldl_impl` ; │││││││┌ @ reduce.jl:48 within `foldl_impl` ; ││││││││┌ @ tuple.jl:276 within `_foldl_impl` ; │││││││││┌ @ operators.jl:613 within `afoldl` ; ││││││││││┌ @ reduce.jl:81 within `BottomRF` ; │││││││││││┌ @ reduce.jl:38 within `mul_prod` ; ││││││││││││┌ @ int.jl:88 within `*` movq 8(%rdi), %rax imulq (%rdi), %rax ; │││││││││└└└└ ; │││││││││┌ @ operators.jl:614 within `afoldl` ; ││││││││││┌ @ reduce.jl:81 within `BottomRF` ; │││││││││││┌ @ reduce.jl:38 within `mul_prod` ; ││││││││││││┌ @ int.jl:88 within `*` imulq 16(%rdi), %rax ; │└└└└└└└└└└└└ retq nop ; └ ``` (cherry picked from commit bada80c)
1 parent 35f675d commit ba80201

File tree

2 files changed

+23
-10
lines changed

2 files changed

+23
-10
lines changed

base/tuple.jl

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -489,17 +489,12 @@ reverse(t::Tuple) = revargs(t...)
489489

490490
## specialized reduction ##
491491

492-
# TODO: these definitions cannot yet be combined, since +(x...)
493-
# where x might be any tuple matches too many methods.
494-
# TODO: this is inconsistent with the regular sum in cases where the arguments
495-
# require size promotion to system size.
496-
sum(x::Tuple{Any, Vararg{Any}}) = +(x...)
497-
498-
# NOTE: should remove, but often used on array sizes
499-
# TODO: this is inconsistent with the regular prod in cases where the arguments
500-
# require size promotion to system size.
501492
prod(x::Tuple{}) = 1
502-
prod(x::Tuple{Any, Vararg{Any}}) = *(x...)
493+
# This is consistent with the regular prod because there is no need for size promotion
494+
# if all elements in the tuple are of system size.
495+
# It is defined here separately in order to support bootstrap, because it's needed earlier
496+
# than the general prod definition is available.
497+
prod(x::Tuple{Int, Vararg{Int}}) = *(x...)
503498

504499
all(x::Tuple{}) = true
505500
all(x::Tuple{Bool}) = x[1]

test/tuple.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -361,6 +361,24 @@ end
361361
@test prod(()) === 1
362362
@test prod((1,2,3)) === 6
363363

364+
# issue 39182
365+
@test sum((0xe1, 0x1f)) === sum([0xe1, 0x1f])
366+
@test sum((Int8(3),)) === Int(3)
367+
@test sum((UInt8(3),)) === UInt(3)
368+
@test sum((3,)) === Int(3)
369+
@test sum((3.0,)) === 3.0
370+
@test sum(("a",)) == sum(["a"])
371+
@test sum((0xe1, 0x1f), init=0x0) == sum([0xe1, 0x1f], init=0x0)
372+
373+
# issue 39183
374+
@test prod((Int8(100), Int8(100))) === 10000
375+
@test prod((Int8(3),)) === Int(3)
376+
@test prod((UInt8(3),)) === UInt(3)
377+
@test prod((3,)) === Int(3)
378+
@test prod((3.0,)) === 3.0
379+
@test prod(("a",)) == prod(["a"])
380+
@test prod((0xe1, 0x1f), init=0x1) == prod([0xe1, 0x1f], init=0x1)
381+
364382
@testset "all" begin
365383
@test all(()) === true
366384
@test all((false,)) === false

0 commit comments

Comments
 (0)