diff --git a/base/range.jl b/base/range.jl index cee15db39b911..c4435f2ff3e97 100644 --- a/base/range.jl +++ b/base/range.jl @@ -850,6 +850,11 @@ first(r::OneTo{T}) where {T} = oneunit(T) first(r::StepRangeLen) = unsafe_getindex(r, 1) first(r::LinRange) = r.start +function first(r::OneTo, n::Integer) + n < 0 && throw(ArgumentError("Number of elements must be non-negative")) + OneTo(oftype(r.stop, min(r.stop, n))) +end + last(r::OrdinalRange{T}) where {T} = convert(T, r.stop) # via steprange_last last(r::StepRangeLen) = unsafe_getindex(r, length(r)) last(r::LinRange) = r.stop diff --git a/test/ranges.jl b/test/ranges.jl index 86cd1c3f2345c..629c2966b2fa6 100644 --- a/test/ranges.jl +++ b/test/ranges.jl @@ -1539,6 +1539,9 @@ end @test size(r) == (3,) @test step(r) == 1 @test first(r) == 1 + @test first(r,2) === Base.OneTo(2) + @test first(r,20) === r + @test_throws ArgumentError first(r,-20) @test last(r) == 3 @test minimum(r) == 1 @test maximum(r) == 3 @@ -1570,6 +1573,9 @@ end @test findall(in(2:(length(r) - 1)), r) === 2:(length(r) - 1) @test findall(in(r), 2:(length(r) - 1)) === 1:(length(r) - 2) end + let r = Base.OneTo(Int8(4)) + @test first(r,4) === r + end @test convert(Base.OneTo, 1:2) === Base.OneTo{Int}(2) @test_throws ArgumentError("first element must be 1, got 2") convert(Base.OneTo, 2:3) @test_throws ArgumentError("step must be 1, got 2") convert(Base.OneTo, 1:2:5)