Skip to content

Commit 507afa4

Browse files
Update src/array_api_extra/_delegation.py
Co-authored-by: Guido Imperiale <[email protected]>
1 parent bb3e4ea commit 507afa4

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

src/array_api_extra/_delegation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ def one_hot(
169169
msg = "x must have an integral dtype."
170170
raise TypeError(msg)
171171
if dtype is None:
172-
dtype = xp.empty(()).dtype # Default float dtype
172+
dtype = xp.__array_namespace_info__().default_dtypes(device=get_device(x))["real floating"]
173173
# Delegate where possible.
174174
if is_jax_namespace(xp):
175175
assert is_jax_array(x)

0 commit comments

Comments
 (0)