diff --git a/jax/_src/array.py b/jax/_src/array.py
index 1c24142981ca..5a94549df60d 100644
--- a/jax/_src/array.py
+++ b/jax/_src/array.py
@@ -46,7 +46,7 @@
     SingleDeviceSharding, XLACompatibleSharding, PmapSharding,
     device_replica_id_map, hashed_index)
 from jax._src.layout import DeviceLocalLayout, Layout
-from jax._src.typing import ArrayLike
+from jax._src.typing import ArrayLike, DLDeviceType
 from jax._src.util import safe_zip, unzip3, use_cpp_class, use_cpp_method
 
 
@@ -404,11 +404,25 @@ def __array__(self, dtype=None, context=None, copy=None):
     kwds = {} if copy is None else {'copy': copy}
     return np.asarray(self._value, dtype=dtype, **kwds)
 
-  def __dlpack__(self, *, stream: int | Any | None = None):
-    if len(self._arrays) != 1:
-      raise BufferError("__dlpack__ only supported for unsharded arrays.")
+  def __dlpack__(self, *, stream: int | Any | None = None,
+                 max_version: tuple[int, int] | None = None,
+                 dl_device: tuple[DLDeviceType, int] | None = None,
+                 copy: bool | None = None):
     from jax._src.dlpack import to_dlpack  # pylint: disable=g-import-not-at-top
-    return to_dlpack(self, stream=stream)
+
+    device_set = self.sharding.device_set
+    if len(device_set) > 1:
+      raise BufferError(
+        "to_dlpack can only pack a dlpack tensor from an array on a singular "
+        f"device, but an array with a Sharding over {len(device_set)} devices "
+        "was provided."
+      )
+    device, = device_set
+    return to_dlpack(self, stream=stream,
+                     max_version=max_version,
+                     src_device=device,
+                     dl_device=dl_device,
+                     copy=copy)
 
   def __dlpack_device__(self) -> tuple[enum.Enum, int]:
     if len(self._arrays) != 1:
diff --git a/jax/_src/dlpack.py b/jax/_src/dlpack.py
index 72503fe18a2c..45577eaa03d3 100644
--- a/jax/_src/dlpack.py
+++ b/jax/_src/dlpack.py
@@ -14,19 +14,21 @@
 
 from __future__ import annotations
 
-import enum
 from typing import Any
-import warnings
 
 from jax._src.api import device_put
 from jax import numpy as jnp
 from jax._src import array
 from jax._src import xla_bridge
+from jax._src.lax.lax import _array_copy
 from jax._src.lib import xla_client
 from jax._src.lib import xla_extension_version
-from jax._src.typing import Array
+from jax._src.typing import Array, DLDeviceType
 from jax._src.sharding import Sharding
 
+DLPACK_VERSION = (0, 8)
+MIN_DLPACK_VERSION = (0, 5)
+
 # A set of dtypes that dlpack supports.
 # Note: Make sure to use a "type", not a dtype instance, when looking up this set
 # because their hashes are different.
@@ -43,45 +45,112 @@
   SUPPORTED_DTYPES = SUPPORTED_DTYPES | frozenset({jnp.bool_})
 
 
-# Mirror of dlpack.h enum
-class DLDeviceType(enum.IntEnum):
-  kDLCPU = 1
-  kDLCUDA = 2
-  kDLROCM = 10
+def _to_dlpack(x: Array, stream: int | Any | None,
+               src_device: xla_client.Device | None = None,
+               device: xla_client.Device | None = None,
+               copy: bool | None = None):
 
+  if src_device is None:
+    src_device, = x.devices()
+  if device and (src_device is None or device != src_device):
+    if copy is not None and not copy:
+      raise ValueError(
+        f"Specified {device=} which requires a copy since the source device "
+        f"is {repr(src_device)}, however copy=False. Set copy=True or "
+        "copy=None to perform the requested operation."
+      )
+    else:
+      arr = device_put(x, device)
+  else:
+    arr = _array_copy(x) if copy else x
+  return xla_client._xla.buffer_to_dlpack_managed_tensor(
+    arr.addressable_data(0), stream=stream
+  )
 
-def to_dlpack(x: Array, take_ownership: bool = False,
-              stream: int | Any | None = None):
+def to_dlpack(x: Array, stream: int | Any | None = None,
+              src_device: xla_client.Device | None = None,
+              dl_device: tuple[DLDeviceType, int] | None = None,
+              max_version: tuple[int, int] | None = None,
+              copy : bool | None = None):
   """Returns a DLPack tensor that encapsulates a :class:`~jax.Array` ``x``.
 
   Args:
     x: a :class:`~jax.Array`, on either CPU or GPU.
-    take_ownership: Deprecated. It is a no-op to set take_ownership. Will be
-      deleted in 01/2024.
     stream: optional platform-dependent stream to wait on until the buffer is
       ready. This corresponds to the `stream` argument to ``__dlpack__``
       documented in https://dmlc.github.io/dlpack/latest/python_spec.html.
+    src_device: either a CPU or GPU :class:`~jax.Device`.
+    dl_device: a tuple of ``(dl_device_type, local_hardware_id)`` in DLPack
+      format e.g. as produced by ``__dlpack_device__``.
+    max_version: the maximum DLPack version that the consumer (i.e. caller of
+      ``__dlpack__``) supports in the form of a 2-tuple of ``(major, minor)``.
+      This function is not guaranteed to return a capsule of version
+      ``max_version``.
+    copy: a boolean indicating whether or not to copy the input. If
+      ``copy=True`` then the function must always copy. When
+      ``copy=False`` then the function must never copy, and must raise an error
+      when a copy is deemed necessary. If ``copy=None`` then the function must
+      avoid a copy if possible but may copy if needed.
 
   Returns:
-    A dlpack PyCapsule object.
+    A DLPack PyCapsule object.
 
   Note:
-    While JAX arrays are always immutable, dlpack buffers cannot be marked as
-    immutable, and it is possible for processes external to JAX to mutate them
-    in-place. If a dlpack buffer derived from a JAX array is mutated, it may
-    lead to undefined behavior when using the associated JAX array.
+    While JAX arrays are always immutable, ``DLPackManagedTensor`` buffers
+    cannot be marked as immutable, and it is possible for processes external
+    to JAX to mutate them in-place. If a DLPack buffer derived from a JAX array
+    is mutated, it may lead to undefined behavior when using the associated JAX
+    array. When JAX eventually supports ``DLManagedTensorVersioned``
+    (DLPack 1.0), it will be possible to specify that a buffer is read-only.
   """
   if not isinstance(x, array.ArrayImpl):
     raise TypeError("Argument to to_dlpack must be a jax.Array, "
                     f"got {type(x)}")
-  assert len(x.devices()) == 1
-  if take_ownership:
-    warnings.warn(
-        "take_ownership in to_dlpack is deprecated and it is a no-op."
+
+  device = None
+  dl_device_type, local_hardware_id = dl_device if dl_device else (None, None)
+  if dl_device_type:
+    try:
+      dl_device_platform = {
+          DLDeviceType.kDLCPU: "cpu",
+          DLDeviceType.kDLCUDA: "cuda",
+          DLDeviceType.kDLROCM: "rocm",
+      }[dl_device_type]
+      backend = xla_bridge.get_backend(dl_device_platform)
+      device = backend.device_from_local_hardware_id(local_hardware_id)
+    except TypeError:
+      # https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__dlpack__.html
+      # recommends using BufferError.
+      raise BufferError(
+          "The device specification passed to to_dlpack contains an unsupported "
+          f"device type (DLDeviceType: {dl_device_type})")
+
+  # As new versions are adopted over time, we can maintain some legacy paths
+  # for compatability mediated through the max_version parameter.
+  # TODO(micky774): Deprecate default usage of DLPackManagedTensor when XLA
+  # supports DLManagedTensorVersioned (DLPack version 1.0) and repurpose the
+  # current _to_dlpack as a legacy path for (0,5) <= max_version < (1,0).
+  if max_version is None or max_version >= DLPACK_VERSION:
+    # Latest
+    return _to_dlpack(
+      x, stream=stream,
+      src_device=src_device,
+      device=device,
+      copy=copy
+    )
+  elif max_version >= MIN_DLPACK_VERSION:
+    # Oldest supported
+    return _to_dlpack(
+      x, stream=stream,
+      src_device=src_device,
+      device=device,
+      copy=copy
+    )
+  else:
+    raise BufferError(
+      f"JAX does not support any version below {MIN_DLPACK_VERSION} but "
+      f"version ({max_version}) was requested."
     )
-  return xla_client._xla.buffer_to_dlpack_managed_tensor(
-      x.addressable_data(0), stream=stream
-  )  # type: ignore
 
 def _place_array(_arr, device, dlpack_device, copy):
   if device and dlpack_device != device:
diff --git a/jax/_src/typing.py b/jax/_src/typing.py
index 6cd466500f71..afbdedf5c936 100644
--- a/jax/_src/typing.py
+++ b/jax/_src/typing.py
@@ -29,6 +29,7 @@
 from collections.abc import Sequence
 from typing import Any, Protocol, Union
 import numpy as np
+import enum
 
 from jax._src.basearray import (
     Array as Array,
@@ -83,3 +84,9 @@ def shape(self) -> Shape: ...
 class DeprecatedArg:
   def __repr__(self):
     return "Deprecated"
+
+# Mirror of dlpack.h enum
+class DLDeviceType(enum.IntEnum):
+  kDLCPU = 1
+  kDLCUDA = 2
+  kDLROCM = 10
diff --git a/jax/experimental/jax2tf/call_tf.py b/jax/experimental/jax2tf/call_tf.py
index 077ae796e3e2..65c95d9c2ea8 100644
--- a/jax/experimental/jax2tf/call_tf.py
+++ b/jax/experimental/jax2tf/call_tf.py
@@ -334,7 +334,7 @@ def _arg_jax_to_tf(arg_jax):
     if (isinstance(arg_jax, jax.Array) and
         list(arg_jax.devices())[0].platform in _DLPACK_PLATFORMS and
         arg_jax.dtype.type in dlpack.SUPPORTED_DTYPES):
-      arg_dlpack = jax.dlpack.to_dlpack(arg_jax, take_ownership=False)
+      arg_dlpack = jax.dlpack.to_dlpack(arg_jax)
       return tf.experimental.dlpack.from_dlpack(arg_dlpack)
     # The following avoids copies to the host on CPU, always for Array
     # and even for ndarray if they are sufficiently aligned.
diff --git a/tests/array_interoperability_test.py b/tests/array_interoperability_test.py
index 9935cd915530..6624ef723a3a 100644
--- a/tests/array_interoperability_test.py
+++ b/tests/array_interoperability_test.py
@@ -73,23 +73,48 @@ def setUp(self):
   @jtu.sample_product(
     shape=all_shapes,
     dtype=dlpack_dtypes,
-    gpu=[False, True],
+    copy=[False, True, None]
   )
-  def testJaxRoundTrip(self, shape, dtype, gpu):
+  @jtu.run_on_devices("gpu")
+  def testJaxRoundTrip(self, shape, dtype, copy):
+    if xb.using_pjrt_c_api():
+      self.skipTest("DLPack support is incomplete in the PJRT C API")  # TODO(skyewm)
     rng = jtu.rand_default(self.rng())
     np = rng(shape, dtype)
-    if gpu and jtu.test_device_matches(["cpu"]):
-      raise unittest.SkipTest("Skipping GPU test case on CPU")
-    device = jax.devices("gpu" if gpu else "cpu")[0]
-    x = jax.device_put(np, device)
-    dlpack = jax.dlpack.to_dlpack(x)
-    y = jax.dlpack.from_dlpack(dlpack)
-    self.assertEqual(y.devices(), {device})
-    self.assertAllClose(np.astype(x.dtype), y)
 
+    def _check_copy(x: jax.Array, y: jax.Array, expect_copy):
+      copied = x.unsafe_buffer_pointer() != y.unsafe_buffer_pointer()
+      assert copied == expect_copy, f"Expected {'a' if expect_copy else 'no'} copy"
+
+    # Check if the source device is preserved
+    x = jax.device_put(np, jax.devices("cpu")[0])
+    device = jax.devices("gpu")[0]
+    y = jax.device_put(x, device)
+    dl_device = y.__dlpack_device__()
+    dlpack = jax.dlpack.to_dlpack(y, copy=copy)
+    z = jax.dlpack.from_dlpack(dlpack)
+
+    self.assertEqual(z.devices(), {device})
+    self.assertAllClose(np.astype(x.dtype), z)
     self.assertRaisesRegex(RuntimeError,
-                           "DLPack tensor may be consumed at most once",
-                           lambda: jax.dlpack.from_dlpack(dlpack))
+                          "DLPack tensor may be consumed at most once",
+                          lambda: jax.dlpack.from_dlpack(dlpack))
+
+    if shape in nonempty_array_shapes:
+      _check_copy(y, z, bool(copy))
+
+    # Check if the destination device can be specified
+    make_dlpack = lambda: x.__dlpack__(dl_device=dl_device, copy=copy)
+    if copy == False:
+      self.assertRaisesRegex(ValueError, "copy=False", make_dlpack)
+      return
+
+    z = jax.dlpack.from_dlpack(make_dlpack())
+    self.assertEqual(z.devices(), {device})
+    self.assertAllClose(x, z)
+
+    if shape in nonempty_array_shapes:
+      _check_copy(x, z, True)
 
   @jtu.sample_product(
     shape=all_shapes,