We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 507afa4 commit 7736d2cCopy full SHA for 7736d2c
src/array_api_extra/_lib/_funcs.py
@@ -394,7 +394,7 @@ def one_hot(
394
# specification.
395
msg = "x must have a concrete size."
396
raise TypeError(msg)
397
- out = xp.zeros((x.size, num_classes), dtype=dtype)
+ out = xp.zeros((x.size, num_classes), dtype=dtype, device=_compat.device(x))
398
x_flattened = xp.reshape(x, (-1,))
399
if supports_fancy_indexing:
400
out = at(out)[xp.arange(x_size), x_flattened].set(1)
0 commit comments