Skip to content

Commit 5b5a938

Browse files
Merge pull request #994 from SciML/ap/retcode
Implement Stalling and Use ReturnCode
2 parents 4c23942 + a0cf827 commit 5b5a938

File tree

2 files changed

+123
-26
lines changed

2 files changed

+123
-26
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DiffEqBase"
22
uuid = "2b5f629d-d688-5b77-993f-72d75c75574e"
33
authors = ["Chris Rackauckas <[email protected]>"]
4-
version = "6.145.6"
4+
version = "6.146.0"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
@@ -79,7 +79,7 @@ PrecompileTools = "1"
7979
Printf = "1.9"
8080
RecursiveArrayTools = "2, 3"
8181
Reexport = "1.0"
82-
SciMLBase = "2.12.0"
82+
SciMLBase = "2.19.0"
8383
SciMLOperators = "0.2, 0.3"
8484
Setfield = "0.8, 1"
8585
SparseArrays = "1.9"

src/termination_conditions.jl

Lines changed: 121 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22
NonlinearSafeTerminationReturnCode
33
44
Return Codes for the safe nonlinear termination conditions.
5+
6+
These return codes have been deprecated. Termination Conditions will return
7+
`SciMLBase.Retcode.T` starting from v7.
58
"""
69
@enumx NonlinearSafeTerminationReturnCode begin
710
"""
@@ -116,15 +119,16 @@ for the last `patience_steps` + terminate if the solution blows up (diverges).
116119
117120
```julia
118121
RelSafeTerminationMode(; protective_threshold = nothing, patience_steps = 100,
119-
patience_objective_multiplier = 3, min_max_factor = 1.3)
122+
patience_objective_multiplier = 3, min_max_factor = 1.3, max_stalled_steps = nothing)
120123
```
121124
"""
122-
Base.@kwdef struct RelSafeTerminationMode{T1, T2, T3} <:
125+
Base.@kwdef struct RelSafeTerminationMode{T1, T2, T3, T4 <: Union{Nothing, Int}} <:
123126
AbstractSafeNonlinearTerminationMode
124127
protective_threshold::T1 = nothing
125128
patience_steps::Int = 100
126129
patience_objective_multiplier::T2 = 3
127130
min_max_factor::T3 = 1.3
131+
max_stalled_steps::T4 = nothing
128132
end
129133

130134
@doc doc"""
@@ -137,15 +141,16 @@ for the last `patience_steps` + terminate if the solution blows up (diverges).
137141
138142
```julia
139143
AbsSafeTerminationMode(; protective_threshold = nothing, patience_steps = 100,
140-
patience_objective_multiplier = 3, min_max_factor = 1.3)
144+
patience_objective_multiplier = 3, min_max_factor = 1.3, max_stalled_steps = nothing)
141145
```
142146
"""
143-
Base.@kwdef struct AbsSafeTerminationMode{T1, T2, T3} <:
147+
Base.@kwdef struct AbsSafeTerminationMode{T1, T2, T3, T4 <: Union{Nothing, Int}} <:
144148
AbstractSafeNonlinearTerminationMode
145149
protective_threshold::T1 = nothing
146150
patience_steps::Int = 100
147151
patience_objective_multiplier::T2 = 3
148152
min_max_factor::T3 = 1.3
153+
max_stalled_steps::T4 = nothing
149154
end
150155

151156
@doc doc"""
@@ -157,15 +162,16 @@ Essentially [`RelSafeTerminationMode`](@ref), but caches the best solution found
157162
158163
```julia
159164
RelSafeBestTerminationMode(; protective_threshold = nothing, patience_steps = 100,
160-
patience_objective_multiplier = 3, min_max_factor = 1.3)
165+
patience_objective_multiplier = 3, min_max_factor = 1.3, max_stalled_steps = nothing)
161166
```
162167
"""
163-
Base.@kwdef struct RelSafeBestTerminationMode{T1, T2, T3} <:
168+
Base.@kwdef struct RelSafeBestTerminationMode{T1, T2, T3, T4 <: Union{Nothing, Int}} <:
164169
AbstractSafeBestNonlinearTerminationMode
165170
protective_threshold::T1 = nothing
166171
patience_steps::Int = 100
167172
patience_objective_multiplier::T2 = 3
168173
min_max_factor::T3 = 1.3
174+
max_stalled_steps::T4 = nothing
169175
end
170176

171177
@doc doc"""
@@ -177,21 +183,23 @@ Essentially [`AbsSafeTerminationMode`](@ref), but caches the best solution found
177183
178184
```julia
179185
AbsSafeBestTerminationMode(; protective_threshold = nothing, patience_steps = 100,
180-
patience_objective_multiplier = 3, min_max_factor = 1.3)
186+
patience_objective_multiplier = 3, min_max_factor = 1.3, max_stalled_steps = nothing)
181187
```
182188
"""
183-
Base.@kwdef struct AbsSafeBestTerminationMode{T1, T2, T3} <:
189+
Base.@kwdef struct AbsSafeBestTerminationMode{T1, T2, T3, T4 <: Union{Nothing, Int}} <:
184190
AbstractSafeBestNonlinearTerminationMode
185191
protective_threshold::T1 = nothing
186192
patience_steps::Int = 100
187193
patience_objective_multiplier::T2 = 3
188194
min_max_factor::T3 = 1.3
195+
max_stalled_steps::T4 = nothing
189196
end
190197

191-
mutable struct NonlinearTerminationModeCache{uType, T,
192-
M <: AbstractNonlinearTerminationMode, I, OT, SV}
198+
mutable struct NonlinearTerminationModeCache{uType, T, dep_retcode,
199+
M <: AbstractNonlinearTerminationMode, I, OT, SV,
200+
R <: Union{NonlinearSafeTerminationReturnCode.T, ReturnCode.T}, UN, ST, MSS}
193201
u::uType
194-
retcode::NonlinearSafeTerminationReturnCode.T
202+
retcode::R
195203
abstol::T
196204
reltol::T
197205
best_objective_value::T
@@ -200,6 +208,10 @@ mutable struct NonlinearTerminationModeCache{uType, T,
200208
objectives_trace::OT
201209
nsteps::Int
202210
saved_values::SV
211+
u0_norm::UN
212+
step_norm_trace::ST
213+
max_stalled_steps::MSS
214+
u_diff_cache::uType
203215
end
204216

205217
get_termination_mode(cache::NonlinearTerminationModeCache) = cache.mode
@@ -227,7 +239,8 @@ end
227239

228240
function SciMLBase.init(du::Union{AbstractArray{T}, T}, u::Union{AbstractArray{T}, T},
229241
mode::AbstractNonlinearTerminationMode, saved_value_prototype...;
230-
abstol = nothing, reltol = nothing, kwargs...) where {T <: Number}
242+
use_deprecated_retcodes::Val{D} = Val(true), # Remove in v8, warn in v7
243+
abstol = nothing, reltol = nothing, kwargs...) where {D, T <: Number}
231244
abstol = _get_tolerance(abstol, T)
232245
reltol = _get_tolerance(reltol, T)
233246
TT = typeof(abstol)
@@ -236,24 +249,75 @@ function SciMLBase.init(du::Union{AbstractArray{T}, T}, u::Union{AbstractArray{T
236249
if mode isa AbstractSafeNonlinearTerminationMode
237250
if mode isa AbsSafeTerminationMode || mode isa AbsSafeBestTerminationMode
238251
initial_objective = maximum(abs, du)
252+
u0_norm = nothing
239253
else
240254
initial_objective = maximum(abs, du) / (maximum(abs, du .+ u) + eps(TT))
255+
u0_norm = mode.max_stalled_steps === nothing ? nothing : norm(u, 2)
241256
end
242257
objectives_trace = Vector{TT}(undef, mode.patience_steps)
258+
step_norm_trace = mode.max_stalled_steps === nothing ? nothing :
259+
Vector{TT}(undef, mode.max_stalled_steps)
243260
best_value = initial_objective
261+
max_stalled_steps = mode.max_stalled_steps
262+
if ArrayInterface.can_setindex(u_) && !(u_ isa Number) && step_norm_trace !== nothing
263+
u_diff_cache = similar(u_)
264+
else
265+
u_diff_cache = u_
266+
end
244267
else
245268
initial_objective = nothing
246269
objectives_trace = nothing
270+
u0_norm = nothing
271+
step_norm_trace = nothing
247272
best_value = __cvt_real(T, Inf)
273+
max_stalled_steps = nothing
274+
u_diff_cache = u_
248275
end
249276

250277
length(saved_value_prototype) == 0 && (saved_value_prototype = nothing)
251278

252-
return NonlinearTerminationModeCache{typeof(u_), TT, typeof(mode),
253-
typeof(initial_objective), typeof(objectives_trace),
254-
typeof(saved_value_prototype)}(u_, NonlinearSafeTerminationReturnCode.Default,
255-
abstol, reltol, best_value, mode, initial_objective, objectives_trace, 0,
256-
saved_value_prototype)
279+
retcode = ifelse(D, NonlinearSafeTerminationReturnCode.Default, ReturnCode.Default)
280+
281+
return NonlinearTerminationModeCache{typeof(u_), TT, D, typeof(mode),
282+
typeof(initial_objective), typeof(objectives_trace), typeof(saved_value_prototype),
283+
typeof(retcode), typeof(u0_norm), typeof(step_norm_trace),
284+
typeof(max_stalled_steps)}(u_, retcode, abstol, reltol, best_value, mode,
285+
initial_objective, objectives_trace, 0, saved_value_prototype, u0_norm,
286+
step_norm_trace, max_stalled_steps, u_diff_cache)
287+
end
288+
289+
function SciMLBase.reinit!(cache::NonlinearTerminationModeCache{uType, T, dep_retcode}, du,
290+
u, saved_value_prototype...; abstol = nothing, reltol = nothing,
291+
kwargs...) where {uType, T, dep_retcode}
292+
length(saved_value_prototype) != 0 && (cache.saved_values = saved_value_prototype)
293+
294+
u_ = cache.mode isa AbstractSafeBestNonlinearTerminationMode ?
295+
(ArrayInterface.can_setindex(u) ? copy(u) : u) : nothing
296+
cache.u = u_
297+
cache.retcode = ifelse(dep_retcode, NonlinearSafeTerminationReturnCode.Default,
298+
ReturnCode.Default)
299+
300+
cache.abstol = _get_tolerance(abstol, T)
301+
cache.reltol = _get_tolerance(reltol, T)
302+
cache.nsteps = 0
303+
304+
mode = get_termination_mode(cache)
305+
if mode isa AbstractSafeNonlinearTerminationMode
306+
if mode isa AbsSafeTerminationMode || mode isa AbsSafeBestTerminationMode
307+
initial_objective = maximum(abs, du)
308+
else
309+
initial_objective = maximum(abs, du) / (maximum(abs, du .+ u) + eps(TT))
310+
cache.max_stalled_steps !== nothing && (cache.u0_norm = norm(u_, 2))
311+
end
312+
best_value = initial_objective
313+
else
314+
initial_objective = nothing
315+
objectives_trace = nothing
316+
best_value = __cvt_real(T, Inf)
317+
end
318+
cache.best_objective_value = best_value
319+
cache.initial_objective = initial_objective
320+
return cache
257321
end
258322

259323
# This dispatch is needed based on how Terminating Callback works!
@@ -273,8 +337,8 @@ function (cache::NonlinearTerminationModeCache)(mode::AbstractNonlinearTerminati
273337
return check_convergence(mode, du, u, uprev, cache.abstol, cache.reltol)
274338
end
275339

276-
function (cache::NonlinearTerminationModeCache)(mode::AbstractSafeNonlinearTerminationMode,
277-
du, u, uprev, args...)
340+
function (cache::NonlinearTerminationModeCache{uType, TT, dep_retcode})(mode::AbstractSafeNonlinearTerminationMode,
341+
du, u, uprev, args...) where {uType, TT, dep_retcode}
278342
if mode isa AbsSafeTerminationMode || mode isa AbsSafeBestTerminationMode
279343
objective = maximum(abs, du)
280344
criteria = cache.abstol
@@ -285,13 +349,15 @@ function (cache::NonlinearTerminationModeCache)(mode::AbstractSafeNonlinearTermi
285349

286350
# Protective Break
287351
if isinf(objective) || isnan(objective)
288-
cache.retcode = NonlinearSafeTerminationReturnCode.ProtectiveTermination
352+
cache.retcode = ifelse(dep_retcode,
353+
NonlinearSafeTerminationReturnCode.ProtectiveTermination, ReturnCode.Unstable)
289354
return true
290355
end
291356
## By default we turn this off since it has the potential for false positives
292357
if cache.mode.protective_threshold !== nothing &&
293358
(objective > cache.initial_objective * cache.mode.protective_threshold * length(du))
294-
cache.retcode = NonlinearSafeTerminationReturnCode.ProtectiveTermination
359+
cache.retcode = ifelse(dep_retcode,
360+
NonlinearSafeTerminationReturnCode.ProtectiveTermination, ReturnCode.Unstable)
295361
return true
296362
end
297363

@@ -307,7 +373,8 @@ function (cache::NonlinearTerminationModeCache)(mode::AbstractSafeNonlinearTermi
307373

308374
# Main Termination Condition
309375
if objective criteria
310-
cache.retcode = NonlinearSafeTerminationReturnCode.Success
376+
cache.retcode = ifelse(dep_retcode,
377+
NonlinearSafeTerminationReturnCode.Success, ReturnCode.Success)
311378
return true
312379
end
313380

@@ -324,13 +391,43 @@ function (cache::NonlinearTerminationModeCache)(mode::AbstractSafeNonlinearTermi
324391
min_obj, max_obj = extrema(cache.objectives_trace)
325392
end
326393
if min_obj < cache.mode.min_max_factor * max_obj
327-
cache.retcode = NonlinearSafeTerminationReturnCode.PatienceTermination
394+
cache.retcode = ifelse(dep_retcode,
395+
NonlinearSafeTerminationReturnCode.PatienceTermination,
396+
ReturnCode.Stalled)
397+
return true
398+
end
399+
end
400+
end
401+
402+
# Test for stalling if that is not disabled
403+
if cache.step_norm_trace !== nothing
404+
if ArrayInterface.can_setindex(cache.u_diff_cache) && !(u isa Number)
405+
@. cache.u_diff_cache = u - uprev
406+
else
407+
cache.u_diff_cache = u .- uprev
408+
end
409+
du_norm = norm(cache.u_diff_cache, 2)
410+
cache.step_norm_trace[mod1(cache.nsteps, length(cache.step_norm_trace))] = du_norm
411+
if cache.nsteps cache.mode.max_stalled_steps
412+
max_step_norm = maximum(cache.step_norm_trace)
413+
if cache.mode isa AbsSafeTerminationMode ||
414+
cache.mode isa AbsSafeBestTerminationMode
415+
stalled_step = max_step_norm cache.abstol
416+
else
417+
stalled_step = max_step_norm
418+
cache.reltol * (max_step_norm + cache.u0_norm)
419+
end
420+
if stalled_step
421+
cache.retcode = ifelse(dep_retcode,
422+
NonlinearSafeTerminationReturnCode.PatienceTermination,
423+
ReturnCode.Stalled)
328424
return true
329425
end
330426
end
331427
end
332428

333-
cache.retcode = NonlinearSafeTerminationReturnCode.Failure
429+
cache.retcode = ifelse(dep_retcode,
430+
NonlinearSafeTerminationReturnCode.Failure, ReturnCode.Failure)
334431
return false
335432
end
336433

0 commit comments

Comments
 (0)