diff --git a/libcxx/include/__algorithm/nth_element.h b/libcxx/include/__algorithm/nth_element.h index dbacf58f9ecdb..ebd1cbf76143d 100644 --- a/libcxx/include/__algorithm/nth_element.h +++ b/libcxx/include/__algorithm/nth_element.h @@ -13,6 +13,7 @@ #include <__algorithm/comp_ref_type.h> #include <__algorithm/iterator_operations.h> #include <__algorithm/sort.h> +#include <__assert> #include <__config> #include <__debug_utils/randomize_range.h> #include <__iterator/iterator_traits.h> @@ -42,6 +43,7 @@ __nth_element_find_guard(_RandomAccessIterator& __i, _RandomAccessIterator& __j, template _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX14 void +// NOLINTNEXTLINE(readability-function-cognitive-complexity) __nth_element(_RandomAccessIterator __first, _RandomAccessIterator __nth, _RandomAccessIterator __last, _Compare __comp) { using _Ops = _IterOps<_AlgPolicy>; @@ -116,10 +118,18 @@ __nth_element(_RandomAccessIterator __first, _RandomAccessIterator __nth, _Rando return; } while (true) { - while (!__comp(*__first, *__i)) + while (!__comp(*__first, *__i)) { ++__i; - while (__comp(*__first, *--__j)) - ; + _LIBCPP_ASSERT_UNCATEGORIZED( + __i != __last, + "Would read out of bounds, does your comparator satisfy the strict-weak ordering requirement?"); + } + do { + _LIBCPP_ASSERT_UNCATEGORIZED( + __j != __first, + "Would read out of bounds, does your comparator satisfy the strict-weak ordering requirement?"); + --__j; + } while (__comp(*__first, *__j)); if (__i >= __j) break; _Ops::iter_swap(__i, __j); @@ -146,11 +156,19 @@ __nth_element(_RandomAccessIterator __first, _RandomAccessIterator __nth, _Rando while (true) { // __m still guards upward moving __i - while (__comp(*__i, *__m)) + while (__comp(*__i, *__m)) { ++__i; + _LIBCPP_ASSERT_UNCATEGORIZED( + __i != __last, + "Would read out of bounds, does your comparator satisfy the strict-weak ordering requirement?"); + } // It is now known that a guard exists for downward moving __j - while (!__comp(*--__j, *__m)) - ; + do { + _LIBCPP_ASSERT_UNCATEGORIZED( + __j != __first, + "Would read out of bounds, does your comparator satisfy the strict-weak ordering requirement?"); + --__j; + } while (!__comp(*__j, *__m)); if (__i >= __j) break; _Ops::iter_swap(__i, __j); diff --git a/libcxx/test/libcxx/algorithms/alg.sorting/assert.sort.invalid_comparator.pass.cpp b/libcxx/test/libcxx/algorithms/alg.sorting/assert.sort.invalid_comparator.pass.cpp index e5e417fe7bda2..96c2821c4a654 100644 --- a/libcxx/test/libcxx/algorithms/alg.sorting/assert.sort.invalid_comparator.pass.cpp +++ b/libcxx/test/libcxx/algorithms/alg.sorting/assert.sort.invalid_comparator.pass.cpp @@ -50,24 +50,34 @@ #include "bad_comparator_values.h" #include "check_assertion.h" -void check_oob_sort_read() { - std::map> comparison_results; // terrible for performance, but really convenient - for (auto line : std::views::split(DATA, '\n') | std::views::filter([](auto const& line) { return !line.empty(); })) { - auto values = std::views::split(line, ' '); - auto it = values.begin(); - std::size_t left = std::stol(std::string((*it).data(), (*it).size())); - it = std::next(it); - std::size_t right = std::stol(std::string((*it).data(), (*it).size())); - it = std::next(it); - bool result = static_cast(std::stol(std::string((*it).data(), (*it).size()))); - comparison_results[left][right] = result; - } - auto predicate = [&](std::size_t* left, std::size_t* right) { +class ComparisonResults { +public: + explicit ComparisonResults(std::string_view data) { + for (auto line : std::views::split(data, '\n') | std::views::filter([](auto const& line) { return !line.empty(); })) { + auto values = std::views::split(line, ' '); + auto it = values.begin(); + std::size_t left = std::stol(std::string((*it).data(), (*it).size())); + it = std::next(it); + std::size_t right = std::stol(std::string((*it).data(), (*it).size())); + it = std::next(it); + bool result = static_cast(std::stol(std::string((*it).data(), (*it).size()))); + comparison_results[left][right] = result; + } + } + + bool compare(size_t* left, size_t* right) const { assert(left != nullptr && right != nullptr && "something is wrong with the test"); - assert(comparison_results.contains(*left) && comparison_results[*left].contains(*right) && "malformed input data?"); - return comparison_results[*left][*right]; - }; + assert(comparison_results.contains(*left) && comparison_results.at(*left).contains(*right) && "malformed input data?"); + return comparison_results.at(*left).at(*right); + } + size_t size() const { return comparison_results.size(); } +private: + std::map> comparison_results; // terrible for performance, but really convenient +}; + +void check_oob_sort_read() { + ComparisonResults comparison_results(SORT_DATA); std::vector> elements; std::set valid_ptrs; for (std::size_t i = 0; i != comparison_results.size(); ++i) { @@ -81,7 +91,7 @@ void check_oob_sort_read() { // because we're reading OOB. assert(valid_ptrs.contains(left)); assert(valid_ptrs.contains(right)); - return predicate(left, right); + return comparison_results.compare(left, right); }; // Check the classic sorting algorithms @@ -117,12 +127,6 @@ void check_oob_sort_read() { std::vector results(copy.size(), nullptr); TEST_LIBCPP_ASSERT_FAILURE(std::partial_sort_copy(copy.begin(), copy.end(), results.begin(), results.end(), checked_predicate), "not a valid strict-weak ordering"); } - { - std::vector copy; - for (auto const& e : elements) - copy.push_back(e.get()); - std::nth_element(copy.begin(), copy.end(), copy.end(), checked_predicate); // doesn't go OOB even with invalid comparator - } // Check the Ranges sorting algorithms { @@ -157,11 +161,38 @@ void check_oob_sort_read() { std::vector results(copy.size(), nullptr); TEST_LIBCPP_ASSERT_FAILURE(std::ranges::partial_sort_copy(copy, results, checked_predicate), "not a valid strict-weak ordering"); } +} + +void check_oob_nth_element_read() { + ComparisonResults results(NTH_ELEMENT_DATA); + std::vector> elements; + std::set valid_ptrs; + for (std::size_t i = 0; i != results.size(); ++i) { + elements.push_back(std::make_unique(i)); + valid_ptrs.insert(elements.back().get()); + } + + auto checked_predicate = [&](size_t* left, size_t* right) { + // If the pointers passed to the comparator are not in the set of pointers we + // set up above, then we're being passed garbage values from the algorithm + // because we're reading OOB. + assert(valid_ptrs.contains(left)); + assert(valid_ptrs.contains(right)); + return results.compare(left, right); + }; + { std::vector copy; for (auto const& e : elements) copy.push_back(e.get()); - std::ranges::nth_element(copy, copy.end(), checked_predicate); // doesn't go OOB even with invalid comparator + TEST_LIBCPP_ASSERT_FAILURE(std::nth_element(copy.begin(), copy.begin(), copy.end(), checked_predicate), "Would read out of bounds"); + } + + { + std::vector copy; + for (auto const& e : elements) + copy.push_back(e.get()); + TEST_LIBCPP_ASSERT_FAILURE(std::ranges::nth_element(copy, copy.begin(), checked_predicate), "Would read out of bounds"); } } @@ -214,6 +245,8 @@ int main(int, char**) { check_oob_sort_read(); + check_oob_nth_element_read(); + check_nan_floats(); check_irreflexive(); diff --git a/libcxx/test/libcxx/algorithms/alg.sorting/bad_comparator_values.h b/libcxx/test/libcxx/algorithms/alg.sorting/bad_comparator_values.h index 19ea023419ea9..c0ffd16cd4ac4 100644 --- a/libcxx/test/libcxx/algorithms/alg.sorting/bad_comparator_values.h +++ b/libcxx/test/libcxx/algorithms/alg.sorting/bad_comparator_values.h @@ -11,7 +11,74 @@ #include -inline constexpr std::string_view DATA = R"( +inline constexpr std::string_view NTH_ELEMENT_DATA = R"( +0 0 0 +0 1 0 +0 2 0 +0 3 0 +0 4 1 +0 5 0 +0 6 0 +0 7 0 +1 0 0 +1 1 0 +1 2 0 +1 3 1 +1 4 1 +1 5 1 +1 6 1 +1 7 1 +2 0 1 +2 1 1 +2 2 1 +2 3 1 +2 4 1 +2 5 1 +2 6 1 +2 7 1 +3 0 1 +3 1 1 +3 2 1 +3 3 1 +3 4 1 +3 5 1 +3 6 1 +3 7 1 +4 0 1 +4 1 1 +4 2 1 +4 3 1 +4 4 1 +4 5 1 +4 6 1 +4 7 1 +5 0 1 +5 1 1 +5 2 1 +5 3 1 +5 4 1 +5 5 1 +5 6 1 +5 7 1 +6 0 1 +6 1 1 +6 2 1 +6 3 1 +6 4 1 +6 5 1 +6 6 1 +6 7 1 +7 0 1 +7 1 1 +7 2 1 +7 3 1 +7 4 1 +7 5 1 +7 6 1 +7 7 1 +)"; + +inline constexpr std::string_view SORT_DATA = R"( 0 0 0 0 1 1 0 2 1