2
2
NonlinearSafeTerminationReturnCode
3
3
4
4
Return Codes for the safe nonlinear termination conditions.
5
+
6
+ These return codes have been deprecated. Termination Conditions will return
7
+ `SciMLBase.Retcode.T` starting from v7.
5
8
"""
6
9
@enumx NonlinearSafeTerminationReturnCode begin
7
10
"""
@@ -116,15 +119,16 @@ for the last `patience_steps` + terminate if the solution blows up (diverges).
116
119
117
120
```julia
118
121
RelSafeTerminationMode(; protective_threshold = nothing, patience_steps = 100,
119
- patience_objective_multiplier = 3, min_max_factor = 1.3)
122
+ patience_objective_multiplier = 3, min_max_factor = 1.3, max_stalled_steps = 20 )
120
123
```
121
124
"""
122
- Base. @kwdef struct RelSafeTerminationMode{T1, T2, T3} < :
125
+ Base. @kwdef struct RelSafeTerminationMode{T1, T2, T3, T4 <: Union{Nothing, Int} } < :
123
126
AbstractSafeNonlinearTerminationMode
124
127
protective_threshold:: T1 = nothing
125
128
patience_steps:: Int = 100
126
129
patience_objective_multiplier:: T2 = 3
127
130
min_max_factor:: T3 = 1.3
131
+ max_stalled_steps:: T4 = nothing
128
132
end
129
133
130
134
@doc doc"""
@@ -137,15 +141,16 @@ for the last `patience_steps` + terminate if the solution blows up (diverges).
137
141
138
142
```julia
139
143
AbsSafeTerminationMode(; protective_threshold = nothing, patience_steps = 100,
140
- patience_objective_multiplier = 3, min_max_factor = 1.3)
144
+ patience_objective_multiplier = 3, min_max_factor = 1.3, max_stalled_steps = 20 )
141
145
```
142
146
"""
143
- Base. @kwdef struct AbsSafeTerminationMode{T1, T2, T3} < :
147
+ Base. @kwdef struct AbsSafeTerminationMode{T1, T2, T3, T4 <: Union{Nothing, Int} } < :
144
148
AbstractSafeNonlinearTerminationMode
145
149
protective_threshold:: T1 = nothing
146
150
patience_steps:: Int = 100
147
151
patience_objective_multiplier:: T2 = 3
148
152
min_max_factor:: T3 = 1.3
153
+ max_stalled_steps:: T4 = nothing
149
154
end
150
155
151
156
@doc doc"""
@@ -157,15 +162,16 @@ Essentially [`RelSafeTerminationMode`](@ref), but caches the best solution found
157
162
158
163
```julia
159
164
RelSafeBestTerminationMode(; protective_threshold = nothing, patience_steps = 100,
160
- patience_objective_multiplier = 3, min_max_factor = 1.3)
165
+ patience_objective_multiplier = 3, min_max_factor = 1.3, max_stalled_steps = 20 )
161
166
```
162
167
"""
163
- Base. @kwdef struct RelSafeBestTerminationMode{T1, T2, T3} < :
168
+ Base. @kwdef struct RelSafeBestTerminationMode{T1, T2, T3, T4 <: Union{Nothing, Int} } < :
164
169
AbstractSafeBestNonlinearTerminationMode
165
170
protective_threshold:: T1 = nothing
166
171
patience_steps:: Int = 100
167
172
patience_objective_multiplier:: T2 = 3
168
173
min_max_factor:: T3 = 1.3
174
+ max_stalled_steps:: T4 = nothing
169
175
end
170
176
171
177
@doc doc"""
@@ -177,21 +183,23 @@ Essentially [`AbsSafeTerminationMode`](@ref), but caches the best solution found
177
183
178
184
```julia
179
185
AbsSafeBestTerminationMode(; protective_threshold = nothing, patience_steps = 100,
180
- patience_objective_multiplier = 3, min_max_factor = 1.3)
186
+ patience_objective_multiplier = 3, min_max_factor = 1.3, max_stalled_steps = 20 )
181
187
```
182
188
"""
183
- Base. @kwdef struct AbsSafeBestTerminationMode{T1, T2, T3} < :
189
+ Base. @kwdef struct AbsSafeBestTerminationMode{T1, T2, T3, T4 <: Union{Nothing, Int} } < :
184
190
AbstractSafeBestNonlinearTerminationMode
185
191
protective_threshold:: T1 = nothing
186
192
patience_steps:: Int = 100
187
193
patience_objective_multiplier:: T2 = 3
188
194
min_max_factor:: T3 = 1.3
195
+ max_stalled_steps:: T4 = nothing
189
196
end
190
197
191
- mutable struct NonlinearTerminationModeCache{uType, T,
192
- M <: AbstractNonlinearTerminationMode , I, OT, SV}
198
+ mutable struct NonlinearTerminationModeCache{uType, T, dep_retcode,
199
+ M <: AbstractNonlinearTerminationMode , I, OT, SV,
200
+ R <: Union{NonlinearSafeTerminationReturnCode.T, ReturnCode.T} , UN, ST, MSS}
193
201
u:: uType
194
- retcode:: NonlinearSafeTerminationReturnCode.T
202
+ retcode:: R
195
203
abstol:: T
196
204
reltol:: T
197
205
best_objective_value:: T
@@ -200,6 +208,10 @@ mutable struct NonlinearTerminationModeCache{uType, T,
200
208
objectives_trace:: OT
201
209
nsteps:: Int
202
210
saved_values:: SV
211
+ u0_norm:: UN
212
+ step_norm_trace:: ST
213
+ max_stalled_steps:: MSS
214
+ u_diff_cache:: uType
203
215
end
204
216
205
217
get_termination_mode (cache:: NonlinearTerminationModeCache ) = cache. mode
227
239
228
240
function SciMLBase. init (du:: Union{AbstractArray{T}, T} , u:: Union{AbstractArray{T}, T} ,
229
241
mode:: AbstractNonlinearTerminationMode , saved_value_prototype... ;
230
- abstol = nothing , reltol = nothing , kwargs... ) where {T <: Number }
242
+ use_deprecated_retcodes:: Val{D} = Val (true ), # Remove in v8, warn in v7
243
+ abstol = nothing , reltol = nothing , kwargs... ) where {D, T <: Number }
231
244
abstol = _get_tolerance (abstol, T)
232
245
reltol = _get_tolerance (reltol, T)
233
246
TT = typeof (abstol)
@@ -236,25 +249,74 @@ function SciMLBase.init(du::Union{AbstractArray{T}, T}, u::Union{AbstractArray{T
236
249
if mode isa AbstractSafeNonlinearTerminationMode
237
250
if mode isa AbsSafeTerminationMode || mode isa AbsSafeBestTerminationMode
238
251
initial_objective = maximum (abs, du)
252
+ u0_norm = nothing
239
253
else
240
254
initial_objective = maximum (abs, du) / (maximum (abs, du .+ u) + eps (TT))
255
+ u0_norm = mode. max_stalled_steps === nothing ? nothing : norm (u, 2 )
241
256
end
242
257
objectives_trace = Vector {TT} (undef, mode. patience_steps)
258
+ step_norm_trace = mode. max_stalled_steps === nothing ? nothing :
259
+ Vector {TT} (undef, mode. max_stalled_steps)
243
260
best_value = initial_objective
261
+ max_stalled_steps = mode. max_stalled_steps
262
+ if ArrayInterface. can_setindex (u_) && step_norm_trace != = nothing
263
+ u_diff_cache = similar (u_)
264
+ else
265
+ u_diff_cache = u_
266
+ end
244
267
else
245
268
initial_objective = nothing
246
269
objectives_trace = nothing
270
+ u0_norm = nothing
271
+ step_norm_trace = nothing
247
272
best_value = __cvt_real (T, Inf )
273
+ max_stalled_steps = nothing
274
+ u_diff_cache = u_
248
275
end
249
276
250
277
length (saved_value_prototype) == 0 && (saved_value_prototype = nothing )
251
278
252
- return NonlinearTerminationModeCache{typeof (u_), TT, typeof (mode),
253
- typeof (initial_objective), typeof (objectives_trace),
254
- typeof (saved_value_prototype)}(u_, NonlinearSafeTerminationReturnCode. Default,
255
- abstol, reltol, best_value, mode, initial_objective, objectives_trace, 0 ,
256
- saved_value_prototype)
257
- end
279
+ retcode = ifelse (D, NonlinearSafeTerminationReturnCode. Default, ReturnCode. Default)
280
+
281
+ return NonlinearTerminationModeCache{typeof (u_), TT, D, typeof (mode),
282
+ typeof (initial_objective), typeof (objectives_trace), typeof (saved_value_prototype),
283
+ typeof (retcode), typeof (u0_norm), typeof (step_norm_trace),
284
+ typeof (max_stalled_steps)}(u_, retcode, abstol, reltol, best_value, mode,
285
+ initial_objective, objectives_trace, 0 , saved_value_prototype, u0_norm,
286
+ step_norm_trace, max_stalled_steps, u_diff_cache)
287
+ end
288
+
289
+ # function SciMLBase.reinit!(cache::NonlinearTerminationModeCache{uType, T, dep_retcode}, du,
290
+ # u, saved_value_prototype...; abstol = nothing, reltol = nothing,
291
+ # kwargs...) where {uType, T, dep_retcode}
292
+ # length(saved_value_prototype) != 0 && (cache.saved_values = saved_value_prototype)
293
+
294
+ # u_ = cache.mode isa AbstractSafeBestNonlinearTerminationMode ?
295
+ # (ArrayInterface.can_setindex(u) ? copy(u) : u) : nothing
296
+ # cache.u = u_
297
+ # cache.retcode = ifelse(dep_retcode, NonlinearSafeTerminationReturnCode.Default,
298
+ # ReturnCode.Default)
299
+
300
+ # cache.abstol = _get_tolerance(abstol, T)
301
+ # cache.reltol = _get_tolerance(reltol, T)
302
+ # cache.nsteps = 0
303
+
304
+ # if mode isa AbstractSafeNonlinearTerminationMode
305
+ # if mode isa AbsSafeTerminationMode || mode isa AbsSafeBestTerminationMode
306
+ # initial_objective = maximum(abs, du)
307
+ # else
308
+ # initial_objective = maximum(abs, du) / (maximum(abs, du .+ u) + eps(TT))
309
+ # end
310
+ # best_value = initial_objective
311
+ # else
312
+ # initial_objective = nothing
313
+ # objectives_trace = nothing
314
+ # best_value = __cvt_real(T, Inf)
315
+ # end
316
+ # cache.best_objective_value = best_value
317
+ # cache.initial_objective = initial_objective
318
+ # return cache
319
+ # end
258
320
259
321
# This dispatch is needed based on how Terminating Callback works!
260
322
# This intentially drops the `abstol` and `reltol` arguments
@@ -273,8 +335,8 @@ function (cache::NonlinearTerminationModeCache)(mode::AbstractNonlinearTerminati
273
335
return check_convergence (mode, du, u, uprev, cache. abstol, cache. reltol)
274
336
end
275
337
276
- function (cache:: NonlinearTerminationModeCache )(mode:: AbstractSafeNonlinearTerminationMode ,
277
- du, u, uprev, args... )
338
+ function (cache:: NonlinearTerminationModeCache{uType, TT, dep_retcode} )(mode:: AbstractSafeNonlinearTerminationMode ,
339
+ du, u, uprev, args... ) where {uType, TT, dep_retcode}
278
340
if mode isa AbsSafeTerminationMode || mode isa AbsSafeBestTerminationMode
279
341
objective = maximum (abs, du)
280
342
criteria = cache. abstol
@@ -285,13 +347,15 @@ function (cache::NonlinearTerminationModeCache)(mode::AbstractSafeNonlinearTermi
285
347
286
348
# Protective Break
287
349
if isinf (objective) || isnan (objective)
288
- cache. retcode = NonlinearSafeTerminationReturnCode. ProtectiveTermination
350
+ cache. retcode = ifelse (dep_retcode,
351
+ NonlinearSafeTerminationReturnCode. ProtectiveTermination, ReturnCode. Unstable)
289
352
return true
290
353
end
291
354
# # By default we turn this off since it has the potential for false positives
292
355
if cache. mode. protective_threshold != = nothing &&
293
356
(objective > cache. initial_objective * cache. mode. protective_threshold * length (du))
294
- cache. retcode = NonlinearSafeTerminationReturnCode. ProtectiveTermination
357
+ cache. retcode = ifelse (dep_retcode,
358
+ NonlinearSafeTerminationReturnCode. ProtectiveTermination, ReturnCode. Unstable)
295
359
return true
296
360
end
297
361
@@ -307,7 +371,8 @@ function (cache::NonlinearTerminationModeCache)(mode::AbstractSafeNonlinearTermi
307
371
308
372
# Main Termination Condition
309
373
if objective ≤ criteria
310
- cache. retcode = NonlinearSafeTerminationReturnCode. Success
374
+ cache. retcode = ifelse (dep_retcode,
375
+ NonlinearSafeTerminationReturnCode. Success, ReturnCode. Success)
311
376
return true
312
377
end
313
378
@@ -324,13 +389,43 @@ function (cache::NonlinearTerminationModeCache)(mode::AbstractSafeNonlinearTermi
324
389
min_obj, max_obj = extrema (cache. objectives_trace)
325
390
end
326
391
if min_obj < cache. mode. min_max_factor * max_obj
327
- cache. retcode = NonlinearSafeTerminationReturnCode. PatienceTermination
392
+ cache. retcode = ifelse (dep_retcode,
393
+ NonlinearSafeTerminationReturnCode. PatienceTermination,
394
+ ReturnCode. Stalled)
395
+ return true
396
+ end
397
+ end
398
+ end
399
+
400
+ # Test for stalling if that is not disabled
401
+ if cache. step_norm_trace != = nothing
402
+ if ArrayInterface. can_setindex (cache. u_diff_cache)
403
+ @. cache. u_diff_cache = u - uprev
404
+ else
405
+ cache. u_diff_cache = u .- uprev
406
+ end
407
+ du_norm = norm (cache. u_diff_cache, 2 )
408
+ cache. step_norm_trace[mod1 (cache. nsteps, length (cache. step_norm_trace))] = du_norm
409
+ if cache. nsteps ≥ cache. mode. max_stalled_steps
410
+ max_step_norm = maximum (cache. step_norm_trace)
411
+ if cache. mode isa AbsSafeTerminationMode ||
412
+ cache. mode isa AbsSafeBestTerminationMode
413
+ stalled_step = max_step_norm ≤ cache. abstol
414
+ else
415
+ stalled_step = max_step_norm ≤
416
+ cache. reltol * (max_step_norm + cache. u0_norm)
417
+ end
418
+ if stalled_step
419
+ cache. retcode = ifelse (dep_retcode,
420
+ NonlinearSafeTerminationReturnCode. PatienceTermination,
421
+ ReturnCode. Stalled)
328
422
return true
329
423
end
330
424
end
331
425
end
332
426
333
- cache. retcode = NonlinearSafeTerminationReturnCode. Failure
427
+ cache. retcode = ifelse (dep_retcode,
428
+ NonlinearSafeTerminationReturnCode. Failure, ReturnCode. Failure)
334
429
return false
335
430
end
336
431
0 commit comments