Skip to content

Commit 35543cf

Browse files
committed
Update
1 parent d0eae05 commit 35543cf

File tree

2 files changed

+77
-7
lines changed

2 files changed

+77
-7
lines changed

jax/_src/array.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -404,11 +404,28 @@ def __array__(self, dtype=None, context=None, copy=None):
404404
kwds = {} if copy is None else {'copy': copy}
405405
return np.asarray(self._value, dtype=dtype, **kwds)
406406

407-
def __dlpack__(self, *, stream: int | Any | None = None):
407+
def __dlpack__(self, *, stream: int | Any | None = None,
408+
max_version: tuple[int, int] | None = None,
409+
dl_device: tuple[enum.Enum, int] | None = None,
410+
copy: bool | None = None):
408411
if len(self._arrays) != 1:
409412
raise ValueError("__dlpack__ only supported for unsharded arrays.")
413+
410414
from jax._src.dlpack import to_dlpack # pylint: disable=g-import-not-at-top
411-
return to_dlpack(self, stream=stream)
415+
416+
device_set = self.sharding.device_set
417+
if len(device_set) > 1:
418+
raise BufferError(
419+
"to_dlpack can only pack a dlpack tensor from an array on a singular "
420+
f"device, but an array with a Sharding over {len(device_set)} devices "
421+
"was provided."
422+
)
423+
device = device_set.pop()
424+
return to_dlpack(self, stream=stream,
425+
max_version=max_version,
426+
device=device,
427+
dl_device=dl_device, # type: ignore
428+
copy=copy)
412429

413430
def __dlpack_device__(self) -> tuple[enum.Enum, int]:
414431
if len(self._arrays) != 1:

jax/_src/dlpack.py

Lines changed: 58 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,10 @@
2424
from jax._src.lib import xla_client
2525
from jax._src.lib import xla_extension_version
2626
from jax._src.typing import Array
27+
from jax._src.api import device_put
2728

29+
DLPACK_VERSION = (0, 1)
30+
MIN_DLPACK_VERSION = (0, 1)
2831

2932
# A set of dtypes that dlpack supports.
3033
# Note: Make sure to use a "type", not a dtype instance, when looking up this set
@@ -48,9 +51,33 @@ class DLDeviceType(enum.IntEnum):
4851
kDLCUDA = 2
4952
kDLROCM = 10
5053

54+
def _to_dlpack(x: Array, stream: int | Any | None,
55+
device: xla_client.Device | None = None,
56+
dlpack_device: xla_client.Device | None = None,
57+
copy: bool | None = None):
58+
arr = None
59+
if dlpack_device and dlpack_device != device:
60+
if copy is not None and not copy:
61+
raise ValueError(
62+
f"Specified {dlpack_device=} which requires a copy since the source device "
63+
f"is {repr(device)}, however copy=False. Set copy=True or "
64+
"copy=None to perform the requested operation."
65+
)
66+
else:
67+
arr = device_put(x, dlpack_device)
68+
if arr is None:
69+
arr = x.copy() if copy else x
70+
71+
return xla_client._xla.buffer_to_dlpack_managed_tensor(
72+
arr.addressable_data(0), stream=stream
73+
) # type: ignore
5174

5275
def to_dlpack(x: Array, take_ownership: bool = False,
53-
stream: int | Any | None = None):
76+
stream: int | Any | None = None,
77+
device: xla_client.Device | None = None,
78+
dl_device: tuple[DLDeviceType, int] | None = None,
79+
max_version: tuple[int, int] | None = None,
80+
copy : bool | None = None):
5481
"""Returns a DLPack tensor that encapsulates a :class:`~jax.Array` ``x``.
5582
5683
Args:
@@ -73,14 +100,40 @@ def to_dlpack(x: Array, take_ownership: bool = False,
73100
if not isinstance(x, array.ArrayImpl):
74101
raise TypeError("Argument to to_dlpack must be a jax.Array, "
75102
f"got {type(x)}")
76-
assert len(x.devices()) == 1
77103
if take_ownership:
78104
warnings.warn(
79105
"take_ownership in to_dlpack is deprecated and it is a no-op."
80106
)
81-
return xla_client._xla.buffer_to_dlpack_managed_tensor(
82-
x.addressable_data(0), stream=stream
83-
) # type: ignore
107+
108+
dlpack_device = None
109+
dl_device_type, local_hardware_id = dl_device if dl_device else (None, None)
110+
if dl_device_type:
111+
try:
112+
dl_device_platform = {
113+
DLDeviceType.kDLCPU: "cpu",
114+
DLDeviceType.kDLCUDA: "cuda",
115+
DLDeviceType.kDLROCM: "rocm",
116+
}[dl_device_type]
117+
backend = xla_bridge.get_backend(dl_device_platform)
118+
dlpack_device = backend.device_from_local_hardware_id(local_hardware_id)
119+
except TypeError:
120+
# https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__dlpack__.html
121+
# recommends using BufferError.
122+
raise BufferError(
123+
"The device specification passed to to_dlpack contains an unsupported "
124+
f"device type (DLDeviceType: {dl_device_type})")
125+
126+
if max_version is None or max_version[0] >= DLPACK_VERSION[0]:
127+
return _to_dlpack(x, stream=stream, device=device, dlpack_device=dlpack_device, copy=copy)
128+
elif max_version >= MIN_DLPACK_VERSION:
129+
# Legacy path to be implemented when XLA adopts DLManagedTensorVersioned format
130+
raise RuntimeError("This branch should be unreachable. "
131+
"Please open a bug if you see this.")
132+
else:
133+
raise BufferError(
134+
f"JAX does not support any version below {MIN_DLPACK_VERSION} but "
135+
f"version ({max_version}) was requested."
136+
)
84137

85138

86139
def from_dlpack(external_array):

0 commit comments

Comments
 (0)