Skip to content

Commit d8f1e0c

Browse files
refactor: format
1 parent 59d36e4 commit d8f1e0c

File tree

4 files changed

+44
-35
lines changed

4 files changed

+44
-35
lines changed

ext/MTKCasADiDynamicOptExt.jl

Lines changed: 36 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -42,14 +42,14 @@ struct CasADiDynamicOptProblem{uType, tType, isinplace, P, F, K} <:
4242
end
4343
end
4444

45-
function (M::MXLinearInterpolation)(τ)
45+
function (M::MXLinearInterpolation)(τ)
4646
nt =- M.t[1]) / M.dt
4747
i = 1 + floor(Int, nt)
4848
Δ = nt - i + 1
4949

5050
(i > length(M.t) || i < 1) && error("Cannot extrapolate past the tspan.")
5151
if i < length(M.t)
52-
M.u[:, i] + Δ*(M.u[:, i + 1] - M.u[:, i])
52+
M.u[:, i] + Δ * (M.u[:, i + 1] - M.u[:, i])
5353
else
5454
M.u[:, i]
5555
end
@@ -74,7 +74,7 @@ The constraints are:
7474
function MTK.CasADiDynamicOptProblem(sys::ODESystem, u0map, tspan, pmap;
7575
dt = nothing,
7676
steps = nothing,
77-
guesses = Dict(), kwargs...)
77+
guesses = Dict(), kwargs...)
7878
MTK.warn_overdetermined(sys, u0map)
7979
_u0map = has_alg_eqs(sys) ? u0map : merge(Dict(u0map), Dict(guesses))
8080
f, u0, p = MTK.process_SciMLProblem(ODEInputFunction, sys, _u0map, pmap;
@@ -104,21 +104,21 @@ function init_model(sys, tspan, steps, u0map, pmap, u0; is_free_t = false)
104104
subject_to!(opti, tₛ >= lo)
105105
subject_to!(opti, tₛ >= hi)
106106
end
107-
pmap[te_sym] = tₛ
107+
pmap[te_sym] = tₛ
108108
tsteps = LinRange(0, 1, steps)
109109
else
110110
tₛ = MX(1)
111111
tsteps = LinRange(tspan[1], tspan[2], steps)
112112
end
113-
113+
114114
U = CasADi.variable!(opti, length(states), steps)
115115
V = CasADi.variable!(opti, length(ctrls), steps)
116116
set_initial!(opti, U, DM(repeat(u0, 1, steps)))
117117
c0 = MTK.value.([pmap[c] for c in ctrls])
118118
!isempty(c0) && set_initial!(opti, V, DM(repeat(c0, 1, steps)))
119119

120-
U_interp = MXLinearInterpolation(U, tsteps, tsteps[2]-tsteps[1])
121-
V_interp = MXLinearInterpolation(V, tsteps, tsteps[2]-tsteps[1])
120+
U_interp = MXLinearInterpolation(U, tsteps, tsteps[2] - tsteps[1])
121+
V_interp = MXLinearInterpolation(V, tsteps, tsteps[2] - tsteps[1])
122122
for (i, ct) in enumerate(ctrls)
123123
pmap[ct] = V[i, :]
124124
end
@@ -185,8 +185,8 @@ function add_user_constraints!(model::CasADiModel, sys, tspan, pmap; is_free_t)
185185
x = MTK.operation(st)
186186
t = only(MTK.arguments(st))
187187
MTK.symbolic_type(t) === MTK.NotSymbolic() || continue
188-
if haskey(stidxmap, x(iv))
189-
idx = stidxmap[x(iv)]
188+
if haskey(stidxmap, x(iv))
189+
idx = stidxmap[x(iv)]
190190
cv = U
191191
else
192192
idx = ctidxmap[x(iv)]
@@ -196,11 +196,11 @@ function add_user_constraints!(model::CasADiModel, sys, tspan, pmap; is_free_t)
196196
end
197197

198198
if cons isa Equation
199-
subject_to!(opti, cons.lhs - cons.rhs==0)
199+
subject_to!(opti, cons.lhs - cons.rhs == 0)
200200
elseif cons.relational_op === Symbolics.geq
201-
subject_to!(opti, cons.lhs - cons.rhs0)
201+
subject_to!(opti, cons.lhs - cons.rhs 0)
202202
else
203-
subject_to!(opti, cons.lhs - cons.rhs0)
203+
subject_to!(opti, cons.lhs - cons.rhs 0)
204204
end
205205
end
206206
end
@@ -227,8 +227,8 @@ function add_cost_function!(model::CasADiModel, sys, tspan, pmap; is_free_t)
227227
x = operation(st)
228228
t = only(arguments(st))
229229
MTK.symbolic_type(t) === MTK.NotSymbolic() || continue
230-
if haskey(stidxmap, x(iv))
231-
idx = stidxmap[x(iv)]
230+
if haskey(stidxmap, x(iv))
231+
idx = stidxmap[x(iv)]
232232
cv = U
233233
else
234234
idx = ctidxmap[x(iv)]
@@ -244,7 +244,8 @@ function add_cost_function!(model::CasADiModel, sys, tspan, pmap; is_free_t)
244244
op = MTK.operation(int)
245245
arg = only(arguments(MTK.value(int)))
246246
lo, hi = (op.domain.domain.left, op.domain.domain.right)
247-
!isequal((lo, hi), tspan) && error("Non-whole interval bounds for integrals are not currently supported for CasADiDynamicOptProblem.")
247+
!isequal((lo, hi), tspan) &&
248+
error("Non-whole interval bounds for integrals are not currently supported for CasADiDynamicOptProblem.")
248249
# Approximate integral as sum.
249250
intmap[int] = dt * tₛ * sum(arg)
250251
end
@@ -253,7 +254,8 @@ function add_cost_function!(model::CasADiModel, sys, tspan, pmap; is_free_t)
253254
minimize!(opti, MX(MTK.value(consolidate(jcosts))))
254255
end
255256

256-
function substitute_casadi_vars(model::CasADiModel, sys, pmap, exprs; auxmap::Dict = Dict(), is_free_t)
257+
function substitute_casadi_vars(
258+
model::CasADiModel, sys, pmap, exprs; auxmap::Dict = Dict(), is_free_t)
257259
@unpack opti, U, V, tₛ = model
258260
iv = MTK.get_iv(sys)
259261
sts = unknowns(sys)
@@ -281,44 +283,44 @@ end
281283

282284
function add_solve_constraints(prob, tableau)
283285
@unpack A, α, c = tableau
284-
@unpack model, f, p = prob
286+
@unpack model, f, p = prob
285287
@unpack opti, U, V, tₛ = model
286288
solver_opti = copy(opti)
287289

288-
tsteps = U.t
290+
tsteps = U.t
289291
dt = tsteps[2] - tsteps[1]
290292

291293
nᵤ = size(U.u, 1)
292294
nᵥ = size(V.u, 1)
293295

294296
if MTK.is_explicit(tableau)
295297
K = MX[]
296-
for k in 1:length(tsteps)-1
298+
for k in 1:(length(tsteps) - 1)
297299
τ = tsteps[k]
298300
for (i, h) in enumerate(c)
299301
ΔU = sum([A[i, j] * K[j] for j in 1:(i - 1)], init = MX(zeros(nᵤ)))
300-
Uₙ = U.u[:, k] + ΔU*dt
302+
Uₙ = U.u[:, k] + ΔU * dt
301303
Vₙ = V.u[:, k]
302304
Kₙ = tₛ * f(Uₙ, Vₙ, p, τ + h * dt) # scale the time
303305
push!(K, Kₙ)
304306
end
305307
ΔU = dt * sum([α[i] * K[i] for i in 1:length(α)])
306-
subject_to!(solver_opti, U.u[:, k] + ΔU == U.u[:, k+1])
308+
subject_to!(solver_opti, U.u[:, k] + ΔU == U.u[:, k + 1])
307309
empty!(K)
308310
end
309311
else
310-
for k in 1:length(tsteps)-1
312+
for k in 1:(length(tsteps) - 1)
311313
τ = tsteps[k]
312314
Kᵢ = variable!(solver_opti, nᵤ, length(α))
313315
ΔUs = A * Kᵢ' # the stepsize at each stage of the implicit method
314316
for (i, h) in enumerate(c)
315-
ΔU = ΔUs[i,:]'
316-
Uₙ = U.u[:,k] + ΔU*dt
317-
Vₙ = V.u[:,k]
318-
subject_to!(solver_opti, Kᵢ[:,i] == tₛ * f(Uₙ, Vₙ, p, τ + h*dt))
317+
ΔU = ΔUs[i, :]'
318+
Uₙ = U.u[:, k] + ΔU * dt
319+
Vₙ = V.u[:, k]
320+
subject_to!(solver_opti, Kᵢ[:, i] == tₛ * f(Uₙ, Vₙ, p, τ + h * dt))
319321
end
320-
ΔU_tot = dt*(Kᵢ*α)
321-
subject_to!(solver_opti, U.u[:, k] + ΔU_tot == U.u[:,k+1])
322+
ΔU_tot = dt * (Kᵢ * α)
323+
subject_to!(solver_opti, U.u[:, k] + ΔU_tot == U.u[:, k + 1])
322324
end
323325
end
324326
solver_opti
@@ -331,7 +333,10 @@ end
331333
332334
NOTE: the solver should be passed in as a string to CasADi. "ipopt"
333335
"""
334-
function DiffEqBase.solve(prob::CasADiDynamicOptProblem, solver::Union{String, Symbol} = "ipopt", tableau_getter = MTK.constructDefault; plugin_options::Dict = Dict(), solver_options::Dict = Dict(), silent = false)
336+
function DiffEqBase.solve(
337+
prob::CasADiDynamicOptProblem, solver::Union{String, Symbol} = "ipopt",
338+
tableau_getter = MTK.constructDefault; plugin_options::Dict = Dict(),
339+
solver_options::Dict = Dict(), silent = false)
335340
@unpack model, u0, p, tspan, f = prob
336341
tableau = tableau_getter()
337342
@unpack opti, U, V, tₛ = model
@@ -366,7 +371,8 @@ function DiffEqBase.solve(prob::CasADiDynamicOptProblem, solver::Union{String, S
366371
end
367372

368373
if failed
369-
ode_sol = SciMLBase.solution_new_retcode(ode_sol, SciMLBase.ReturnCode.ConvergenceFailure)
374+
ode_sol = SciMLBase.solution_new_retcode(
375+
ode_sol, SciMLBase.ReturnCode.ConvergenceFailure)
370376
!isnothing(input_sol) && (input_sol = SciMLBase.solution_new_retcode(
371377
input_sol, SciMLBase.ReturnCode.ConvergenceFailure))
372378
end

src/ModelingToolkit.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -349,7 +349,8 @@ export AnalysisPoint, get_sensitivity_function, get_comp_sensitivity_function,
349349
function FMIComponent end
350350

351351
include("systems/optimal_control_interface.jl")
352-
export AbstractDynamicOptProblem, JuMPDynamicOptProblem, InfiniteOptDynamicOptProblem, CasADiDynamicOptProblem
352+
export AbstractDynamicOptProblem, JuMPDynamicOptProblem, InfiniteOptDynamicOptProblem,
353+
CasADiDynamicOptProblem
353354
export DynamicOptSolution
354355

355356
end # module

src/systems/optimal_control_interface.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ function constructDefault(T::Type = Float64)
4141
A = map(T, A)
4242
α = map(T, α)
4343
c = map(T, c)
44-
44+
4545
DiffEqBase.ImplicitRKTableau(A, c, α, 5)
4646
end
4747

test/extensions/dynamic_optimization.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,8 @@ const M = ModelingToolkit
6161
@test jsol.sol(0.6)[1] 3.5
6262
@test jsol.sol(0.3)[1] 7.0
6363

64-
cprob = CasADiDynamicOptProblem(lksys, u0map, tspan, parammap; guesses = guess, dt = 0.01)
64+
cprob = CasADiDynamicOptProblem(
65+
lksys, u0map, tspan, parammap; guesses = guess, dt = 0.01)
6566
csol = solve(cprob, "ipopt", constructTsitouras5, silent = true)
6667
@test csol.sol(0.6)[1] 3.5
6768
@test csol.sol(0.3)[1] 7.0
@@ -87,7 +88,8 @@ const M = ModelingToolkit
8788
jsol = solve(jprob, Ipopt.Optimizer, constructRadauIA3, silent = true) # 12.190 s, 9.68 GiB
8889
@test all(u -> u > [1, 1], jsol.sol.u)
8990

90-
cprob = CasADiDynamicOptProblem(lksys, u0map, tspan, parammap; guesses = guess, dt = 0.01)
91+
cprob = CasADiDynamicOptProblem(
92+
lksys, u0map, tspan, parammap; guesses = guess, dt = 0.01)
9193
csol = solve(cprob, "ipopt", constructRadauIA3, silent = true)
9294
@test all(u -> u > [1, 1], csol.sol.u)
9395
end
@@ -220,7 +222,7 @@ end
220222
jprob = JuMPDynamicOptProblem(rocket, u0map, (ts, te), pmap; dt = 0.001, cse = false)
221223
jsol = solve(jprob, Ipopt.Optimizer, constructRadauIIA5, silent = true)
222224
@test jsol.sol.u[end][1] > 1.012
223-
225+
224226
cprob = CasADiDynamicOptProblem(rocket, u0map, (ts, te), pmap; dt = 0.001, cse = false)
225227
csol = solve(cprob, "ipopt"; silent = true)
226228
@test csol.sol.u[end][1] > 1.012

0 commit comments

Comments
 (0)