Skip to content

fix: allow specifying type of buffers inside MTKParameters #3585

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 38 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
c38d473
fix: allow specifying type of buffers inside `MTKParameters`
AayushSabharwal Apr 28, 2025
9c40240
fix: fix accidental narrowing of nonnumeric buffer
AayushSabharwal Apr 28, 2025
7b4b205
fix: error if `container_type` passed to `MTKParameters` is not an `A…
AayushSabharwal Apr 30, 2025
5f7b2fb
fix: retain `split` kwarg when simplifying initialization system
AayushSabharwal May 5, 2025
7a51652
refactor: make `EmptySciMLFunction` subtype `SciMLBase.AbstractSciMLF…
AayushSabharwal May 5, 2025
e56ea2a
fix: retain `iip` and `u0Type` in `SCCNonlinearProblem` constructor
AayushSabharwal May 5, 2025
d1642f0
feat: add `p_constructor` kwarg to `MTKParameters` constructor
AayushSabharwal May 5, 2025
91b9388
feat: implement `ArrayInterface.ismutable` for `MTKParameters`
AayushSabharwal May 5, 2025
7629546
fix: handle immutable MTKParameters in `remake_buffer`
AayushSabharwal May 5, 2025
3547062
feat: add `p_constructor` kwarg to problem constructors
AayushSabharwal May 5, 2025
35a7659
feat: propagate `p_constructor` to `InitializationProblem`
AayushSabharwal May 5, 2025
fd869e1
fix: propagate `iip` of parent problem to `InitializationProblem`
AayushSabharwal May 5, 2025
b69d79e
fix: retain type of buffers when promoting `u0`/`p` of initialization…
AayushSabharwal May 5, 2025
aff87fa
fix: handle immutable buffers in initialization
AayushSabharwal May 5, 2025
ad0347b
fix: handle `u0_constructor`, `p_constructor` in `remake_initializati…
AayushSabharwal May 5, 2025
926fced
fix: handle immutable MTKParameters in symbolic `late_binding_update_…
AayushSabharwal May 5, 2025
52ce641
fix: handle edge case in floating point type promotion
AayushSabharwal May 5, 2025
57e4778
fix: call `u0_constructor` on `resid_prototype`
AayushSabharwal May 5, 2025
ba2bfff
test: test initialization on static array problems
AayushSabharwal May 5, 2025
d3ce046
build: bump SciMLBase, StochasticDelayDiffEq compat
AayushSabharwal May 8, 2025
d351267
fix: handle empty `syms` in `concrete_getu`
AayushSabharwal May 12, 2025
329b7d5
fix: fix `is_update_oop` passed as type
AayushSabharwal May 12, 2025
6c57557
fix: fix call to `remake_initialization_data`
AayushSabharwal May 12, 2025
f1e422c
Update Project.toml
ChrisRackauckas May 12, 2025
7600f5f
fix: fix `update_initializeprob!`
AayushSabharwal May 12, 2025
4631a1e
fix: add `f` field to `MockIntegrator`
AayushSabharwal May 12, 2025
a35ae7d
fix: handle discretes properly in `get_mtkparameters_reconstructor`
AayushSabharwal May 15, 2025
1de4cd5
refactor: move ChainRulesCoreExt into main package
AayushSabharwal May 15, 2025
1dac3e7
fix: use `@ignore_derivatives` inside `update_initializeprob!`
AayushSabharwal May 15, 2025
179f0bd
Update Project.toml
ChrisRackauckas May 15, 2025
2569966
fix: add `struct IntervalNonlinearFunctionExpr`
AayushSabharwal May 16, 2025
0987a9f
ci: add SciMLSensitivity/Core8 to downstream CI
AayushSabharwal May 16, 2025
37f5320
build: bump SciMLBase compat
AayushSabharwal May 17, 2025
22151e7
fix: only promote tunables/initials in `promote_u0_p` if non-empty
AayushSabharwal May 20, 2025
2c0976e
refactor: reorder tests
AayushSabharwal May 20, 2025
6d9e49b
fix: always construct unscalarized observed equations for array varia…
AayushSabharwal May 20, 2025
ee9ee8a
test: update tests to account for changes to array hack
AayushSabharwal May 20, 2025
5189598
fix: fix `narrow_buffer_type` for `BlockedArray` of arrays
AayushSabharwal May 20, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/Downstream.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ jobs:
- {user: SciML, repo: MethodOfLines.jl, group: 2D_Diffusion}
- {user: SciML, repo: MethodOfLines.jl, group: DAE}
- {user: SciML, repo: ModelingToolkitNeuralNets.jl, group: All}
- {user: SciML, repo: SciMLSensitivity.jl, group: Core8}

- {user: Neuroblox, repo: Neuroblox.jl, group: All}
steps:
Expand Down
7 changes: 3 additions & 4 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Expand Down Expand Up @@ -65,7 +66,6 @@ Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
[weakdeps]
BifurcationKit = "0f109fa4-8a5d-4b75-95aa-f515264e7665"
CasADi = "c49709b8-5c63-11e9-2fb2-69db5844192f"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DeepDiffs = "ab62b9b5-e342-54a8-a765-a90f495de1a6"
FMI = "14a09403-18e3-468f-ad8a-74f8dda2d9ac"
InfiniteOpt = "20393b10-9daf-11e9-18c9-8db751c92c57"
Expand All @@ -74,7 +74,6 @@ LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800"
[extensions]
MTKBifurcationKitExt = "BifurcationKit"
MTKCasADiDynamicOptExt = "CasADi"
MTKChainRulesCoreExt = "ChainRulesCore"
MTKDeepDiffsExt = "DeepDiffs"
MTKFMIExt = "FMI"
MTKInfiniteOptExt = "InfiniteOpt"
Expand Down Expand Up @@ -142,15 +141,15 @@ RecursiveArrayTools = "3.26"
Reexport = "0.2, 1"
RuntimeGeneratedFunctions = "0.5.9"
SCCNonlinearSolve = "1.0.0"
SciMLBase = "2.84"
SciMLBase = "2.91.1"
SciMLStructures = "1.7"
Serialization = "1"
Setfield = "0.7, 0.8, 1"
SimpleNonlinearSolve = "0.1.0, 1, 2"
SparseArrays = "1"
SpecialFunctions = "0.7, 0.8, 0.9, 0.10, 1.0, 2"
StaticArrays = "0.10, 0.11, 0.12, 1.0"
StochasticDelayDiffEq = "1.8.1"
StochasticDelayDiffEq = "1.10"
StochasticDiffEq = "6.72.1"
SymbolicIndexingInterface = "0.3.39"
SymbolicUtils = "3.26.1"
Expand Down
4 changes: 4 additions & 0 deletions src/ModelingToolkit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ import BlockArrays: BlockArray, BlockedArray, Block, blocksize, blocksizes, bloc
using OffsetArrays: Origin
import CommonSolve
import EnumX
import ChainRulesCore
import ChainRulesCore: Tangent, ZeroTangent, NoTangent, zero_tangent, unthunk

using RuntimeGeneratedFunctions
using RuntimeGeneratedFunctions: drop_expr
Expand Down Expand Up @@ -204,6 +206,8 @@ include("structural_transformation/StructuralTransformations.jl")
@reexport using .StructuralTransformations
include("inputoutput.jl")

include("adjoints.jl")

for S in subtypes(ModelingToolkit.AbstractSystem)
S = nameof(S)
@eval convert_system(::Type{<:$S}, sys::$S) = sys
Expand Down
32 changes: 12 additions & 20 deletions ext/MTKChainRulesCoreExt.jl → src/adjoints.jl
Original file line number Diff line number Diff line change
@@ -1,17 +1,11 @@
module MTKChainRulesCoreExt

import ModelingToolkit as MTK
import ChainRulesCore
import ChainRulesCore: Tangent, ZeroTangent, NoTangent, zero_tangent, unthunk

function ChainRulesCore.rrule(::Type{MTK.MTKParameters}, tunables, args...)
function ChainRulesCore.rrule(::Type{MTKParameters}, tunables, args...)
function mtp_pullback(dt)
dt = unthunk(dt)
dtunables = dt isa AbstractArray ? dt : dt.tunable
(NoTangent(), dtunables[1:length(tunables)],
ntuple(_ -> NoTangent(), length(args))...)
end
MTK.MTKParameters(tunables, args...), mtp_pullback
MTKParameters(tunables, args...), mtp_pullback
end

function subset_idxs(idxs, portion, template)
Expand Down Expand Up @@ -70,23 +64,23 @@ function selected_tangents(
end

function ChainRulesCore.rrule(
::typeof(MTK.remake_buffer), indp, oldbuf::MTK.MTKParameters, idxs, vals)
::typeof(remake_buffer), indp, oldbuf::MTKParameters, idxs, vals)
if idxs isa AbstractSet
idxs = collect(idxs)
end
idxs = map(idxs) do i
i isa MTK.ParameterIndex ? i : MTK.parameter_index(indp, i)
i isa ParameterIndex ? i : parameter_index(indp, i)
end
newbuf = MTK.remake_buffer(indp, oldbuf, idxs, vals)
newbuf = remake_buffer(indp, oldbuf, idxs, vals)
tunable_idxs = reduce(
vcat, (idx.idx for idx in idxs if idx.portion isa MTK.SciMLStructures.Tunable);
vcat, (idx.idx for idx in idxs if idx.portion isa SciMLStructures.Tunable);
init = Union{Int, AbstractVector{Int}}[])
initials_idxs = reduce(
vcat, (idx.idx for idx in idxs if idx.portion isa MTK.SciMLStructures.Initials);
vcat, (idx.idx for idx in idxs if idx.portion isa SciMLStructures.Initials);
init = Union{Int, AbstractVector{Int}}[])
disc_idxs = subset_idxs(idxs, MTK.SciMLStructures.Discrete(), oldbuf.discrete)
const_idxs = subset_idxs(idxs, MTK.SciMLStructures.Constants(), oldbuf.constant)
nn_idxs = subset_idxs(idxs, MTK.NONNUMERIC_PORTION, oldbuf.nonnumeric)
disc_idxs = subset_idxs(idxs, SciMLStructures.Discrete(), oldbuf.discrete)
const_idxs = subset_idxs(idxs, SciMLStructures.Constants(), oldbuf.constant)
nn_idxs = subset_idxs(idxs, NONNUMERIC_PORTION, oldbuf.nonnumeric)

pullback = let idxs = idxs
function remake_buffer_pullback(buf′)
Expand All @@ -102,13 +96,11 @@ function ChainRulesCore.rrule(
oldbuf′ = Tangent{typeof(oldbuf)}(;
tunable, initials, discrete, constant, nonnumeric)
idxs′ = NoTangent()
vals′ = map(i -> MTK._ducktyped_parameter_values(buf′, i), idxs)
vals′ = map(i -> _ducktyped_parameter_values(buf′, i), idxs)
return f′, indp′, oldbuf′, idxs′, vals′
end
end
newbuf, pullback
end

ChainRulesCore.@non_differentiable Base.getproperty(sys::MTK.AbstractSystem, x::Symbol)

end
ChainRulesCore.@non_differentiable Base.getproperty(sys::AbstractSystem, x::Symbol)
13 changes: 9 additions & 4 deletions src/linearization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ function (linfun::LinearizationFunction)(u, p, t)
linfun.num_states == length(u) ||
error("Number of unknown variables ($(linfun.num_states)) does not match the number of input unknowns ($(length(u)))")
integ_cache = (linfun.caches,)
integ = MockIntegrator{true}(u, p, t, integ_cache, nothing)
integ = MockIntegrator{true}(u, p, t, fun, integ_cache, nothing)
u, p, success = SciMLBase.get_initial_values(
linfun.prob, integ, fun, linfun.initializealg, Val(true);
linfun.initialize_kwargs...)
Expand Down Expand Up @@ -325,7 +325,7 @@ Mock `DEIntegrator` to allow using `CheckInit` without having to create a new in

$(TYPEDFIELDS)
"""
struct MockIntegrator{iip, U, P, T, C, O} <: SciMLBase.DEIntegrator{Nothing, iip, U, T}
struct MockIntegrator{iip, U, P, T, F, C, O} <: SciMLBase.DEIntegrator{Nothing, iip, U, T}
"""
The state vector.
"""
Expand All @@ -339,6 +339,10 @@ struct MockIntegrator{iip, U, P, T, C, O} <: SciMLBase.DEIntegrator{Nothing, iip
"""
t::T
"""
The wrapped `SciMLFunction`.
"""
f::F
"""
The integrator cache.
"""
cache::C
Expand All @@ -348,8 +352,9 @@ struct MockIntegrator{iip, U, P, T, C, O} <: SciMLBase.DEIntegrator{Nothing, iip
opts::O
end

function MockIntegrator{iip}(u::U, p::P, t::T, cache::C, opts::O) where {iip, U, P, T, C, O}
return MockIntegrator{iip, U, P, T, C, O}(u, p, t, cache, opts)
function MockIntegrator{iip}(
u::U, p::P, t::T, f::F, cache::C, opts::O) where {iip, U, P, T, F, C, O}
return MockIntegrator{iip, U, P, T, F, C, O}(u, p, t, f, cache, opts)
end

SymbolicIndexingInterface.state_values(integ::MockIntegrator) = integ.u
Expand Down
12 changes: 0 additions & 12 deletions src/structural_transformation/symbolics_tearing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -860,12 +860,9 @@ function cse_and_array_hacks(sys, obs, subeqs, unknowns, neweqs; cse = true, arr
# map of array observed variable (unscalarized) to number of its
# scalarized terms that appear in observed equations
arr_obs_occurrences = Dict()
# to check if array variables occur in unscalarized form anywhere
all_vars = Set()
for (i, eq) in enumerate(obs)
lhs = eq.lhs
rhs = eq.rhs
vars!(all_vars, rhs)

# HACK 1
if cse && is_getindexed_array(rhs)
Expand Down Expand Up @@ -920,7 +917,6 @@ function cse_and_array_hacks(sys, obs, subeqs, unknowns, neweqs; cse = true, arr
tempvar; T = Symbolics.symtype(rhs_arr)))
tempvar = setmetadata(
tempvar, Symbolics.ArrayShapeCtx, Symbolics.shape(rhs_arr))
vars!(all_vars, rhs_arr)
tempeq = tempvar ~ rhs_arr
rhs_to_tempvar[rhs_arr] = tempvar
push!(obs, tempeq)
Expand All @@ -946,18 +942,10 @@ function cse_and_array_hacks(sys, obs, subeqs, unknowns, neweqs; cse = true, arr
cnt == 0 && continue
arr_obs_occurrences[arg1] = cnt + 1
end
for eq in neweqs
vars!(all_vars, eq.rhs)
end

# also count unscalarized variables used in callbacks
for ev in Iterators.flatten((continuous_events(sys), discrete_events(sys)))
vars!(all_vars, ev)
end
obs_arr_eqs = Equation[]
for (arrvar, cnt) in arr_obs_occurrences
cnt == length(arrvar) || continue
arrvar in all_vars || continue
# firstindex returns 1 for multidimensional array symbolics
firstind = first(eachindex(arrvar))
scal = [arrvar[i] for i in eachindex(arrvar)]
Expand Down
4 changes: 2 additions & 2 deletions src/systems/diffeqs/abstractodesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1479,7 +1479,7 @@ function InitializationProblem{iip, specialize}(sys::AbstractSystem,
end

if simplify_system
isys = structural_simplify(isys; fully_determined)
isys = structural_simplify(isys; fully_determined, split = is_split(sys))
end

ts = get_tearing_state(isys)
Expand Down Expand Up @@ -1554,6 +1554,6 @@ function InitializationProblem{iip, specialize}(sys::AbstractSystem,
else
NonlinearLeastSquaresProblem
end
TProb(isys, u0map, parammap; kwargs...,
TProb{iip}(isys, u0map, parammap; kwargs...,
build_initializeprob = false, is_initializeprob = true)
end
6 changes: 3 additions & 3 deletions src/systems/jumps/jumpsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -408,7 +408,7 @@ function DiffEqBase.DiscreteProblem(sys::JumpSystem, u0map, tspan::Union{Tuple,
error("The passed in JumpSystem contains `Equation`s or continuous events, please use a problem type that supports these features, such as ODEProblem.")
end

_f, u0, p = process_SciMLProblem(EmptySciMLFunction, sys, u0map, parammap;
_f, u0, p = process_SciMLProblem(EmptySciMLFunction{true}, sys, u0map, parammap;
t = tspan === nothing ? nothing : tspan[1], tofloat = false, check_length = false, build_initializeprob = false, cse)
f = DiffEqBase.DISCRETE_INPLACE_DEFAULT

Expand Down Expand Up @@ -449,7 +449,7 @@ function DiscreteProblemExpr{iip}(sys::JumpSystem, u0map, tspan::Union{Tuple, No
error("A completed `JumpSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `DiscreteProblemExpr`")
end

_, u0, p = process_SciMLProblem(EmptySciMLFunction, sys, u0map, parammap;
_, u0, p = process_SciMLProblem(EmptySciMLFunction{iip}, sys, u0map, parammap;
t = tspan === nothing ? nothing : tspan[1], tofloat = false, check_length = false)
# identity function to make syms works
quote
Expand Down Expand Up @@ -506,7 +506,7 @@ function DiffEqBase.ODEProblem(sys::JumpSystem, u0map, tspan::Union{Tuple, Nothi
return ODEProblem(osys, u0map, tspan, parammap; check_length = false,
build_initializeprob = false, kwargs...)
else
_, u0, p = process_SciMLProblem(EmptySciMLFunction, sys, u0map, parammap;
_, u0, p = process_SciMLProblem(EmptySciMLFunction{true}, sys, u0map, parammap;
t = tspan === nothing ? nothing : tspan[1], tofloat = false,
check_length = false, build_initializeprob = false, cse)
f = (du, u, p, t) -> (du .= 0; nothing)
Expand Down
41 changes: 28 additions & 13 deletions src/systems/nonlinear/initializesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -513,8 +513,7 @@ function SciMLBase.remake_initialization_data(
length(oldinitprob.f.resid_prototype), new_initu0, new_initp))
end
initprob = remake(oldinitprob; f = newf, u0 = new_initu0, p = new_initp)
return SciMLBase.OverrideInitData(initprob, oldinitdata.update_initializeprob!,
oldinitdata.initializeprobmap, oldinitdata.initializeprobpmap; metadata = oldinitdata.metadata)
return @set oldinitdata.initializeprob = initprob
end

dvs = unknowns(sys)
Expand Down Expand Up @@ -582,21 +581,35 @@ function SciMLBase.remake_initialization_data(
op, missing_unknowns, missing_pars = build_operating_point!(sys,
u0map, pmap, defs, cmap, dvs, ps)
floatT = float_type_from_varmap(op)
u0_constructor = p_constructor = identity
if newu0 isa StaticArray
u0_constructor = vals -> SymbolicUtils.Code.create_array(
typeof(newu0), floatT, Val(1), Val(length(vals)), vals...)
end
if newp isa StaticArray || newp isa MTKParameters && newp.initials isa StaticArray
p_constructor = vals -> SymbolicUtils.Code.create_array(
typeof(newp.initials), floatT, Val(1), Val(length(vals)), vals...)
end
kws = maybe_build_initialization_problem(
sys, op, u0map, pmap, t0, defs, guesses, missing_unknowns;
use_scc, initialization_eqs, floatT, allow_incomplete = true)
sys, SciMLBase.isinplace(odefn), op, u0map, pmap, t0, defs, guesses, missing_unknowns;
use_scc, initialization_eqs, floatT, u0_constructor, p_constructor, allow_incomplete = true)

return SciMLBase.remake_initialization_data(sys, kws, newu0, t0, newp, newu0, newp)
odefn = remake(odefn; kws...)
return SciMLBase.remake_initialization_data(sys, odefn, newu0, t0, newp, newu0, newp)
end

function promote_u0_p(u0, p::MTKParameters, t0)
u0 = DiffEqBase.promote_u0(u0, p.tunable, t0)
u0 = DiffEqBase.promote_u0(u0, p.initials, t0)

tunables = DiffEqBase.promote_u0(p.tunable, u0, t0)
initials = DiffEqBase.promote_u0(p.initials, u0, t0)
p = SciMLStructures.replace(SciMLStructures.Tunable(), p, tunables)
p = SciMLStructures.replace(SciMLStructures.Initials(), p, initials)
if !isempty(p.tunable)
tunables = DiffEqBase.promote_u0(p.tunable, u0, t0)
p = SciMLStructures.replace(SciMLStructures.Tunable(), p, tunables)
end
if !isempty(p.initials)
initials = DiffEqBase.promote_u0(p.initials, u0, t0)
p = SciMLStructures.replace(SciMLStructures.Initials(), p, initials)
end

return u0, p
end
Expand Down Expand Up @@ -627,12 +640,12 @@ function SciMLBase.late_binding_update_u0_p(
if length(newu0) != length(prob.u0)
throw(ArgumentError("Expected `newu0` to be of same length as unknowns ($(length(prob.u0))). Got $(typeof(newu0)) of length $(length(newu0))"))
end
meta.set_initial_unknowns!(newp, newu0)
newp = meta.set_initial_unknowns!(newp, newu0)
return newu0, newp
end

newp = p === missing ? copy(newp) : newp

syms = []
vals = []
allsyms = all_symbols(sys)
for (k, v) in u0
v === nothing && continue
Expand All @@ -644,9 +657,11 @@ function SciMLBase.late_binding_update_u0_p(
k = k2
end
is_parameter(sys, Initial(k)) || continue
setp(sys, Initial(k))(newp, v)
push!(syms, Initial(k))
push!(vals, v)
end

newp = setp_oop(sys, syms)(newp, vals)
return newu0, newp
end

Expand Down
Loading
Loading