Skip to content

Commit e63aad0

Browse files
Merge pull request #2207 from SciML/myb/dde
Add DDE support in `System`
2 parents 47d8f05 + d76d73f commit e63aad0

File tree

4 files changed

+210
-13
lines changed

4 files changed

+210
-13
lines changed

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 151 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -120,9 +120,14 @@ function generate_function(sys::AbstractODESystem, dvs = states(sys), ps = param
120120
implicit_dae = false,
121121
ddvs = implicit_dae ? map(Differential(get_iv(sys)), dvs) :
122122
nothing,
123+
isdde = false,
123124
has_difference = false,
124125
kwargs...)
125-
eqs = [eq for eq in equations(sys) if !isdifferenceeq(eq)]
126+
if isdde
127+
eqs = delay_to_function(sys)
128+
else
129+
eqs = [eq for eq in equations(sys) if !isdifferenceeq(eq)]
130+
end
126131
if !implicit_dae
127132
check_operator_variables(eqs, Differential)
128133
check_lhs(eqs, Differential, Set(dvs))
@@ -136,15 +141,59 @@ function generate_function(sys::AbstractODESystem, dvs = states(sys), ps = param
136141
p = map(x -> time_varying_as_func(value(x), sys), ps)
137142
t = get_iv(sys)
138143

139-
pre, sol_states = get_substitutions_and_solved_states(sys,
140-
no_postprocess = has_difference)
144+
if isdde
145+
build_function(rhss, u, DDE_HISTORY_FUN, p, t; kwargs...)
146+
else
147+
pre, sol_states = get_substitutions_and_solved_states(sys,
148+
no_postprocess = has_difference)
141149

142-
if implicit_dae
143-
build_function(rhss, ddvs, u, p, t; postprocess_fbody = pre, states = sol_states,
144-
kwargs...)
150+
if implicit_dae
151+
build_function(rhss, ddvs, u, p, t; postprocess_fbody = pre,
152+
states = sol_states,
153+
kwargs...)
154+
else
155+
build_function(rhss, u, p, t; postprocess_fbody = pre, states = sol_states,
156+
kwargs...)
157+
end
158+
end
159+
end
160+
161+
function isdelay(var, iv)
162+
iv === nothing && return false
163+
isvariable(var) || return false
164+
if istree(var) && !ModelingToolkit.isoperator(var, Symbolics.Operator)
165+
args = arguments(var)
166+
length(args) == 1 || return false
167+
isequal(args[1], iv) || return true
168+
end
169+
return false
170+
end
171+
const DDE_HISTORY_FUN = Sym{Symbolics.FnType{Tuple{Any, <:Real}, Vector{Real}}}(:___history___)
172+
function delay_to_function(sys::AbstractODESystem)
173+
delay_to_function(full_equations(sys),
174+
get_iv(sys),
175+
Dict{Any, Int}(operation(s) => i for (i, s) in enumerate(states(sys))),
176+
parameters(sys),
177+
DDE_HISTORY_FUN)
178+
end
179+
function delay_to_function(eqs::Vector{<:Equation}, iv, sts, ps, h)
180+
delay_to_function.(eqs, (iv,), (sts,), (ps,), (h,))
181+
end
182+
function delay_to_function(eq::Equation, iv, sts, ps, h)
183+
delay_to_function(eq.lhs, iv, sts, ps, h) ~ delay_to_function(eq.rhs, iv, sts, ps, h)
184+
end
185+
function delay_to_function(expr, iv, sts, ps, h)
186+
if isdelay(expr, iv)
187+
v = operation(expr)
188+
time = arguments(expr)[1]
189+
idx = sts[v]
190+
return term(getindex, h(Sym{Any}(:ˍ₋arg3), time), idx, type = Real) # BIG BIG HACK
191+
elseif istree(expr)
192+
return similarterm(expr,
193+
operation(expr),
194+
map(x -> delay_to_function(x, iv, sts, ps, h), arguments(expr)))
145195
else
146-
build_function(rhss, u, p, t; postprocess_fbody = pre, states = sol_states,
147-
kwargs...)
196+
return expr
148197
end
149198
end
150199

@@ -485,6 +534,30 @@ function DiffEqBase.DAEFunction{iip}(sys::AbstractODESystem, dvs = states(sys),
485534
observed = observedfun)
486535
end
487536

537+
function DiffEqBase.DDEFunction(sys::AbstractODESystem, args...; kwargs...)
538+
DDEFunction{true}(sys, args...; kwargs...)
539+
end
540+
541+
function DiffEqBase.DDEFunction{iip}(sys::AbstractODESystem, dvs = states(sys),
542+
ps = parameters(sys), u0 = nothing;
543+
eval_module = @__MODULE__,
544+
checkbounds = false,
545+
kwargs...) where {iip}
546+
f_gen = generate_function(sys, dvs, ps; isdde = true,
547+
expression = Val{true},
548+
expression_module = eval_module, checkbounds = checkbounds,
549+
kwargs...)
550+
f_oop, f_iip = (drop_expr(@RuntimeGeneratedFunction(eval_module, ex)) for ex in f_gen)
551+
f(u, p, h, t) = f_oop(u, p, h, t)
552+
f(du, u, p, h, t) = f_iip(du, u, p, h, t)
553+
554+
DDEFunction{iip}(f,
555+
sys = sys,
556+
syms = Symbol.(dvs),
557+
indepsym = Symbol(get_iv(sys)),
558+
paramsyms = Symbol.(ps))
559+
end
560+
488561
"""
489562
```julia
490563
ODEFunctionExpr{iip}(sys::AbstractODESystem, dvs = states(sys),
@@ -577,9 +650,14 @@ end
577650
"""
578651
u0, p, defs = get_u0_p(sys, u0map, parammap; use_union=false, tofloat=!use_union)
579652
580-
Take dictionaries with initial conditions and parameters and convert them to numeric arrays `u0` and `p`. Also return the merged dictionary `defs` containing the entire operating point.
653+
Take dictionaries with initial conditions and parameters and convert them to numeric arrays `u0` and `p`. Also return the merged dictionary `defs` containing the entire operating point.
581654
"""
582-
function get_u0_p(sys, u0map, parammap; use_union = false, tofloat = !use_union)
655+
function get_u0_p(sys,
656+
u0map,
657+
parammap;
658+
use_union = false,
659+
tofloat = !use_union,
660+
symbolic_u0 = false)
583661
eqs = equations(sys)
584662
dvs = states(sys)
585663
ps = parameters(sys)
@@ -588,7 +666,11 @@ function get_u0_p(sys, u0map, parammap; use_union = false, tofloat = !use_union)
588666
defs = mergedefaults(defs, parammap, ps)
589667
defs = mergedefaults(defs, u0map, dvs)
590668

591-
u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = true)
669+
if symbolic_u0
670+
u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = false, use_union = false)
671+
else
672+
u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = true)
673+
end
592674
p = varmap_to_vars(parammap, ps; defaults = defs, tofloat, use_union)
593675
p = p === nothing ? SciMLBase.NullParameters() : p
594676
u0, p, defs
@@ -604,13 +686,14 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
604686
eval_expression = true,
605687
use_union = false,
606688
tofloat = !use_union,
689+
symbolic_u0 = false,
607690
kwargs...)
608691
eqs = equations(sys)
609692
dvs = states(sys)
610693
ps = parameters(sys)
611694
iv = get_iv(sys)
612695

613-
u0, p, defs = get_u0_p(sys, u0map, parammap; tofloat, use_union)
696+
u0, p, defs = get_u0_p(sys, u0map, parammap; tofloat, use_union, symbolic_u0)
614697

615698
if implicit_dae && du0map !== nothing
616699
ddvs = map(Differential(iv), dvs)
@@ -802,6 +885,62 @@ function DiffEqBase.DAEProblem{iip}(sys::AbstractODESystem, du0map, u0map, tspan
802885
end
803886
end
804887

888+
function generate_history(sys::AbstractODESystem, u0; kwargs...)
889+
build_function(u0, parameters(sys), get_iv(sys); expression = Val{false}, kwargs...)
890+
end
891+
892+
function DiffEqBase.DDEProblem(sys::AbstractODESystem, args...; kwargs...)
893+
DDEProblem{true}(sys, args...; kwargs...)
894+
end
895+
function DiffEqBase.DDEProblem{iip}(sys::AbstractODESystem, u0map = [],
896+
tspan = get_tspan(sys),
897+
parammap = DiffEqBase.NullParameters();
898+
callback = nothing,
899+
check_length = true,
900+
kwargs...) where {iip}
901+
has_difference = any(isdifferenceeq, equations(sys))
902+
f, u0, p = process_DEProblem(DDEFunction{iip}, sys, u0map, parammap;
903+
t = tspan !== nothing ? tspan[1] : tspan,
904+
has_difference = has_difference,
905+
symbolic_u0 = true,
906+
check_length, kwargs...)
907+
h_oop, h_iip = generate_history(sys, u0)
908+
h = h_oop
909+
u0 = h(p, tspan[1])
910+
cbs = process_events(sys; callback, has_difference, kwargs...)
911+
if has_discrete_subsystems(sys) && (dss = get_discrete_subsystems(sys)) !== nothing
912+
affects, clocks, svs = ModelingToolkit.generate_discrete_affect(dss...)
913+
discrete_cbs = map(affects, clocks, svs) do affect, clock, sv
914+
if clock isa Clock
915+
PeriodicCallback(DiscreteSaveAffect(affect, sv), clock.dt)
916+
else
917+
error("$clock is not a supported clock type.")
918+
end
919+
end
920+
if cbs === nothing
921+
if length(discrete_cbs) == 1
922+
cbs = only(discrete_cbs)
923+
else
924+
cbs = CallbackSet(discrete_cbs...)
925+
end
926+
else
927+
cbs = CallbackSet(cbs, discrete_cbs)
928+
end
929+
else
930+
svs = nothing
931+
end
932+
kwargs = filter_kwargs(kwargs)
933+
934+
kwargs1 = (;)
935+
if cbs !== nothing
936+
kwargs1 = merge(kwargs1, (callback = cbs,))
937+
end
938+
if svs !== nothing
939+
kwargs1 = merge(kwargs1, (disc_saved_values = svs,))
940+
end
941+
DDEProblem{iip}(f, u0, h, tspan, p; kwargs1..., kwargs...)
942+
end
943+
805944
"""
806945
```julia
807946
ODEProblemExpr{iip}(sys::AbstractODESystem, u0map, tspan,

src/systems/diffeqs/odesystem.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,9 +187,10 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
187187
@assert all(control -> any(isequal.(control, ps)), controls) "All controls must also be parameters."
188188

189189
iv′ = value(scalarize(iv))
190-
dvs′ = value.(scalarize(dvs))
191190
ps′ = value.(scalarize(ps))
192191
ctrl′ = value.(scalarize(controls))
192+
dvs′ = value.(scalarize(dvs))
193+
dvs′ = filter(x -> !isdelay(x, iv), dvs′)
193194

194195
if !(isempty(default_u0) && isempty(default_p))
195196
Base.depwarn("`default_u0` and `default_p` are deprecated. Use `defaults` instead.",
@@ -258,6 +259,10 @@ function ODESystem(eqs, iv = nothing; kwargs...)
258259
push!(algeeq, eq)
259260
end
260261
end
262+
for v in allstates
263+
isdelay(v, iv) || continue
264+
collect_vars!(allstates, ps, arguments(v)[1], iv)
265+
end
261266
algevars = setdiff(allstates, diffvars)
262267
# the orders here are very important!
263268
return ODESystem(Equation[diffeq; algeeq; compressed_eqs], iv,

src/systems/systemstructure.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,7 @@ end
253253
function TearingState(sys; quick_cancel = false, check = true)
254254
sys = flatten(sys)
255255
ivs = independent_variables(sys)
256+
iv = length(ivs) == 1 ? ivs[1] : nothing
256257
eqs = copy(equations(sys))
257258
neqs = length(eqs)
258259
dervaridxs = OrderedSet{Int}()
@@ -287,6 +288,7 @@ function TearingState(sys; quick_cancel = false, check = true)
287288
isalgeq = true
288289
statevars = []
289290
for var in vars
291+
ModelingToolkit.isdelay(var, iv) && continue
290292
set_incidence = true
291293
@label ANOTHER_VAR
292294
_var, _ = var_from_nested_derivative(var)

test/dde.jl

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
using ModelingToolkit, DelayDiffEq, Test
2+
p0 = 0.2;
3+
q0 = 0.3;
4+
v0 = 1;
5+
d0 = 5;
6+
p1 = 0.2;
7+
q1 = 0.3;
8+
v1 = 1;
9+
d1 = 1;
10+
d2 = 1;
11+
beta0 = 1;
12+
beta1 = 1;
13+
tau = 1;
14+
function bc_model(du, u, h, p, t)
15+
du[1] = (v0 / (1 + beta0 * (h(p, t - tau)[3]^2))) * (p0 - q0) * u[1] - d0 * u[1]
16+
du[2] = (v0 / (1 + beta0 * (h(p, t - tau)[3]^2))) * (1 - p0 + q0) * u[1] +
17+
(v1 / (1 + beta1 * (h(p, t - tau)[3]^2))) * (p1 - q1) * u[2] - d1 * u[2]
18+
du[3] = (v1 / (1 + beta1 * (h(p, t - tau)[3]^2))) * (1 - p1 + q1) * u[2] - d2 * u[3]
19+
end
20+
lags = [tau]
21+
h(p, t) = ones(3)
22+
h2(p, t) = ones(3) .- t * q1 * 10
23+
tspan = (0.0, 10.0)
24+
u0 = [1.0, 1.0, 1.0]
25+
prob = DDEProblem(bc_model, u0, h, tspan, constant_lags = lags)
26+
alg = MethodOfSteps(Vern9())
27+
sol = solve(prob, alg, reltol = 1e-7, abstol = 1e-10)
28+
prob2 = DDEProblem(bc_model, u0, h2, tspan, constant_lags = lags)
29+
sol2 = solve(prob2, alg, reltol = 1e-7, abstol = 1e-10)
30+
31+
@parameters p0=0.2 p1=0.2 q0=0.3 q1=0.3 v0=1 v1=1 d0=5 d1=1 d2=1 beta0=1 beta1=1
32+
@variables t x₀(t) x₁(t) x₂(..)
33+
tau = 1
34+
D = Differential(t)
35+
eqs = [D(x₀) ~ (v0 / (1 + beta0 * (x₂(t - tau)^2))) * (p0 - q0) * x₀ - d0 * x₀
36+
D(x₁) ~ (v0 / (1 + beta0 * (x₂(t - tau)^2))) * (1 - p0 + q0) * x₀ +
37+
(v1 / (1 + beta1 * (x₂(t - tau)^2))) * (p1 - q1) * x₁ - d1 * x₁
38+
D(x₂(t)) ~ (v1 / (1 + beta1 * (x₂(t - tau)^2))) * (1 - p1 + q1) * x₁ - d2 * x₂(t)]
39+
@named sys = System(eqs)
40+
prob = DDEProblem(sys,
41+
[x₀ => 1.0, x₁ => 1.0, x₂(t) => 1.0],
42+
tspan,
43+
constant_lags = [tau])
44+
sol_mtk = solve(prob, alg, reltol = 1e-7, abstol = 1e-10)
45+
@test sol_mtk.u[end] sol.u[end]
46+
prob2 = DDEProblem(sys,
47+
[x₀ => 1.0 - t * q1 * 10, x₁ => 1.0 - t * q1 * 10, x₂(t) => 1.0 - t * q1 * 10],
48+
tspan,
49+
constant_lags = [tau])
50+
sol2_mtk = solve(prob2, alg, reltol = 1e-7, abstol = 1e-10)
51+
@test sol2_mtk.u[end] sol2.u[end]

0 commit comments

Comments
 (0)