@@ -259,7 +259,7 @@ function SciMLBase.init(du::Union{AbstractArray{T}, T}, u::Union{AbstractArray{T
259
259
Vector {TT} (undef, mode. max_stalled_steps)
260
260
best_value = initial_objective
261
261
max_stalled_steps = mode. max_stalled_steps
262
- if ArrayInterface. can_setindex (u_) && step_norm_trace != = nothing
262
+ if ArrayInterface. can_setindex (u_) && ! (u_ isa Number) && step_norm_trace != = nothing
263
263
u_diff_cache = similar (u_)
264
264
else
265
265
u_diff_cache = u_
@@ -286,37 +286,38 @@ function SciMLBase.init(du::Union{AbstractArray{T}, T}, u::Union{AbstractArray{T
286
286
step_norm_trace, max_stalled_steps, u_diff_cache)
287
287
end
288
288
289
- # function SciMLBase.reinit!(cache::NonlinearTerminationModeCache{uType, T, dep_retcode}, du,
290
- # u, saved_value_prototype...; abstol = nothing, reltol = nothing,
291
- # kwargs...) where {uType, T, dep_retcode}
292
- # length(saved_value_prototype) != 0 && (cache.saved_values = saved_value_prototype)
293
-
294
- # u_ = cache.mode isa AbstractSafeBestNonlinearTerminationMode ?
295
- # (ArrayInterface.can_setindex(u) ? copy(u) : u) : nothing
296
- # cache.u = u_
297
- # cache.retcode = ifelse(dep_retcode, NonlinearSafeTerminationReturnCode.Default,
298
- # ReturnCode.Default)
299
-
300
- # cache.abstol = _get_tolerance(abstol, T)
301
- # cache.reltol = _get_tolerance(reltol, T)
302
- # cache.nsteps = 0
303
-
304
- # if mode isa AbstractSafeNonlinearTerminationMode
305
- # if mode isa AbsSafeTerminationMode || mode isa AbsSafeBestTerminationMode
306
- # initial_objective = maximum(abs, du)
307
- # else
308
- # initial_objective = maximum(abs, du) / (maximum(abs, du .+ u) + eps(TT))
309
- # end
310
- # best_value = initial_objective
311
- # else
312
- # initial_objective = nothing
313
- # objectives_trace = nothing
314
- # best_value = __cvt_real(T, Inf)
315
- # end
316
- # cache.best_objective_value = best_value
317
- # cache.initial_objective = initial_objective
318
- # return cache
319
- # end
289
+ function SciMLBase. reinit! (cache:: NonlinearTerminationModeCache{uType, T, dep_retcode} , du,
290
+ u, saved_value_prototype... ; abstol = nothing , reltol = nothing ,
291
+ kwargs... ) where {uType, T, dep_retcode}
292
+ length (saved_value_prototype) != 0 && (cache. saved_values = saved_value_prototype)
293
+
294
+ u_ = cache. mode isa AbstractSafeBestNonlinearTerminationMode ?
295
+ (ArrayInterface. can_setindex (u) ? copy (u) : u) : nothing
296
+ cache. u = u_
297
+ cache. retcode = ifelse (dep_retcode, NonlinearSafeTerminationReturnCode. Default,
298
+ ReturnCode. Default)
299
+
300
+ cache. abstol = _get_tolerance (abstol, T)
301
+ cache. reltol = _get_tolerance (reltol, T)
302
+ cache. nsteps = 0
303
+
304
+ if mode isa AbstractSafeNonlinearTerminationMode
305
+ if mode isa AbsSafeTerminationMode || mode isa AbsSafeBestTerminationMode
306
+ initial_objective = maximum (abs, du)
307
+ else
308
+ initial_objective = maximum (abs, du) / (maximum (abs, du .+ u) + eps (TT))
309
+ cache. max_stalled_steps != = nothing && (cache. u0_norm = norm (u_, 2 ))
310
+ end
311
+ best_value = initial_objective
312
+ else
313
+ initial_objective = nothing
314
+ objectives_trace = nothing
315
+ best_value = __cvt_real (T, Inf )
316
+ end
317
+ cache. best_objective_value = best_value
318
+ cache. initial_objective = initial_objective
319
+ return cache
320
+ end
320
321
321
322
# This dispatch is needed based on how Terminating Callback works!
322
323
# This intentially drops the `abstol` and `reltol` arguments
@@ -399,7 +400,7 @@ function (cache::NonlinearTerminationModeCache{uType, TT, dep_retcode})(mode::Ab
399
400
400
401
# Test for stalling if that is not disabled
401
402
if cache. step_norm_trace != = nothing
402
- if ArrayInterface. can_setindex (cache. u_diff_cache)
403
+ if ArrayInterface. can_setindex (cache. u_diff_cache) && ! (u isa Number)
403
404
@. cache. u_diff_cache = u - uprev
404
405
else
405
406
cache. u_diff_cache = u .- uprev
0 commit comments