diff --git a/src/layers/show.jl b/src/layers/show.jl index 421131f365..ea852adc3a 100644 --- a/src/layers/show.jl +++ b/src/layers/show.jl @@ -113,7 +113,10 @@ underscorise(n::Integer) = join(reverse(join.(reverse.(Iterators.partition(digits(n), 3)))), '_') function _nan_show(io::IO, x) - if !isempty(x) && _all(iszero, x) + if any(y -> y isa Zygote.AbstractGPUArray, x) + # These friendly warnings take 10-20 sec to compile the first time, for models on GPU. + printstyled(io, " (on GPU)", color=:light_black) + elseif !isempty(x) && _all(iszero, x) printstyled(io, " (all zero)", color=:cyan) elseif _any(isnan, x) printstyled(io, " (some NaN)", color=:red)