Skip to content

Commit a4fa9b5

Browse files
committed
Implement Stalling and Use ReturnCode
1 parent efad806 commit a4fa9b5

File tree

2 files changed

+122
-27
lines changed

2 files changed

+122
-27
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DiffEqBase"
22
uuid = "2b5f629d-d688-5b77-993f-72d75c75574e"
33
authors = ["Chris Rackauckas <[email protected]>"]
4-
version = "6.145.6"
4+
version = "6.146.0"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
@@ -79,7 +79,7 @@ PrecompileTools = "1"
7979
Printf = "1.9"
8080
RecursiveArrayTools = "2, 3"
8181
Reexport = "1.0"
82-
SciMLBase = "2.12.0"
82+
SciMLBase = "2.18.0"
8383
SciMLOperators = "0.2, 0.3"
8484
Setfield = "0.8, 1"
8585
SparseArrays = "1.9"

src/termination_conditions.jl

Lines changed: 120 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22
NonlinearSafeTerminationReturnCode
33
44
Return Codes for the safe nonlinear termination conditions.
5+
6+
These return codes have been deprecated. Termination Conditions will return
7+
`SciMLBase.Retcode.T` starting from v7.
58
"""
69
@enumx NonlinearSafeTerminationReturnCode begin
710
"""
@@ -116,15 +119,16 @@ for the last `patience_steps` + terminate if the solution blows up (diverges).
116119
117120
```julia
118121
RelSafeTerminationMode(; protective_threshold = nothing, patience_steps = 100,
119-
patience_objective_multiplier = 3, min_max_factor = 1.3)
122+
patience_objective_multiplier = 3, min_max_factor = 1.3, max_stalled_steps = 20)
120123
```
121124
"""
122-
Base.@kwdef struct RelSafeTerminationMode{T1, T2, T3} <:
125+
Base.@kwdef struct RelSafeTerminationMode{T1, T2, T3, T4 <: Union{Nothing, Int}} <:
123126
AbstractSafeNonlinearTerminationMode
124127
protective_threshold::T1 = nothing
125128
patience_steps::Int = 100
126129
patience_objective_multiplier::T2 = 3
127130
min_max_factor::T3 = 1.3
131+
max_stalled_steps::T4 = nothing
128132
end
129133

130134
@doc doc"""
@@ -137,15 +141,16 @@ for the last `patience_steps` + terminate if the solution blows up (diverges).
137141
138142
```julia
139143
AbsSafeTerminationMode(; protective_threshold = nothing, patience_steps = 100,
140-
patience_objective_multiplier = 3, min_max_factor = 1.3)
144+
patience_objective_multiplier = 3, min_max_factor = 1.3, max_stalled_steps = 20)
141145
```
142146
"""
143-
Base.@kwdef struct AbsSafeTerminationMode{T1, T2, T3} <:
147+
Base.@kwdef struct AbsSafeTerminationMode{T1, T2, T3, T4 <: Union{Nothing, Int}} <:
144148
AbstractSafeNonlinearTerminationMode
145149
protective_threshold::T1 = nothing
146150
patience_steps::Int = 100
147151
patience_objective_multiplier::T2 = 3
148152
min_max_factor::T3 = 1.3
153+
max_stalled_steps::T4 = nothing
149154
end
150155

151156
@doc doc"""
@@ -157,15 +162,16 @@ Essentially [`RelSafeTerminationMode`](@ref), but caches the best solution found
157162
158163
```julia
159164
RelSafeBestTerminationMode(; protective_threshold = nothing, patience_steps = 100,
160-
patience_objective_multiplier = 3, min_max_factor = 1.3)
165+
patience_objective_multiplier = 3, min_max_factor = 1.3, max_stalled_steps = 20)
161166
```
162167
"""
163-
Base.@kwdef struct RelSafeBestTerminationMode{T1, T2, T3} <:
168+
Base.@kwdef struct RelSafeBestTerminationMode{T1, T2, T3, T4 <: Union{Nothing, Int}} <:
164169
AbstractSafeBestNonlinearTerminationMode
165170
protective_threshold::T1 = nothing
166171
patience_steps::Int = 100
167172
patience_objective_multiplier::T2 = 3
168173
min_max_factor::T3 = 1.3
174+
max_stalled_steps::T4 = nothing
169175
end
170176

171177
@doc doc"""
@@ -177,21 +183,23 @@ Essentially [`AbsSafeTerminationMode`](@ref), but caches the best solution found
177183
178184
```julia
179185
AbsSafeBestTerminationMode(; protective_threshold = nothing, patience_steps = 100,
180-
patience_objective_multiplier = 3, min_max_factor = 1.3)
186+
patience_objective_multiplier = 3, min_max_factor = 1.3, max_stalled_steps = 20)
181187
```
182188
"""
183-
Base.@kwdef struct AbsSafeBestTerminationMode{T1, T2, T3} <:
189+
Base.@kwdef struct AbsSafeBestTerminationMode{T1, T2, T3, T4 <: Union{Nothing, Int}} <:
184190
AbstractSafeBestNonlinearTerminationMode
185191
protective_threshold::T1 = nothing
186192
patience_steps::Int = 100
187193
patience_objective_multiplier::T2 = 3
188194
min_max_factor::T3 = 1.3
195+
max_stalled_steps::T4 = nothing
189196
end
190197

191-
mutable struct NonlinearTerminationModeCache{uType, T,
192-
M <: AbstractNonlinearTerminationMode, I, OT, SV}
198+
mutable struct NonlinearTerminationModeCache{uType, T, dep_retcode,
199+
M <: AbstractNonlinearTerminationMode, I, OT, SV,
200+
R <: Union{NonlinearSafeTerminationReturnCode.T, ReturnCode.T}, UN, ST, MSS}
193201
u::uType
194-
retcode::NonlinearSafeTerminationReturnCode.T
202+
retcode::R
195203
abstol::T
196204
reltol::T
197205
best_objective_value::T
@@ -200,6 +208,10 @@ mutable struct NonlinearTerminationModeCache{uType, T,
200208
objectives_trace::OT
201209
nsteps::Int
202210
saved_values::SV
211+
u0_norm::UN
212+
step_norm_trace::ST
213+
max_stalled_steps::MSS
214+
u_diff_cache::uType
203215
end
204216

205217
get_termination_mode(cache::NonlinearTerminationModeCache) = cache.mode
@@ -227,7 +239,8 @@ end
227239

228240
function SciMLBase.init(du::Union{AbstractArray{T}, T}, u::Union{AbstractArray{T}, T},
229241
mode::AbstractNonlinearTerminationMode, saved_value_prototype...;
230-
abstol = nothing, reltol = nothing, kwargs...) where {T <: Number}
242+
use_deprecated_retcodes::Val{D} = Val(true), # Remove in v8, warn in v7
243+
abstol = nothing, reltol = nothing, kwargs...) where {D, T <: Number}
231244
abstol = _get_tolerance(abstol, T)
232245
reltol = _get_tolerance(reltol, T)
233246
TT = typeof(abstol)
@@ -236,25 +249,74 @@ function SciMLBase.init(du::Union{AbstractArray{T}, T}, u::Union{AbstractArray{T
236249
if mode isa AbstractSafeNonlinearTerminationMode
237250
if mode isa AbsSafeTerminationMode || mode isa AbsSafeBestTerminationMode
238251
initial_objective = maximum(abs, du)
252+
u0_norm = nothing
239253
else
240254
initial_objective = maximum(abs, du) / (maximum(abs, du .+ u) + eps(TT))
255+
u0_norm = mode.max_stalled_steps === nothing ? nothing : norm(u, 2)
241256
end
242257
objectives_trace = Vector{TT}(undef, mode.patience_steps)
258+
step_norm_trace = mode.max_stalled_steps === nothing ? nothing :
259+
Vector{TT}(undef, mode.max_stalled_steps)
243260
best_value = initial_objective
261+
max_stalled_steps = mode.max_stalled_steps
262+
if ArrayInterface.can_setindex(u_) && step_norm_trace !== nothing
263+
u_diff_cache = similar(u_)
264+
else
265+
u_diff_cache = u_
266+
end
244267
else
245268
initial_objective = nothing
246269
objectives_trace = nothing
270+
u0_norm = nothing
271+
step_norm_trace = nothing
247272
best_value = __cvt_real(T, Inf)
273+
max_stalled_steps = nothing
274+
u_diff_cache = u_
248275
end
249276

250277
length(saved_value_prototype) == 0 && (saved_value_prototype = nothing)
251278

252-
return NonlinearTerminationModeCache{typeof(u_), TT, typeof(mode),
253-
typeof(initial_objective), typeof(objectives_trace),
254-
typeof(saved_value_prototype)}(u_, NonlinearSafeTerminationReturnCode.Default,
255-
abstol, reltol, best_value, mode, initial_objective, objectives_trace, 0,
256-
saved_value_prototype)
257-
end
279+
retcode = ifelse(D, NonlinearSafeTerminationReturnCode.Default, ReturnCode.Default)
280+
281+
return NonlinearTerminationModeCache{typeof(u_), TT, D, typeof(mode),
282+
typeof(initial_objective), typeof(objectives_trace), typeof(saved_value_prototype),
283+
typeof(retcode), typeof(u0_norm), typeof(step_norm_trace),
284+
typeof(max_stalled_steps)}(u_, retcode, abstol, reltol, best_value, mode,
285+
initial_objective, objectives_trace, 0, saved_value_prototype, u0_norm,
286+
step_norm_trace, max_stalled_steps, u_diff_cache)
287+
end
288+
289+
# function SciMLBase.reinit!(cache::NonlinearTerminationModeCache{uType, T, dep_retcode}, du,
290+
# u, saved_value_prototype...; abstol = nothing, reltol = nothing,
291+
# kwargs...) where {uType, T, dep_retcode}
292+
# length(saved_value_prototype) != 0 && (cache.saved_values = saved_value_prototype)
293+
294+
# u_ = cache.mode isa AbstractSafeBestNonlinearTerminationMode ?
295+
# (ArrayInterface.can_setindex(u) ? copy(u) : u) : nothing
296+
# cache.u = u_
297+
# cache.retcode = ifelse(dep_retcode, NonlinearSafeTerminationReturnCode.Default,
298+
# ReturnCode.Default)
299+
300+
# cache.abstol = _get_tolerance(abstol, T)
301+
# cache.reltol = _get_tolerance(reltol, T)
302+
# cache.nsteps = 0
303+
304+
# if mode isa AbstractSafeNonlinearTerminationMode
305+
# if mode isa AbsSafeTerminationMode || mode isa AbsSafeBestTerminationMode
306+
# initial_objective = maximum(abs, du)
307+
# else
308+
# initial_objective = maximum(abs, du) / (maximum(abs, du .+ u) + eps(TT))
309+
# end
310+
# best_value = initial_objective
311+
# else
312+
# initial_objective = nothing
313+
# objectives_trace = nothing
314+
# best_value = __cvt_real(T, Inf)
315+
# end
316+
# cache.best_objective_value = best_value
317+
# cache.initial_objective = initial_objective
318+
# return cache
319+
# end
258320

259321
# This dispatch is needed based on how Terminating Callback works!
260322
# This intentially drops the `abstol` and `reltol` arguments
@@ -273,8 +335,8 @@ function (cache::NonlinearTerminationModeCache)(mode::AbstractNonlinearTerminati
273335
return check_convergence(mode, du, u, uprev, cache.abstol, cache.reltol)
274336
end
275337

276-
function (cache::NonlinearTerminationModeCache)(mode::AbstractSafeNonlinearTerminationMode,
277-
du, u, uprev, args...)
338+
function (cache::NonlinearTerminationModeCache{uType, TT, dep_retcode})(mode::AbstractSafeNonlinearTerminationMode,
339+
du, u, uprev, args...) where {uType, TT, dep_retcode}
278340
if mode isa AbsSafeTerminationMode || mode isa AbsSafeBestTerminationMode
279341
objective = maximum(abs, du)
280342
criteria = cache.abstol
@@ -285,13 +347,15 @@ function (cache::NonlinearTerminationModeCache)(mode::AbstractSafeNonlinearTermi
285347

286348
# Protective Break
287349
if isinf(objective) || isnan(objective)
288-
cache.retcode = NonlinearSafeTerminationReturnCode.ProtectiveTermination
350+
cache.retcode = ifelse(dep_retcode,
351+
NonlinearSafeTerminationReturnCode.ProtectiveTermination, ReturnCode.Unstable)
289352
return true
290353
end
291354
## By default we turn this off since it has the potential for false positives
292355
if cache.mode.protective_threshold !== nothing &&
293356
(objective > cache.initial_objective * cache.mode.protective_threshold * length(du))
294-
cache.retcode = NonlinearSafeTerminationReturnCode.ProtectiveTermination
357+
cache.retcode = ifelse(dep_retcode,
358+
NonlinearSafeTerminationReturnCode.ProtectiveTermination, ReturnCode.Unstable)
295359
return true
296360
end
297361

@@ -307,7 +371,8 @@ function (cache::NonlinearTerminationModeCache)(mode::AbstractSafeNonlinearTermi
307371

308372
# Main Termination Condition
309373
if objective criteria
310-
cache.retcode = NonlinearSafeTerminationReturnCode.Success
374+
cache.retcode = ifelse(dep_retcode,
375+
NonlinearSafeTerminationReturnCode.Success, ReturnCode.Success)
311376
return true
312377
end
313378

@@ -324,13 +389,43 @@ function (cache::NonlinearTerminationModeCache)(mode::AbstractSafeNonlinearTermi
324389
min_obj, max_obj = extrema(cache.objectives_trace)
325390
end
326391
if min_obj < cache.mode.min_max_factor * max_obj
327-
cache.retcode = NonlinearSafeTerminationReturnCode.PatienceTermination
392+
cache.retcode = ifelse(dep_retcode,
393+
NonlinearSafeTerminationReturnCode.PatienceTermination,
394+
ReturnCode.Stalled)
395+
return true
396+
end
397+
end
398+
end
399+
400+
# Test for stalling if that is not disabled
401+
if cache.step_norm_trace !== nothing
402+
if ArrayInterface.can_setindex(cache.u_diff_cache)
403+
@. cache.u_diff_cache = u - uprev
404+
else
405+
cache.u_diff_cache = u .- uprev
406+
end
407+
du_norm = norm(cache.u_diff_cache, 2)
408+
cache.step_norm_trace[mod1(cache.nsteps, length(cache.step_norm_trace))] = du_norm
409+
if cache.nsteps cache.mode.max_stalled_steps
410+
max_step_norm = maximum(cache.step_norm_trace)
411+
if cache.mode isa AbsSafeTerminationMode ||
412+
cache.mode isa AbsSafeBestTerminationMode
413+
stalled_step = max_step_norm cache.abstol
414+
else
415+
stalled_step = max_step_norm
416+
cache.reltol * (max_step_norm + cache.u0_norm)
417+
end
418+
if stalled_step
419+
cache.retcode = ifelse(dep_retcode,
420+
NonlinearSafeTerminationReturnCode.PatienceTermination,
421+
ReturnCode.Stalled)
328422
return true
329423
end
330424
end
331425
end
332426

333-
cache.retcode = NonlinearSafeTerminationReturnCode.Failure
427+
cache.retcode = ifelse(dep_retcode,
428+
NonlinearSafeTerminationReturnCode.Failure, ReturnCode.Failure)
334429
return false
335430
end
336431

0 commit comments

Comments
 (0)