diff --git a/base/iterators.jl b/base/iterators.jl index a03d426e05622..b51920cdddb68 100644 --- a/base/iterators.jl +++ b/base/iterators.jl @@ -973,7 +973,7 @@ cycle(xs) = Cycle(xs) eltype(::Type{Cycle{I}}) where {I} = eltype(I) IteratorEltype(::Type{Cycle{I}}) where {I} = IteratorEltype(I) -IteratorSize(::Type{Cycle{I}}) where {I} = IsInfinite() +IteratorSize(::Type{Cycle{I}}) where {I} = IsInfinite() # XXX: this is false if iterator ever becomes empty iterate(it::Cycle) = iterate(it.xs) isdone(it::Cycle) = isdone(it.xs) @@ -1422,43 +1422,30 @@ julia> sum(a) # Sum the remaining elements 7 ``` """ -mutable struct Stateful{T, VS, N<:Integer} +mutable struct Stateful{T, VS} itr::T # A bit awkward right now, but adapted to the new iteration protocol nextvalstate::Union{VS, Nothing} - - # Number of remaining elements, if itr is HasLength or HasShape. - # if not, store -1 - number_of_consumed_elements. - # This allows us to defer calculating length until asked for. - # See PR #45924 - remaining::N @inline function Stateful{<:Any, Any}(itr::T) where {T} - itl = iterlength(itr) - new{T, Any, typeof(itl)}(itr, iterate(itr), itl) + return new{T, Any}(itr, iterate(itr)) end @inline function Stateful(itr::T) where {T} VS = approx_iter_type(T) - itl = iterlength(itr) - return new{T, VS, typeof(itl)}(itr, iterate(itr)::VS, itl) + return new{T, VS}(itr, iterate(itr)::VS) end end -function iterlength(it)::Signed - if IteratorSize(it) isa Union{HasShape, HasLength} - return length(it) - else - -1 - end +function reset!(s::Stateful) + setfield!(s, :nextvalstate, iterate(s.itr)) # bypass convert call of setproperty! + return s end - -function reset!(s::Stateful{T,VS}, itr::T=s.itr) where {T,VS} +function reset!(s::Stateful{T}, itr::T) where {T} s.itr = itr - itl = iterlength(itr) - setfield!(s, :nextvalstate, iterate(itr)) - s.remaining = itl - s + reset!(s) + return s end + # Try to find an appropriate type for the (value, state tuple), # by doing a recursive unrolling of the iteration protocol up to # fixpoint. @@ -1480,7 +1467,6 @@ end Stateful(x::Stateful) = x convert(::Type{Stateful}, itr) = Stateful(itr) - @inline isdone(s::Stateful, st=nothing) = s.nextvalstate === nothing @inline function popfirst!(s::Stateful) @@ -1490,8 +1476,6 @@ convert(::Type{Stateful}, itr) = Stateful(itr) else val, state = vs Core.setfield!(s, :nextvalstate, iterate(s.itr, state)) - rem = s.remaining - s.remaining = rem - typeof(rem)(1) return val end end @@ -1501,20 +1485,10 @@ end return ns !== nothing ? ns[1] : sentinel end @inline iterate(s::Stateful, state=nothing) = s.nextvalstate === nothing ? nothing : (popfirst!(s), nothing) -IteratorSize(::Type{<:Stateful{T}}) where {T} = IteratorSize(T) isa HasShape ? HasLength() : IteratorSize(T) +IteratorSize(::Type{<:Stateful{T}}) where {T} = IteratorSize(T) isa IsInfinite ? IsInfinite() : SizeUnknown() eltype(::Type{<:Stateful{T}}) where {T} = eltype(T) IteratorEltype(::Type{<:Stateful{T}}) where {T} = IteratorEltype(T) -function length(s::Stateful) - rem = s.remaining - # If rem is actually remaining length, return it. - # else, rem is number of consumed elements. - if rem >= 0 - rem - else - length(s.itr) - (typeof(rem)(1) - rem) - end -end end # if statement several hundred lines above """ diff --git a/test/iterators.jl b/test/iterators.jl index d8184eab7b656..6fd308a31d746 100644 --- a/test/iterators.jl +++ b/test/iterators.jl @@ -853,8 +853,10 @@ end v, s = iterate(z) @test Base.isdone(z, s) end - # Stateful wrapping mutable iterators of known length (#43245) - @test length(Iterators.Stateful(Iterators.Stateful(1:5))) == 5 + # Stateful does not define length + let s = Iterators.Stateful(Iterators.Stateful(1:5)) + @test_throws MethodError length(s) + end end @testset "pair for Svec" begin @@ -866,6 +868,10 @@ end @testset "inference for large zip #26765" begin x = zip(1:2, ["a", "b"], (1.0, 2.0), Base.OneTo(2), Iterators.repeated("a"), 1.0:0.2:2.0, (1 for i in 1:2), Iterators.Stateful(["a", "b", "c"]), (1.0 for i in 1:2, j in 1:3)) + @test Base.IteratorSize(x) isa Base.SizeUnknown + x = zip(1:2, ["a", "b"], (1.0, 2.0), Base.OneTo(2), Iterators.repeated("a"), 1.0:0.2:2.0, + (1 for i in 1:2), Iterators.cycle(Iterators.Stateful(["a", "b", "c"])), (1.0 for i in 1:2, j in 1:3)) + @test Base.IteratorSize(x) isa Base.HasLength @test @inferred(length(x)) == 2 z = Iterators.filter(x -> x[1] >= 1, x) @test @inferred(eltype(z)) <: Tuple{Int,String,Float64,Int,String,Float64,Any,String,Any} @@ -874,20 +880,20 @@ end end @testset "Stateful fix #30643" begin - @test Base.IteratorSize(1:10) isa Base.HasShape + @test Base.IteratorSize(1:10) isa Base.HasShape{1} a = Iterators.Stateful(1:10) - @test Base.IteratorSize(a) isa Base.HasLength - @test length(a) == 10 + @test Base.IteratorSize(a) isa Base.SizeUnknown + @test !Base.isdone(a) @test length(collect(a)) == 10 - @test length(a) == 0 + @test Base.isdone(a) b = Iterators.Stateful(Iterators.take(1:10,3)) - @test Base.IteratorSize(b) isa Base.HasLength - @test length(b) == 3 + @test Base.IteratorSize(b) isa Base.SizeUnknown + @test !Base.isdone(b) @test length(collect(b)) == 3 - @test length(b) == 0 + @test Base.isdone(b) c = Iterators.Stateful(Iterators.countfrom(1)) @test Base.IteratorSize(c) isa Base.IsInfinite - @test length(Iterators.take(c,3)) == 3 + @test !Base.isdone(Iterators.take(c,3)) @test length(collect(Iterators.take(c,3))) == 3 d = Iterators.Stateful(Iterators.filter(isodd,1:10)) @test Base.IteratorSize(d) isa Base.SizeUnknown @@ -1010,6 +1016,11 @@ end @test collect(Iterators.partition(lstrip("01111", '0'), 2)) == ["11", "11"] end +let itr = (i for i in 1:9) # Base.eltype == Any + @test first(Iterators.partition(itr, 3)) isa Vector{Any} + @test collect(zip(repeat([Iterators.Stateful(itr)], 3)...)) == [(1, 2, 3), (4, 5, 6), (7, 8, 9)] +end + @testset "no single-argument map methods" begin maps = (tuple, Returns(nothing), (() -> nothing)) mappers = (Iterators.map, map, foreach)