Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 52 additions & 5 deletions lib/OrdinaryDiffEqCore/src/integrators/integrator_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,20 +76,60 @@ function modify_dt_for_tstops!(integrator)
if has_tstop(integrator)
tdir_t = integrator.tdir * integrator.t
tdir_tstop = first_tstop(integrator)
distance_to_tstop = abs(tdir_tstop - tdir_t)

# Store the original dt to check if it gets significantly reduced
original_dt = abs(integrator.dt)

if integrator.opts.adaptive
integrator.dt = integrator.tdir *
min(abs(integrator.dt), abs(tdir_tstop - tdir_t)) # step! to the end
if original_dt < distance_to_tstop
# Normal step, no tstop interference
integrator.next_step_tstop = false
else
# Distance is smaller, entering tstop snap mode
integrator.next_step_tstop = true
integrator.tstop_target = integrator.tdir * tdir_tstop
end
integrator.dt = integrator.tdir * min(original_dt, distance_to_tstop)
elseif iszero(integrator.dtcache) && integrator.dtchangeable
integrator.dt = integrator.tdir * abs(tdir_tstop - tdir_t)
integrator.dt = integrator.tdir * distance_to_tstop
integrator.next_step_tstop = true
integrator.tstop_target = integrator.tdir * tdir_tstop
elseif integrator.dtchangeable && !integrator.force_stepfail
# always try to step! with dtcache, but lower if a tstop
# however, if force_stepfail then don't set to dtcache, and no tstop worry
integrator.dt = integrator.tdir *
min(abs(integrator.dtcache), abs(tdir_tstop - tdir_t)) # step! to the end
if abs(integrator.dtcache) < distance_to_tstop
# Normal step with dtcache, no tstop interference
integrator.next_step_tstop = false
else
# Distance is smaller, entering tstop snap mode
integrator.next_step_tstop = true
integrator.tstop_target = integrator.tdir * tdir_tstop
end
integrator.dt = integrator.tdir * min(abs(integrator.dtcache), distance_to_tstop)
else
integrator.next_step_tstop = false
end
else
integrator.next_step_tstop = false
end
end

function handle_tstop_step!(integrator)
# Check if dt is extremely small (< eps(t))
eps_threshold = eps(abs(integrator.t))

if abs(integrator.dt) < eps_threshold
# Skip perform_step! entirely for tiny dt
integrator.accept_step = true
else
# Normal step
perform_step!(integrator, integrator.cache)
end

# Flag will be reset in fixed_t_for_floatingpoint_error! when t is updated
end

# Want to extend savevalues! for DDEIntegrator
function savevalues!(integrator::ODEIntegrator, force_save = false, reduce_size = true)
_savevalues!(integrator, force_save, reduce_size)
Expand Down Expand Up @@ -328,6 +368,13 @@ function log_step!(progress_name, progress_id, progress_message, dt, u, p, t, ts
end

function fixed_t_for_floatingpoint_error!(integrator, ttmp)
# If we're in tstop snap mode, use exact tstop target
if integrator.next_step_tstop
# Reset the flag now that we're snapping to tstop
integrator.next_step_tstop = false
return integrator.tstop_target
end

if has_tstop(integrator)
tstop = integrator.tdir * first_tstop(integrator)
if abs(ttmp - tstop) <
Expand Down
2 changes: 2 additions & 0 deletions lib/OrdinaryDiffEqCore/src/integrators/type.jl
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,8 @@ mutable struct ODEIntegrator{algType <: Union{OrdinaryDiffEqAlgorithm, DAEAlgori
force_stepfail::Bool
last_stepfail::Bool
just_hit_tstop::Bool
next_step_tstop::Bool
tstop_target::tType
do_error_check::Bool
event_last_time::Int
vector_event_last_time::Int
Expand Down
13 changes: 11 additions & 2 deletions lib/OrdinaryDiffEqCore/src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -502,6 +502,8 @@ function SciMLBase.__init(
u_modified = false
EEst = EEstT(1)
just_hit_tstop = false
next_step_tstop = false
tstop_target = zero(t)
isout = false
accept_step = false
force_stepfail = false
Expand Down Expand Up @@ -540,7 +542,7 @@ function SciMLBase.__init(
callback_cache,
kshortsize, force_stepfail,
last_stepfail,
just_hit_tstop, do_error_check,
just_hit_tstop, next_step_tstop, tstop_target, do_error_check,
event_last_time,
vector_event_last_time,
last_event_error, accept_step,
Expand Down Expand Up @@ -607,7 +609,14 @@ function SciMLBase.solve!(integrator::ODEIntegrator)
if integrator.do_error_check && check_error!(integrator) != ReturnCode.Success
return integrator.sol
end
perform_step!(integrator, integrator.cache)

# Use special tstop handling if flag is set, otherwise normal stepping
if integrator.next_step_tstop
handle_tstop_step!(integrator)
else
perform_step!(integrator, integrator.cache)
end

loopfooter!(integrator)
if isempty(integrator.opts.tstops)
break
Expand Down
265 changes: 264 additions & 1 deletion test/interface/ode_tstops_tests.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using OrdinaryDiffEq, Test, Random
using OrdinaryDiffEq, Test, Random, StaticArrays, DiffEqCallbacks
import ODEProblemLibrary: prob_ode_linear
Random.seed!(100)

Expand Down Expand Up @@ -88,3 +88,266 @@ end
sol2 = solve(prob2, Tsit5())
@test 0.0:0.07:1.0 ⊆ sol2.t
end

@testset "Tstop Overshoot and Dense Time Event Tests" begin
# Tests for issue #2752: tstop overshoot errors with StaticArrays

@testset "StaticArrays vs Arrays with extreme precision" begin
# Test the specific case that was failing: extreme precision + StaticArrays
function precise_dynamics(u, p, t)
x = @view u[1:2]
v = @view u[3:4]

# Electromagnetic-like dynamics
dv = -0.01 * x + 1e-6 * sin(100*t) * SVector{2}(1, 1)

return SVector{4}(v[1], v[2], dv[1], dv[2])
end

function precise_dynamics_array!(du, u, p, t)
x = @view u[1:2]
v = @view u[3:4]

dv = -0.01 * x + 1e-6 * sin(100*t) * [1, 1]
du[1] = v[1]
du[2] = v[2]
du[3] = dv[1]
du[4] = dv[2]
end

# Initial conditions
u0_static = SVector{4}(1.0, -0.5, 0.01, 0.01)
u0_array = [1.0, -0.5, 0.01, 0.01]
tspan = (0.0, 2.0)
tstops = [0.5, 1.0, 1.5]

# Test with extreme tolerances that originally caused issues
prob_static = ODEProblem(precise_dynamics, u0_static, tspan)
sol_static = solve(prob_static, Vern9(); reltol=1e-12, abstol=1e-15,
tstops=tstops, save_everystep=false)
@test sol_static.retcode == :Success
for tstop in tstops
@test tstop ∈ sol_static.t
end

prob_array = ODEProblem(precise_dynamics_array!, u0_array, tspan)
sol_array = solve(prob_array, Vern9(); reltol=1e-12, abstol=1e-15,
tstops=tstops, save_everystep=false)
@test sol_array.retcode == :Success
for tstop in tstops
@test tstop ∈ sol_array.t
end

# Solutions should be very close despite different array types
@test isapprox(sol_static(2.0), sol_array(2.0), rtol=1e-10)
end

@testset "Duplicate tstops handling" begin
function simple_ode(u, p, t)
[0.1 * u[1]]
end

u0 = SVector{1}(1.0)
tspan = (0.0, 2.0)

# Test multiple identical tstops - should all be processed
duplicate_tstops = [0.5, 0.5, 0.5, 1.0, 1.0]

prob = ODEProblem(simple_ode, u0, tspan)
sol = solve(prob, Vern9(); tstops=duplicate_tstops)

@test sol.retcode == :Success

# Count how many times each tstop appears in solution
count_05 = count(t -> abs(t - 0.5) < 1e-12, sol.t)
count_10 = count(t -> abs(t - 1.0) < 1e-12, sol.t)

# Should handle all duplicate tstops (though may not save all due to deduplication)
@test count_05 >= 1 # At least one 0.5
@test count_10 >= 1 # At least one 1.0

# Test with StaticArrays too
prob_static = ODEProblem(simple_ode, u0, tspan)
sol_static = solve(prob_static, Vern9(); tstops=duplicate_tstops)
@test sol_static.retcode == :Success
end

@testset "PresetTimeCallback with identical times" begin
# Test PresetTimeCallback scenarios where callbacks are set at same times as tstops

event_times = Float64[]
callback_times = Float64[]

function affect_preset!(integrator)
push!(callback_times, integrator.t)
integrator.u[1] += 0.1 # Small modification
end

function simple_growth(u, p, t)
[0.1 * u[1]]
end

u0 = SVector{1}(1.0)
tspan = (0.0, 3.0)

# Define times where both tstops and callbacks should trigger
critical_times = [0.5, 1.0, 1.5, 2.0, 2.5]

# Create PresetTimeCallback at the same times as tstops
preset_cb = PresetTimeCallback(critical_times, affect_preset!)

prob = ODEProblem(simple_growth, u0, tspan)
sol = solve(prob, Vern9(); tstops=critical_times, callback=preset_cb,
reltol=1e-10, abstol=1e-12)

@test sol.retcode == :Success

# Verify all tstops were hit
for time in critical_times
@test any(abs.(sol.t .- time) .< 1e-10)
end

# Verify all callbacks were triggered
@test length(callback_times) == length(critical_times)
for time in critical_times
@test any(abs.(callback_times .- time) .< 1e-10)
end

# Test the same with regular arrays
u0_array = [1.0]
callback_times_array = Float64[]

function affect_preset_array!(integrator)
push!(callback_times_array, integrator.t)
integrator.u[1] += 0.1
end

function simple_growth_array!(du, u, p, t)
du[1] = 0.1 * u[1]
end

preset_cb_array = PresetTimeCallback(critical_times, affect_preset_array!)

prob_array = ODEProblem(simple_growth_array!, u0_array, tspan)
sol_array = solve(prob_array, Vern9(); tstops=critical_times, callback=preset_cb_array,
reltol=1e-10, abstol=1e-12)

@test sol_array.retcode == :Success
@test length(callback_times_array) == length(critical_times)

# Both should have triggered all events
@test length(callback_times) == length(callback_times_array) == length(critical_times)
end

@testset "Tiny tstop step handling" begin
# Test cases where tstop is very close to current time (dt < eps(t))
function test_ode(u, p, t)
[0.01 * u[1]]
end

u0 = SVector{1}(1.0)
tspan = (0.0, 1.0)

# Create tstop very close to start time (would cause tiny dt)
tiny_tstops = [1e-15, 1e-14, 1e-13]

for tiny_tstop in tiny_tstops
prob = ODEProblem(test_ode, u0, tspan)
sol = solve(prob, Vern9(); tstops=[tiny_tstop], save_everystep=false)

@test sol.retcode == :Success
@test any(abs.(sol.t .- tiny_tstop) .< 1e-14) # Should handle tiny tstop correctly
end
end

@testset "Multiple close tstops with StaticArrays" begin
# Test with multiple tstops that are very close together - stress test the flag logic
function oscillator(u, p, t)
SVector{2}(u[2], -u[1]) # Simple harmonic oscillator
end

u0 = SVector{2}(1.0, 0.0)
tspan = (0.0, 4.0)

# Multiple tstops close together (within floating-point precision range)
close_tstops = [1.0, 1.0 + 1e-14, 1.0 + 2e-14, 1.0 + 5e-14,
2.0, 2.0 + 1e-15, 2.0 + 1e-14,
3.0, 3.0 + 1e-13]

prob = ODEProblem(oscillator, u0, tspan)
sol = solve(prob, Vern9(); tstops=close_tstops, reltol=1e-12, abstol=1e-15)

@test sol.retcode == :Success

# Should handle all close tstops without error
# (Some might be deduplicated, but no errors should occur)
unique_times = [1.0, 2.0, 3.0]
for time in unique_times
@test any(abs.(sol.t .- time) .< 1e-10) # At least hit the main times
end
end

@testset "Backward integration with tstop flags" begin
# Test that the fix works for backward time integration
function decay_ode(u, p, t)
[-0.1 * u[1]]
end

u0 = SVector{1}(1.0)
tspan = (2.0, 0.0) # Backward integration
tstops = [1.5, 1.0, 0.5]

prob = ODEProblem(decay_ode, u0, tspan)
sol = solve(prob, Vern9(); tstops=tstops, reltol=1e-12, abstol=1e-15)

@test sol.retcode == :Success
for tstop in tstops
@test tstop ∈ sol.t
end
end

@testset "Continuous callbacks during tstop steps" begin
# Test that continuous callbacks work properly with tstop flag mechanism

crossing_times = Float64[]

function affect_continuous!(integrator)
push!(crossing_times, integrator.t)
end

function condition_continuous(u, t, integrator)
u[1] - 0.5 # Crosses when u[1] = 0.5
end

function exponential_growth(u, p, t)
[0.2 * u[1]] # Exponential growth
end

u0 = SVector{1}(0.1) # Start below 0.5
tspan = (0.0, 10.0)
tstops = [2.0, 4.0, 6.0, 8.0] # Regular tstops

continuous_cb = ContinuousCallback(condition_continuous, affect_continuous!)

prob = ODEProblem(exponential_growth, u0, tspan)
sol = solve(prob, Vern9(); tstops=tstops, callback=continuous_cb,
reltol=1e-10, abstol=1e-12)

@test sol.retcode == :Success

# Should hit all tstops
for tstop in tstops
@test tstop ∈ sol.t
end

# Should also detect continuous callback crossings
@test length(crossing_times) > 0 # At least one crossing detected

# Verify crossings are at correct value
for crossing_time in crossing_times
u_at_crossing = sol(crossing_time)
@test abs(u_at_crossing[1] - 0.5) < 1e-8 # Should be very close to 0.5
end
end

end
Loading