
NumPy __array_ufunc__ does not work with typing #6524


Open
milliams opened this issue Apr 27, 2022 · 7 comments

Comments

@milliams

What is your issue?

When NumPy functions are applied to xarray objects, which override their behaviour via __array_ufunc__, type checkers like MyPy do not infer the correct return type.

For example, the function np.exp has been adapted by xarray to return a DataArray if it is passed a DataArray, so that code like the following will work:

import xarray as xr
import numpy as np

da = xr.DataArray([1, 2, 3, 4])
blah = np.exp(da).rename("blah")

This code creates an xr.DataArray and uses NumPy to calculate the exp of its values. Since the result is returned as an xr.DataArray, xarray methods like rename can be called on it.

However, running MyPy on this code gives the error:

error: "ndarray[Any, dtype[Any]]" has no attribute "rename"

This is because there is typing information from NumPy which claims that np.exp returns ndarray[Any, dtype[Any]].

Now, I'm unsure whether this is a bug in xarray for not providing typing information for the __array_ufunc__ code, a bug in NumPy for not having the flexibility to type these overrides, or a bug in MyPy for not allowing overrides like this to happen, but I wanted to start with the place where the observed error occurs.
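
For readers unfamiliar with the mechanism, here is a minimal sketch of how __array_ufunc__ changes what a NumPy ufunc returns at runtime. The Labeled class is a hypothetical stand-in for xr.DataArray, written only for illustration; it is not how xarray implements the protocol.

```python
import numpy as np


class Labeled:
    """Hypothetical stand-in for xr.DataArray: an array of values plus a name."""

    def __init__(self, values, name=None):
        self.values = np.asarray(values, dtype=float)
        self.name = name

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        # Only plain calls such as np.exp(x) are handled in this sketch.
        if method != "__call__":
            return NotImplemented
        # Unwrap Labeled inputs, apply the ufunc, and re-wrap the result.
        raw = [x.values if isinstance(x, Labeled) else x for x in inputs]
        return Labeled(ufunc(*raw, **kwargs), self.name)

    def rename(self, name):
        return Labeled(self.values, name)


result = np.exp(Labeled([0.0, 1.0]))
print(type(result).__name__)       # Labeled, not ndarray
print(result.rename("blah").name)  # blah
```

The return type is decided by the argument's own __array_ufunc__ at call time, which is information a fixed return annotation on np.exp cannot express.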

@milliams milliams added the needs triage Issue that has not been reviewed by xarray team member label Apr 27, 2022
@dcherian dcherian added topic-typing and removed needs triage Issue that has not been reviewed by xarray team member labels Apr 27, 2022
@max-sixty
Collaborator

max-sixty commented Apr 27, 2022

Thanks for the issue @milliams .

My guess is that the "ndarray[Any, dtype[Any]]" type signature is from np.exp. Do you know whether we get the same result with arrays from another library that implements __array_ufunc__, using that library's equivalent of the rename method?

@milliams
Author

I've checked Pandas, Dask and AstroPy and while they all implement __array_ufunc__ in some way, none of them have typing support yet. I imagine that you are ahead of the curve with third-party NumPy typing here, so you are more likely to hit these kinds of corner cases.

@max-sixty
Collaborator

Thanks, I see the same thing re pandas. It makes me think it may be coming from xarray, since it's not just np.exp(series).rename("blah") which is raising.

But I'm not sure where it would be coming from. Our __array_ufunc__ isn't typed...!

@shoyer
Member

shoyer commented Apr 28, 2022

I think this would need to get updated on the NumPy side. Ideally NumPy ufuncs would be typed to check for __array_ufunc__. Something like:

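from typing import Any, Protocol, TypeVar

from numpy import ndarray

class HasArrayUFunc(Protocol):
    # Any object that defines __array_ufunc__ matches this protocol.
    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs) -> Any:
        ...

ArrayOrHasArrayUFunc = TypeVar("ArrayOrHasArrayUFunc", ndarray, HasArrayUFunc)

# Stub-style signature: the return type follows the argument type.
def exp(x: ArrayOrHasArrayUFunc) -> ArrayOrHasArrayUFunc:
    ...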
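
Until something like that exists upstream, the same idea can be tried locally. The following is only a sketch under assumptions: typed_exp and Tagged are hypothetical names, the TypeVar is bound rather than constrained, and the casts simply assert (without checking) that the argument's __array_ufunc__ returns its own type.

```python
from typing import Any, Protocol, TypeVar, cast

import numpy as np


class HasArrayUFunc(Protocol):
    def __array_ufunc__(self, ufunc: Any, method: Any, *inputs: Any, **kwargs: Any) -> Any:
        ...


T = TypeVar("T", bound=HasArrayUFunc)


def typed_exp(x: T) -> T:
    """Hypothetical wrapper: promise the checker that np.exp preserves the input type."""
    # Assumption: x's __array_ufunc__ returns an object of x's own type.
    # The casts only affect static checking; nothing is verified at runtime.
    return cast(T, np.exp(cast(Any, x)))


class Tagged(np.ndarray):
    """Minimal ndarray subclass used only to exercise the wrapper."""


plain = typed_exp(np.array([0.0, 1.0]))
tagged = typed_exp(np.array([0.0, 1.0]).view(Tagged))
print(type(plain).__name__, type(tagged).__name__)  # ndarray Tagged
```

The promise is not always true in general, since an __array_ufunc__ implementation is free to return any type, which may be part of why a fully satisfying annotation is hard to write.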

@headtr1ck
Collaborator

Should we open an issue on numpy for this?

@shoyer
Member

shoyer commented Jun 5, 2022

error: "ndarray[Any, dtype[Any]]" has no attribute "rename"

Yes, it's worth discussing. I don't know if there will be a satisfying resolution, though.

@brendan-m-murphy

brendan-m-murphy commented Feb 27, 2025

If it's useful for anyone else, here is a hack:

from typing import cast, TypeVar

import numpy as np
import xarray as xr


XrData = TypeVar("XrData", xr.DataArray, xr.Dataset)


class XrTypesHack:
    """Wrap a numpy ufunc to operate on xarray DataArrays and Datasets with correct return types.

    There is a problem with numpy's type hints that causes mypy to think that numpy functions
    applied to xarray DataArrays and Datasets return numpy arrays: https://github.com/pydata/xarray/issues/8388.

    Further, using a class to do this wrapping seems to be necessary. Possibly related to some
    old mypy issue: https://github.com/python/mypy/issues/1551
    """
    def __init__(self, ufunc: np.ufunc) -> None:
        self.ufunc = ufunc

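    def __call__(self, data: XrData) -> XrData:
        result = self.ufunc(data)
        if isinstance(data, xr.DataArray):
            return cast(xr.DataArray, result)
        if isinstance(data, xr.Dataset):
            return cast(xr.Dataset, result)
        # Fail loudly instead of silently returning None for unsupported inputs.
        raise TypeError(f"expected DataArray or Dataset, got {type(data).__name__}")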


# example usage
xr_sqrt = XrTypesHack(np.sqrt)

I tried writing a function that mapped np.ufunc to Callable[[XrData], XrData] but ran into mypy issues. The class hack is suggested in mypy issue #1551.
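
For a one-off call, a plain typing.cast at the call site is a lighter alternative to wrapping the ufunc. This is a sketch applied to the example from the top of the thread, not something proposed by the participants; it assumes the result really is a DataArray, which the cast does not verify.

```python
from typing import cast

import numpy as np
import xarray as xr

da = xr.DataArray([1, 2, 3, 4])
# The cast is a no-op at runtime; it only overrides the type mypy infers for np.exp(da).
blah = cast(xr.DataArray, np.exp(da)).rename("blah")
print(blah.name)  # blah
```

The trade-off is that every call site needs its own cast, whereas a wrapper keeps the assumption in one place.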
