From 7d143dafbc85c0d57d05538173ad0cf419a2a808 Mon Sep 17 00:00:00 2001 From: Tom White Date: Wed, 15 Jan 2025 15:33:14 +0000 Subject: [PATCH] Changes to allow nan functions to work with xarray --- cubed/array/nan_functions.py | 4 ++-- cubed/array_api/array_object.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/cubed/array/nan_functions.py b/cubed/array/nan_functions.py index 1801fa5be..fe0d7b667 100644 --- a/cubed/array/nan_functions.py +++ b/cubed/array/nan_functions.py @@ -8,9 +8,9 @@ # https://github.com/data-apis/array-api/issues/621 -def nanmean(x, /, *, axis=None, keepdims=False, split_every=None): +def nanmean(x, /, *, axis=None, dtype=None, keepdims=False, split_every=None): """Compute the arithmetic mean along the specified axis, ignoring NaNs.""" - dtype = x.dtype + dtype = dtype or x.dtype intermediate_dtype = [("n", nxp.int64), ("total", nxp.float64)] return reduction( x, diff --git a/cubed/array_api/array_object.py b/cubed/array_api/array_object.py index 94121d7ea..83b1ee941 100644 --- a/cubed/array_api/array_object.py +++ b/cubed/array_api/array_object.py @@ -367,9 +367,9 @@ def __array_namespace__(self, /, *, api_version=None): "2023.12", ): raise ValueError(f"Unrecognized array API version: {api_version!r}") - import cubed.array_api as array_api + import cubed - return array_api + return cubed def __bool__(self, /): if self.ndim != 0: