Skip to content

Commit ffb6f00

Browse files
committed
BroadcastStyle{N+1}
1 parent ac20bd3 commit ffb6f00

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

src/array.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,13 @@ MLUtils.batch(xs::AbstractArray{<:OneHotVector{<:Any, L}}) where L = OneHotMatri
9090

9191
Adapt.adapt_structure(T, x::OneHotArray{<:Any, L}) where L = OneHotArray(adapt(T, _indices(x)), L)
9292

93-
Base.BroadcastStyle(::Type{<:OneHotArray{<: Any, <: Any, <: Any, N, T}}) where {N, T <: AbstractGPUArray} = Base.BroadcastStyle(T)
93+
function Base.BroadcastStyle(::Type{<:OneHotArray{<: Any, <: Any, <: Any, var"N+1", T}}) where {var"N+1", T <: AbstractGPUArray}
94+
# We want CuArrayStyle{N+1}(). There's an AbstractGPUArrayStyle but it doesn't do what we need.
95+
S = Base.BroadcastStyle(T)
96+
# S has dim N not N+1. The following hack to fix it relies on the arraystyle having N as its first type parameter, which
97+
# isn't guaranteed, but there are not so many GPU broadcasting styles in the wild. (Far fewer than there are array wrappers.)
98+
(typeof(S).name.wrapper){var"N+1"}()
99+
end
94100

95101
Base.map(f, x::OneHotLike) = Base.broadcast(f, x)
96102

0 commit comments

Comments
 (0)