Skip to content

Improve foldl's stability on nested Iterators. #45789

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jun 28, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 21 additions & 6 deletions base/operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -902,6 +902,9 @@ julia> [1:5;] .|> (x -> x^2) |> sum |> inv
"""
|>(x, f) = f(x)

_stable_typeof(x) = typeof(x)
_stable_typeof(::Type{T}) where {T} = @isdefined(T) ? Type{T} : DataType

"""
f = Returns(value)

Expand All @@ -928,7 +931,7 @@ julia> f.value
struct Returns{V} <: Function
value::V
Returns{V}(value) where {V} = new{V}(value)
Returns(value) = new{Core.Typeof(value)}(value)
Returns(value) = new{_stable_typeof(value)}(value)
end

(obj::Returns)(@nospecialize(args...); @nospecialize(kw...)) = obj.value
Expand Down Expand Up @@ -1014,7 +1017,19 @@ struct ComposedFunction{O,I} <: Function
ComposedFunction(outer, inner) = new{Core.Typeof(outer),Core.Typeof(inner)}(outer, inner)
end

(c::ComposedFunction)(x...; kw...) = c.outer(c.inner(x...; kw...))
function (c::ComposedFunction)(x...; kw...)
fs = unwrap_composed(c)
call_composed(fs[1](x...; kw...), tail(fs)...)
end
unwrap_composed(c::ComposedFunction) = (unwrap_composed(c.inner)..., unwrap_composed(c.outer)...)
unwrap_composed(c) = (maybeconstructor(c),)
call_composed(x, f, fs...) = (@inline; call_composed(f(x), fs...))
call_composed(x, f) = f(x)

struct Constructor{F} <: Function end
(::Constructor{F})(args...; kw...) where {F} = (@inline; F(args...; kw...))
maybeconstructor(::Type{F}) where {F} = Constructor{F}()
maybeconstructor(f) = f

∘(f) = f
∘(f, g) = ComposedFunction(f, g)
Expand Down Expand Up @@ -1078,8 +1093,8 @@ struct Fix1{F,T} <: Function
f::F
x::T

Fix1(f::F, x::T) where {F,T} = new{F,T}(f, x)
Fix1(f::Type{F}, x::T) where {F,T} = new{Type{F},T}(f, x)
Fix1(f::F, x) where {F} = new{F,_stable_typeof(x)}(f, x)
Fix1(f::Type{F}, x) where {F} = new{Type{F},_stable_typeof(x)}(f, x)
end

(f::Fix1)(y) = f.f(f.x, y)
Expand All @@ -1095,8 +1110,8 @@ struct Fix2{F,T} <: Function
f::F
x::T

Fix2(f::F, x::T) where {F,T} = new{F,T}(f, x)
Fix2(f::Type{F}, x::T) where {F,T} = new{Type{F},T}(f, x)
Fix2(f::F, x) where {F} = new{F,_stable_typeof(x)}(f, x)
Fix2(f::Type{F}, x) where {F} = new{Type{F},_stable_typeof(x)}(f, x)
end

(f::Fix2)(y) = f.f(y, f.x)
Expand Down
30 changes: 19 additions & 11 deletions base/reduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -140,17 +140,25 @@ what is returned is `itr′` and

op′ = (xfₙ ∘ ... ∘ xf₂ ∘ xf₁)(op)
"""
_xfadjoint(op, itr) = (op, itr)
_xfadjoint(op, itr::Generator) =
if itr.f === identity
_xfadjoint(op, itr.iter)
else
_xfadjoint(MappingRF(itr.f, op), itr.iter)
end
_xfadjoint(op, itr::Filter) =
_xfadjoint(FilteringRF(itr.flt, op), itr.itr)
_xfadjoint(op, itr::Flatten) =
_xfadjoint(FlatteningRF(op), itr.it)
function _xfadjoint(op, itr)
itr′, wrap = _xfadjoint_unwrap(itr)
wrap(op), itr′
end

_xfadjoint_unwrap(itr) = itr, identity
function _xfadjoint_unwrap(itr::Generator)
itr′, wrap = _xfadjoint_unwrap(itr.iter)
itr.f === identity && return itr′, wrap
return itr′, wrap ∘ Fix1(MappingRF, itr.f)
end
function _xfadjoint_unwrap(itr::Filter)
itr′, wrap = _xfadjoint_unwrap(itr.itr)
return itr′, wrap ∘ Fix1(FilteringRF, itr.flt)
end
function _xfadjoint_unwrap(itr::Flatten)
itr′, wrap = _xfadjoint_unwrap(itr.it)
return itr′, wrap ∘ FlatteningRF
end

"""
mapfoldl(f, op, itr; [init])
Expand Down
12 changes: 12 additions & 0 deletions test/operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,15 @@ Base.promote_rule(::Type{T19714}, ::Type{Int}) = T19714

end

@testset "Nested ComposedFunction's stability" begin
f(x) = (1, 1, x...)
g = (f ∘ (f ∘ f)) ∘ (f ∘ f ∘ f)
@test (@inferred (g∘g)(1)) == ntuple(Returns(1), 25)
@test (@inferred g(1)) == ntuple(Returns(1), 13)
h = (-) ∘ (-) ∘ (-) ∘ (-) ∘ (-) ∘ (-) ∘ sum
@test (@inferred h((1, 2, 3); init = 0.0)) == 6.0
end

@testset "function negation" begin
str = randstring(20)
@test filter(!isuppercase, str) == replace(str, r"[A-Z]" => "")
Expand Down Expand Up @@ -308,4 +317,7 @@ end
val = [1,2,3]
@test Returns(val)(1) === val
@test sprint(show, Returns(1.0)) == "Returns{Float64}(1.0)"

illtype = Vector{Core._typevar(:T, Union{}, Any)}
@test Returns(illtype) == Returns{DataType}(illtype)
end
13 changes: 13 additions & 0 deletions test/reduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -677,3 +677,16 @@ end
@test mapreduce(+, +, oa, oa) == 2len
end
end

# issue #45748
@testset "foldl's stability for nested Iterators" begin
a = Iterators.flatten((1:3, 1:3))
b = (2i for i in a if i > 0)
c = Base.Generator(Float64, b)
d = (sin(i) for i in c if i > 0)
@test @inferred(sum(d)) == sum(collect(d))
@test @inferred(extrema(d)) == extrema(collect(d))
@test @inferred(maximum(c)) == maximum(collect(c))
@test @inferred(prod(b)) == prod(collect(b))
@test @inferred(minimum(a)) == minimum(collect(a))
end