Skip to content

Commit 49f94d0

Browse files
committed
Automatically register GPU converters upon loading the glue package.
1 parent ee9efbc commit 49f94d0

File tree

4 files changed

+19
-2
lines changed

4 files changed

+19
-2
lines changed

lib/FluxAMDGPU/src/FluxAMDGPU.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,13 @@ using Flux: OneHotArray, OneHotLike
1010
Base.BroadcastStyle(::Type{<:OneHotArray{<:Any, <:Any, <:Any, N, <:ROCArray}}) where N =
1111
AMDGPU.ROCArrayStyle{N}()
1212

13+
function __init__()
14+
if Flux.default_gpu_converter[] === identity
15+
@info "Registering AMDGPU.jl as the default GPU converter"
16+
Flux.default_gpu_converter[] = roc
17+
else
18+
@warn "Not registering AMDGPU.jl as the default GPU converter as another one has been registered already."
19+
end
20+
end
21+
1322
end # module

lib/FluxAMDGPU/test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ using .AMDGPU
99
ENV["JULIA_GPU_ALLOWSCALAR"] = "false"
1010

1111
using .Flux
12-
Flux.default_gpu_converter[] = AMDGPU.roc
12+
@assert Flux.default_gpu_converter[] == roc
1313

1414
using Zygote
1515
using Zygote: pullback

lib/FluxCUDA/src/FluxCUDA.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,5 +9,13 @@ include("onehot.jl")
99
include("ctc.jl")
1010
include("cudnn.jl")
1111

12+
function __init__()
13+
if Flux.default_gpu_converter[] === identity
14+
@info "Registering CUDA.jl as the default GPU converter"
15+
Flux.default_gpu_converter[] = cu
16+
else
17+
@warn "Not registering CUDA.jl as the default GPU converter as another one has been registered already."
18+
end
19+
end
1220

1321
end # module

lib/FluxCUDA/test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ using .CUDA
99
ENV["JULIA_GPU_ALLOWSCALAR"] = "false"
1010

1111
using .Flux
12-
Flux.default_gpu_converter[] = cu
12+
@assert Flux.default_gpu_converter[] == cu
1313

1414
using Zygote
1515
using Zygote: pullback

0 commit comments

Comments
 (0)