diff --git a/src/common_defaults.jl b/src/common_defaults.jl index 3fd2400e2..2ec515552 100644 --- a/src/common_defaults.jl +++ b/src/common_defaults.jl @@ -22,20 +22,20 @@ end @inline ODE_DEFAULT_NORM(u::Union{AbstractFloat, Complex}, t) = @fastmath abs(u) @inline function ODE_DEFAULT_NORM(u::Array{T}, t) where {T <: Union{AbstractFloat, Complex}} - x = abs2(u[1]) - @inbounds for i in 2:length(u) - x += abs2(u[i]) + x = zero(T) + @inbounds @fastmath for ui in u + x += abs2(ui) end - Base.FastMath.sqrt_fast(real(x) / length(u)) + Base.FastMath.sqrt_fast(real(x) / max(length(u), 1)) end @inline function ODE_DEFAULT_NORM(u::StaticArrays.StaticArray{T}, t) where {T <: Union{AbstractFloat, Complex}} - Base.FastMath.sqrt_fast(real(sum(abs2, u)) / length(u)) + Base.FastMath.sqrt_fast(real(sum(abs2, u)) / max(length(u), 1)) end @inline function ODE_DEFAULT_NORM(u::AbstractArray, t) - Base.FastMath.sqrt_fast(UNITLESS_ABS2(u) / recursive_length(u)) + Base.FastMath.sqrt_fast(UNITLESS_ABS2(u) / max(recursive_length(u), 1)) end @inline ODE_DEFAULT_NORM(u, t) = norm(u) diff --git a/test/ode_default_norm.jl b/test/ode_default_norm.jl index 5fa70d2a3..c6dc707ea 100644 --- a/test/ode_default_norm.jl +++ b/test/ode_default_norm.jl @@ -44,3 +44,5 @@ u7 = ArrayPartition(u1, ones(0)) @test UNITLESS_ABS2(u7) == 3.0 @test recursive_length(u7) == 3 @test ODE_DEFAULT_NORM(u7, 0.0) == 1.0 + +@test ODE_DEFAULT_NORM(Float64[], 0.0) == 0.0