diff --git a/numcodecs/compat.py b/numcodecs/compat.py index 52827ca6..a239e838 100644 --- a/numcodecs/compat.py +++ b/numcodecs/compat.py @@ -6,6 +6,7 @@ import array +import cupy as cp import numpy as np @@ -48,8 +49,8 @@ def ensure_ndarray(buf): """ - if isinstance(buf, np.ndarray): - # already a numpy array + if isinstance(buf, (np.ndarray, cp.ndarray)): + # already an array arr = buf elif isinstance(buf, array.array) and buf.typecode in 'cu': @@ -151,6 +152,12 @@ def ensure_bytes(buf): if arr.dtype == object: raise TypeError('object arrays are not supported') + # Force CuPy arrays to NumPy arrays + # because they don't have a `tobytes` method (yet) + # xref: https://github.com/cupy/cupy/pull/2617 + if isinstance(arr, cp.ndarray): + arr = arr.get() + # create bytes buf = arr.tobytes(order='A') @@ -160,6 +167,12 @@ def ensure_bytes(buf): def ensure_text(s, encoding='utf-8'): if not isinstance(s, text_type): s = ensure_contiguous_ndarray(s) + + # Force CuPy arrays to NumPy arrays + # as they support the buffer protocol + if isinstance(s, cp.ndarray): + s = s.get() + s = codecs.decode(s, encoding) return s