diff --git a/spec/draft/API_specification/searching_functions.rst b/spec/draft/API_specification/searching_functions.rst index c952f1aad..1a584f158 100644 --- a/spec/draft/API_specification/searching_functions.rst +++ b/spec/draft/API_specification/searching_functions.rst @@ -22,6 +22,7 @@ Objects in API argmax argmin + count_nonzero nonzero searchsorted where diff --git a/src/array_api_stubs/_draft/searching_functions.py b/src/array_api_stubs/_draft/searching_functions.py index 029459b9a..4eee3173b 100644 --- a/src/array_api_stubs/_draft/searching_functions.py +++ b/src/array_api_stubs/_draft/searching_functions.py @@ -1,7 +1,7 @@ -__all__ = ["argmax", "argmin", "nonzero", "searchsorted", "where"] +__all__ = ["argmax", "argmin", "count_nonzero", "nonzero", "searchsorted", "where"] -from ._types import Optional, Tuple, Literal, array +from ._types import Optional, Tuple, Literal, Union, array def argmax(x: array, /, *, axis: Optional[int] = None, keepdims: bool = False) -> array: @@ -54,15 +54,41 @@ def argmin(x: array, /, *, axis: Optional[int] = None, keepdims: bool = False) - """ -def nonzero(x: array, /) -> Tuple[array, ...]: +def count_nonzero( + x: array, + /, + *, + axis: Optional[Union[int, Tuple[int, ...]]] = None, + keepdims: bool = False, +) -> array: """ - Returns the indices of the array elements which are non-zero. + Counts the number of array elements which are non-zero. - .. note:: - If ``x`` has a complex floating-point data type, non-zero elements are those elements having at least one component (real or imaginary) which is non-zero. + Parameters + ---------- + x: array + input array. + axis: Optional[Union[int, Tuple[int, ...]]] + axis or axes along which to count non-zero values. By default, the number of non-zero values must be computed over the entire array. If a tuple of integers, the number of non-zero values must be computed over multiple axes. Default: ``None``. + keepdims: bool + if ``True``, the reduced axes (dimensions) must be included in the result as singleton dimensions, and, accordingly, the result must be compatible with the input array (see :ref:`broadcasting`). Otherwise, if ``False``, the reduced axes (dimensions) must not be included in the result. Default: ``False``. - .. note:: - If ``x`` has a boolean data type, non-zero elements are those elements which are equal to ``True``. + Returns + ------- + out: array + if the number of non-zeros values was computed over the entire array, a zero-dimensional array containing the total number of non-zero values; otherwise, a non-zero-dimensional array containing the counts along the specified axes. The returned array must have the default array index data type. + + Notes + ----- + + - If ``x`` has a complex floating-point data type, non-zero elements are those elements having at least one component (real or imaginary) which is non-zero. + - If ``x`` has a boolean data type, non-zero elements are those elements which are equal to ``True``. + """ + + +def nonzero(x: array, /) -> Tuple[array, ...]: + """ + Returns the indices of the array elements which are non-zero. .. admonition:: Data-dependent output shape :class: admonition important @@ -76,12 +102,15 @@ def nonzero(x: array, /) -> Tuple[array, ...]: Returns ------- - out: Typle[array, ...] + out: Tuple[array, ...] a tuple of ``k`` arrays, one for each dimension of ``x`` and each of size ``n`` (where ``n`` is the total number of non-zero elements), containing the indices of the non-zero elements in that dimension. The indices must be returned in row-major, C-style order. The returned array must have the default array index data type. Notes ----- + - If ``x`` has a complex floating-point data type, non-zero elements are those elements having at least one component (real or imaginary) which is non-zero. + - If ``x`` has a boolean data type, non-zero elements are those elements which are equal to ``True``. + .. versionchanged:: 2022.12 Added complex data type support. """