Skip to content

refactor: make @constants create tunable = false parameters #3641

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
May 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 4 additions & 21 deletions src/constants.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,9 @@
import SymbolicUtils: symtype, term, hasmetadata, issym
struct MTKConstantCtx end

isconstant(x::Num) = isconstant(unwrap(x))
"""
Test whether `x` is a constant-type Sym.
"""
function isconstant(x)
x = unwrap(x)
x isa Symbolic && getmetadata(x, MTKConstantCtx, false)
x isa Symbolic && !getmetadata(x, VariableTunable, true)
end

"""
Expand All @@ -16,12 +12,11 @@ end
Maps the parameter to a constant. The parameter must have a default.
"""
function toconstant(s)
hasmetadata(s, Symbolics.VariableDefaultValue) ||
throw(ArgumentError("Constant `$(s)` must be assigned a default value."))
setmetadata(s, MTKConstantCtx, true)
s = toparam(s)
setmetadata(s, VariableTunable, false)
end

toconstant(s::Num) = wrap(toconstant(value(s)))
toconstant(s::Union{Num, Symbolics.Arr}) = wrap(toconstant(value(s)))

"""
$(SIGNATURES)
Expand All @@ -36,15 +31,3 @@ macro constants(xs...)
xs,
toconstant) |> esc
end

"""
Substitute all `@constants` in the given expression
"""
function subs_constants(eqs)
consts = collect_constants(eqs)
if !isempty(consts)
csubs = Dict(c => getdefault(c) for c in consts)
eqs = substitute(eqs, csubs)
end
return eqs
end
3 changes: 1 addition & 2 deletions src/inputoutput.jl
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ function generate_control_function(sys::AbstractSystem, inputs = unbound_inputs(
disturbance_inputs = unwrap.(disturbance_inputs)

eqs = [eq for eq in full_equations(sys)]
eqs = map(subs_constants, eqs)

if disturbance_inputs !== nothing && !disturbance_argument
# Set all disturbance *inputs* to zero (we just want to keep the disturbance state)
subs = Dict(disturbance_inputs .=> 0)
Expand All @@ -237,7 +237,6 @@ function generate_control_function(sys::AbstractSystem, inputs = unbound_inputs(
p = reorder_parameters(sys, ps)
t = get_iv(sys)

# pre = has_difference ? (ex -> ex) : get_postprocess_fbody(sys)
if disturbance_argument
args = (dvs, inputs, p..., t, disturbance_inputs)
else
Expand Down
4 changes: 2 additions & 2 deletions src/problems/optimizationproblem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,10 @@ function SciMLBase.OptimizationFunction{iip}(sys::System;
else
_cons_h = cons_hess_prototype = nothing
end
cons_expr = subs_constants(cstr)
cons_expr = cstr
end

obj_expr = subs_constants(cost(sys))
obj_expr = cost(sys)

observedfun = ObservedFunctionCache(
sys; expression, eval_expression, eval_module, checkbounds, cse)
Expand Down
2 changes: 1 addition & 1 deletion src/structural_transformation/StructuralTransformations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ using ModelingToolkit: System, AbstractSystem, var_from_nested_derivative, Diffe
has_tearing_state, defaults, InvalidSystemException,
ExtraEquationsSystemException,
ExtraVariablesSystemException,
get_postprocess_fbody, vars!,
vars!,
IncrementalCycleTracker, add_edge_checked!, topological_sort,
invalidate_cache!, Substitutions, get_or_construct_tearing_state,
filter_kwargs, lower_varname_with_unit,
Expand Down
265 changes: 1 addition & 264 deletions src/structural_transformation/codegen.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
using LinearAlgebra

using ModelingToolkit: process_events, get_preprocess_constants
using ModelingToolkit: process_events

const MAX_INLINE_NLSOLVE_SIZE = 8

Expand Down Expand Up @@ -96,136 +96,6 @@ function torn_system_with_nlsolve_jacobian_sparsity(state, var_eq_matching, var_
sparse(I, J, true, length(eqs_idxs), length(states_idxs))
end

function gen_nlsolve!(is_not_prepended_assignment, eqs, vars, u0map::AbstractDict,
assignments, (deps, invdeps), var2assignment; checkbounds = true)
isempty(vars) && throw(ArgumentError("vars may not be empty"))
length(eqs) == length(vars) ||
throw(ArgumentError("vars must be of the same length as the number of equations to find the roots of"))
rhss = map(x -> x.rhs, eqs)
# We use `vars` instead of `graph` to capture parameters, too.
paramset = ModelingToolkit.vars(r for r in rhss)

# Compute necessary assignments for the nlsolve expr
init_assignments = [var2assignment[p] for p in paramset if haskey(var2assignment, p)]
if isempty(init_assignments)
needed_assignments_idxs = Int[]
needed_assignments = similar(assignments, 0)
else
tmp = [init_assignments]
# `deps[init_assignments]` gives the dependency of `init_assignments`
while true
next_assignments = unique(reduce(vcat, deps[init_assignments]))
isempty(next_assignments) && break
init_assignments = next_assignments
push!(tmp, init_assignments)
end
needed_assignments_idxs = unique(reduce(vcat, reverse(tmp)))
needed_assignments = assignments[needed_assignments_idxs]
end

# Compute `params`. They are like enclosed variables
rhsvars = [ModelingToolkit.vars(r.rhs) for r in needed_assignments]
vars_set = Set(vars)
outer_set = BitSet()
inner_set = BitSet()
for (i, vs) in enumerate(rhsvars)
j = needed_assignments_idxs[i]
if isdisjoint(vars_set, vs)
push!(outer_set, j)
else
push!(inner_set, j)
end
end
init_refine = BitSet()
for i in inner_set
union!(init_refine, invdeps[i])
end
intersect!(init_refine, outer_set)
setdiff!(outer_set, init_refine)
union!(inner_set, init_refine)

next_refine = BitSet()
while true
for i in init_refine
id = invdeps[i]
isempty(id) && break
union!(next_refine, id)
end
intersect!(next_refine, outer_set)
isempty(next_refine) && break
setdiff!(outer_set, next_refine)
union!(inner_set, next_refine)

init_refine, next_refine = next_refine, init_refine
empty!(next_refine)
end
global2local = Dict(j => i for (i, j) in enumerate(needed_assignments_idxs))
inner_idxs = [global2local[i] for i in collect(inner_set)]
outer_idxs = [global2local[i] for i in collect(outer_set)]
extravars = reduce(union!, rhsvars[inner_idxs], init = Set())
union!(paramset, extravars)
setdiff!(paramset, vars)
setdiff!(paramset, [needed_assignments[i].lhs for i in inner_idxs])
union!(paramset, [needed_assignments[i].lhs for i in outer_idxs])
params = collect(paramset)

# splatting to tighten the type
u0 = []
for v in vars
v in keys(u0map) || (push!(u0, 1e-3); continue)
u = substitute(v, u0map)
for i in 1:length(u0map)
u = substitute(u, u0map)
u isa Number && (push!(u0, u); break)
end
u isa Number || error("$v doesn't have a default.")
end
u0 = [u0...]
# specialize on the scalar case
isscalar = length(u0) == 1
u0 = isscalar ? u0[1] : SVector(u0...)

fname = gensym("fun")
# f is the function to find roots on
if isscalar
funex = rhss[1]
pre = get_preprocess_constants(funex)
else
funex = MakeArray(rhss, SVector)
pre = get_preprocess_constants(rhss)
end
f = Func(
[DestructuredArgs(vars, inbounds = !checkbounds)
DestructuredArgs(params, inbounds = !checkbounds)],
[],
pre(Let(needed_assignments[inner_idxs],
funex,
false))) |> SymbolicUtils.Code.toexpr

# solver call contains code to call the root-finding solver on the function f
solver_call = LiteralExpr(quote
$numerical_nlsolve($fname,
# initial guess
$u0,
# "captured variables"
($(params...),))
end)

preassignments = []
for i in outer_idxs
ii = needed_assignments_idxs[i]
is_not_prepended_assignment[ii] || continue
is_not_prepended_assignment[ii] = false
push!(preassignments, assignments[ii])
end

nlsolve_expr = Assignment[preassignments
fname ← drop_expr(@RuntimeGeneratedFunction(f))
DestructuredArgs(vars, inbounds = !checkbounds) ← solver_call]

nlsolve_expr
end

"""
find_solve_sequence(sccs, vars)

Expand All @@ -242,136 +112,3 @@ function find_solve_sequence(sccs, vars)
return find_solve_sequence(sccs, vars′)
end
end

function build_observed_function(state, ts, var_eq_matching, var_sccs,
is_solver_unknown_idxs,
assignments,
deps,
sol_states,
var2assignment;
expression = false,
output_type = Array,
checkbounds = true)
is_not_prepended_assignment = trues(length(assignments))
if (isscalar = !(ts isa AbstractVector))
ts = [ts]
end
ts = unwrap.(Symbolics.scalarize(ts))

vars = Set()
sys = state.sys
foreach(Base.Fix1(vars!, vars), ts)
ivs = independent_variables(sys)
dep_vars = collect(setdiff(vars, ivs))

fullvars = state.fullvars
s = state.structure
unknown_vars = fullvars[is_solver_unknown_idxs]
algvars = fullvars[.!is_solver_unknown_idxs]

required_algvars = Set(intersect(algvars, vars))
obs = observed(sys)
observed_idx = Dict(x.lhs => i for (i, x) in enumerate(obs))
namespaced_to_obs = Dict(unknowns(sys, x.lhs) => x.lhs for x in obs)
namespaced_to_sts = Dict(unknowns(sys, x) => x for x in unknowns(sys))
sts = Set(unknowns(sys))

# FIXME: This is a rather rough estimate of dependencies. We assume
# the expression depends on everything before the `maxidx`.
subs = Dict()
maxidx = 0
for (i, s) in enumerate(dep_vars)
idx = get(observed_idx, s, nothing)
if idx !== nothing
idx > maxidx && (maxidx = idx)
else
s′ = get(namespaced_to_obs, s, nothing)
if s′ !== nothing
subs[s] = s′
s = s′
idx = get(observed_idx, s, nothing)
end
if idx !== nothing
idx > maxidx && (maxidx = idx)
elseif !(s in sts)
s′ = get(namespaced_to_sts, s, nothing)
if s′ !== nothing
subs[s] = s′
continue
end
throw(ArgumentError("$s is either an observed nor an unknown variable."))
end
continue
end
end
ts = map(t -> substitute(t, subs), ts)
vs = Set()
for idx in 1:maxidx
vars!(vs, obs[idx].rhs)
union!(required_algvars, intersect(algvars, vs))
empty!(vs)
end
for eq in assignments
vars!(vs, eq.rhs)
union!(required_algvars, intersect(algvars, vs))
empty!(vs)
end

varidxs = findall(x -> x in required_algvars, fullvars)
subset = find_solve_sequence(var_sccs, varidxs)
if !isempty(subset)
eqs = equations(sys)

nested_torn_vars_idxs = []
for iscc in subset
torn_vars_idxs = Int[var
for var in var_sccs[iscc]
if var_eq_matching[var] !== unassigned]
isempty(torn_vars_idxs) || push!(nested_torn_vars_idxs, torn_vars_idxs)
end
torn_eqs = [[eqs[var_eq_matching[i]] for i in idxs]
for idxs in nested_torn_vars_idxs]
torn_vars = [fullvars[idxs] for idxs in nested_torn_vars_idxs]
u0map = defaults(sys)
assignments = copy(assignments)
solves = map(zip(torn_eqs, torn_vars)) do (eqs, vars)
gen_nlsolve!(is_not_prepended_assignment, eqs, vars,
u0map, assignments, deps, var2assignment;
checkbounds = checkbounds)
end
else
solves = []
end

subs = []
for sym in vars
eqidx = get(observed_idx, sym, nothing)
eqidx === nothing && continue
push!(subs, sym ← obs[eqidx].rhs)
end
pre = get_postprocess_fbody(sys)
cpre = get_preprocess_constants([obs[1:maxidx];
isscalar ? ts[1] : MakeArray(ts, output_type)])
pre2 = x -> pre(cpre(x))
ex = Code.toexpr(
Func(
[DestructuredArgs(unknown_vars, inbounds = !checkbounds)
DestructuredArgs(parameters(sys), inbounds = !checkbounds)
independent_variables(sys)],
[],
pre2(Let(
[collect(Iterators.flatten(solves))
assignments[is_not_prepended_assignment]
map(eq -> eq.lhs ← eq.rhs, obs[1:maxidx])
subs],
isscalar ? ts[1] : MakeArray(ts, output_type),
false))),
sol_states)

expression ? ex : drop_expr(@RuntimeGeneratedFunction(ex))
end

struct ODAEProblem{iip} end

@deprecate ODAEProblem(args...; kw...) ODEProblem(args...; kw...)
@deprecate ODAEProblem{iip}(args...; kw...) where {iip} ODEProblem{iip}(args...; kw...)
4 changes: 1 addition & 3 deletions src/structural_transformation/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -224,14 +224,12 @@ function find_eq_solvables!(state::TearingState, ieq, to_rm = Int[], coeffs = no
a, b, islinear = linear_expansion(term, var)
a, b = unwrap(a), unwrap(b)
islinear || (all_int_vars = false; continue)
a = ModelingToolkit.fold_constants(a)
b = ModelingToolkit.fold_constants(b)
if a isa Symbolic
all_int_vars = false
if !allow_symbolic
if allow_parameter
all(
x -> ModelingToolkit.isparameter(x) || ModelingToolkit.isconstant(x),
x -> ModelingToolkit.isparameter(x),
vars(a)) || continue
else
continue
Expand Down
Loading
Loading