Skip to content

Commit 1e2db4c

Browse files
committed
Add a reinit function
1 parent a4fa9b5 commit 1e2db4c

File tree

1 file changed

+34
-33
lines changed

1 file changed

+34
-33
lines changed

src/termination_conditions.jl

Lines changed: 34 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,7 @@ function SciMLBase.init(du::Union{AbstractArray{T}, T}, u::Union{AbstractArray{T
259259
Vector{TT}(undef, mode.max_stalled_steps)
260260
best_value = initial_objective
261261
max_stalled_steps = mode.max_stalled_steps
262-
if ArrayInterface.can_setindex(u_) && step_norm_trace !== nothing
262+
if ArrayInterface.can_setindex(u_) && !(u_ isa Number) && step_norm_trace !== nothing
263263
u_diff_cache = similar(u_)
264264
else
265265
u_diff_cache = u_
@@ -286,37 +286,38 @@ function SciMLBase.init(du::Union{AbstractArray{T}, T}, u::Union{AbstractArray{T
286286
step_norm_trace, max_stalled_steps, u_diff_cache)
287287
end
288288

289-
# function SciMLBase.reinit!(cache::NonlinearTerminationModeCache{uType, T, dep_retcode}, du,
290-
# u, saved_value_prototype...; abstol = nothing, reltol = nothing,
291-
# kwargs...) where {uType, T, dep_retcode}
292-
# length(saved_value_prototype) != 0 && (cache.saved_values = saved_value_prototype)
293-
294-
# u_ = cache.mode isa AbstractSafeBestNonlinearTerminationMode ?
295-
# (ArrayInterface.can_setindex(u) ? copy(u) : u) : nothing
296-
# cache.u = u_
297-
# cache.retcode = ifelse(dep_retcode, NonlinearSafeTerminationReturnCode.Default,
298-
# ReturnCode.Default)
299-
300-
# cache.abstol = _get_tolerance(abstol, T)
301-
# cache.reltol = _get_tolerance(reltol, T)
302-
# cache.nsteps = 0
303-
304-
# if mode isa AbstractSafeNonlinearTerminationMode
305-
# if mode isa AbsSafeTerminationMode || mode isa AbsSafeBestTerminationMode
306-
# initial_objective = maximum(abs, du)
307-
# else
308-
# initial_objective = maximum(abs, du) / (maximum(abs, du .+ u) + eps(TT))
309-
# end
310-
# best_value = initial_objective
311-
# else
312-
# initial_objective = nothing
313-
# objectives_trace = nothing
314-
# best_value = __cvt_real(T, Inf)
315-
# end
316-
# cache.best_objective_value = best_value
317-
# cache.initial_objective = initial_objective
318-
# return cache
319-
# end
289+
function SciMLBase.reinit!(cache::NonlinearTerminationModeCache{uType, T, dep_retcode}, du,
290+
u, saved_value_prototype...; abstol = nothing, reltol = nothing,
291+
kwargs...) where {uType, T, dep_retcode}
292+
length(saved_value_prototype) != 0 && (cache.saved_values = saved_value_prototype)
293+
294+
u_ = cache.mode isa AbstractSafeBestNonlinearTerminationMode ?
295+
(ArrayInterface.can_setindex(u) ? copy(u) : u) : nothing
296+
cache.u = u_
297+
cache.retcode = ifelse(dep_retcode, NonlinearSafeTerminationReturnCode.Default,
298+
ReturnCode.Default)
299+
300+
cache.abstol = _get_tolerance(abstol, T)
301+
cache.reltol = _get_tolerance(reltol, T)
302+
cache.nsteps = 0
303+
304+
if mode isa AbstractSafeNonlinearTerminationMode
305+
if mode isa AbsSafeTerminationMode || mode isa AbsSafeBestTerminationMode
306+
initial_objective = maximum(abs, du)
307+
else
308+
initial_objective = maximum(abs, du) / (maximum(abs, du .+ u) + eps(TT))
309+
cache.max_stalled_steps !== nothing && (cache.u0_norm = norm(u_, 2))
310+
end
311+
best_value = initial_objective
312+
else
313+
initial_objective = nothing
314+
objectives_trace = nothing
315+
best_value = __cvt_real(T, Inf)
316+
end
317+
cache.best_objective_value = best_value
318+
cache.initial_objective = initial_objective
319+
return cache
320+
end
320321

321322
# This dispatch is needed based on how Terminating Callback works!
322323
# This intentially drops the `abstol` and `reltol` arguments
@@ -399,7 +400,7 @@ function (cache::NonlinearTerminationModeCache{uType, TT, dep_retcode})(mode::Ab
399400

400401
# Test for stalling if that is not disabled
401402
if cache.step_norm_trace !== nothing
402-
if ArrayInterface.can_setindex(cache.u_diff_cache)
403+
if ArrayInterface.can_setindex(cache.u_diff_cache) && !(u isa Number)
403404
@. cache.u_diff_cache = u - uprev
404405
else
405406
cache.u_diff_cache = u .- uprev

0 commit comments

Comments
 (0)