From a0af83768398e9fdc5be4688309e7cf4aeae7639 Mon Sep 17 00:00:00 2001 From: Athan Reines Date: Thu, 18 Apr 2024 01:36:22 -0700 Subject: [PATCH] feat: add `nextafter` to specification --- .../elementwise_functions.rst | 1 + .../_draft/elementwise_functions.py | 30 +++++++++++++++++++ 2 files changed, 31 insertions(+) diff --git a/spec/draft/API_specification/elementwise_functions.rst b/spec/draft/API_specification/elementwise_functions.rst index 4919cff98..a853ca18c 100644 --- a/spec/draft/API_specification/elementwise_functions.rst +++ b/spec/draft/API_specification/elementwise_functions.rst @@ -66,6 +66,7 @@ Objects in API minimum multiply negative + nextafter not_equal positive pow diff --git a/src/array_api_stubs/_draft/elementwise_functions.py b/src/array_api_stubs/_draft/elementwise_functions.py index 4462329d6..ec0b0567c 100644 --- a/src/array_api_stubs/_draft/elementwise_functions.py +++ b/src/array_api_stubs/_draft/elementwise_functions.py @@ -48,6 +48,7 @@ "minimum", "multiply", "negative", + "nextafter", "not_equal", "positive", "pow", @@ -2069,6 +2070,35 @@ def negative(x: array, /) -> array: """ +def nextafter(x1: array, x2: array, /) -> array: + """ + Returns the next representable floating-point value for each element ``x1_i`` of the input array ``x1`` in the direction of the respective element ``x2_i`` of the input array ``x2``. + + Parameters + ---------- + x1: array + first input array. Should have a real-valued floating-point data type. + x2: array + second input array. Must be compatible with ``x1`` (see :ref:`broadcasting`). Should have the same data type as ``x1``. + + Returns + ------- + out: array + an array containing the element-wise results. The returned array must have the same data type as ``x1``. + + Notes + ----- + + **Special cases** + + For real-valued floating-point operands, + + - If either ``x1_i`` or ``x2_i`` is ``NaN``, the result is ``NaN``. + - If ``x1_i`` is ``-0`` and ``x2_i`` is ``+0``, the result is ``+0``. + - If ``x1_i`` is ``+0`` and ``x2_i`` is ``-0``, the result is ``-0``. + """ + + def not_equal(x1: array, x2: array, /) -> array: """ Computes the truth value of ``x1_i != x2_i`` for each element ``x1_i`` of the input array ``x1`` with the respective element ``x2_i`` of the input array ``x2``.