Skip to content

Commit 7736d2c

Browse files
Update src/array_api_extra/_lib/_funcs.py
Co-authored-by: Guido Imperiale <[email protected]>
1 parent 507afa4 commit 7736d2c

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

src/array_api_extra/_lib/_funcs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -394,7 +394,7 @@ def one_hot(
394394
# specification.
395395
msg = "x must have a concrete size."
396396
raise TypeError(msg)
397-
out = xp.zeros((x.size, num_classes), dtype=dtype)
397+
out = xp.zeros((x.size, num_classes), dtype=dtype, device=_compat.device(x))
398398
x_flattened = xp.reshape(x, (-1,))
399399
if supports_fancy_indexing:
400400
out = at(out)[xp.arange(x_size), x_flattened].set(1)

0 commit comments

Comments
 (0)