diff --git a/Project.toml b/Project.toml index 15f993706..d6e3fa802 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "DiffEqBase" uuid = "2b5f629d-d688-5b77-993f-72d75c75574e" authors = ["Chris Rackauckas "] -version = "6.145.6" +version = "6.146.0" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" @@ -79,7 +79,7 @@ PrecompileTools = "1" Printf = "1.9" RecursiveArrayTools = "2, 3" Reexport = "1.0" -SciMLBase = "2.12.0" +SciMLBase = "2.19.0" SciMLOperators = "0.2, 0.3" Setfield = "0.8, 1" SparseArrays = "1.9" diff --git a/src/termination_conditions.jl b/src/termination_conditions.jl index d3761ab05..c71beebe3 100644 --- a/src/termination_conditions.jl +++ b/src/termination_conditions.jl @@ -2,6 +2,9 @@ NonlinearSafeTerminationReturnCode Return Codes for the safe nonlinear termination conditions. + +These return codes have been deprecated. Termination Conditions will return +`SciMLBase.Retcode.T` starting from v7. """ @enumx NonlinearSafeTerminationReturnCode begin """ @@ -116,15 +119,16 @@ for the last `patience_steps` + terminate if the solution blows up (diverges). ```julia RelSafeTerminationMode(; protective_threshold = nothing, patience_steps = 100, - patience_objective_multiplier = 3, min_max_factor = 1.3) + patience_objective_multiplier = 3, min_max_factor = 1.3, max_stalled_steps = nothing) ``` """ -Base.@kwdef struct RelSafeTerminationMode{T1, T2, T3} <: +Base.@kwdef struct RelSafeTerminationMode{T1, T2, T3, T4 <: Union{Nothing, Int}} <: AbstractSafeNonlinearTerminationMode protective_threshold::T1 = nothing patience_steps::Int = 100 patience_objective_multiplier::T2 = 3 min_max_factor::T3 = 1.3 + max_stalled_steps::T4 = nothing end @doc doc""" @@ -137,15 +141,16 @@ for the last `patience_steps` + terminate if the solution blows up (diverges). ```julia AbsSafeTerminationMode(; protective_threshold = nothing, patience_steps = 100, - patience_objective_multiplier = 3, min_max_factor = 1.3) + patience_objective_multiplier = 3, min_max_factor = 1.3, max_stalled_steps = nothing) ``` """ -Base.@kwdef struct AbsSafeTerminationMode{T1, T2, T3} <: +Base.@kwdef struct AbsSafeTerminationMode{T1, T2, T3, T4 <: Union{Nothing, Int}} <: AbstractSafeNonlinearTerminationMode protective_threshold::T1 = nothing patience_steps::Int = 100 patience_objective_multiplier::T2 = 3 min_max_factor::T3 = 1.3 + max_stalled_steps::T4 = nothing end @doc doc""" @@ -157,15 +162,16 @@ Essentially [`RelSafeTerminationMode`](@ref), but caches the best solution found ```julia RelSafeBestTerminationMode(; protective_threshold = nothing, patience_steps = 100, - patience_objective_multiplier = 3, min_max_factor = 1.3) + patience_objective_multiplier = 3, min_max_factor = 1.3, max_stalled_steps = nothing) ``` """ -Base.@kwdef struct RelSafeBestTerminationMode{T1, T2, T3} <: +Base.@kwdef struct RelSafeBestTerminationMode{T1, T2, T3, T4 <: Union{Nothing, Int}} <: AbstractSafeBestNonlinearTerminationMode protective_threshold::T1 = nothing patience_steps::Int = 100 patience_objective_multiplier::T2 = 3 min_max_factor::T3 = 1.3 + max_stalled_steps::T4 = nothing end @doc doc""" @@ -177,21 +183,23 @@ Essentially [`AbsSafeTerminationMode`](@ref), but caches the best solution found ```julia AbsSafeBestTerminationMode(; protective_threshold = nothing, patience_steps = 100, - patience_objective_multiplier = 3, min_max_factor = 1.3) + patience_objective_multiplier = 3, min_max_factor = 1.3, max_stalled_steps = nothing) ``` """ -Base.@kwdef struct AbsSafeBestTerminationMode{T1, T2, T3} <: +Base.@kwdef struct AbsSafeBestTerminationMode{T1, T2, T3, T4 <: Union{Nothing, Int}} <: AbstractSafeBestNonlinearTerminationMode protective_threshold::T1 = nothing patience_steps::Int = 100 patience_objective_multiplier::T2 = 3 min_max_factor::T3 = 1.3 + max_stalled_steps::T4 = nothing end -mutable struct NonlinearTerminationModeCache{uType, T, - M <: AbstractNonlinearTerminationMode, I, OT, SV} +mutable struct NonlinearTerminationModeCache{uType, T, dep_retcode, + M <: AbstractNonlinearTerminationMode, I, OT, SV, + R <: Union{NonlinearSafeTerminationReturnCode.T, ReturnCode.T}, UN, ST, MSS} u::uType - retcode::NonlinearSafeTerminationReturnCode.T + retcode::R abstol::T reltol::T best_objective_value::T @@ -200,6 +208,10 @@ mutable struct NonlinearTerminationModeCache{uType, T, objectives_trace::OT nsteps::Int saved_values::SV + u0_norm::UN + step_norm_trace::ST + max_stalled_steps::MSS + u_diff_cache::uType end get_termination_mode(cache::NonlinearTerminationModeCache) = cache.mode @@ -227,7 +239,8 @@ end function SciMLBase.init(du::Union{AbstractArray{T}, T}, u::Union{AbstractArray{T}, T}, mode::AbstractNonlinearTerminationMode, saved_value_prototype...; - abstol = nothing, reltol = nothing, kwargs...) where {T <: Number} + use_deprecated_retcodes::Val{D} = Val(true), # Remove in v8, warn in v7 + abstol = nothing, reltol = nothing, kwargs...) where {D, T <: Number} abstol = _get_tolerance(abstol, T) reltol = _get_tolerance(reltol, T) TT = typeof(abstol) @@ -236,24 +249,75 @@ function SciMLBase.init(du::Union{AbstractArray{T}, T}, u::Union{AbstractArray{T if mode isa AbstractSafeNonlinearTerminationMode if mode isa AbsSafeTerminationMode || mode isa AbsSafeBestTerminationMode initial_objective = maximum(abs, du) + u0_norm = nothing else initial_objective = maximum(abs, du) / (maximum(abs, du .+ u) + eps(TT)) + u0_norm = mode.max_stalled_steps === nothing ? nothing : norm(u, 2) end objectives_trace = Vector{TT}(undef, mode.patience_steps) + step_norm_trace = mode.max_stalled_steps === nothing ? nothing : + Vector{TT}(undef, mode.max_stalled_steps) best_value = initial_objective + max_stalled_steps = mode.max_stalled_steps + if ArrayInterface.can_setindex(u_) && !(u_ isa Number) && step_norm_trace !== nothing + u_diff_cache = similar(u_) + else + u_diff_cache = u_ + end else initial_objective = nothing objectives_trace = nothing + u0_norm = nothing + step_norm_trace = nothing best_value = __cvt_real(T, Inf) + max_stalled_steps = nothing + u_diff_cache = u_ end length(saved_value_prototype) == 0 && (saved_value_prototype = nothing) - return NonlinearTerminationModeCache{typeof(u_), TT, typeof(mode), - typeof(initial_objective), typeof(objectives_trace), - typeof(saved_value_prototype)}(u_, NonlinearSafeTerminationReturnCode.Default, - abstol, reltol, best_value, mode, initial_objective, objectives_trace, 0, - saved_value_prototype) + retcode = ifelse(D, NonlinearSafeTerminationReturnCode.Default, ReturnCode.Default) + + return NonlinearTerminationModeCache{typeof(u_), TT, D, typeof(mode), + typeof(initial_objective), typeof(objectives_trace), typeof(saved_value_prototype), + typeof(retcode), typeof(u0_norm), typeof(step_norm_trace), + typeof(max_stalled_steps)}(u_, retcode, abstol, reltol, best_value, mode, + initial_objective, objectives_trace, 0, saved_value_prototype, u0_norm, + step_norm_trace, max_stalled_steps, u_diff_cache) +end + +function SciMLBase.reinit!(cache::NonlinearTerminationModeCache{uType, T, dep_retcode}, du, + u, saved_value_prototype...; abstol = nothing, reltol = nothing, + kwargs...) where {uType, T, dep_retcode} + length(saved_value_prototype) != 0 && (cache.saved_values = saved_value_prototype) + + u_ = cache.mode isa AbstractSafeBestNonlinearTerminationMode ? + (ArrayInterface.can_setindex(u) ? copy(u) : u) : nothing + cache.u = u_ + cache.retcode = ifelse(dep_retcode, NonlinearSafeTerminationReturnCode.Default, + ReturnCode.Default) + + cache.abstol = _get_tolerance(abstol, T) + cache.reltol = _get_tolerance(reltol, T) + cache.nsteps = 0 + + mode = get_termination_mode(cache) + if mode isa AbstractSafeNonlinearTerminationMode + if mode isa AbsSafeTerminationMode || mode isa AbsSafeBestTerminationMode + initial_objective = maximum(abs, du) + else + initial_objective = maximum(abs, du) / (maximum(abs, du .+ u) + eps(TT)) + cache.max_stalled_steps !== nothing && (cache.u0_norm = norm(u_, 2)) + end + best_value = initial_objective + else + initial_objective = nothing + objectives_trace = nothing + best_value = __cvt_real(T, Inf) + end + cache.best_objective_value = best_value + cache.initial_objective = initial_objective + return cache end # This dispatch is needed based on how Terminating Callback works! @@ -273,8 +337,8 @@ function (cache::NonlinearTerminationModeCache)(mode::AbstractNonlinearTerminati return check_convergence(mode, du, u, uprev, cache.abstol, cache.reltol) end -function (cache::NonlinearTerminationModeCache)(mode::AbstractSafeNonlinearTerminationMode, - du, u, uprev, args...) +function (cache::NonlinearTerminationModeCache{uType, TT, dep_retcode})(mode::AbstractSafeNonlinearTerminationMode, + du, u, uprev, args...) where {uType, TT, dep_retcode} if mode isa AbsSafeTerminationMode || mode isa AbsSafeBestTerminationMode objective = maximum(abs, du) criteria = cache.abstol @@ -285,13 +349,15 @@ function (cache::NonlinearTerminationModeCache)(mode::AbstractSafeNonlinearTermi # Protective Break if isinf(objective) || isnan(objective) - cache.retcode = NonlinearSafeTerminationReturnCode.ProtectiveTermination + cache.retcode = ifelse(dep_retcode, + NonlinearSafeTerminationReturnCode.ProtectiveTermination, ReturnCode.Unstable) return true end ## By default we turn this off since it has the potential for false positives if cache.mode.protective_threshold !== nothing && (objective > cache.initial_objective * cache.mode.protective_threshold * length(du)) - cache.retcode = NonlinearSafeTerminationReturnCode.ProtectiveTermination + cache.retcode = ifelse(dep_retcode, + NonlinearSafeTerminationReturnCode.ProtectiveTermination, ReturnCode.Unstable) return true end @@ -307,7 +373,8 @@ function (cache::NonlinearTerminationModeCache)(mode::AbstractSafeNonlinearTermi # Main Termination Condition if objective ≤ criteria - cache.retcode = NonlinearSafeTerminationReturnCode.Success + cache.retcode = ifelse(dep_retcode, + NonlinearSafeTerminationReturnCode.Success, ReturnCode.Success) return true end @@ -324,13 +391,43 @@ function (cache::NonlinearTerminationModeCache)(mode::AbstractSafeNonlinearTermi min_obj, max_obj = extrema(cache.objectives_trace) end if min_obj < cache.mode.min_max_factor * max_obj - cache.retcode = NonlinearSafeTerminationReturnCode.PatienceTermination + cache.retcode = ifelse(dep_retcode, + NonlinearSafeTerminationReturnCode.PatienceTermination, + ReturnCode.Stalled) + return true + end + end + end + + # Test for stalling if that is not disabled + if cache.step_norm_trace !== nothing + if ArrayInterface.can_setindex(cache.u_diff_cache) && !(u isa Number) + @. cache.u_diff_cache = u - uprev + else + cache.u_diff_cache = u .- uprev + end + du_norm = norm(cache.u_diff_cache, 2) + cache.step_norm_trace[mod1(cache.nsteps, length(cache.step_norm_trace))] = du_norm + if cache.nsteps ≥ cache.mode.max_stalled_steps + max_step_norm = maximum(cache.step_norm_trace) + if cache.mode isa AbsSafeTerminationMode || + cache.mode isa AbsSafeBestTerminationMode + stalled_step = max_step_norm ≤ cache.abstol + else + stalled_step = max_step_norm ≤ + cache.reltol * (max_step_norm + cache.u0_norm) + end + if stalled_step + cache.retcode = ifelse(dep_retcode, + NonlinearSafeTerminationReturnCode.PatienceTermination, + ReturnCode.Stalled) return true end end end - cache.retcode = NonlinearSafeTerminationReturnCode.Failure + cache.retcode = ifelse(dep_retcode, + NonlinearSafeTerminationReturnCode.Failure, ReturnCode.Failure) return false end