From 11b4bff55f74dec1eedcdd8471e4fe1b40b59676 Mon Sep 17 00:00:00 2001 From: Lilith Hafner Date: Wed, 14 Jun 2023 16:20:54 -0500 Subject: [PATCH 1/7] fixes --- base/sort.jl | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/base/sort.jl b/base/sort.jl index 99f2ed3e1aeb8..6891ad6620f9c 100644 --- a/base/sort.jl +++ b/base/sort.jl @@ -524,13 +524,13 @@ struct WithoutMissingVector{T, U} <: AbstractVector{T} new{nonmissingtype(eltype(data)), typeof(data)}(data) end end -Base.@propagate_inbounds function Base.getindex(v::WithoutMissingVector, i) - out = v.data[i] +Base.@propagate_inbounds function Base.getindex(v::WithoutMissingVector, i::Integer) + out = v.data[i::Integer] @assert !(out isa Missing) out::eltype(v) end -Base.@propagate_inbounds function Base.setindex!(v::WithoutMissingVector, x, i) - v.data[i] = x +Base.@propagate_inbounds function Base.setindex!(v::WithoutMissingVector, x, i::Integer) + v.data[i::Integer] = x v end Base.size(v::WithoutMissingVector) = size(v.data) @@ -590,8 +590,9 @@ function _sort!(v::AbstractVector, a::MissingOptimization, o::Ordering, kw) # we can assume v is equal to eachindex(o.data) which allows a copying partition # without allocations. lo_i, hi_i = lo, hi - for i in eachindex(o.data) # equal to copy(v) - x = o.data[i] + cv = eachindex(o.data) # equal to copy(v) + for i in lo:hi + x = o.data[cv[i]] if ismissing(x) == (o.order == Reverse) # should x go at the beginning/end? 
v[lo_i] = i lo_i += 1 From 9211d88ae2ea8948aeb3afc05cc1e7518eed3f45 Mon Sep 17 00:00:00 2001 From: Lilith Hafner Date: Wed, 14 Jun 2023 16:40:36 -0500 Subject: [PATCH 2/7] add tests --- test/sorting.jl | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/test/sorting.jl b/test/sorting.jl index cf98182307088..543869e99c689 100644 --- a/test/sorting.jl +++ b/test/sorting.jl @@ -1025,6 +1025,38 @@ Base.similar(A::MyArray49392, ::Type{T}, dims::Dims{N}) where {T, N} = MyArray49 @test all(sort!(y, dims=2) .== sort!(x,dims=2)) end +@testset "MissingOptimization fastpath for Perm ordering when lo:hi ≠ eachindex(v)" begin + v = [rand() < .5 ? missing : rand() for _ in 1:100] + ix = collect(1:100) + sort!(ix, 1, 10, Base.Sort.DEFAULT_STABLE, Base.Order.Perm(Base.Order.Forward, v)) + @test issorted(v[ix[1:10]]) +end + +struct NonScalarIndexingOfWithoutMissingVectorAlgorithm <: Base.Sort.Algorithm end +function Base.Sort._sort!(v::AbstractVector, ::NonScalarIndexingOfWithoutMissingVectorAlgorithm, o::Base.Order.Ordering, kw) + Base.Sort.@getkw lo hi + first_half = v[lo:lo+(hi-lo)÷2] + second_half = v[lo+(hi-lo)÷2+1:hi] + whole = v[lo:hi] + all(vcat(first_half, second_half) .=== whole) || error() + out = Base.Sort._sort!(whole, Base.Sort.DEFAULT_STABLE, o, (;kw..., lo=1, hi=length(whole))) + v[lo:hi] .= whole + out +end + +@testset "Non-scaler indexing of WithoutMissingVector" begin + @testset "Unit test" begin + wmv = Base.Sort.WithoutMissingVector(Union{Missing, Int}[1, 7, 2, 9]) + @test wmv[[1, 3]] == [1, 2] + @test wmv[1:3] == [1, 7, 2] + end + @testset "End to end" begin + alg = Base.Sort.InitialOptimizations(NonScalarIndexingOfWithoutMissingVectorAlgorithm()) + @test issorted(sort(rand(100); alg)) + @test issorted(sort([rand() < .5 ? missing : randstring() for _ in 1:100]; alg)) + end +end + # This testset is at the end of the file because it is slow. 
@testset "searchsorted" begin numTypes = [ Int8, Int16, Int32, Int64, Int128, From d91a81624b7907aef78bbd055200e26806c6facb Mon Sep 17 00:00:00 2001 From: Lilith Hafner Date: Wed, 14 Jun 2023 16:57:24 -0500 Subject: [PATCH 3/7] better error message --- base/sort.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/base/sort.jl b/base/sort.jl index 6891ad6620f9c..e1424e1f4d663 100644 --- a/base/sort.jl +++ b/base/sort.jl @@ -2168,7 +2168,7 @@ function _sort!(v::AbstractVector, a::Algorithm, o::Ordering, kw) scratch else # This error prevents infinite recursion for unknown algorithms - throw(ArgumentError("Base.Sort._sort!(::$(typeof(v)), ::$(typeof(a)), ::$(typeof(o))) is not defined")) + throw(ArgumentError("Base.Sort._sort!(::$(typeof(v)), ::$(typeof(a)), ::$(typeof(o)), ::Any) is not defined")) end end From 90a4ab362060ce6753e5ddabb1675725e2b3bdc0 Mon Sep 17 00:00:00 2001 From: Lilith Hafner Date: Wed, 14 Jun 2023 17:14:49 -0500 Subject: [PATCH 4/7] remove unnecessary type annotations --- base/sort.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/base/sort.jl b/base/sort.jl index e1424e1f4d663..df9b082bf05db 100644 --- a/base/sort.jl +++ b/base/sort.jl @@ -525,12 +525,12 @@ struct WithoutMissingVector{T, U} <: AbstractVector{T} end end Base.@propagate_inbounds function Base.getindex(v::WithoutMissingVector, i::Integer) - out = v.data[i::Integer] + out = v.data[i] @assert !(out isa Missing) out::eltype(v) end Base.@propagate_inbounds function Base.setindex!(v::WithoutMissingVector, x, i::Integer) - v.data[i::Integer] = x + v.data[i] = x v end Base.size(v::WithoutMissingVector) = size(v.data) From b4aeb27682d8e519f4ea0839a87a3a53f46de3f3 Mon Sep 17 00:00:00 2001 From: Lilith Hafner Date: Wed, 14 Jun 2023 19:35:31 -0500 Subject: [PATCH 5/7] make dispatch loop detection more permissive --- base/sort.jl | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/base/sort.jl b/base/sort.jl index 
df9b082bf05db..8ef4eef65e8b1 100644 --- a/base/sort.jl +++ b/base/sort.jl @@ -436,7 +436,7 @@ for (sym, exp, type) in [ (:mn, :(throw(ArgumentError("mn is needed but has not been computed"))), :(eltype(v))), (:mx, :(throw(ArgumentError("mx is needed but has not been computed"))), :(eltype(v))), (:scratch, nothing, :(Union{Nothing, Vector})), # could have different eltype - (:allow_legacy_dispatch, true, Bool)] + (:legacy_dispatch_entry, nothing, Union{Nothing, Algorithm})] usym = Symbol(:_, sym) @eval function $usym(v, o, kw) # using missing instead of nothing because scratch could === nothing. @@ -2150,25 +2150,25 @@ end # Support 3-, 5-, and 6-argument versions of sort! for calling into the internals in the old way sort!(v::AbstractVector, a::Algorithm, o::Ordering) = sort!(v, firstindex(v), lastindex(v), a, o) function sort!(v::AbstractVector, lo::Integer, hi::Integer, a::Algorithm, o::Ordering) - _sort!(v, a, o, (; lo, hi, allow_legacy_dispatch=false)) + _sort!(v, a, o, (; lo, hi, legacy_dispatch_entry=a)) v end sort!(v::AbstractVector, lo::Integer, hi::Integer, a::Algorithm, o::Ordering, _) = sort!(v, lo, hi, a, o) function sort!(v::AbstractVector, lo::Integer, hi::Integer, a::Algorithm, o::Ordering, scratch::Vector) - _sort!(v, a, o, (; lo, hi, scratch, allow_legacy_dispatch=false)) + _sort!(v, a, o, (; lo, hi, scratch, legacy_dispatch_entry=a)) v end # Support dispatch on custom algorithms in the old way # sort!(::AbstractVector, ::Integer, ::Integer, ::MyCustomAlgorithm, ::Ordering) = ... 
function _sort!(v::AbstractVector, a::Algorithm, o::Ordering, kw) - @getkw lo hi scratch allow_legacy_dispatch - if allow_legacy_dispatch - sort!(v, lo, hi, a, o) - scratch - else + @getkw lo hi scratch legacy_dispatch_entry + if legacy_dispatch_entry === a # This error prevents infinite recursion for unknown algorithms throw(ArgumentError("Base.Sort._sort!(::$(typeof(v)), ::$(typeof(a)), ::$(typeof(o)), ::Any) is not defined")) + else + sort!(v, lo, hi, a, o) + scratch end end From e879d8e841b5b07cbfd25dbbda3db92a9303410d Mon Sep 17 00:00:00 2001 From: Lilith Hafner Date: Wed, 14 Jun 2023 19:51:22 -0500 Subject: [PATCH 6/7] add test --- test/sorting.jl | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/test/sorting.jl b/test/sorting.jl index 543869e99c689..147a70a5db7d9 100644 --- a/test/sorting.jl +++ b/test/sorting.jl @@ -1032,8 +1032,8 @@ end @test issorted(v[ix[1:10]]) end -struct NonScalarIndexingOfWithoutMissingVectorAlgorithm <: Base.Sort.Algorithm end -function Base.Sort._sort!(v::AbstractVector, ::NonScalarIndexingOfWithoutMissingVectorAlgorithm, o::Base.Order.Ordering, kw) +struct NonScalarIndexingOfWithoutMissingVectorAlg <: Base.Sort.Algorithm end +function Base.Sort._sort!(v::AbstractVector, ::NonScalarIndexingOfWithoutMissingVectorAlg, o::Base.Order.Ordering, kw) Base.Sort.@getkw lo hi first_half = v[lo:lo+(hi-lo)÷2] second_half = v[lo+(hi-lo)÷2+1:hi] @@ -1051,12 +1051,20 @@ end @test wmv[1:3] == [1, 7, 2] end @testset "End to end" begin - alg = Base.Sort.InitialOptimizations(NonScalarIndexingOfWithoutMissingVectorAlgorithm()) + alg = Base.Sort.InitialOptimizations(NonScalarIndexingOfWithoutMissingVectorAlg()) @test issorted(sort(rand(100); alg)) @test issorted(sort([rand() < .5 ? 
missing : randstring() for _ in 1:100]; alg)) end end +struct DispatchLoopTestAlg <: Base.Sort.Algorithm end +function Base.sort!(v::AbstractVector, lo::Integer, hi::Integer, ::DispatchLoopTestAlg, order::Base.Order.Ordering) + sort!(view(v, lo:hi); order) +end +@testset "Support dispatch from the old style to the new style and back" begin + @test issorted(sort!(rand(100), Base.Sort.InitialOptimizations(DispatchLoopTestAlg()), Base.Order.Forward)) +end + # This testset is at the end of the file because it is slow. @testset "searchsorted" begin numTypes = [ Int8, Int16, Int32, Int64, Int128, From 3c862b3c826a9dcc12fc3e940511811bf9f8091a Mon Sep 17 00:00:00 2001 From: Lilith Hafner Date: Wed, 14 Jun 2023 20:22:53 -0500 Subject: [PATCH 7/7] move Algorithm declaration to top of file to avoid use before definition error --- base/sort.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/base/sort.jl b/base/sort.jl index 8ef4eef65e8b1..90f8755d3b1a4 100644 --- a/base/sort.jl +++ b/base/sort.jl @@ -44,6 +44,7 @@ export # not exported by Base SMALL_ALGORITHM, SMALL_THRESHOLD +abstract type Algorithm end ## functions requiring only ordering ## @@ -499,8 +500,6 @@ internal or recursive calls. """ function _sort! end -abstract type Algorithm end - """ MissingOptimization(next) <: Algorithm