2
2
NonlinearSafeTerminationReturnCode
3
3
4
4
Return Codes for the safe nonlinear termination conditions.
5
+
6
+ These return codes have been deprecated. Termination Conditions will return
7
+ `SciMLBase.Retcode.T` starting from v7.
5
8
"""
6
9
@enumx NonlinearSafeTerminationReturnCode begin
7
10
"""
@@ -116,15 +119,16 @@ for the last `patience_steps` + terminate if the solution blows up (diverges).
116
119
117
120
```julia
118
121
RelSafeTerminationMode(; protective_threshold = nothing, patience_steps = 100,
119
- patience_objective_multiplier = 3, min_max_factor = 1.3)
122
+ patience_objective_multiplier = 3, min_max_factor = 1.3, max_stalled_steps = nothing )
120
123
```
121
124
"""
122
- Base. @kwdef struct RelSafeTerminationMode{T1, T2, T3} < :
125
+ Base. @kwdef struct RelSafeTerminationMode{T1, T2, T3, T4 <: Union{Nothing, Int} } < :
123
126
AbstractSafeNonlinearTerminationMode
124
127
protective_threshold:: T1 = nothing
125
128
patience_steps:: Int = 100
126
129
patience_objective_multiplier:: T2 = 3
127
130
min_max_factor:: T3 = 1.3
131
+ max_stalled_steps:: T4 = nothing
128
132
end
129
133
130
134
@doc doc"""
@@ -137,15 +141,16 @@ for the last `patience_steps` + terminate if the solution blows up (diverges).
137
141
138
142
```julia
139
143
AbsSafeTerminationMode(; protective_threshold = nothing, patience_steps = 100,
140
- patience_objective_multiplier = 3, min_max_factor = 1.3)
144
+ patience_objective_multiplier = 3, min_max_factor = 1.3, max_stalled_steps = nothing )
141
145
```
142
146
"""
143
- Base. @kwdef struct AbsSafeTerminationMode{T1, T2, T3} < :
147
+ Base. @kwdef struct AbsSafeTerminationMode{T1, T2, T3, T4 <: Union{Nothing, Int} } < :
144
148
AbstractSafeNonlinearTerminationMode
145
149
protective_threshold:: T1 = nothing
146
150
patience_steps:: Int = 100
147
151
patience_objective_multiplier:: T2 = 3
148
152
min_max_factor:: T3 = 1.3
153
+ max_stalled_steps:: T4 = nothing
149
154
end
150
155
151
156
@doc doc"""
@@ -157,15 +162,16 @@ Essentially [`RelSafeTerminationMode`](@ref), but caches the best solution found
157
162
158
163
```julia
159
164
RelSafeBestTerminationMode(; protective_threshold = nothing, patience_steps = 100,
160
- patience_objective_multiplier = 3, min_max_factor = 1.3)
165
+ patience_objective_multiplier = 3, min_max_factor = 1.3, max_stalled_steps = nothing )
161
166
```
162
167
"""
163
- Base. @kwdef struct RelSafeBestTerminationMode{T1, T2, T3} < :
168
+ Base. @kwdef struct RelSafeBestTerminationMode{T1, T2, T3, T4 <: Union{Nothing, Int} } < :
164
169
AbstractSafeBestNonlinearTerminationMode
165
170
protective_threshold:: T1 = nothing
166
171
patience_steps:: Int = 100
167
172
patience_objective_multiplier:: T2 = 3
168
173
min_max_factor:: T3 = 1.3
174
+ max_stalled_steps:: T4 = nothing
169
175
end
170
176
171
177
@doc doc"""
@@ -177,21 +183,23 @@ Essentially [`AbsSafeTerminationMode`](@ref), but caches the best solution found
177
183
178
184
```julia
179
185
AbsSafeBestTerminationMode(; protective_threshold = nothing, patience_steps = 100,
180
- patience_objective_multiplier = 3, min_max_factor = 1.3)
186
+ patience_objective_multiplier = 3, min_max_factor = 1.3, max_stalled_steps = nothing )
181
187
```
182
188
"""
183
- Base. @kwdef struct AbsSafeBestTerminationMode{T1, T2, T3} < :
189
+ Base. @kwdef struct AbsSafeBestTerminationMode{T1, T2, T3, T4 <: Union{Nothing, Int} } < :
184
190
AbstractSafeBestNonlinearTerminationMode
185
191
protective_threshold:: T1 = nothing
186
192
patience_steps:: Int = 100
187
193
patience_objective_multiplier:: T2 = 3
188
194
min_max_factor:: T3 = 1.3
195
+ max_stalled_steps:: T4 = nothing
189
196
end
190
197
191
- mutable struct NonlinearTerminationModeCache{uType, T,
192
- M <: AbstractNonlinearTerminationMode , I, OT, SV}
198
+ mutable struct NonlinearTerminationModeCache{uType, T, dep_retcode,
199
+ M <: AbstractNonlinearTerminationMode , I, OT, SV,
200
+ R <: Union{NonlinearSafeTerminationReturnCode.T, ReturnCode.T} , UN, ST, MSS}
193
201
u:: uType
194
- retcode:: NonlinearSafeTerminationReturnCode.T
202
+ retcode:: R
195
203
abstol:: T
196
204
reltol:: T
197
205
best_objective_value:: T
@@ -200,6 +208,10 @@ mutable struct NonlinearTerminationModeCache{uType, T,
200
208
objectives_trace:: OT
201
209
nsteps:: Int
202
210
saved_values:: SV
211
+ u0_norm:: UN
212
+ step_norm_trace:: ST
213
+ max_stalled_steps:: MSS
214
+ u_diff_cache:: uType
203
215
end
204
216
205
217
get_termination_mode (cache:: NonlinearTerminationModeCache ) = cache. mode
227
239
228
240
function SciMLBase. init (du:: Union{AbstractArray{T}, T} , u:: Union{AbstractArray{T}, T} ,
229
241
mode:: AbstractNonlinearTerminationMode , saved_value_prototype... ;
230
- abstol = nothing , reltol = nothing , kwargs... ) where {T <: Number }
242
+ use_deprecated_retcodes:: Val{D} = Val (true ), # Remove in v8, warn in v7
243
+ abstol = nothing , reltol = nothing , kwargs... ) where {D, T <: Number }
231
244
abstol = _get_tolerance (abstol, T)
232
245
reltol = _get_tolerance (reltol, T)
233
246
TT = typeof (abstol)
@@ -236,24 +249,75 @@ function SciMLBase.init(du::Union{AbstractArray{T}, T}, u::Union{AbstractArray{T
236
249
if mode isa AbstractSafeNonlinearTerminationMode
237
250
if mode isa AbsSafeTerminationMode || mode isa AbsSafeBestTerminationMode
238
251
initial_objective = maximum (abs, du)
252
+ u0_norm = nothing
239
253
else
240
254
initial_objective = maximum (abs, du) / (maximum (abs, du .+ u) + eps (TT))
255
+ u0_norm = mode. max_stalled_steps === nothing ? nothing : norm (u, 2 )
241
256
end
242
257
objectives_trace = Vector {TT} (undef, mode. patience_steps)
258
+ step_norm_trace = mode. max_stalled_steps === nothing ? nothing :
259
+ Vector {TT} (undef, mode. max_stalled_steps)
243
260
best_value = initial_objective
261
+ max_stalled_steps = mode. max_stalled_steps
262
+ if ArrayInterface. can_setindex (u_) && ! (u_ isa Number) && step_norm_trace != = nothing
263
+ u_diff_cache = similar (u_)
264
+ else
265
+ u_diff_cache = u_
266
+ end
244
267
else
245
268
initial_objective = nothing
246
269
objectives_trace = nothing
270
+ u0_norm = nothing
271
+ step_norm_trace = nothing
247
272
best_value = __cvt_real (T, Inf )
273
+ max_stalled_steps = nothing
274
+ u_diff_cache = u_
248
275
end
249
276
250
277
length (saved_value_prototype) == 0 && (saved_value_prototype = nothing )
251
278
252
- return NonlinearTerminationModeCache{typeof (u_), TT, typeof (mode),
253
- typeof (initial_objective), typeof (objectives_trace),
254
- typeof (saved_value_prototype)}(u_, NonlinearSafeTerminationReturnCode. Default,
255
- abstol, reltol, best_value, mode, initial_objective, objectives_trace, 0 ,
256
- saved_value_prototype)
279
+ retcode = ifelse (D, NonlinearSafeTerminationReturnCode. Default, ReturnCode. Default)
280
+
281
+ return NonlinearTerminationModeCache{typeof (u_), TT, D, typeof (mode),
282
+ typeof (initial_objective), typeof (objectives_trace), typeof (saved_value_prototype),
283
+ typeof (retcode), typeof (u0_norm), typeof (step_norm_trace),
284
+ typeof (max_stalled_steps)}(u_, retcode, abstol, reltol, best_value, mode,
285
+ initial_objective, objectives_trace, 0 , saved_value_prototype, u0_norm,
286
+ step_norm_trace, max_stalled_steps, u_diff_cache)
287
+ end
288
+
289
+ function SciMLBase. reinit! (cache:: NonlinearTerminationModeCache{uType, T, dep_retcode} , du,
290
+ u, saved_value_prototype... ; abstol = nothing , reltol = nothing ,
291
+ kwargs... ) where {uType, T, dep_retcode}
292
+ length (saved_value_prototype) != 0 && (cache. saved_values = saved_value_prototype)
293
+
294
+ u_ = cache. mode isa AbstractSafeBestNonlinearTerminationMode ?
295
+ (ArrayInterface. can_setindex (u) ? copy (u) : u) : nothing
296
+ cache. u = u_
297
+ cache. retcode = ifelse (dep_retcode, NonlinearSafeTerminationReturnCode. Default,
298
+ ReturnCode. Default)
299
+
300
+ cache. abstol = _get_tolerance (abstol, T)
301
+ cache. reltol = _get_tolerance (reltol, T)
302
+ cache. nsteps = 0
303
+
304
+ mode = get_termination_mode (cache)
305
+ if mode isa AbstractSafeNonlinearTerminationMode
306
+ if mode isa AbsSafeTerminationMode || mode isa AbsSafeBestTerminationMode
307
+ initial_objective = maximum (abs, du)
308
+ else
309
+ initial_objective = maximum (abs, du) / (maximum (abs, du .+ u) + eps (TT))
310
+ cache. max_stalled_steps != = nothing && (cache. u0_norm = norm (u_, 2 ))
311
+ end
312
+ best_value = initial_objective
313
+ else
314
+ initial_objective = nothing
315
+ objectives_trace = nothing
316
+ best_value = __cvt_real (T, Inf )
317
+ end
318
+ cache. best_objective_value = best_value
319
+ cache. initial_objective = initial_objective
320
+ return cache
257
321
end
258
322
259
323
# This dispatch is needed based on how Terminating Callback works!
@@ -273,8 +337,8 @@ function (cache::NonlinearTerminationModeCache)(mode::AbstractNonlinearTerminati
273
337
return check_convergence (mode, du, u, uprev, cache. abstol, cache. reltol)
274
338
end
275
339
276
- function (cache:: NonlinearTerminationModeCache )(mode:: AbstractSafeNonlinearTerminationMode ,
277
- du, u, uprev, args... )
340
+ function (cache:: NonlinearTerminationModeCache{uType, TT, dep_retcode} )(mode:: AbstractSafeNonlinearTerminationMode ,
341
+ du, u, uprev, args... ) where {uType, TT, dep_retcode}
278
342
if mode isa AbsSafeTerminationMode || mode isa AbsSafeBestTerminationMode
279
343
objective = maximum (abs, du)
280
344
criteria = cache. abstol
@@ -285,13 +349,15 @@ function (cache::NonlinearTerminationModeCache)(mode::AbstractSafeNonlinearTermi
285
349
286
350
# Protective Break
287
351
if isinf (objective) || isnan (objective)
288
- cache. retcode = NonlinearSafeTerminationReturnCode. ProtectiveTermination
352
+ cache. retcode = ifelse (dep_retcode,
353
+ NonlinearSafeTerminationReturnCode. ProtectiveTermination, ReturnCode. Unstable)
289
354
return true
290
355
end
291
356
# # By default we turn this off since it has the potential for false positives
292
357
if cache. mode. protective_threshold != = nothing &&
293
358
(objective > cache. initial_objective * cache. mode. protective_threshold * length (du))
294
- cache. retcode = NonlinearSafeTerminationReturnCode. ProtectiveTermination
359
+ cache. retcode = ifelse (dep_retcode,
360
+ NonlinearSafeTerminationReturnCode. ProtectiveTermination, ReturnCode. Unstable)
295
361
return true
296
362
end
297
363
@@ -307,7 +373,8 @@ function (cache::NonlinearTerminationModeCache)(mode::AbstractSafeNonlinearTermi
307
373
308
374
# Main Termination Condition
309
375
if objective ≤ criteria
310
- cache. retcode = NonlinearSafeTerminationReturnCode. Success
376
+ cache. retcode = ifelse (dep_retcode,
377
+ NonlinearSafeTerminationReturnCode. Success, ReturnCode. Success)
311
378
return true
312
379
end
313
380
@@ -324,13 +391,43 @@ function (cache::NonlinearTerminationModeCache)(mode::AbstractSafeNonlinearTermi
324
391
min_obj, max_obj = extrema (cache. objectives_trace)
325
392
end
326
393
if min_obj < cache. mode. min_max_factor * max_obj
327
- cache. retcode = NonlinearSafeTerminationReturnCode. PatienceTermination
394
+ cache. retcode = ifelse (dep_retcode,
395
+ NonlinearSafeTerminationReturnCode. PatienceTermination,
396
+ ReturnCode. Stalled)
397
+ return true
398
+ end
399
+ end
400
+ end
401
+
402
+ # Test for stalling if that is not disabled
403
+ if cache. step_norm_trace != = nothing
404
+ if ArrayInterface. can_setindex (cache. u_diff_cache) && ! (u isa Number)
405
+ @. cache. u_diff_cache = u - uprev
406
+ else
407
+ cache. u_diff_cache = u .- uprev
408
+ end
409
+ du_norm = norm (cache. u_diff_cache, 2 )
410
+ cache. step_norm_trace[mod1 (cache. nsteps, length (cache. step_norm_trace))] = du_norm
411
+ if cache. nsteps ≥ cache. mode. max_stalled_steps
412
+ max_step_norm = maximum (cache. step_norm_trace)
413
+ if cache. mode isa AbsSafeTerminationMode ||
414
+ cache. mode isa AbsSafeBestTerminationMode
415
+ stalled_step = max_step_norm ≤ cache. abstol
416
+ else
417
+ stalled_step = max_step_norm ≤
418
+ cache. reltol * (max_step_norm + cache. u0_norm)
419
+ end
420
+ if stalled_step
421
+ cache. retcode = ifelse (dep_retcode,
422
+ NonlinearSafeTerminationReturnCode. PatienceTermination,
423
+ ReturnCode. Stalled)
328
424
return true
329
425
end
330
426
end
331
427
end
332
428
333
- cache. retcode = NonlinearSafeTerminationReturnCode. Failure
429
+ cache. retcode = ifelse (dep_retcode,
430
+ NonlinearSafeTerminationReturnCode. Failure, ReturnCode. Failure)
334
431
return false
335
432
end
336
433
0 commit comments