Skip to content

Commit 8c8b5c5

Browse files
committed
Revert "MAINT: clarify default_device output"
This reverts commit 2c87d36.
1 parent 3312d8f commit 8c8b5c5

File tree

5 files changed

+14
-35
lines changed

5 files changed

+14
-35
lines changed

array_api_compat/common/_aliases.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
# These functions are modified from the NumPy versions.
2020

21-
# Creation functions add the device keyword (which does nothing for NumPy and Dask)
21+
# Creation functions add the device keyword (which does nothing for NumPy)
2222

2323
def arange(
2424
start: Union[int, float],

array_api_compat/cupy/_info.py

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
complex128,
2727
)
2828

29-
3029
class __array_namespace_info__:
3130
"""
3231
Get the array API inspection namespace for CuPy.
@@ -118,7 +117,7 @@ def default_device(self):
118117
119118
Returns
120119
-------
121-
device : Device
120+
device : str
122121
The default device used for new CuPy arrays.
123122
124123
Examples
@@ -127,15 +126,6 @@ def default_device(self):
127126
>>> info.default_device()
128127
Device(0)
129128
130-
Notes
131-
-----
132-
This method returns the static default device when CuPy is initialized.
133-
However, the *current* device used by creation functions (``empty`` etc.)
134-
can be changed globally or with a context manager.
135-
136-
See Also
137-
--------
138-
https://github.com/data-apis/array-api/issues/835
139129
"""
140130
return cuda.Device(0)
141131

@@ -322,7 +312,7 @@ def devices(self):
322312
323313
Returns
324314
-------
325-
devices : list[Device]
315+
devices : list of str
326316
The devices supported by CuPy.
327317
328318
See Also

array_api_compat/dask/array/_info.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ def default_device(self):
130130
131131
Returns
132132
-------
133-
device : Device
133+
device : str
134134
The default device used for new Dask arrays.
135135
136136
Examples
@@ -335,7 +335,7 @@ def devices(self):
335335
336336
Returns
337337
-------
338-
devices : list[Device]
338+
devices : list of str
339339
The devices supported by Dask.
340340
341341
See Also

array_api_compat/numpy/_info.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ def default_device(self):
119119
120120
Returns
121121
-------
122-
device : Device
122+
device : str
123123
The default device used for new NumPy arrays.
124124
125125
Examples
@@ -326,7 +326,7 @@ def devices(self):
326326
327327
Returns
328328
-------
329-
devices : list[Device]
329+
devices : list of str
330330
The devices supported by NumPy.
331331
332332
See Also

array_api_compat/torch/_info.py

Lines changed: 7 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -102,24 +102,15 @@ def default_device(self):
102102
103103
Returns
104104
-------
105-
device : Device
105+
device : str
106106
The default device used for new PyTorch arrays.
107107
108108
Examples
109109
--------
110110
>>> info = np.__array_namespace_info__()
111111
>>> info.default_device()
112-
device(type='cpu')
112+
'cpu'
113113
114-
Notes
115-
-----
116-
This method returns the static default device when PyTorch is initialized.
117-
However, the *current* device used by creation functions (``empty`` etc.)
118-
can be changed at runtime.
119-
120-
See Also
121-
--------
122-
https://github.com/data-apis/array-api/issues/835
123114
"""
124115
return torch.device("cpu")
125116

@@ -129,9 +120,9 @@ def default_dtypes(self, *, device=None):
129120
130121
Parameters
131122
----------
132-
device : Device, optional
133-
The device to get the default data types for.
134-
Unused for PyTorch, as all devices use the same default dtypes.
123+
device : str, optional
124+
The device to get the default data types for. For PyTorch, only
125+
``'cpu'`` is allowed.
135126
136127
Returns
137128
-------
@@ -259,9 +250,8 @@ def dtypes(self, *, device=None, kind=None):
259250
260251
Parameters
261252
----------
262-
device : Device, optional
253+
device : str, optional
263254
The device to get the data types for.
264-
Unused for PyTorch, as all devices use the same dtypes.
265255
kind : str or tuple of str, optional
266256
The kind of data types to return. If ``None``, all data types are
267257
returned. If a string, only data types of that kind are returned.
@@ -320,7 +310,7 @@ def devices(self):
320310
321311
Returns
322312
-------
323-
devices : list[Device]
313+
devices : list of str
324314
The devices supported by PyTorch.
325315
326316
See Also
@@ -343,7 +333,6 @@ def devices(self):
343333
# device:
344334
try:
345335
torch.device('notadevice')
346-
raise AssertionError("unreachable") # pragma: nocover
347336
except RuntimeError as e:
348337
# The error message is something like:
349338
# "Expected one of cpu, cuda, ipu, xpu, mkldnn, opengl, opencl, ideep, hip, ve, fpga, ort, xla, lazy, vulkan, mps, meta, hpu, mtia, privateuseone device type at start of device string: notadevice"

0 commit comments

Comments
 (0)