Skip to content

Commit d455d58

Browse files
Merge pull request #3543 from AayushSabharwal/as/system
feat: add unified `System` type
2 parents ac22a7a + 43bb2e1 commit d455d58

File tree

142 files changed

+6977
-9838
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

142 files changed

+6977
-9838
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ BifurcationKit = "0.4"
9090
BlockArrays = "1.1"
9191
BoundaryValueDiffEqAscher = "1.6.0"
9292
BoundaryValueDiffEqMIRK = "1.7.0"
93-
CasADi = "1.0.6"
93+
CasADi = "1.0.7"
9494
ChainRulesCore = "1"
9595
Combinatorics = "1"
9696
CommonSolve = "0.2.4"

ext/MTKBifurcationKitExt.jl

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ module MTKBifurcationKitExt
55
# Imports
66
using ModelingToolkit, Setfield
77
import BifurcationKit
8+
using SymbolicIndexingInterface: is_time_dependent
89

910
### Observable Plotting Handling ###
1011

@@ -23,7 +24,7 @@ struct ObservableRecordFromSolution{S, T}
2324
# A Vector of pairs (Symbolic => value) with the default values of all system variables and parameters.
2425
subs_vals::T
2526

26-
function ObservableRecordFromSolution(nsys::NonlinearSystem,
27+
function ObservableRecordFromSolution(nsys::System,
2728
plot_var,
2829
bif_idx,
2930
u0_vals,
@@ -82,7 +83,7 @@ end
8283
### Creates BifurcationProblem Overloads ###
8384

8485
# When input is a NonlinearSystem.
85-
function BifurcationKit.BifurcationProblem(nsys::NonlinearSystem,
86+
function BifurcationKit.BifurcationProblem(nsys::System,
8687
u0_bif,
8788
ps,
8889
bif_par,
@@ -92,7 +93,15 @@ function BifurcationKit.BifurcationProblem(nsys::NonlinearSystem,
9293
jac = true,
9394
kwargs...)
9495
if !ModelingToolkit.iscomplete(nsys)
95-
error("A completed `NonlinearSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `BifurcationProblem`")
96+
error("A completed `System` is required. Call `complete` or `structural_simplify` on the system before creating a `BifurcationProblem`")
97+
end
98+
if is_time_dependent(nsys)
99+
nsys = System([0 ~ eq.rhs for eq in full_equations(nsys)],
100+
unknowns(nsys),
101+
parameters(nsys);
102+
observed = observed(nsys),
103+
name = nameof(nsys))
104+
nsys = complete(nsys)
96105
end
97106
@set! nsys.index_cache = nothing # force usage of a parameter vector instead of `MTKParameters`
98107
# Creates F and J functions.
@@ -143,17 +152,4 @@ function BifurcationKit.BifurcationProblem(nsys::NonlinearSystem,
143152
kwargs...)
144153
end
145154

146-
# When input is a ODESystem.
147-
function BifurcationKit.BifurcationProblem(osys::ODESystem, args...; kwargs...)
148-
if !ModelingToolkit.iscomplete(osys)
149-
error("A completed `ODESystem` is required. Call `complete` or `structural_simplify` on the system before creating a `BifurcationProblem`")
150-
end
151-
nsys = NonlinearSystem([0 ~ eq.rhs for eq in full_equations(osys)],
152-
unknowns(osys),
153-
parameters(osys);
154-
observed = observed(osys),
155-
name = nameof(osys))
156-
return BifurcationKit.BifurcationProblem(complete(nsys), args...; kwargs...)
157-
end
158-
159155
end # module

ext/MTKCasADiDynamicOptExt.jl

Lines changed: 37 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,9 @@ function (M::MXLinearInterpolation)(τ)
5656
end
5757

5858
"""
59-
CasADiDynamicOptProblem(sys::ODESystem, u0, tspan, p; dt, steps)
59+
CasADiDynamicOptProblem(sys::System, u0, tspan, p; dt, steps)
6060
61-
Convert an ODESystem representing an optimal control system into a CasADi model
61+
Convert an System representing an optimal control system into a CasADi model
6262
for solving using optimization. Must provide either `dt`, the timestep between collocation
6363
points (which, along with the timespan, determines the number of points), or directly
6464
provide the number of points as `steps`.
@@ -68,10 +68,10 @@ The optimization variables:
6868
- a vector-of-vectors V representing the controls as an interpolation array
6969
7070
The constraints are:
71-
- The set of user constraints passed to the ODESystem via `constraints`
71+
- The set of user constraints passed to the System via `constraints`
7272
- The solver constraints that encode the time-stepping used by the solver
7373
"""
74-
function MTK.CasADiDynamicOptProblem(sys::ODESystem, u0map, tspan, pmap;
74+
function MTK.CasADiDynamicOptProblem(sys::System, u0map, tspan, pmap;
7575
dt = nothing,
7676
steps = nothing,
7777
guesses = Dict(), kwargs...)
@@ -80,7 +80,8 @@ function MTK.CasADiDynamicOptProblem(sys::ODESystem, u0map, tspan, pmap;
8080
f, u0, p = MTK.process_SciMLProblem(ODEInputFunction, sys, _u0map, pmap;
8181
t = tspan !== nothing ? tspan[1] : tspan, output_type = MX, kwargs...)
8282

83-
pmap = Dict{Any, Any}(pmap)
83+
pmap = MTK.recursive_unwrap(MTK.AnyDict(pmap))
84+
MTK.evaluate_varmap!(pmap, keys(pmap))
8485
steps, is_free_t = MTK.process_tspan(tspan, dt, steps)
8586
model = init_model(sys, tspan, steps, u0map, pmap, u0; is_free_t)
8687

@@ -143,15 +144,15 @@ function set_casadi_bounds!(model, sys, pmap)
143144
for (i, u) in enumerate(unknowns(sys))
144145
if MTK.hasbounds(u)
145146
lo, hi = MTK.getbounds(u)
146-
subject_to!(opti, Symbolics.fixpoint_sub(lo, pmap) <= U.u[i, :])
147-
subject_to!(opti, U.u[i, :] <= Symbolics.fixpoint_sub(hi, pmap))
147+
subject_to!(opti, Symbolics.fast_substitute(lo, pmap) <= U.u[i, :])
148+
subject_to!(opti, U.u[i, :] <= Symbolics.fast_substitute(hi, pmap))
148149
end
149150
end
150151
for (i, v) in enumerate(MTK.unbound_inputs(sys))
151152
if MTK.hasbounds(v)
152153
lo, hi = MTK.getbounds(v)
153-
subject_to!(opti, Symbolics.fixpoint_sub(lo, pmap) <= V.u[i, :])
154-
subject_to!(opti, V.u[i, :] <= Symbolics.fixpoint_sub(hi, pmap))
154+
subject_to!(opti, Symbolics.fast_substitute(lo, pmap) <= V.u[i, :])
155+
subject_to!(opti, V.u[i, :] <= Symbolics.fast_substitute(hi, pmap))
155156
end
156157
end
157158
end
@@ -167,15 +168,15 @@ function add_user_constraints!(model::CasADiModel, sys, tspan, pmap; is_free_t)
167168
@unpack opti, U, V, tₛ = model
168169

169170
iv = MTK.get_iv(sys)
170-
conssys = MTK.get_constraintsystem(sys)
171-
jconstraints = isnothing(conssys) ? nothing : MTK.get_constraints(conssys)
171+
jconstraints = MTK.get_constraints(sys)
172172
(isnothing(jconstraints) || isempty(jconstraints)) && return nothing
173173

174174
stidxmap = Dict([v => i for (i, v) in enumerate(unknowns(sys))])
175175
ctidxmap = Dict([v => i for (i, v) in enumerate(MTK.unbound_inputs(sys))])
176-
cons_unknowns = map(MTK.default_toterm, unknowns(conssys))
176+
cons_dvs, cons_ps = MTK.process_constraint_system(
177+
jconstraints, Set(unknowns(sys)), parameters(sys), iv; validate = false)
177178

178-
auxmap = Dict([u => MTK.default_toterm(MTK.value(u)) for u in unknowns(conssys)])
179+
auxmap = Dict([u => MTK.default_toterm(MTK.value(u)) for u in cons_dvs])
179180
jconstraints = substitute_casadi_vars(model, sys, pmap, jconstraints; is_free_t, auxmap)
180181
# Manually substitute fixed-t variables
181182
for (i, cons) in enumerate(jconstraints)
@@ -207,9 +208,8 @@ end
207208

208209
function add_cost_function!(model::CasADiModel, sys, tspan, pmap; is_free_t)
209210
@unpack opti, U, V, tₛ = model
210-
jcosts = copy(MTK.get_costs(sys))
211-
consolidate = MTK.get_consolidate(sys)
212-
if isnothing(jcosts) || isempty(jcosts)
211+
jcosts = cost(sys)
212+
if Symbolics._iszero(jcosts)
213213
minimize!(opti, MX(0))
214214
return
215215
end
@@ -218,24 +218,22 @@ function add_cost_function!(model::CasADiModel, sys, tspan, pmap; is_free_t)
218218
stidxmap = Dict([v => i for (i, v) in enumerate(unknowns(sys))])
219219
ctidxmap = Dict([v => i for (i, v) in enumerate(MTK.unbound_inputs(sys))])
220220

221-
jcosts = substitute_casadi_vars(model, sys, pmap, jcosts; is_free_t)
221+
jcosts = substitute_casadi_vars(model, sys, pmap, [jcosts]; is_free_t)[1]
222222
# Substitute fixed-time variables.
223-
for i in 1:length(jcosts)
224-
costvars = MTK.vars(jcosts[i])
225-
for st in costvars
226-
MTK.iscall(st) || continue
227-
x = operation(st)
228-
t = only(arguments(st))
229-
MTK.symbolic_type(t) === MTK.NotSymbolic() || continue
230-
if haskey(stidxmap, x(iv))
231-
idx = stidxmap[x(iv)]
232-
cv = U
233-
else
234-
idx = ctidxmap[x(iv)]
235-
cv = V
236-
end
237-
jcosts[i] = Symbolics.substitute(jcosts[i], Dict(x(t) => cv(t)[idx]))
223+
costvars = MTK.vars(jcosts)
224+
for st in costvars
225+
MTK.iscall(st) || continue
226+
x = operation(st)
227+
t = only(arguments(st))
228+
MTK.symbolic_type(t) === MTK.NotSymbolic() || continue
229+
if haskey(stidxmap, x(iv))
230+
idx = stidxmap[x(iv)]
231+
cv = U
232+
else
233+
idx = ctidxmap[x(iv)]
234+
cv = V
238235
end
236+
jcosts = Symbolics.substitute(jcosts, Dict(x(t) => cv(t)[idx]))
239237
end
240238

241239
dt = U.t[2] - U.t[1]
@@ -249,9 +247,9 @@ function add_cost_function!(model::CasADiModel, sys, tspan, pmap; is_free_t)
249247
# Approximate integral as sum.
250248
intmap[int] = dt * tₛ * sum(arg)
251249
end
252-
jcosts = map(c -> Symbolics.substitute(c, intmap), jcosts)
253-
jcosts = MTK.value.(jcosts)
254-
minimize!(opti, MX(MTK.value(consolidate(jcosts))))
250+
jcosts = Symbolics.substitute(jcosts, intmap)
251+
jcosts = MTK.value(jcosts)
252+
minimize!(opti, MX(jcosts))
255253
end
256254

257255
function substitute_casadi_vars(
@@ -264,20 +262,20 @@ function substitute_casadi_vars(
264262
x_ops = [MTK.operation(MTK.unwrap(st)) for st in sts]
265263
c_ops = [MTK.operation(MTK.unwrap(ct)) for ct in cts]
266264

267-
exprs = map(c -> Symbolics.fixpoint_sub(c, auxmap), exprs)
268-
exprs = map(c -> Symbolics.fixpoint_sub(c, Dict(pmap)), exprs)
265+
exprs = map(c -> Symbolics.fast_substitute(c, auxmap), exprs)
266+
exprs = map(c -> Symbolics.fast_substitute(c, Dict(pmap)), exprs)
269267
# tf means different things in different contexts; a [tf] in a cost function
270268
# should be tₛ, while a x(tf) should translate to x[1]
271269
if is_free_t
272270
free_t_map = Dict([[x(tₛ) => U.u[i, end] for (i, x) in enumerate(x_ops)];
273271
[c(tₛ) => V.u[i, end] for (i, c) in enumerate(c_ops)]])
274-
exprs = map(c -> Symbolics.fixpoint_sub(c, free_t_map), exprs)
272+
exprs = map(c -> Symbolics.fast_substitute(c, free_t_map), exprs)
275273
end
276274

277275
# for variables like x(t)
278276
whole_interval_map = Dict([[v => U.u[i, :] for (i, v) in enumerate(sts)];
279277
[v => V.u[i, :] for (i, v) in enumerate(cts)]])
280-
exprs = map(c -> Symbolics.fixpoint_sub(c, whole_interval_map), exprs)
278+
exprs = map(c -> Symbolics.fast_substitute(c, whole_interval_map), exprs)
281279
exprs
282280
end
283281

ext/MTKFMIExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,7 @@ function MTK.FMIComponent(::Val{Ver}; fmu = nothing, tolerance = 1e-6,
289289
end
290290

291291
eqs = [observed; diffeqs]
292-
return ODESystem(eqs, t, states, params; parameter_dependencies, defaults = defs,
292+
return System(eqs, t, states, params; parameter_dependencies, defaults = defs,
293293
discrete_events = [instance_management_callback], name, initialization_eqs)
294294
end
295295

0 commit comments

Comments
 (0)