From c38d4733e3b1f419f6828b6b570049f422b8adb6 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 28 Apr 2025 11:51:35 +0530 Subject: [PATCH 01/40] fix: allow specifying type of buffers inside `MTKParameters` --- src/systems/parameter_buffer.jl | 56 ++++++++++++++++++++++++++------- src/systems/problem_utils.jl | 2 +- test/initial_values.jl | 26 +++++++++++++++ test/mtkparameters.jl | 13 ++++++-- 4 files changed, 82 insertions(+), 15 deletions(-) diff --git a/src/systems/parameter_buffer.jl b/src/systems/parameter_buffer.jl index 6142c95776..1e847256b3 100644 --- a/src/systems/parameter_buffer.jl +++ b/src/systems/parameter_buffer.jl @@ -28,7 +28,11 @@ the default behavior). """ function MTKParameters( sys::AbstractSystem, p, u0 = Dict(); tofloat = false, - t0 = nothing, substitution_limit = 1000, floatT = nothing) + t0 = nothing, substitution_limit = 1000, floatT = nothing, + container_type = Vector) + if !(container_type <: AbstractArray) + container_type = Array + end ic = if has_index_cache(sys) && get_index_cache(sys) !== nothing get_index_cache(sys) else @@ -133,18 +137,23 @@ function MTKParameters( end end end - tunable_buffer = narrow_buffer_type(tunable_buffer) + tunable_buffer = narrow_buffer_type(tunable_buffer; container_type) if isempty(tunable_buffer) tunable_buffer = SizedVector{0, Float64}() end - initials_buffer = narrow_buffer_type(initials_buffer) + initials_buffer = narrow_buffer_type(initials_buffer; container_type) if isempty(initials_buffer) initials_buffer = SizedVector{0, Float64}() end - disc_buffer = narrow_buffer_type.(disc_buffer) - const_buffer = narrow_buffer_type.(const_buffer) + disc_buffer = narrow_buffer_type.(disc_buffer; container_type) + const_buffer = narrow_buffer_type.(const_buffer; container_type) # Don't narrow nonnumeric types - nonnumeric_buffer = nonnumeric_buffer + if !isempty(nonnumeric_buffer) + nonnumeric_buffer = map(nonnumeric_buffer) do buf + SymbolicUtils.Code.create_array( + container_type, nothing, Val(1), Val(length(buf)), buf...) + end + end mtkps = MTKParameters{ typeof(tunable_buffer), typeof(initials_buffer), typeof(disc_buffer), @@ -160,21 +169,44 @@ function rebuild_with_caches(p::MTKParameters, cache_templates::BufferTemplate.. @set p.caches = buffers end -function narrow_buffer_type(buffer::AbstractArray) +function narrow_buffer_type(buffer::AbstractArray; container_type = typeof(buffer)) type = Union{} for x in buffer type = promote_type(type, typeof(x)) end - return convert.(type, buffer) + return SymbolicUtils.Code.create_array( + container_type, type, Val(ndims(buffer)), Val(length(buffer)), buffer...) end -function narrow_buffer_type(buffer::AbstractArray{<:AbstractArray}) - buffer = narrow_buffer_type.(buffer) +function narrow_buffer_type( + buffer::AbstractArray{<:AbstractArray}; container_type = typeof(buffer)) + type = Union{} + for arr in buffer + for x in arr + type = promote_type(type, typeof(x)) + end + end + buffer = map(buffer) do buf + SymbolicUtils.Code.create_array( + container_type, type, Val(ndims(buf)), Val(size(buf)), buf...) + end + return SymbolicUtils.Code.create_array( + container_type, nothing, Val(ndims(buffer)), Val(size(buffer)), buffer...) +end + +function narrow_buffer_type(buffer::BlockedArray; container_type = typeof(parent(buffer))) type = Union{} for x in buffer - type = promote_type(type, eltype(x)) + type = promote_type(type, typeof(x)) + end + tmp = SymbolicUtils.Code.create_array( + container_type, type, Val(ndims(buffer)), Val(size(buffer)), buffer...) + blocks = ntuple(Val(ndims(buffer))) do i + bsizes = blocksizes(buffer, i) + SymbolicUtils.Code.create_array( + container_type, Int, Val(1), Val(length(bsizes)), bsizes...) end - return broadcast.(convert, type, buffer) + return BlockedArray(tmp, blocks...) end function buffer_to_arraypartition(buf) diff --git a/src/systems/problem_utils.jl b/src/systems/problem_utils.jl index 4525b0e46b..5116453ef2 100644 --- a/src/systems/problem_utils.jl +++ b/src/systems/problem_utils.jl @@ -1155,7 +1155,7 @@ function process_SciMLProblem( end evaluate_varmap!(op, ps; limit = substitution_limit) if is_split(sys) - p = MTKParameters(sys, op; floatT = floatT) + p = MTKParameters(sys, op; floatT = floatT, container_type = pType) else p = better_varmap_to_vars(op, ps; tofloat, container_type = pType) end diff --git a/test/initial_values.jl b/test/initial_values.jl index b3614de0f4..ea5e8e01bb 100644 --- a/test/initial_values.jl +++ b/test/initial_values.jl @@ -2,6 +2,7 @@ using ModelingToolkit using ModelingToolkit: t_nounits as t, D_nounits as D, get_u0 using OrdinaryDiffEq using DataInterpolations +using StaticArrays using SymbolicIndexingInterface: getu @variables x(t)[1:3]=[1.0, 2.0, 3.0] y(t) z(t)[1:2] @@ -309,3 +310,28 @@ end @test prob[w2] ≈ -1.0 @test prob.ps[β] ≈ 8 / 3 end + +@testset "MTKParameters uses given `pType` for inner buffers" begin + @parameters σ ρ β + @variables x(t) y(t) z(t) + + eqs = [D(D(x)) ~ σ * (y - x), + D(y) ~ x * (ρ - z) - y, + D(z) ~ x * y - β * z] + + @mtkbuild sys = ODESystem(eqs, t) + + u0 = SA[D(x) => 2.0f0, + x => 1.0f0, + y => 0.0f0, + z => 0.0f0] + + p = SA[σ => 28.0f0, + ρ => 10.0f0, + β => 8.0f0 / 3] + + tspan = (0.0f0, 100.0f0) + prob = ODEProblem(sys, u0, tspan, p) + @test prob.p.tunable isa SVector + @test prob.p.initials isa SVector +end diff --git a/test/mtkparameters.jl b/test/mtkparameters.jl index 55de0768e0..9c857b6a55 100644 --- a/test/mtkparameters.jl +++ b/test/mtkparameters.jl @@ -1,8 +1,8 @@ using ModelingToolkit using ModelingToolkit: t_nounits as t, D_nounits as D, MTKParameters -using SymbolicIndexingInterface +using SymbolicIndexingInterface, StaticArrays using SciMLStructures: SciMLStructures, canonicalize, Tunable, Discrete, Constants -using BlockArrays: BlockedArray, Block +using BlockArrays: BlockedArray, BlockedVector, Block using OrdinaryDiffEq using ForwardDiff using JET @@ -27,6 +27,15 @@ end @test getp(sys, a)(ps) == getp(sys, b)(ps) == getp(sys, c)(ps) == 0.0 @test getp(sys, d)(ps) isa Int +@testset "`container_type`" begin + ps2 = MTKParameters(sys, ivs; container_type = SVector) + @test ps2.tunable isa SVector + @test ps2.initials isa SVector + @test ps2.discrete isa Tuple{<:BlockedVector{Float64, <:SVector}} + @test ps2.constant isa Tuple{<:SVector, <:SVector, <:SVector{1, <:SMatrix}} + @test ps2.nonnumeric isa Tuple{<:SVector} +end + ivs[a] = 1.0 ps = MTKParameters(sys, ivs) for (p, val) in ivs From 9c402402f5843fdf9884ed1028dda8d762b27e74 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 28 Apr 2025 15:46:34 +0530 Subject: [PATCH 02/40] fix: fix accidental narrowing of nonnumeric buffer --- src/systems/parameter_buffer.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/systems/parameter_buffer.jl b/src/systems/parameter_buffer.jl index 1e847256b3..419cb0cd77 100644 --- a/src/systems/parameter_buffer.jl +++ b/src/systems/parameter_buffer.jl @@ -151,7 +151,7 @@ function MTKParameters( if !isempty(nonnumeric_buffer) nonnumeric_buffer = map(nonnumeric_buffer) do buf SymbolicUtils.Code.create_array( - container_type, nothing, Val(1), Val(length(buf)), buf...) + container_type, eltype(buf), Val(1), Val(length(buf)), buf...) end end From 7b4b205b2fd9b55f2663bd6dfb646df07e18b80e Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 30 Apr 2025 12:24:20 +0530 Subject: [PATCH 03/40] fix: error if `container_type` passed to `MTKParameters` is not an `AbstractArray` subtype --- src/systems/parameter_buffer.jl | 5 ++++- src/systems/problem_utils.jl | 4 ++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/src/systems/parameter_buffer.jl b/src/systems/parameter_buffer.jl index 419cb0cd77..d899d7247d 100644 --- a/src/systems/parameter_buffer.jl +++ b/src/systems/parameter_buffer.jl @@ -31,7 +31,10 @@ function MTKParameters( t0 = nothing, substitution_limit = 1000, floatT = nothing, container_type = Vector) if !(container_type <: AbstractArray) - container_type = Array + throw(ArgumentError(""" + `container_type` for `MTKParameters` must be a subtype of `AbstractArray`. Found \ + $container_type. + """)) end ic = if has_index_cache(sys) && get_index_cache(sys) !== nothing get_index_cache(sys) diff --git a/src/systems/problem_utils.jl b/src/systems/problem_utils.jl index 5116453ef2..67ecdd0f43 100644 --- a/src/systems/problem_utils.jl +++ b/src/systems/problem_utils.jl @@ -1155,6 +1155,10 @@ function process_SciMLProblem( end evaluate_varmap!(op, ps; limit = substitution_limit) if is_split(sys) + # `pType` is usually `Dict` when the user passes key-value pairs. + if !(pType <: AbstractArray) + pType = Array + end p = MTKParameters(sys, op; floatT = floatT, container_type = pType) else p = better_varmap_to_vars(op, ps; tofloat, container_type = pType) From 5f7b2fbf4409536a8791888120f8886dab8dca2c Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 5 May 2025 16:12:21 +0530 Subject: [PATCH 04/40] fix: retain `split` kwarg when simplifying initialization system --- src/systems/diffeqs/abstractodesystem.jl | 4 ++-- test/initializationsystem.jl | 12 ++++++++++++ 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/src/systems/diffeqs/abstractodesystem.jl b/src/systems/diffeqs/abstractodesystem.jl index 942e508644..d802f49fee 100644 --- a/src/systems/diffeqs/abstractodesystem.jl +++ b/src/systems/diffeqs/abstractodesystem.jl @@ -1479,7 +1479,7 @@ function InitializationProblem{iip, specialize}(sys::AbstractSystem, end if simplify_system - isys = structural_simplify(isys; fully_determined) + isys = structural_simplify(isys; fully_determined, split = is_split(sys)) end ts = get_tearing_state(isys) @@ -1554,6 +1554,6 @@ function InitializationProblem{iip, specialize}(sys::AbstractSystem, else NonlinearLeastSquaresProblem end - TProb(isys, u0map, parammap; kwargs..., + TProb{iip}(isys, u0map, parammap; kwargs..., build_initializeprob = false, is_initializeprob = true) end diff --git a/test/initializationsystem.jl b/test/initializationsystem.jl index fea989f0f3..cadf1f0a01 100644 --- a/test/initializationsystem.jl +++ b/test/initializationsystem.jl @@ -1599,3 +1599,15 @@ end @test SciMLBase.successful_retcode(sol) @test sol.u[1] ≈ new_u0 end + +@testset "Initialization system retains `split` kwarg of parent" begin + @parameters g + @variables x(t) y(t) [state_priority = 10] λ(t) + eqs = [D(D(x)) ~ λ * x + D(D(y)) ~ λ * y - g + x^2 + y^2 ~ 1] + @mtkbuild pend=ODESystem(eqs, t) split=false + prob = ODEProblem(pend, [x => 1.0, D(x) => 0.0], (0.0, 1.0), + [g => 1.0]; guesses = [y => 1.0, λ => 1.0]) + @test !ModelingToolkit.is_split(prob.f.initialization_data.initializeprob.f.sys) +end From 7a5165233bd376b02124ac14081c1f0a3e05db0a Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 5 May 2025 16:13:13 +0530 Subject: [PATCH 05/40] refactor: make `EmptySciMLFunction` subtype `SciMLBase.AbstractSciMLFunction` --- src/systems/jumps/jumpsystem.jl | 6 +++--- src/systems/nonlinear/nonlinearsystem.jl | 2 +- src/systems/problem_utils.jl | 6 +++--- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/systems/jumps/jumpsystem.jl b/src/systems/jumps/jumpsystem.jl index 06f5e1b623..34b6df3bf5 100644 --- a/src/systems/jumps/jumpsystem.jl +++ b/src/systems/jumps/jumpsystem.jl @@ -408,7 +408,7 @@ function DiffEqBase.DiscreteProblem(sys::JumpSystem, u0map, tspan::Union{Tuple, error("The passed in JumpSystem contains `Equation`s or continuous events, please use a problem type that supports these features, such as ODEProblem.") end - _f, u0, p = process_SciMLProblem(EmptySciMLFunction, sys, u0map, parammap; + _f, u0, p = process_SciMLProblem(EmptySciMLFunction{true}, sys, u0map, parammap; t = tspan === nothing ? nothing : tspan[1], tofloat = false, check_length = false, build_initializeprob = false, cse) f = DiffEqBase.DISCRETE_INPLACE_DEFAULT @@ -449,7 +449,7 @@ function DiscreteProblemExpr{iip}(sys::JumpSystem, u0map, tspan::Union{Tuple, No error("A completed `JumpSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `DiscreteProblemExpr`") end - _, u0, p = process_SciMLProblem(EmptySciMLFunction, sys, u0map, parammap; + _, u0, p = process_SciMLProblem(EmptySciMLFunction{iip}, sys, u0map, parammap; t = tspan === nothing ? nothing : tspan[1], tofloat = false, check_length = false) # identity function to make syms works quote @@ -506,7 +506,7 @@ function DiffEqBase.ODEProblem(sys::JumpSystem, u0map, tspan::Union{Tuple, Nothi return ODEProblem(osys, u0map, tspan, parammap; check_length = false, build_initializeprob = false, kwargs...) else - _, u0, p = process_SciMLProblem(EmptySciMLFunction, sys, u0map, parammap; + _, u0, p = process_SciMLProblem(EmptySciMLFunction{true}, sys, u0map, parammap; t = tspan === nothing ? nothing : tspan[1], tofloat = false, check_length = false, build_initializeprob = false, cse) f = (du, u, p, t) -> (du .= 0; nothing) diff --git a/src/systems/nonlinear/nonlinearsystem.jl b/src/systems/nonlinear/nonlinearsystem.jl index 7146fb6b5e..949184d8f7 100644 --- a/src/systems/nonlinear/nonlinearsystem.jl +++ b/src/systems/nonlinear/nonlinearsystem.jl @@ -705,7 +705,7 @@ function SciMLBase.SCCNonlinearProblem{iip}(sys::NonlinearSystem, u0map, obs = observed(sys) _, u0, p = process_SciMLProblem( - EmptySciMLFunction, sys, u0map, parammap; eval_expression, eval_module, kwargs...) + EmptySciMLFunction{iip}, sys, u0map, parammap; eval_expression, eval_module, kwargs...) explicitfuns = [] nlfuns = [] diff --git a/src/systems/problem_utils.jl b/src/systems/problem_utils.jl index 67ecdd0f43..8f9dd944cf 100644 --- a/src/systems/problem_utils.jl +++ b/src/systems/problem_utils.jl @@ -539,13 +539,13 @@ A simple utility meant to be used as the `constructor` passed to `process_SciMLP case constructing a SciMLFunction is not required. The arguments passed to it are available in the `args` field, and the keyword arguments in the `kwargs` field. """ -struct EmptySciMLFunction{A, K} +struct EmptySciMLFunction{iip, A, K} <: SciMLBase.AbstractSciMLFunction{iip} args::A kwargs::K end -function EmptySciMLFunction(args...; kwargs...) - return EmptySciMLFunction{typeof(args), typeof(kwargs)}(args, kwargs) +function EmptySciMLFunction{iip}(args...; kwargs...) where {iip} + return EmptySciMLFunction{iip, typeof(args), typeof(kwargs)}(args, kwargs) end """ From e56ea2a4606c2b11e1911839381a494dae3e167e Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 5 May 2025 16:16:18 +0530 Subject: [PATCH 06/40] fix: retain `iip` and `u0Type` in `SCCNonlinearProblem` constructor --- src/systems/nonlinear/nonlinearsystem.jl | 4 +++- test/scc_nonlinear_problem.jl | 7 +++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/src/systems/nonlinear/nonlinearsystem.jl b/src/systems/nonlinear/nonlinearsystem.jl index 949184d8f7..f8182fa28a 100644 --- a/src/systems/nonlinear/nonlinearsystem.jl +++ b/src/systems/nonlinear/nonlinearsystem.jl @@ -835,7 +835,9 @@ function SciMLBase.SCCNonlinearProblem{iip}(sys::NonlinearSystem, u0map, subprobs = [] for (f, vscc) in zip(nlfuns, var_sccs) - prob = NonlinearProblem(f, u0[vscc], p) + _u0 = SymbolicUtils.Code.create_array( + typeof(u0), eltype(u0), Val(1), Val(length(vscc)), u0[vscc]...) + prob = NonlinearProblem{iip}(f, _u0, p) push!(subprobs, prob) end diff --git a/test/scc_nonlinear_problem.jl b/test/scc_nonlinear_problem.jl index b2b326d090..ac7978d269 100644 --- a/test/scc_nonlinear_problem.jl +++ b/test/scc_nonlinear_problem.jl @@ -2,6 +2,7 @@ using ModelingToolkit using NonlinearSolve, SCCNonlinearSolve using OrdinaryDiffEq using SciMLBase, Symbolics +using StaticArrays using LinearAlgebra, Test using ModelingToolkit: t_nounits as t, D_nounits as D @@ -32,6 +33,12 @@ using ModelingToolkit: t_nounits as t, D_nounits as D @test SciMLBase.successful_retcode(sol1) @test SciMLBase.successful_retcode(sol2) @test sol1[u] ≈ sol2[u] + + sccprob = SCCNonlinearProblem{false}(model, SA[u => zeros(8)]) + for prob in sccprob.probs + @test prob.u0 isa SVector + @test !SciMLBase.isinplace(prob) + end end @testset "With parameters" begin From d1642f07d7bd9433d29a9bfd138e8b086a85bd8b Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 5 May 2025 16:16:52 +0530 Subject: [PATCH 07/40] feat: add `p_constructor` kwarg to `MTKParameters` constructor --- src/systems/parameter_buffer.jl | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/systems/parameter_buffer.jl b/src/systems/parameter_buffer.jl index d899d7247d..5c777dbfe4 100644 --- a/src/systems/parameter_buffer.jl +++ b/src/systems/parameter_buffer.jl @@ -29,7 +29,7 @@ the default behavior). function MTKParameters( sys::AbstractSystem, p, u0 = Dict(); tofloat = false, t0 = nothing, substitution_limit = 1000, floatT = nothing, - container_type = Vector) + container_type = Vector, p_constructor = identity) if !(container_type <: AbstractArray) throw(ArgumentError(""" `container_type` for `MTKParameters` must be a subtype of `AbstractArray`. Found \ @@ -140,21 +140,21 @@ function MTKParameters( end end end - tunable_buffer = narrow_buffer_type(tunable_buffer; container_type) + tunable_buffer = p_constructor(narrow_buffer_type(tunable_buffer; container_type)) if isempty(tunable_buffer) tunable_buffer = SizedVector{0, Float64}() end - initials_buffer = narrow_buffer_type(initials_buffer; container_type) + initials_buffer = p_constructor(narrow_buffer_type(initials_buffer; container_type)) if isempty(initials_buffer) initials_buffer = SizedVector{0, Float64}() end - disc_buffer = narrow_buffer_type.(disc_buffer; container_type) - const_buffer = narrow_buffer_type.(const_buffer; container_type) + disc_buffer = p_constructor.(narrow_buffer_type.(disc_buffer; container_type)) + const_buffer = p_constructor.(narrow_buffer_type.(const_buffer; container_type)) # Don't narrow nonnumeric types if !isempty(nonnumeric_buffer) nonnumeric_buffer = map(nonnumeric_buffer) do buf - SymbolicUtils.Code.create_array( - container_type, eltype(buf), Val(1), Val(length(buf)), buf...) + p_constructor(SymbolicUtils.Code.create_array( + container_type, eltype(buf), Val(1), Val(length(buf)), buf...)) end end From 91b93882d5f73345cadd0072cc77c2cd15f2aa24 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 5 May 2025 16:17:07 +0530 Subject: [PATCH 08/40] feat: implement `ArrayInterface.ismutable` for `MTKParameters` --- src/systems/parameter_buffer.jl | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/systems/parameter_buffer.jl b/src/systems/parameter_buffer.jl index 5c777dbfe4..82a44a94e8 100644 --- a/src/systems/parameter_buffer.jl +++ b/src/systems/parameter_buffer.jl @@ -366,6 +366,14 @@ function Base.copy(p::MTKParameters) ) end +function ArrayInterface.ismutable(::Type{MTKParameters{ + T, I, D, C, N, H}}) where {T, I, D, C, N, H} + ArrayInterface.ismutable(T) || ArrayInterface.ismutable(I) || + any(ArrayInterface.ismutable, fieldtypes(D)) || + any(ArrayInterface.ismutable, fieldtypes(C)) || + any(ArrayInterface.ismutable, fieldtypes(N)) +end + function SymbolicIndexingInterface.parameter_values(p::MTKParameters, pind::ParameterIndex) _ducktyped_parameter_values(p, pind) end From 762954616c237810197ffd716e93efed5f438259 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 5 May 2025 16:17:45 +0530 Subject: [PATCH 09/40] fix: handle immutable MTKParameters in `remake_buffer` --- src/systems/parameter_buffer.jl | 32 +++++++++++++++++++++++++++++++- 1 file changed, 31 insertions(+), 1 deletion(-) diff --git a/src/systems/parameter_buffer.jl b/src/systems/parameter_buffer.jl index 82a44a94e8..7c34f67deb 100644 --- a/src/systems/parameter_buffer.jl +++ b/src/systems/parameter_buffer.jl @@ -637,8 +637,9 @@ end nonnumerics = $(Expr(:tuple, (:($similar(oldbuf.nonnumeric[$i], $(nonnumericT[i]))) for i in 1:length(nonnumericT))...)) $((:($copyto!(nonnumerics[$i], oldbuf.nonnumeric[$i])) for i in 1:length(nonnumericT))...) + caches = copy.(oldbuf.caches) newbuf = MTKParameters( - tunables, initials, discretes, constants, nonnumerics, copy.(oldbuf.caches)) + tunables, initials, discretes, constants, nonnumerics, caches) end if idxs <: AbstractArray push!(expr.args, :(for (idx, val) in zip(idxs, vals) @@ -649,6 +650,22 @@ end push!(expr.args, :($setindex!(newbuf, vals[$i], idxs[$i]))) end end + if !ArrayInterface.ismutable(oldbuf) + push!(expr.args, :(tunables = $similar_type($T, $tunablesT)(tunables))) + push!(expr.args, :(initials = $similar_type($I, $initialsT)(initials))) + push!(expr.args, + :(discretes = $(Expr(:tuple, + (:($similar_type($(fieldtype(D, i)), $(discretesT[i]))(discretes[$i])) for i in 1:length(discretesT))...)))) + push!(expr.args, + :(constants = $(Expr(:tuple, + (:($similar_type($(fieldtype(C, i)), $(constantsT[i]))(constants[$i])) for i in 1:length(constantsT))...)))) + push!(expr.args, + :(nonnumerics = $(Expr(:tuple, + (:($similar_type($(fieldtype(C, i)), $(nonnumericT[i]))(nonnumerics[$i])) for i in 1:length(nonnumericT))...)))) + push!(expr.args, + :(newbuf = MTKParameters( + tunables, initials, discretes, constants, nonnumerics, caches))) + end push!(expr.args, :(return newbuf)) return expr @@ -748,6 +765,19 @@ function __remake_buffer(indp, oldbuf::MTKParameters, idxs, vals; validate = tru oldbuf.constant, newbuf.constant) @set! newbuf.nonnumeric = narrow_buffer_type_and_fallback_undefs.( oldbuf.nonnumeric, newbuf.nonnumeric) + if !ArrayInterface.ismutable(oldbuf) + @set! newbuf.tunable = similar_type(oldbuf.tunable, eltype(newbuf.tunable))(newbuf.tunable) + @set! newbuf.initials = similar_type(oldbuf.initials, eltype(newbuf.initials))(newbuf.initials) + @set! newbuf.discrete = ntuple(Val(length(newbuf.discrete))) do i + similar_type.(oldbuf.discrete[i], eltype(newbuf.discrete[i]))(newbuf.discrete[i]) + end + @set! newbuf.constant = ntuple(Val(length(newbuf.constant))) do i + similar_type.(oldbuf.constant[i], eltype(newbuf.constant[i]))(newbuf.constant[i]) + end + @set! newbuf.nonnumeric = ntuple(Val(length(newbuf.nonnumeric))) do i + similar_type.(oldbuf.nonnumeric[i], eltype(newbuf.nonnumeric[i]))(newbuf.nonnumeric[i]) + end + end return newbuf end From 354706282a7b377ae4fbc7efce154b58c33ed4c2 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 5 May 2025 16:21:23 +0530 Subject: [PATCH 10/40] feat: add `p_constructor` kwarg to problem constructors --- src/systems/problem_utils.jl | 14 ++++++++++---- test/initial_values.jl | 28 ++++++++++++++++++++++++++++ 2 files changed, 38 insertions(+), 4 deletions(-) diff --git a/src/systems/problem_utils.jl b/src/systems/problem_utils.jl index 8f9dd944cf..071a47a355 100644 --- a/src/systems/problem_utils.jl +++ b/src/systems/problem_utils.jl @@ -1016,6 +1016,7 @@ Keyword arguments: - `tofloat`, `is_initializeprob`: Passed to [`better_varmap_to_vars`](@ref) for building `u0` (and possibly `p`). - `u0_constructor`: A function to apply to the `u0` value returned from `better_varmap_to_vars` to construct the final `u0` value. +- `p_constructor`: A function to apply to each array buffer created when constructing the parameter object. - `du0map`: A map of derivatives to values. See `implicit_dae`. - `check_length`: Whether to check the number of equations along with number of unknowns and length of `u0` vector for consistency. If `false`, do not check with equations. This is @@ -1044,8 +1045,8 @@ function process_SciMLProblem( warn_initialize_determined = true, initialization_eqs = [], eval_expression = false, eval_module = @__MODULE__, fully_determined = nothing, check_initialization_units = false, tofloat = true, - u0_constructor = identity, du0map = nothing, check_length = true, - symbolic_u0 = false, warn_cyclic_dependency = false, + u0_constructor = identity, p_constructor = identity, du0map = nothing, + check_length = true, symbolic_u0 = false, warn_cyclic_dependency = false, circular_dependency_max_cycle_length = length(all_symbols(sys)), circular_dependency_max_cycles = 10, substitution_limit = 100, use_scc = true, @@ -1095,6 +1096,11 @@ function process_SciMLProblem( u0_constructor = vals -> SymbolicUtils.Code.create_array( u0Type, floatT, Val(1), Val(length(vals)), vals...) end + if p_constructor === identity && pType <: StaticArray + p_constructor = vals -> SymbolicUtils.Code.create_array( + pType, floatT, Val(1), Val(length(vals)), vals...) + end + if build_initializeprob kws = maybe_build_initialization_problem( sys, op, u0map, pmap, t, defs, guesses, missing_unknowns; @@ -1159,9 +1165,9 @@ function process_SciMLProblem( if !(pType <: AbstractArray) pType = Array end - p = MTKParameters(sys, op; floatT = floatT, container_type = pType) + p = MTKParameters(sys, op; floatT = floatT, container_type = pType, p_constructor) else - p = better_varmap_to_vars(op, ps; tofloat, container_type = pType) + p = p_constructor(better_varmap_to_vars(op, ps; tofloat, container_type = pType)) end if implicit_dae && du0map !== nothing diff --git a/test/initial_values.jl b/test/initial_values.jl index ea5e8e01bb..8cff39f784 100644 --- a/test/initial_values.jl +++ b/test/initial_values.jl @@ -335,3 +335,31 @@ end @test prob.p.tunable isa SVector @test prob.p.initials isa SVector end + +@testset "`p_constructor` keyword argument" begin + @parameters σ ρ β + @variables x(t) y(t) z(t) + + eqs = [D(D(x)) ~ σ * (y - x), + D(y) ~ x * (ρ - z) - y, + D(z) ~ x * y - β * z] + + @mtkbuild sys = ODESystem(eqs, t) + + u0 = [D(x) => 2.0f0, + x => 1.0f0, + y => 0.0f0, + z => 0.0f0] + + p = [σ => 28.0f0, + ρ => 10.0f0, + β => 8.0f0 / 3] + u0_constructor = p_constructor = vals -> SVector{length(vals)}(vals...) + prob = ODEProblem(sys, u0, tspan, p; u0_constructor, p_constructor) + @test prob.p.tunable isa SVector + @test prob.p.initials isa SVector + + @mtkbuild sys=ODESystem(eqs, t) split=false + prob = ODEProblem(sys, u0, tspan, p; u0_constructor, p_constructor) + @test prob.p isa SVector +end From 35a7659cef421bc4a615ef1030f28a14ed2de83e Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 5 May 2025 16:27:37 +0530 Subject: [PATCH 11/40] feat: propagate `p_constructor` to `InitializationProblem` --- src/systems/problem_utils.jl | 8 +++++--- test/initial_values.jl | 39 ++++++++++++++++++------------------ 2 files changed, 24 insertions(+), 23 deletions(-) diff --git a/src/systems/problem_utils.jl b/src/systems/problem_utils.jl index 071a47a355..501ca14cd6 100644 --- a/src/systems/problem_utils.jl +++ b/src/systems/problem_utils.jl @@ -869,7 +869,8 @@ All other keyword arguments are forwarded to `InitializationProblem`. function maybe_build_initialization_problem( sys::AbstractSystem, op::AbstractDict, u0map, pmap, t, defs, guesses, missing_unknowns; implicit_dae = false, u0_constructor = identity, - floatT = Float64, initialization_eqs = [], use_scc = true, kwargs...) + p_constructor = identity, floatT = Float64, initialization_eqs = [], + use_scc = true, kwargs...) guesses = merge(ModelingToolkit.guesses(sys), todict(guesses)) if t === nothing && is_time_dependent(sys) @@ -877,7 +878,8 @@ function maybe_build_initialization_problem( end initializeprob = ModelingToolkit.InitializationProblem{true, SciMLBase.FullSpecialize}( - sys, t, u0map, pmap; guesses, initialization_eqs, use_scc, kwargs...) + sys, t, u0map, pmap; guesses, initialization_eqs, + use_scc, u0_constructor, p_constructor, kwargs...) if state_values(initializeprob) !== nothing initializeprob = remake(initializeprob; u0 = floatT.(state_values(initializeprob))) end @@ -1109,7 +1111,7 @@ function process_SciMLProblem( warn_cyclic_dependency, check_units = check_initialization_units, circular_dependency_max_cycle_length, circular_dependency_max_cycles, use_scc, force_time_independent = force_initialization_time_independent, algebraic_only, allow_incomplete, - u0_constructor, floatT) + u0_constructor, p_constructor, floatT) kwargs = merge(kwargs, kws) end diff --git a/test/initial_values.jl b/test/initial_values.jl index 8cff39f784..0ed8f7bffe 100644 --- a/test/initial_values.jl +++ b/test/initial_values.jl @@ -3,7 +3,7 @@ using ModelingToolkit: t_nounits as t, D_nounits as D, get_u0 using OrdinaryDiffEq using DataInterpolations using StaticArrays -using SymbolicIndexingInterface: getu +using SymbolicIndexingInterface @variables x(t)[1:3]=[1.0, 2.0, 3.0] y(t) z(t)[1:2] @@ -337,29 +337,28 @@ end end @testset "`p_constructor` keyword argument" begin - @parameters σ ρ β - @variables x(t) y(t) z(t) - - eqs = [D(D(x)) ~ σ * (y - x), - D(y) ~ x * (ρ - z) - y, - D(z) ~ x * y - β * z] - - @mtkbuild sys = ODESystem(eqs, t) - - u0 = [D(x) => 2.0f0, - x => 1.0f0, - y => 0.0f0, - z => 0.0f0] + @parameters g = 1.0 + @variables x(t) y(t) [state_priority = 10, guess = 1.0] λ(t) [guess = 1.0] + eqs = [D(D(x)) ~ λ * x + D(D(y)) ~ λ * y - g + x^2 + y^2 ~ 1] + @mtkbuild pend = ODESystem(eqs, t) - p = [σ => 28.0f0, - ρ => 10.0f0, - β => 8.0f0 / 3] + u0 = [x => 1.0, D(x) => 0.0] u0_constructor = p_constructor = vals -> SVector{length(vals)}(vals...) - prob = ODEProblem(sys, u0, tspan, p; u0_constructor, p_constructor) + tspan = (0.0, 5.0) + prob = ODEProblem(pend, u0, tspan; u0_constructor, p_constructor) + @test prob.u0 isa SVector @test prob.p.tunable isa SVector @test prob.p.initials isa SVector + initdata = prob.f.initialization_data + @test state_values(initdata.initializeprob) isa SVector + @test parameter_values(initdata.initializeprob).tunable isa SVector - @mtkbuild sys=ODESystem(eqs, t) split=false - prob = ODEProblem(sys, u0, tspan, p; u0_constructor, p_constructor) + @mtkbuild pend=ODESystem(eqs, t) split=false + prob = ODEProblem(pend, u0, tspan; u0_constructor, p_constructor) @test prob.p isa SVector + initdata = prob.f.initialization_data + @test state_values(initdata.initializeprob) isa SVector + @test parameter_values(initdata.initializeprob) isa SVector end From fd869e1d93c4a9d6b8192ab5c3aec28cd707716f Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 5 May 2025 16:30:10 +0530 Subject: [PATCH 12/40] fix: propagate `iip` of parent problem to `InitializationProblem` --- src/systems/problem_utils.jl | 7 ++++--- test/initializationsystem.jl | 13 +++++++++++++ 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/src/systems/problem_utils.jl b/src/systems/problem_utils.jl index 501ca14cd6..116adf9deb 100644 --- a/src/systems/problem_utils.jl +++ b/src/systems/problem_utils.jl @@ -867,7 +867,7 @@ denotes whether the `SciMLProblem` being constructed is in implicit DAE form (`D All other keyword arguments are forwarded to `InitializationProblem`. """ function maybe_build_initialization_problem( - sys::AbstractSystem, op::AbstractDict, u0map, pmap, t, defs, + sys::AbstractSystem, iip, op::AbstractDict, u0map, pmap, t, defs, guesses, missing_unknowns; implicit_dae = false, u0_constructor = identity, p_constructor = identity, floatT = Float64, initialization_eqs = [], use_scc = true, kwargs...) @@ -877,7 +877,7 @@ function maybe_build_initialization_problem( t = zero(floatT) end - initializeprob = ModelingToolkit.InitializationProblem{true, SciMLBase.FullSpecialize}( + initializeprob = ModelingToolkit.InitializationProblem{iip}( sys, t, u0map, pmap; guesses, initialization_eqs, use_scc, u0_constructor, p_constructor, kwargs...) if state_values(initializeprob) !== nothing @@ -1105,7 +1105,8 @@ function process_SciMLProblem( if build_initializeprob kws = maybe_build_initialization_problem( - sys, op, u0map, pmap, t, defs, guesses, missing_unknowns; + sys, constructor <: SciMLBase.AbstractSciMLFunction{true}, + op, u0map, pmap, t, defs, guesses, missing_unknowns; implicit_dae, warn_initialize_determined, initialization_eqs, eval_expression, eval_module, fully_determined, warn_cyclic_dependency, check_units = check_initialization_units, diff --git a/test/initializationsystem.jl b/test/initializationsystem.jl index cadf1f0a01..87ba755259 100644 --- a/test/initializationsystem.jl +++ b/test/initializationsystem.jl @@ -1611,3 +1611,16 @@ end [g => 1.0]; guesses = [y => 1.0, λ => 1.0]) @test !ModelingToolkit.is_split(prob.f.initialization_data.initializeprob.f.sys) end + +@testset "`InitializationProblem` retains `iip` of parent" begin + @parameters g + @variables x(t) y(t) [state_priority = 10] λ(t) + eqs = [D(D(x)) ~ λ * x + D(D(y)) ~ λ * y - g + x^2 + y^2 ~ 1] + @mtkbuild pend = ODESystem(eqs, t) + prob = ODEProblem(pend, SA[x => 1.0, D(x) => 0.0], (0.0, 1.0), + SA[g => 1.0]; guesses = [y => 1.0, λ => 1.0]) + @test !SciMLBase.isinplace(prob) + @test !SciMLBase.isinplace(prob.f.initialization_data.initializeprob) +end From b69d79e591217668f5491e5ed8da521e1fd92007 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 5 May 2025 16:33:06 +0530 Subject: [PATCH 13/40] fix: retain type of buffers when promoting `u0`/`p` of initialization problem --- src/systems/problem_utils.jl | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/src/systems/problem_utils.jl b/src/systems/problem_utils.jl index 116adf9deb..5e7edf4e55 100644 --- a/src/systems/problem_utils.jl +++ b/src/systems/problem_utils.jl @@ -881,7 +881,13 @@ function maybe_build_initialization_problem( sys, t, u0map, pmap; guesses, initialization_eqs, use_scc, u0_constructor, p_constructor, kwargs...) if state_values(initializeprob) !== nothing - initializeprob = remake(initializeprob; u0 = floatT.(state_values(initializeprob))) + _u0 = state_values(initializeprob) + if ArrayInterface.ismutable(_u0) + _u0 = floatT.(_u0) + else + _u0 = similar_type(_u0, floatT)(_u0) + end + initializeprob = remake(initializeprob; u0 = _u0) end initp = parameter_values(initializeprob) if is_split(sys) @@ -890,9 +896,13 @@ function maybe_build_initialization_problem( buffer, repack, _ = SciMLStructures.canonicalize(SciMLStructures.Initials(), initp) initp = repack(floatT.(buffer)) elseif initp isa AbstractArray - initp′ = similar(initp, floatT) - copyto!(initp′, initp) - initp = initp′ + if ArrayInterface.ismutable(initp) + initp′ = similar(initp, floatT) + copyto!(initp′, initp) + initp = initp′ + else + initp = similar_type(initp, floatT)(initp) + end end initializeprob = remake(initializeprob; p = initp) From aff87fa0bf62f45a5d732cffb2dec682f50939b4 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 5 May 2025 16:34:36 +0530 Subject: [PATCH 14/40] fix: handle immutable buffers in initialization --- src/systems/nonlinear/initializesystem.jl | 5 +- src/systems/problem_utils.jl | 228 +++++++++++++++------- 2 files changed, 162 insertions(+), 71 deletions(-) diff --git a/src/systems/nonlinear/initializesystem.jl b/src/systems/nonlinear/initializesystem.jl index c9d0c8f3a5..ce29db68b4 100644 --- a/src/systems/nonlinear/initializesystem.jl +++ b/src/systems/nonlinear/initializesystem.jl @@ -513,8 +513,7 @@ function SciMLBase.remake_initialization_data( length(oldinitprob.f.resid_prototype), new_initu0, new_initp)) end initprob = remake(oldinitprob; f = newf, u0 = new_initu0, p = new_initp) - return SciMLBase.OverrideInitData(initprob, oldinitdata.update_initializeprob!, - oldinitdata.initializeprobmap, oldinitdata.initializeprobpmap; metadata = oldinitdata.metadata) + return @set oldinitdata.initializeprob = initprob end dvs = unknowns(sys) @@ -627,7 +626,7 @@ function SciMLBase.late_binding_update_u0_p( if length(newu0) != length(prob.u0) throw(ArgumentError("Expected `newu0` to be of same length as unknowns ($(length(prob.u0))). Got $(typeof(newu0)) of length $(length(newu0))")) end - meta.set_initial_unknowns!(newp, newu0) + newp = meta.set_initial_unknowns!(newp, newu0) return newu0, newp end diff --git a/src/systems/problem_utils.jl b/src/systems/problem_utils.jl index 5e7edf4e55..897b2fcc44 100644 --- a/src/systems/problem_utils.jl +++ b/src/systems/problem_utils.jl @@ -491,32 +491,6 @@ function scalarize_varmap!(varmap::AbstractDict) return varmap end -struct GetUpdatedMTKParameters{G, S} - # `getu` functor which gets parameters that are unknowns during initialization - getpunknowns::G - # `setu` functor which returns a modified MTKParameters using those parameters - setpunknowns::S -end - -function (f::GetUpdatedMTKParameters)(prob, initializesol) - p = parameter_values(prob) - p === nothing && return nothing - mtkp = copy(p) - f.setpunknowns(mtkp, f.getpunknowns(initializesol)) - mtkp -end - -struct UpdateInitializeprob{G, S} - # `getu` functor which gets all values from prob - getvals::G - # `setu` functor which updates initializeprob with values - setvals::S -end - -function (f::UpdateInitializeprob)(initializeprob, prob) - f.setvals(initializeprob, f.getvals(prob)) -end - function get_temporary_value(p, floatT = Float64) stype = symtype(unwrap(p)) return if stype == Real @@ -669,6 +643,72 @@ function concrete_getu(indp, syms::AbstractVector) return Base.Fix1(reduce, vcat) ∘ getu(indp, split_syms) end +""" + $(TYPEDSIGNATURES) + +Given a source system `srcsys` and destination system `dstsys`, return a function that +takes a value provider of `srcsys` and a value provider of `dstsys` and returns the +`MTKParameters` object of the latter with values from the former. + +# Keyword Arguments +- `initials`: Whether to include the `Initial` parameters of `dstsys` among the values + to be transferred. +- `p_constructor`: The `p_constructor` argument to `process_SciMLProblem`. +""" +function get_mtkparameters_reconstructor(srcsys::AbstractSystem, dstsys::AbstractSystem; + initials = false, unwrap_initials = false, p_constructor = identity) + # if we call `getu` on this (and it were able to handle empty tuples) we get the + # fields of `MTKParameters` except caches. + syms = reorder_parameters( + dstsys, parameters(dstsys; initial_parameters = initials); flatten = false) + # `dstsys` is an initialization system, do basically everything is a tunable + # and tunables are a mix of different types in `srcsys`. No initials. Constants + # are going to be constants in `srcsys`, as are `nonnumeric`. + + # `syms[1]` is always the tunables because `srcsys` will have initials. + tunable_syms = syms[1] + tunable_getter = p_constructor ∘ concrete_getu(srcsys, tunable_syms) + rest_getters = map(Base.tail(Base.tail(syms))) do buf + if buf == () + return Returns(()) + else + return Base.Fix1(broadcast, p_constructor) ∘ getu(srcsys, buf) + end + end + initials_getter = if initials + initsyms = Vector{Any}(syms[2]) + allsyms = Set(all_symbols(srcsys)) + if unwrap_initials + for i in eachindex(initsyms) + sym = initsyms[i] + innersym = if operation(sym) === getindex + sym, idxs... = arguments(sym) + only(arguments(sym))[idxs...] + else + only(arguments(sym)) + end + if innersym in allsyms + initsyms[i] = innersym + end + end + end + p_constructor ∘ concrete_getu(srcsys, initsyms) + else + Returns(SizedVector{0, Float64}()) + end + getters = (tunable_getter, initials_getter, rest_getters...) + getter = let getters = getters + function _getter(valp, initprob) + oldcache = parameter_values(initprob).caches + MTKParameters(getters[1](valp), getters[2](valp), getters[3](valp), + getters[4](valp), getters[5](valp), oldcache isa Tuple{} ? () : + copy.(oldcache)) + end + end + + return getter +end + """ $(TYPEDSIGNATURES) @@ -676,41 +716,16 @@ Construct a `ReconstructInitializeprob` which reconstructs the `u0` and `p` of ` with values from `srcsys`. """ function ReconstructInitializeprob( - srcsys::AbstractSystem, dstsys::AbstractSystem) + srcsys::AbstractSystem, dstsys::AbstractSystem; u0_constructor = identity, p_constructor = identity) @assert is_initializesystem(dstsys) - ugetter = getu(srcsys, unknowns(dstsys)) + ugetter = u0_constructor ∘ getu(srcsys, unknowns(dstsys)) if is_split(dstsys) - # if we call `getu` on this (and it were able to handle empty tuples) we get the - # fields of `MTKParameters` except caches. - syms = reorder_parameters(dstsys, parameters(dstsys); flatten = false) - # `dstsys` is an initialization system, do basically everything is a tunable - # and tunables are a mix of different types in `srcsys`. No initials. Constants - # are going to be constants in `srcsys`, as are `nonnumeric`. - - # `syms[1]` is always the tunables because `srcsys` will have initials. - tunable_syms = syms[1] - tunable_getter = concrete_getu(srcsys, tunable_syms) - rest_getters = map(Base.tail(Base.tail(syms))) do buf - if buf == () - return Returns(()) - else - return getu(srcsys, buf) - end - end - getters = (tunable_getter, Returns(SizedVector{0, Float64}()), rest_getters...) - pgetter = let getters = getters - function _getter(valp, initprob) - oldcache = parameter_values(initprob).caches - MTKParameters(getters[1](valp), getters[2](valp), getters[3](valp), - getters[4](valp), getters[5](valp), oldcache isa Tuple{} ? () : - copy.(oldcache)) - end - end + pgetter = get_mtkparameters_reconstructor(srcsys, dstsys; p_constructor) else syms = parameters(dstsys) - pgetter = let inner = concrete_getu(srcsys, syms) + pgetter = let inner = concrete_getu(srcsys, syms), p_constructor = p_constructor function _getter2(valp, initprob) - inner(valp) + p_constructor(inner(valp)) end end end @@ -763,6 +778,54 @@ function (rip::ReconstructInitializeprob)(srcvalp, dstvalp) return u0, newp end +""" + $(TYPEDSIGNATURES) + +Given `sys` and its corresponding initialization system `initsys`, return the +`initializeprobpmap` function in `OverrideInitData` for the systems. +""" +function construct_initializeprobpmap( + sys::AbstractSystem, initsys::AbstractSystem; p_constructor = identity) + @assert is_initializesystem(initsys) + if is_split(sys) + return let getter = get_mtkparameters_reconstructor( + initsys, sys; initials = true, p_constructor) + function initprobpmap_split(prob, initsol) + getter(initsol, prob) + end + end + else + return let getter = getu(initsys, parameters(sys; initial_parameters = true)), + p_constructor = p_constructor + + function initprobpmap_nosplit(prob, initsol) + return p_constructor(getter(initsol)) + end + end + end +end + +function get_scimlfn(valp) + valp isa SciMLBase.AbstractSciMLFunction && return valp + if hasmethod(symbolic_container, Tuple{typeof(valp)}) && + (sc = symbolic_container(valp)) !== valp + return get_scimlfn(sc) + end + throw(ArgumentError("SciMLFunction not found. This should never happen.")) +end + +""" + $(TYPEDSIGNATURES) + +A function to be used as `update_initializeprob!` in `OverrideInitData`. Requires +`is_update_oop = Val{true}` to be passed to `update_initializeprob!`. +""" +function update_initializeprob!(initprob, prob) + p = get_scimlfn(prob).initialization_data.metadata.oop_reconstruct_u0_p.getter( + prob, initprob) + return remake(initprob; p) +end + """ $(TYPEDEF) @@ -804,8 +867,8 @@ struct InitializationMetadata{R <: ReconstructInitializeprob, GUU, SIU} """ get_updated_u0::GUU """ - A function which takes the `u0` of the problem and sets - `Initial.(unknowns(sys))`. + A function which takes parameter object and `u0` of the problem and sets + `Initial.(unknowns(sys))` in the former, returning the updated parameter object. """ set_initial_unknowns!::SIU end @@ -856,6 +919,38 @@ function (guu::GetUpdatedU0)(prob, initprob) return buffer end +struct SetInitialUnknowns{S} + setter!::S +end + +function SetInitialUnknowns(sys::AbstractSystem) + return SetInitialUnknowns(setu(sys, Initial.(unknowns(sys)))) +end + +function (siu::SetInitialUnknowns)(p::MTKParameters, u0) + if ArrayInterface.ismutable(p.initials) + siu.setter!(p, u0) + else + originalT = similar_type(p.initials) + @set! p.initials = MVector{length(p.initials), eltype(p.initials)}(p.initials) + siu.setter!(p, u0) + @set! p.initials = originalT(p.initials) + end + return p +end + +function (siu::SetInitialUnknowns)(p::Vector, u0) + if ArrayInterface.ismutable(p) + siu.setter!(p, u0) + else + originalT = similar_type(p) + p = MVector{length(p), eltype(p)}(p) + siu.setter!(p, u0) + p = originalT(p) + end + return p +end + """ $(TYPEDSIGNATURES) @@ -913,8 +1008,9 @@ function maybe_build_initialization_problem( end meta = InitializationMetadata( u0map, pmap, guesses, Vector{Equation}(initialization_eqs), - use_scc, ReconstructInitializeprob(sys, initializeprob.f.sys), - get_initial_unknowns, setp(sys, Initial.(unknowns(sys)))) + use_scc, ReconstructInitializeprob( + sys, initializeprob.f.sys; u0_constructor, p_constructor), + get_initial_unknowns, SetInitialUnknowns(sys)) if is_time_dependent(sys) all_init_syms = Set(all_symbols(initializeprob)) @@ -930,11 +1026,8 @@ function maybe_build_initialization_problem( if initializeprobmap === nothing && isempty(punknowns) initializeprobpmap = nothing else - allsyms = all_symbols(initializeprob) - initdvs = filter(x -> any(isequal(x), allsyms), unknowns(sys)) - getpunknowns = getu(initializeprob, [punknowns; initdvs]) - setpunknowns = setp(sys, [punknowns; Initial.(initdvs)]) - initializeprobpmap = GetUpdatedMTKParameters(getpunknowns, setpunknowns) + initializeprobpmap = construct_initializeprobpmap( + sys, initializeprob.f.sys; p_constructor) end reqd_syms = parameter_symbols(initializeprob) @@ -942,8 +1035,7 @@ function maybe_build_initialization_problem( if initializeprobmap === nothing && initializeprobpmap === nothing update_initializeprob! = nothing else - update_initializeprob! = UpdateInitializeprob( - getu(sys, reqd_syms), setu(initializeprob, reqd_syms)) + update_initializeprob! = ModelingToolkit.update_initializeprob! end for p in punknowns @@ -967,7 +1059,7 @@ function maybe_build_initialization_problem( return (; initialization_data = SciMLBase.OverrideInitData( initializeprob, update_initializeprob!, initializeprobmap, - initializeprobpmap; metadata = meta)) + initializeprobpmap; metadata = meta, is_update_oop = Val{true})) end """ From ad0347b431ffc74606b61f0048e6e01a1d36d3b6 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 5 May 2025 21:40:50 +0530 Subject: [PATCH 15/40] fix: handle `u0_constructor`, `p_constructor` in `remake_initialization_data` --- src/systems/nonlinear/initializesystem.jl | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/src/systems/nonlinear/initializesystem.jl b/src/systems/nonlinear/initializesystem.jl index ce29db68b4..fe932418a4 100644 --- a/src/systems/nonlinear/initializesystem.jl +++ b/src/systems/nonlinear/initializesystem.jl @@ -581,11 +581,21 @@ function SciMLBase.remake_initialization_data( op, missing_unknowns, missing_pars = build_operating_point!(sys, u0map, pmap, defs, cmap, dvs, ps) floatT = float_type_from_varmap(op) + u0_constructor = p_constructor = identity + if newu0 isa StaticArray + u0_constructor = vals -> SymbolicUtils.Code.create_array( + typeof(newu0), floatT, Val(1), Val(length(vals)), vals...) + end + if newp isa StaticArray || newp isa MTKParameters && newp.initials isa StaticArray + p_constructor = vals -> SymbolicUtils.Code.create_array( + typeof(newp.initials), floatT, Val(1), Val(length(vals)), vals...) + end kws = maybe_build_initialization_problem( - sys, op, u0map, pmap, t0, defs, guesses, missing_unknowns; - use_scc, initialization_eqs, floatT, allow_incomplete = true) + sys, SciMLBase.isinplace(odefn), op, u0map, pmap, t0, defs, guesses, missing_unknowns; + use_scc, initialization_eqs, floatT, u0_constructor, p_constructor, allow_incomplete = true) - return SciMLBase.remake_initialization_data(sys, kws, newu0, t0, newp, newu0, newp) + odefn = remake(odefn; kws...) + return SciMLBase.remake_initialization_data(sys, odefn, newu0, t0, newp, newu0, newp) end function promote_u0_p(u0, p::MTKParameters, t0) From 926fcedfc186bae3e755c8c65f1fbb86e5f9a49e Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 5 May 2025 21:41:10 +0530 Subject: [PATCH 16/40] fix: handle immutable MTKParameters in symbolic `late_binding_update_u0_p` --- src/systems/nonlinear/initializesystem.jl | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/systems/nonlinear/initializesystem.jl b/src/systems/nonlinear/initializesystem.jl index fe932418a4..c7a590197f 100644 --- a/src/systems/nonlinear/initializesystem.jl +++ b/src/systems/nonlinear/initializesystem.jl @@ -640,8 +640,8 @@ function SciMLBase.late_binding_update_u0_p( return newu0, newp end - newp = p === missing ? copy(newp) : newp - + syms = [] + vals = [] allsyms = all_symbols(sys) for (k, v) in u0 v === nothing && continue @@ -653,9 +653,11 @@ function SciMLBase.late_binding_update_u0_p( k = k2 end is_parameter(sys, Initial(k)) || continue - setp(sys, Initial(k))(newp, v) + push!(syms, Initial(k)) + push!(vals, v) end + newp = setp_oop(sys, syms)(newp, vals) return newu0, newp end From 52ce6418efbc1a81e4076e32bf14b5bc833196ba Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 5 May 2025 21:41:35 +0530 Subject: [PATCH 17/40] fix: handle edge case in floating point type promotion --- src/systems/problem_utils.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/systems/problem_utils.jl b/src/systems/problem_utils.jl index 897b2fcc44..999e4405ff 100644 --- a/src/systems/problem_utils.jl +++ b/src/systems/problem_utils.jl @@ -1187,7 +1187,7 @@ function process_SciMLProblem( u0map, pmap, defs, cmap, dvs, ps) floatT = Bool - if u0Type <: AbstractArray && eltype(u0Type) <: Real + if u0Type <: AbstractArray && eltype(u0Type) <: Real && eltype(u0Type) != Union{} floatT = float(eltype(u0Type)) else floatT = float_type_from_varmap(op, floatT) From 57e4778aa89059225eb0afef4a5499ce17c46390 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 5 May 2025 21:43:05 +0530 Subject: [PATCH 18/40] fix: call `u0_constructor` on `resid_prototype` --- src/systems/nonlinear/nonlinearsystem.jl | 17 +--------------- src/systems/problem_utils.jl | 25 +++++++++++++++++++++++- test/nonlinearsystem.jl | 16 +++++++++++++++ 3 files changed, 41 insertions(+), 17 deletions(-) diff --git a/src/systems/nonlinear/nonlinearsystem.jl b/src/systems/nonlinear/nonlinearsystem.jl index f8182fa28a..d0a12e22d4 100644 --- a/src/systems/nonlinear/nonlinearsystem.jl +++ b/src/systems/nonlinear/nonlinearsystem.jl @@ -345,16 +345,6 @@ function hessian_sparsity(sys::NonlinearSystem) unknowns(sys)) for eq in equations(sys)] end -function calculate_resid_prototype(N, u0, p) - u0ElType = u0 === nothing ? Float64 : eltype(u0) - if SciMLStructures.isscimlstructure(p) - u0ElType = promote_type( - eltype(SciMLStructures.canonicalize(SciMLStructures.Tunable(), p)[1]), - u0ElType) - end - return zeros(u0ElType, N) -end - """ ```julia SciMLBase.NonlinearFunction{iip}(sys::NonlinearSystem, dvs = unknowns(sys), @@ -381,6 +371,7 @@ function SciMLBase.NonlinearFunction{iip}(sys::NonlinearSystem, dvs = unknowns(s eval_module = @__MODULE__, sparse = false, simplify = false, initialization_data = nothing, cse = true, + resid_prototype = nothing, kwargs...) where {iip} if !iscomplete(sys) error("A completed `NonlinearSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `NonlinearFunction`") @@ -402,12 +393,6 @@ function SciMLBase.NonlinearFunction{iip}(sys::NonlinearSystem, dvs = unknowns(s observedfun = ObservedFunctionCache( sys; eval_expression, eval_module, checkbounds = get(kwargs, :checkbounds, false), cse) - if length(dvs) == length(equations(sys)) - resid_prototype = nothing - else - resid_prototype = calculate_resid_prototype(length(equations(sys)), u0, p) - end - NonlinearFunction{iip}(f; sys = sys, jac = _jac === nothing ? nothing : _jac, diff --git a/src/systems/problem_utils.jl b/src/systems/problem_utils.jl index 999e4405ff..fc377607e8 100644 --- a/src/systems/problem_utils.jl +++ b/src/systems/problem_utils.jl @@ -1082,6 +1082,22 @@ function float_type_from_varmap(varmap, floatT = Bool) return float(floatT) end +""" + $(TYPEDSIGNATURES) + +Calculate the `resid_prototype` for a `NonlinearFunction` with `N` equations and the +provided `u0` and `p`. +""" +function calculate_resid_prototype(N::Int, u0, p) + u0ElType = u0 === nothing ? Float64 : eltype(u0) + if SciMLStructures.isscimlstructure(p) + u0ElType = promote_type( + eltype(SciMLStructures.canonicalize(SciMLStructures.Tunable(), p)[1]), + u0ElType) + end + return zeros(u0ElType, N) +end + """ $(TYPEDSIGNATURES) @@ -1293,7 +1309,14 @@ function process_SciMLProblem( end initialization_data = SciMLBase.remake_initialization_data( kwargs.initialization_data, kwargs, u0, t0, p, u0, p) - kwargs = merge(kwargs,) + kwargs = merge(kwargs, (; initialization_data)) + end + + if constructor <: NonlinearFunction && length(dvs) != length(eqs) + kwargs = merge(kwargs, + (; + resid_prototype = u0_constructor(calculate_resid_prototype( + length(eqs), u0, p)))) end f = constructor(sys, dvs, ps, u0; p = p, diff --git a/test/nonlinearsystem.jl b/test/nonlinearsystem.jl index a315371141..158475f7c9 100644 --- a/test/nonlinearsystem.jl +++ b/test/nonlinearsystem.jl @@ -442,3 +442,19 @@ end @test !in(D(y), vs) end end + +@testset "oop `NonlinearLeastSquaresProblem` with `u0 === nothing`" begin + @variables x y + @named sys = NonlinearSystem([0 ~ x - y], [], []; observed = [x ~ 1.0, y ~ 1.0]) + prob = NonlinearLeastSquaresProblem{false}(complete(sys), nothing) + sol = solve(prob) + resid = sol.resid + @test resid == [0.0] + @test resid isa Vector + prob = NonlinearLeastSquaresProblem{false}( + complete(sys), nothing; u0_constructor = splat(SVector)) + sol = solve(prob) + resid = sol.resid + @test resid == [0.0] + @test resid isa SVector +end From ba2bfff3c182053a71cdc326a09d3d2972e7a950 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 5 May 2025 21:43:16 +0530 Subject: [PATCH 19/40] test: test initialization on static array problems --- test/initializationsystem.jl | 80 ++++++++++++++++++++++++------------ 1 file changed, 53 insertions(+), 27 deletions(-) diff --git a/test/initializationsystem.jl b/test/initializationsystem.jl index 87ba755259..df52291fbe 100644 --- a/test/initializationsystem.jl +++ b/test/initializationsystem.jl @@ -1,6 +1,6 @@ using ModelingToolkit, OrdinaryDiffEq, NonlinearSolve, Test using StochasticDiffEq, DelayDiffEq, StochasticDelayDiffEq, JumpProcesses -using ForwardDiff +using ForwardDiff, StaticArrays using SymbolicIndexingInterface, SciMLStructures using SciMLStructures: Tunable using ModelingToolkit: t_nounits as t, D_nounits as D, observed @@ -594,22 +594,36 @@ end @parameters p q @brownian a b x = _x(t) - + sarray_ctor = splat(SVector) # `System` constructor creates appropriate type with mtkbuild # `Problem` and `alg` create the problem to test and allow calling `init` with # the correct solver. # `rhss` allows adding terms to the end of equations (only 2 equations allowed) to influence # the system type (brownian vars to turn it into an SDE). - @testset "$Problem with $(SciMLBase.parameterless_type(alg))" for (System, Problem, alg, rhss) in [ - (ModelingToolkit.System, ODEProblem, Tsit5(), zeros(2)), - (ModelingToolkit.System, SDEProblem, ImplicitEM(), [a, b]), - (ModelingToolkit.System, DDEProblem, MethodOfSteps(Tsit5()), [_x(t - 0.1), 0.0]), - (ModelingToolkit.System, SDDEProblem, ImplicitEM(), [_x(t - 0.1) + a, b]) - ] + @testset "$Problem with $(SciMLBase.parameterless_type(alg)) and $ctor ctor" for ((System, Problem, alg, rhss), (ctor, expectedT)) in Iterators.product( + [ + (ModelingToolkit.System, ODEProblem, Tsit5(), zeros(2)), + (ModelingToolkit.System, SDEProblem, ImplicitEM(), [a, b]), + (ModelingToolkit.System, DDEProblem, + MethodOfSteps(Tsit5()), [_x(t - 0.1), 0.0]), + (ModelingToolkit.System, SDDEProblem, ImplicitEM(), [_x(t - 0.1) + a, b]) + ], + [(identity, Any), (sarray_ctor, SVector)]) + u0_constructor = p_constructor = ctor + if ctor !== identity + Problem = Problem{false} + end function test_parameter(prob, sym, val) if prob.u0 !== nothing + @test prob.u0 isa expectedT @test init(prob, alg).ps[sym] ≈ val end + @test prob.p.tunable isa expectedT + initprob = prob.f.initialization_data.initializeprob + if state_values(initprob) !== nothing + @test state_values(initprob) isa expectedT + end + @test parameter_values(initprob).tunable isa expectedT @test solve(prob, alg).ps[sym] ≈ val end function test_initializesystem(sys, u0map, pmap, p, equation) @@ -626,64 +640,64 @@ end @mtkbuild sys = System( [D(x) ~ x * q + rhss[1], D(y) ~ y * p + rhss[2]], t; defaults = [p => missing], guesses = [p => 1.0]) pmap[p] = 2q - prob = Problem(sys, u0map, (0.0, 1.0), pmap) + prob = Problem(sys, u0map, (0.0, 1.0), pmap; u0_constructor, p_constructor) test_parameter(prob, p, 2.0) prob2 = remake(prob; u0 = u0map, p = pmap) - prob2.ps[p] = 0.0 + prob2 = remake(prob2; p = setp_oop(prob2, p)(prob2, 0.0)) test_parameter(prob2, p, 2.0) # `missing` default, provided guess @mtkbuild sys = System( [D(x) ~ x + rhss[1], p ~ x + y + rhss[2]], t; defaults = [p => missing], guesses = [p => 0.0]) - prob = Problem(sys, u0map, (0.0, 1.0)) + prob = Problem(sys, u0map, (0.0, 1.0); u0_constructor, p_constructor) test_parameter(prob, p, 2.0) test_initializesystem(sys, u0map, pmap, p, 0 ~ p - x - y) prob2 = remake(prob; u0 = u0map) - prob2.ps[p] = 0.0 + prob2 = remake(prob2; p = setp_oop(prob2, p)(prob2, 0.0)) test_parameter(prob2, p, 2.0) # `missing` to Problem, equation from default @mtkbuild sys = System( [D(x) ~ x * q + rhss[1], D(y) ~ y * p + rhss[2]], t; defaults = [p => 2q], guesses = [p => 1.0]) pmap[p] = missing - prob = Problem(sys, u0map, (0.0, 1.0), pmap) + prob = Problem(sys, u0map, (0.0, 1.0), pmap; u0_constructor, p_constructor) test_parameter(prob, p, 2.0) test_initializesystem(sys, u0map, pmap, p, 0 ~ 2q - p) prob2 = remake(prob; u0 = u0map, p = pmap) - prob2.ps[p] = 0.0 + prob2 = remake(prob2; p = setp_oop(prob2, p)(prob2, 0.0)) test_parameter(prob2, p, 2.0) # `missing` to Problem, provided guess @mtkbuild sys = System( [D(x) ~ x + rhss[1], p ~ x + y + rhss[2]], t; guesses = [p => 0.0]) - prob = Problem(sys, u0map, (0.0, 1.0), pmap) + prob = Problem(sys, u0map, (0.0, 1.0), pmap; u0_constructor, p_constructor) test_parameter(prob, p, 2.0) test_initializesystem(sys, u0map, pmap, p, 0 ~ x + y - p) prob2 = remake(prob; u0 = u0map, p = pmap) - prob2.ps[p] = 0.0 + prob2 = remake(prob2; p = setp_oop(prob2, p)(prob2, 0.0)) test_parameter(prob2, p, 2.0) # No `missing`, default and guess @mtkbuild sys = System( [D(x) ~ x * q + rhss[1], D(y) ~ y * p + rhss[2]], t; defaults = [p => 2q], guesses = [p => 0.0]) delete!(pmap, p) - prob = Problem(sys, u0map, (0.0, 1.0), pmap) + prob = Problem(sys, u0map, (0.0, 1.0), pmap; u0_constructor, p_constructor) test_parameter(prob, p, 2.0) test_initializesystem(sys, u0map, pmap, p, 0 ~ 2q - p) prob2 = remake(prob; u0 = u0map, p = pmap) - prob2.ps[p] = 0.0 + prob2 = remake(prob2; p = setp_oop(prob2, p)(prob2, 0.0)) test_parameter(prob2, p, 2.0) # Default overridden by Problem, guess provided @mtkbuild sys = System( [D(x) ~ q * x + rhss[1], D(y) ~ y * p + rhss[2]], t; defaults = [p => 2q], guesses = [p => 1.0]) _pmap = merge(pmap, Dict(p => q)) - prob = Problem(sys, u0map, (0.0, 1.0), _pmap) + prob = Problem(sys, u0map, (0.0, 1.0), _pmap; u0_constructor, p_constructor) test_parameter(prob, p, _pmap[q]) test_initializesystem(sys, u0map, _pmap, p, 0 ~ q - p) # Problem dependent value with guess, no `missing` @mtkbuild sys = System( [D(x) ~ y * q + p + rhss[1], D(y) ~ x * p + q + rhss[2]], t; guesses = [p => 0.0]) _pmap = merge(pmap, Dict(p => 3q)) - prob = Problem(sys, u0map, (0.0, 1.0), _pmap) + prob = Problem(sys, u0map, (0.0, 1.0), _pmap; u0_constructor, p_constructor) test_parameter(prob, p, 3pmap[q]) # Should not be solved for: @@ -691,7 +705,7 @@ end @mtkbuild sys = System( [D(x) ~ q * x + rhss[1], D(y) ~ y * p + rhss[2]], t; defaults = [p => 2q], guesses = [p => 1.0]) _pmap = merge(pmap, Dict(p => 1.0)) - prob = Problem(sys, u0map, (0.0, 1.0), _pmap) + prob = Problem(sys, u0map, (0.0, 1.0), _pmap; u0_constructor, p_constructor) @test prob.ps[p] ≈ 1.0 initsys = prob.f.initialization_data.initializeprob.f.sys @test is_parameter(initsys, p) @@ -700,7 +714,7 @@ end @parameters r::Int s::Int @mtkbuild sys = System( [D(x) ~ s * x + rhss[1], D(y) ~ y * r + rhss[2]], t; defaults = [s => 2r], guesses = [s => 1.0]) - prob = Problem(sys, u0map, (0.0, 1.0), [r => 1]) + prob = Problem(sys, u0map, (0.0, 1.0), [r => 1]; u0_constructor, p_constructor) @test prob.ps[r] == 1 @test prob.ps[s] == 2 initsys = prob.f.initialization_data.initializeprob.f.sys @@ -714,7 +728,7 @@ end # Unsatisfiable initialization prob = Problem(sys, [x => 1.0, y => 1.0], (0.0, 1.0), - [p => 2.0]; initialization_eqs = [x^2 + y^2 ~ 3]) + [p => 2.0]; initialization_eqs = [x^2 + y^2 ~ 3], u0_constructor, p_constructor) @test prob.f.initialization_data !== nothing @test solve(prob, alg).retcode == ReturnCode.InitialFailure cache = init(prob, alg) @@ -791,8 +805,17 @@ end prob_alg_combinations = zip( [NonlinearProblem, NonlinearLeastSquaresProblem], [nl_algs, nlls_algs]) - @testset "Parameter initialization" begin + sarray_ctor = splat(SVector) + @testset "Parameter initialization with ctor $ctor" for (ctor, expectedT) in [ + (identity, Any), + (sarray_ctor, SVector) + ] + u0_constructor = p_constructor = ctor function test_parameter(prob, alg, param, val) + if prob.u0 !== nothing + @test prob.u0 isa expectedT + end + @test prob.p.tunable isa expectedT integ = init(prob, alg) @test integ.ps[param]≈val rtol=1e-5 # some algorithms are a little temperamental @@ -818,7 +841,10 @@ end # guesses = [q => 1.0], initialization_eqs = [p^2 + q^2 + 2p * q ~ 0]) for (probT, algs) in prob_alg_combinations - prob = probT(sys, []) + if ctor != identity + probT = probT{false} + end + prob = probT(sys, []; u0_constructor, p_constructor) @test prob.f.initialization_data !== nothing @test prob.f.initialization_data.initializeprobmap === nothing for alg in algs @@ -826,11 +852,11 @@ end end # `update_initializeprob!` works - prob.ps[p] = -2.0 + prob = remake(prob; p = setp_oop(prob, p)(prob, -2.0)) for alg in algs test_parameter(prob, alg, q, 2.0) end - prob.ps[p] = 2.0 + prob = remake(prob; p = setp_oop(prob, p)(prob, 2.0)) # `remake` works prob2 = remake(prob; p = [p => -2.0]) From d3ce046101b5ccdfdbafacabd8d9c431f61ac900 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 8 May 2025 23:16:19 +0530 Subject: [PATCH 20/40] build: bump SciMLBase, StochasticDelayDiffEq compat --- Project.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 6d637ef0a0..e602e97fe2 100644 --- a/Project.toml +++ b/Project.toml @@ -142,7 +142,7 @@ RecursiveArrayTools = "3.26" Reexport = "0.2, 1" RuntimeGeneratedFunctions = "0.5.9" SCCNonlinearSolve = "1.0.0" -SciMLBase = "2.84" +SciMLBase = "2.88" SciMLStructures = "1.7" Serialization = "1" Setfield = "0.7, 0.8, 1" @@ -150,7 +150,7 @@ SimpleNonlinearSolve = "0.1.0, 1, 2" SparseArrays = "1" SpecialFunctions = "0.7, 0.8, 0.9, 0.10, 1.0, 2" StaticArrays = "0.10, 0.11, 0.12, 1.0" -StochasticDelayDiffEq = "1.8.1" +StochasticDelayDiffEq = "1.10" StochasticDiffEq = "6.72.1" SymbolicIndexingInterface = "0.3.39" SymbolicUtils = "3.26.1" From d3512678bb6c8a7d5532862b540e776531366d14 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 12 May 2025 12:28:33 +0530 Subject: [PATCH 21/40] fix: handle empty `syms` in `concrete_getu` --- src/systems/problem_utils.jl | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/systems/problem_utils.jl b/src/systems/problem_utils.jl index fc377607e8..739d2c5d4d 100644 --- a/src/systems/problem_utils.jl +++ b/src/systems/problem_utils.jl @@ -667,7 +667,11 @@ function get_mtkparameters_reconstructor(srcsys::AbstractSystem, dstsys::Abstrac # `syms[1]` is always the tunables because `srcsys` will have initials. tunable_syms = syms[1] - tunable_getter = p_constructor ∘ concrete_getu(srcsys, tunable_syms) + tunable_getter = if isempty(tunable_syms) + Returns(SizedVector{0, Float64}()) + else + p_constructor ∘ concrete_getu(srcsys, tunable_syms) + end rest_getters = map(Base.tail(Base.tail(syms))) do buf if buf == () return Returns(()) @@ -675,7 +679,7 @@ function get_mtkparameters_reconstructor(srcsys::AbstractSystem, dstsys::Abstrac return Base.Fix1(broadcast, p_constructor) ∘ getu(srcsys, buf) end end - initials_getter = if initials + initials_getter = if initials && !isempty(syms[2]) initsyms = Vector{Any}(syms[2]) allsyms = Set(all_symbols(srcsys)) if unwrap_initials From 329b7d5ab22f996151cd6060e227845785a9b36d Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 12 May 2025 15:04:22 +0530 Subject: [PATCH 22/40] fix: fix `is_update_oop` passed as type --- src/systems/problem_utils.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/systems/problem_utils.jl b/src/systems/problem_utils.jl index 739d2c5d4d..087f99dcb4 100644 --- a/src/systems/problem_utils.jl +++ b/src/systems/problem_utils.jl @@ -822,7 +822,7 @@ end $(TYPEDSIGNATURES) A function to be used as `update_initializeprob!` in `OverrideInitData`. Requires -`is_update_oop = Val{true}` to be passed to `update_initializeprob!`. +`is_update_oop = Val(true)` to be passed to `update_initializeprob!`. """ function update_initializeprob!(initprob, prob) p = get_scimlfn(prob).initialization_data.metadata.oop_reconstruct_u0_p.getter( @@ -1063,7 +1063,7 @@ function maybe_build_initialization_problem( return (; initialization_data = SciMLBase.OverrideInitData( initializeprob, update_initializeprob!, initializeprobmap, - initializeprobpmap; metadata = meta, is_update_oop = Val{true})) + initializeprobpmap; metadata = meta, is_update_oop = Val(true))) end """ From 6c57557bd31e1787a5a25bcd633d35a78eaa5cd9 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 12 May 2025 15:04:31 +0530 Subject: [PATCH 23/40] fix: fix call to `remake_initialization_data` --- src/systems/problem_utils.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/systems/problem_utils.jl b/src/systems/problem_utils.jl index 087f99dcb4..3fef424ad6 100644 --- a/src/systems/problem_utils.jl +++ b/src/systems/problem_utils.jl @@ -1312,7 +1312,7 @@ function process_SciMLProblem( t0 = zero(floatT) end initialization_data = SciMLBase.remake_initialization_data( - kwargs.initialization_data, kwargs, u0, t0, p, u0, p) + sys, kwargs, u0, t0, p, u0, p) kwargs = merge(kwargs, (; initialization_data)) end From f1e422c42d421f8385610d959875397207c9f11c Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Mon, 12 May 2025 12:52:39 +0000 Subject: [PATCH 24/40] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index e602e97fe2..1099993e33 100644 --- a/Project.toml +++ b/Project.toml @@ -142,7 +142,7 @@ RecursiveArrayTools = "3.26" Reexport = "0.2, 1" RuntimeGeneratedFunctions = "0.5.9" SCCNonlinearSolve = "1.0.0" -SciMLBase = "2.88" +SciMLBase = "2.89.1" SciMLStructures = "1.7" Serialization = "1" Setfield = "0.7, 0.8, 1" From 7600f5f2e0d4a13a189b7c8cd0873eb544a6beee Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 12 May 2025 19:37:30 +0530 Subject: [PATCH 25/40] fix: fix `update_initializeprob!` --- src/systems/problem_utils.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/systems/problem_utils.jl b/src/systems/problem_utils.jl index 3fef424ad6..c1e4642c79 100644 --- a/src/systems/problem_utils.jl +++ b/src/systems/problem_utils.jl @@ -825,7 +825,7 @@ A function to be used as `update_initializeprob!` in `OverrideInitData`. Require `is_update_oop = Val(true)` to be passed to `update_initializeprob!`. """ function update_initializeprob!(initprob, prob) - p = get_scimlfn(prob).initialization_data.metadata.oop_reconstruct_u0_p.getter( + p = get_scimlfn(prob).initialization_data.metadata.oop_reconstruct_u0_p.pgetter( prob, initprob) return remake(initprob; p) end From 4631a1e2062445d7e0e05aba175dd440673ef862 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 12 May 2025 23:49:03 +0530 Subject: [PATCH 26/40] fix: add `f` field to `MockIntegrator` --- src/linearization.jl | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/src/linearization.jl b/src/linearization.jl index 77f4422b63..b30d275818 100644 --- a/src/linearization.jl +++ b/src/linearization.jl @@ -285,7 +285,7 @@ function (linfun::LinearizationFunction)(u, p, t) linfun.num_states == length(u) || error("Number of unknown variables ($(linfun.num_states)) does not match the number of input unknowns ($(length(u)))") integ_cache = (linfun.caches,) - integ = MockIntegrator{true}(u, p, t, integ_cache, nothing) + integ = MockIntegrator{true}(u, p, t, fun, integ_cache, nothing) u, p, success = SciMLBase.get_initial_values( linfun.prob, integ, fun, linfun.initializealg, Val(true); linfun.initialize_kwargs...) @@ -325,7 +325,7 @@ Mock `DEIntegrator` to allow using `CheckInit` without having to create a new in $(TYPEDFIELDS) """ -struct MockIntegrator{iip, U, P, T, C, O} <: SciMLBase.DEIntegrator{Nothing, iip, U, T} +struct MockIntegrator{iip, U, P, T, F, C, O} <: SciMLBase.DEIntegrator{Nothing, iip, U, T} """ The state vector. """ @@ -339,6 +339,10 @@ struct MockIntegrator{iip, U, P, T, C, O} <: SciMLBase.DEIntegrator{Nothing, iip """ t::T """ + The wrapped `SciMLFunction`. + """ + f::F + """ The integrator cache. """ cache::C @@ -348,8 +352,9 @@ struct MockIntegrator{iip, U, P, T, C, O} <: SciMLBase.DEIntegrator{Nothing, iip opts::O end -function MockIntegrator{iip}(u::U, p::P, t::T, cache::C, opts::O) where {iip, U, P, T, C, O} - return MockIntegrator{iip, U, P, T, C, O}(u, p, t, cache, opts) +function MockIntegrator{iip}( + u::U, p::P, t::T, f::F, cache::C, opts::O) where {iip, U, P, T, F, C, O} + return MockIntegrator{iip, U, P, T, F, C, O}(u, p, t, f, cache, opts) end SymbolicIndexingInterface.state_values(integ::MockIntegrator) = integ.u From a35ae7d8d796c33c4e25e290f9894f0d2d2c319b Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 15 May 2025 17:23:03 +0530 Subject: [PATCH 27/40] fix: handle discretes properly in `get_mtkparameters_reconstructor` --- src/systems/problem_utils.jl | 47 ++++++++++++++++++++++++------------ 1 file changed, 32 insertions(+), 15 deletions(-) diff --git a/src/systems/problem_utils.jl b/src/systems/problem_utils.jl index c1e4642c79..48f03b07a4 100644 --- a/src/systems/problem_utils.jl +++ b/src/systems/problem_utils.jl @@ -672,13 +672,6 @@ function get_mtkparameters_reconstructor(srcsys::AbstractSystem, dstsys::Abstrac else p_constructor ∘ concrete_getu(srcsys, tunable_syms) end - rest_getters = map(Base.tail(Base.tail(syms))) do buf - if buf == () - return Returns(()) - else - return Base.Fix1(broadcast, p_constructor) ∘ getu(srcsys, buf) - end - end initials_getter = if initials && !isempty(syms[2]) initsyms = Vector{Any}(syms[2]) allsyms = Set(all_symbols(srcsys)) @@ -700,7 +693,29 @@ function get_mtkparameters_reconstructor(srcsys::AbstractSystem, dstsys::Abstrac else Returns(SizedVector{0, Float64}()) end - getters = (tunable_getter, initials_getter, rest_getters...) + discs_getter = if isempty(syms[3]) + Returns(()) + else + ic = get_index_cache(dstsys) + blockarrsizes = Tuple(map(ic.discrete_buffer_sizes) do bufsizes + map(x -> x.length, bufsizes) + end) + # discretes need to be blocked arrays + # the `getu` returns a tuple of arrays corresponding to `p.discretes` + # `Base.Fix1(...)` applies `p_constructor` to each of the arrays in the tuple + # `Base.Fix2(...)` does `BlockedArray.(tuple_of_arrs, blockarrsizes)` returning a + # tuple of `BlockedArray`s + Base.Fix2(Broadcast.BroadcastFunction(BlockedArray), blockarrsizes) ∘ + Base.Fix1(broadcast, p_constructor) ∘ getu(srcsys, syms[3]) + end + rest_getters = map(Base.tail(Base.tail(Base.tail(syms)))) do buf + if buf == () + return Returns(()) + else + return Base.Fix1(broadcast, p_constructor) ∘ getu(srcsys, buf) + end + end + getters = (tunable_getter, initials_getter, discs_getter, rest_getters...) getter = let getters = getters function _getter(valp, initprob) oldcache = parameter_values(initprob).caches @@ -772,12 +787,14 @@ function (rip::ReconstructInitializeprob)(srcvalp, dstvalp) copyto!(newbuf, buf) newp = repack(newbuf) end - # and initials portion - buf, repack, alias = SciMLStructures.canonicalize(SciMLStructures.Initials(), newp) - if eltype(buf) != T - newbuf = similar(buf, T) - copyto!(newbuf, buf) - newp = repack(newbuf) + if newp isa MTKParameters + # and initials portion + buf, repack, alias = SciMLStructures.canonicalize(SciMLStructures.Initials(), newp) + if eltype(buf) != T + newbuf = similar(buf, T) + copyto!(newbuf, buf) + newp = repack(newbuf) + end end return u0, newp end @@ -793,7 +810,7 @@ function construct_initializeprobpmap( @assert is_initializesystem(initsys) if is_split(sys) return let getter = get_mtkparameters_reconstructor( - initsys, sys; initials = true, p_constructor) + initsys, sys; initials = true, unwrap_initials = true, p_constructor) function initprobpmap_split(prob, initsol) getter(initsol, prob) end From 1de4cd5e11549e6d2a05874e7fb64926a9d0e4ba Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 15 May 2025 17:43:12 +0530 Subject: [PATCH 28/40] refactor: move ChainRulesCoreExt into main package --- Project.toml | 3 +- src/ModelingToolkit.jl | 4 +++ .../adjoints.jl | 32 +++++++------------ 3 files changed, 17 insertions(+), 22 deletions(-) rename ext/MTKChainRulesCoreExt.jl => src/adjoints.jl (73%) diff --git a/Project.toml b/Project.toml index 1099993e33..302c74b1bc 100644 --- a/Project.toml +++ b/Project.toml @@ -8,6 +8,7 @@ ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e" +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" @@ -65,7 +66,6 @@ Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" [weakdeps] BifurcationKit = "0f109fa4-8a5d-4b75-95aa-f515264e7665" CasADi = "c49709b8-5c63-11e9-2fb2-69db5844192f" -ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" DeepDiffs = "ab62b9b5-e342-54a8-a765-a90f495de1a6" FMI = "14a09403-18e3-468f-ad8a-74f8dda2d9ac" InfiniteOpt = "20393b10-9daf-11e9-18c9-8db751c92c57" @@ -74,7 +74,6 @@ LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800" [extensions] MTKBifurcationKitExt = "BifurcationKit" MTKCasADiDynamicOptExt = "CasADi" -MTKChainRulesCoreExt = "ChainRulesCore" MTKDeepDiffsExt = "DeepDiffs" MTKFMIExt = "FMI" MTKInfiniteOptExt = "InfiniteOpt" diff --git a/src/ModelingToolkit.jl b/src/ModelingToolkit.jl index 6d8d89a976..9ed6b9994a 100644 --- a/src/ModelingToolkit.jl +++ b/src/ModelingToolkit.jl @@ -62,6 +62,8 @@ import BlockArrays: BlockArray, BlockedArray, Block, blocksize, blocksizes, bloc using OffsetArrays: Origin import CommonSolve import EnumX +import ChainRulesCore +import ChainRulesCore: Tangent, ZeroTangent, NoTangent, zero_tangent, unthunk using RuntimeGeneratedFunctions using RuntimeGeneratedFunctions: drop_expr @@ -204,6 +206,8 @@ include("structural_transformation/StructuralTransformations.jl") @reexport using .StructuralTransformations include("inputoutput.jl") +include("adjoints.jl") + for S in subtypes(ModelingToolkit.AbstractSystem) S = nameof(S) @eval convert_system(::Type{<:$S}, sys::$S) = sys diff --git a/ext/MTKChainRulesCoreExt.jl b/src/adjoints.jl similarity index 73% rename from ext/MTKChainRulesCoreExt.jl rename to src/adjoints.jl index a2974ea2dd..98266de938 100644 --- a/ext/MTKChainRulesCoreExt.jl +++ b/src/adjoints.jl @@ -1,17 +1,11 @@ -module MTKChainRulesCoreExt - -import ModelingToolkit as MTK -import ChainRulesCore -import ChainRulesCore: Tangent, ZeroTangent, NoTangent, zero_tangent, unthunk - -function ChainRulesCore.rrule(::Type{MTK.MTKParameters}, tunables, args...) +function ChainRulesCore.rrule(::Type{MTKParameters}, tunables, args...) function mtp_pullback(dt) dt = unthunk(dt) dtunables = dt isa AbstractArray ? dt : dt.tunable (NoTangent(), dtunables[1:length(tunables)], ntuple(_ -> NoTangent(), length(args))...) end - MTK.MTKParameters(tunables, args...), mtp_pullback + MTKParameters(tunables, args...), mtp_pullback end function subset_idxs(idxs, portion, template) @@ -70,23 +64,23 @@ function selected_tangents( end function ChainRulesCore.rrule( - ::typeof(MTK.remake_buffer), indp, oldbuf::MTK.MTKParameters, idxs, vals) + ::typeof(remake_buffer), indp, oldbuf::MTKParameters, idxs, vals) if idxs isa AbstractSet idxs = collect(idxs) end idxs = map(idxs) do i - i isa MTK.ParameterIndex ? i : MTK.parameter_index(indp, i) + i isa ParameterIndex ? i : parameter_index(indp, i) end - newbuf = MTK.remake_buffer(indp, oldbuf, idxs, vals) + newbuf = remake_buffer(indp, oldbuf, idxs, vals) tunable_idxs = reduce( - vcat, (idx.idx for idx in idxs if idx.portion isa MTK.SciMLStructures.Tunable); + vcat, (idx.idx for idx in idxs if idx.portion isa SciMLStructures.Tunable); init = Union{Int, AbstractVector{Int}}[]) initials_idxs = reduce( - vcat, (idx.idx for idx in idxs if idx.portion isa MTK.SciMLStructures.Initials); + vcat, (idx.idx for idx in idxs if idx.portion isa SciMLStructures.Initials); init = Union{Int, AbstractVector{Int}}[]) - disc_idxs = subset_idxs(idxs, MTK.SciMLStructures.Discrete(), oldbuf.discrete) - const_idxs = subset_idxs(idxs, MTK.SciMLStructures.Constants(), oldbuf.constant) - nn_idxs = subset_idxs(idxs, MTK.NONNUMERIC_PORTION, oldbuf.nonnumeric) + disc_idxs = subset_idxs(idxs, SciMLStructures.Discrete(), oldbuf.discrete) + const_idxs = subset_idxs(idxs, SciMLStructures.Constants(), oldbuf.constant) + nn_idxs = subset_idxs(idxs, NONNUMERIC_PORTION, oldbuf.nonnumeric) pullback = let idxs = idxs function remake_buffer_pullback(buf′) @@ -102,13 +96,11 @@ function ChainRulesCore.rrule( oldbuf′ = Tangent{typeof(oldbuf)}(; tunable, initials, discrete, constant, nonnumeric) idxs′ = NoTangent() - vals′ = map(i -> MTK._ducktyped_parameter_values(buf′, i), idxs) + vals′ = map(i -> _ducktyped_parameter_values(buf′, i), idxs) return f′, indp′, oldbuf′, idxs′, vals′ end end newbuf, pullback end -ChainRulesCore.@non_differentiable Base.getproperty(sys::MTK.AbstractSystem, x::Symbol) - -end +ChainRulesCore.@non_differentiable Base.getproperty(sys::AbstractSystem, x::Symbol) From 1dac3e7d92eaa49229aad4dcf450bcec62ccb04e Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 15 May 2025 17:43:23 +0530 Subject: [PATCH 29/40] fix: use `@ignore_derivatives` inside `update_initializeprob!` --- src/systems/problem_utils.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/systems/problem_utils.jl b/src/systems/problem_utils.jl index 48f03b07a4..e8ce6bd443 100644 --- a/src/systems/problem_utils.jl +++ b/src/systems/problem_utils.jl @@ -842,8 +842,8 @@ A function to be used as `update_initializeprob!` in `OverrideInitData`. Require `is_update_oop = Val(true)` to be passed to `update_initializeprob!`. """ function update_initializeprob!(initprob, prob) - p = get_scimlfn(prob).initialization_data.metadata.oop_reconstruct_u0_p.pgetter( - prob, initprob) + pgetter = ChainRulesCore.@ignore_derivatives get_scimlfn(prob).initialization_data.metadata.oop_reconstruct_u0_p.pgetter + p = pgetter(prob, initprob) return remake(initprob; p) end From 179f0bd958b479b8d76013b62f87ef8c324681e7 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Thu, 15 May 2025 20:09:28 +0000 Subject: [PATCH 30/40] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 302c74b1bc..423d73b90b 100644 --- a/Project.toml +++ b/Project.toml @@ -141,7 +141,7 @@ RecursiveArrayTools = "3.26" Reexport = "0.2, 1" RuntimeGeneratedFunctions = "0.5.9" SCCNonlinearSolve = "1.0.0" -SciMLBase = "2.89.1" +SciMLBase = "2.90.0" SciMLStructures = "1.7" Serialization = "1" Setfield = "0.7, 0.8, 1" From 25699664039ad1060f1a68e1556f6ed247f0a74d Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 16 May 2025 11:20:43 +0530 Subject: [PATCH 31/40] fix: add `struct IntervalNonlinearFunctionExpr` --- src/systems/nonlinear/nonlinearsystem.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/systems/nonlinear/nonlinearsystem.jl b/src/systems/nonlinear/nonlinearsystem.jl index d0a12e22d4..cf1c46207f 100644 --- a/src/systems/nonlinear/nonlinearsystem.jl +++ b/src/systems/nonlinear/nonlinearsystem.jl @@ -495,6 +495,8 @@ function NonlinearFunctionExpr{iip}(sys::NonlinearSystem, dvs = unknowns(sys), !linenumbers ? Base.remove_linenums!(ex) : ex end +struct IntervalNonlinearFunctionExpr end + """ $(TYPEDSIGNATURES) From 0987a9f5cbecd211e1a441dda03b46dc93394b62 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 16 May 2025 13:23:54 +0530 Subject: [PATCH 32/40] ci: add SciMLSensitivity/Core8 to downstream CI --- .github/workflows/Downstream.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/Downstream.yml b/.github/workflows/Downstream.yml index 4f84e2c45d..6d79ee746a 100644 --- a/.github/workflows/Downstream.yml +++ b/.github/workflows/Downstream.yml @@ -38,6 +38,7 @@ jobs: - {user: SciML, repo: MethodOfLines.jl, group: 2D_Diffusion} - {user: SciML, repo: MethodOfLines.jl, group: DAE} - {user: SciML, repo: ModelingToolkitNeuralNets.jl, group: All} + - {user: SciML, repo: SciMLSensitivity.jl, group: Core8} - {user: Neuroblox, repo: Neuroblox.jl, group: All} steps: From 37f5320fcf250779071a578c18cdc438843b464e Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Sun, 18 May 2025 02:17:25 +0530 Subject: [PATCH 33/40] build: bump SciMLBase compat --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 423d73b90b..ed9981dfdd 100644 --- a/Project.toml +++ b/Project.toml @@ -141,7 +141,7 @@ RecursiveArrayTools = "3.26" Reexport = "0.2, 1" RuntimeGeneratedFunctions = "0.5.9" SCCNonlinearSolve = "1.0.0" -SciMLBase = "2.90.0" +SciMLBase = "2.91.1" SciMLStructures = "1.7" Serialization = "1" Setfield = "0.7, 0.8, 1" From 22151e7801d124fece04b653aa6b83b2e5cddf74 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 20 May 2025 17:37:02 +0530 Subject: [PATCH 34/40] fix: only promote tunables/initials in `promote_u0_p` if non-empty --- src/systems/nonlinear/initializesystem.jl | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/systems/nonlinear/initializesystem.jl b/src/systems/nonlinear/initializesystem.jl index c7a590197f..d2a988dc07 100644 --- a/src/systems/nonlinear/initializesystem.jl +++ b/src/systems/nonlinear/initializesystem.jl @@ -602,10 +602,14 @@ function promote_u0_p(u0, p::MTKParameters, t0) u0 = DiffEqBase.promote_u0(u0, p.tunable, t0) u0 = DiffEqBase.promote_u0(u0, p.initials, t0) - tunables = DiffEqBase.promote_u0(p.tunable, u0, t0) - initials = DiffEqBase.promote_u0(p.initials, u0, t0) - p = SciMLStructures.replace(SciMLStructures.Tunable(), p, tunables) - p = SciMLStructures.replace(SciMLStructures.Initials(), p, initials) + if !isempty(p.tunable) + tunables = DiffEqBase.promote_u0(p.tunable, u0, t0) + p = SciMLStructures.replace(SciMLStructures.Tunable(), p, tunables) + end + if !isempty(p.initials) + initials = DiffEqBase.promote_u0(p.initials, u0, t0) + p = SciMLStructures.replace(SciMLStructures.Initials(), p, initials) + end return u0, p end From 2c0976e21581e42e204cb3568e554d62dbbcc8d4 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 20 May 2025 17:38:00 +0530 Subject: [PATCH 35/40] refactor: reorder tests --- test/runtests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index 1b8b5e03db..86d69228ce 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -138,9 +138,9 @@ end activate_extensions_env() @safetestset "Dynamic Optimization Collocation Solvers" include("extensions/dynamic_optimization.jl") @safetestset "HomotopyContinuation Extension Test" include("extensions/homotopy_continuation.jl") - @safetestset "Auto Differentiation Test" include("extensions/ad.jl") @safetestset "LabelledArrays Test" include("labelledarrays.jl") @safetestset "BifurcationKit Extension Test" include("extensions/bifurcationkit.jl") @safetestset "InfiniteOpt Extension Test" include("extensions/test_infiniteopt.jl") + @safetestset "Auto Differentiation Test" include("extensions/ad.jl") end end From 6d9e49b0f8b777f4baa0b96ecfd2402d9f22d095 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 20 May 2025 23:09:59 +0530 Subject: [PATCH 36/40] fix: always construct unscalarized observed equations for array variables --- src/structural_transformation/symbolics_tearing.jl | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/src/structural_transformation/symbolics_tearing.jl b/src/structural_transformation/symbolics_tearing.jl index 548c7da519..d41cbf12d2 100644 --- a/src/structural_transformation/symbolics_tearing.jl +++ b/src/structural_transformation/symbolics_tearing.jl @@ -860,12 +860,9 @@ function cse_and_array_hacks(sys, obs, subeqs, unknowns, neweqs; cse = true, arr # map of array observed variable (unscalarized) to number of its # scalarized terms that appear in observed equations arr_obs_occurrences = Dict() - # to check if array variables occur in unscalarized form anywhere - all_vars = Set() for (i, eq) in enumerate(obs) lhs = eq.lhs rhs = eq.rhs - vars!(all_vars, rhs) # HACK 1 if cse && is_getindexed_array(rhs) @@ -920,7 +917,6 @@ function cse_and_array_hacks(sys, obs, subeqs, unknowns, neweqs; cse = true, arr tempvar; T = Symbolics.symtype(rhs_arr))) tempvar = setmetadata( tempvar, Symbolics.ArrayShapeCtx, Symbolics.shape(rhs_arr)) - vars!(all_vars, rhs_arr) tempeq = tempvar ~ rhs_arr rhs_to_tempvar[rhs_arr] = tempvar push!(obs, tempeq) @@ -946,18 +942,10 @@ function cse_and_array_hacks(sys, obs, subeqs, unknowns, neweqs; cse = true, arr cnt == 0 && continue arr_obs_occurrences[arg1] = cnt + 1 end - for eq in neweqs - vars!(all_vars, eq.rhs) - end - # also count unscalarized variables used in callbacks - for ev in Iterators.flatten((continuous_events(sys), discrete_events(sys))) - vars!(all_vars, ev) - end obs_arr_eqs = Equation[] for (arrvar, cnt) in arr_obs_occurrences cnt == length(arrvar) || continue - arrvar in all_vars || continue # firstindex returns 1 for multidimensional array symbolics firstind = first(eachindex(arrvar)) scal = [arrvar[i] for i in eachindex(arrvar)] From ee9ee8a2171a92ff2f9ccff214d94b1e83af4d02 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 20 May 2025 23:10:26 +0530 Subject: [PATCH 37/40] test: update tests to account for changes to array hack --- test/initializationsystem.jl | 4 ++-- test/reduction.jl | 2 +- test/structural_transformation/utils.jl | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/test/initializationsystem.jl b/test/initializationsystem.jl index df52291fbe..7c512d37af 100644 --- a/test/initializationsystem.jl +++ b/test/initializationsystem.jl @@ -1347,8 +1347,8 @@ end guesses = [p[1] => q, p[2] => 2q]) @test ModelingToolkit.is_parameter_solvable(p, Dict(), defaults(sys), guesses(sys)) prob = ODEProblem(sys, [x => 1.0], (0.0, 1.0), [q => 2.0]) - @test length(ModelingToolkit.observed(prob.f.initialization_data.initializeprob.f.sys)) == - 3 + initsys = prob.f.initialization_data.initializeprob.f.sys + @test length(ModelingToolkit.observed(initsys)) == 4 sol = solve(prob, Tsit5()) @test sol.ps[p] ≈ [2.0, 4.0] end diff --git a/test/reduction.jl b/test/reduction.jl index fa9029a652..adeb4005d7 100644 --- a/test/reduction.jl +++ b/test/reduction.jl @@ -176,7 +176,7 @@ A = reshape(1:(N^2), N, N) eqs = xs ~ A * xs @named sys′ = NonlinearSystem(eqs, [xs], []) sys = structural_simplify(sys′) -@test length(equations(sys)) == 3 && length(observed(sys)) == 2 +@test length(equations(sys)) == 3 && length(observed(sys)) == 3 # issue 958 @parameters k₁ k₂ k₋₁ E₀ diff --git a/test/structural_transformation/utils.jl b/test/structural_transformation/utils.jl index b5335ad6b1..67cf2f72a0 100644 --- a/test/structural_transformation/utils.jl +++ b/test/structural_transformation/utils.jl @@ -76,7 +76,7 @@ end end @mtkbuild sys = ODESystem([D(x) ~ y[1] + y[2], y ~ foo(x)], t) @test length(equations(sys)) == 1 - @test length(observed(sys)) == 3 + @test length(observed(sys)) == 4 prob = ODEProblem(sys, [x => 1.0], (0.0, 1.0), [foo => _tmp_fn2]) val[] = 0 @test_nowarn prob.f(prob.u0, prob.p, 0.0) From 518959810701abdd8ae2f250153b64b4b99a4077 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 21 May 2025 00:27:52 +0530 Subject: [PATCH 38/40] fix: fix `narrow_buffer_type` for `BlockedArray` of arrays --- src/systems/parameter_buffer.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/systems/parameter_buffer.jl b/src/systems/parameter_buffer.jl index 7c34f67deb..3eb3063f77 100644 --- a/src/systems/parameter_buffer.jl +++ b/src/systems/parameter_buffer.jl @@ -198,6 +198,9 @@ function narrow_buffer_type( end function narrow_buffer_type(buffer::BlockedArray; container_type = typeof(parent(buffer))) + if eltype(buffer) <: AbstractArray + buffer = narrow_buffer_type.(buffer; container_type) + end type = Union{} for x in buffer type = promote_type(type, typeof(x)) From b1818ba0ea7ee438982c928dd902cdb0a808c4e7 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 21 May 2025 13:46:36 +0530 Subject: [PATCH 39/40] refactor: remove `container_type` kwarg of `MTKParameters`, use `p_constructor` --- src/systems/parameter_buffer.jl | 44 +++++++++++---------------------- test/mtkparameters.jl | 4 +-- 2 files changed, 17 insertions(+), 31 deletions(-) diff --git a/src/systems/parameter_buffer.jl b/src/systems/parameter_buffer.jl index 3eb3063f77..c3d2a0e831 100644 --- a/src/systems/parameter_buffer.jl +++ b/src/systems/parameter_buffer.jl @@ -29,13 +29,7 @@ the default behavior). function MTKParameters( sys::AbstractSystem, p, u0 = Dict(); tofloat = false, t0 = nothing, substitution_limit = 1000, floatT = nothing, - container_type = Vector, p_constructor = identity) - if !(container_type <: AbstractArray) - throw(ArgumentError(""" - `container_type` for `MTKParameters` must be a subtype of `AbstractArray`. Found \ - $container_type. - """)) - end + p_constructor = identity) ic = if has_index_cache(sys) && get_index_cache(sys) !== nothing get_index_cache(sys) else @@ -140,22 +134,19 @@ function MTKParameters( end end end - tunable_buffer = p_constructor(narrow_buffer_type(tunable_buffer; container_type)) + tunable_buffer = narrow_buffer_type(tunable_buffer; p_constructor) if isempty(tunable_buffer) tunable_buffer = SizedVector{0, Float64}() end - initials_buffer = p_constructor(narrow_buffer_type(initials_buffer; container_type)) + initials_buffer = narrow_buffer_type(initials_buffer; p_constructor) if isempty(initials_buffer) initials_buffer = SizedVector{0, Float64}() end - disc_buffer = p_constructor.(narrow_buffer_type.(disc_buffer; container_type)) - const_buffer = p_constructor.(narrow_buffer_type.(const_buffer; container_type)) + disc_buffer = narrow_buffer_type.(disc_buffer; p_constructor) + const_buffer = narrow_buffer_type.(const_buffer; p_constructor) # Don't narrow nonnumeric types if !isempty(nonnumeric_buffer) - nonnumeric_buffer = map(nonnumeric_buffer) do buf - p_constructor(SymbolicUtils.Code.create_array( - container_type, eltype(buf), Val(1), Val(length(buf)), buf...)) - end + nonnumeric_buffer = map(p_constructor, nonnumeric_buffer) end mtkps = MTKParameters{ @@ -172,17 +163,16 @@ function rebuild_with_caches(p::MTKParameters, cache_templates::BufferTemplate.. @set p.caches = buffers end -function narrow_buffer_type(buffer::AbstractArray; container_type = typeof(buffer)) +function narrow_buffer_type(buffer::AbstractArray; p_constructor = identity) type = Union{} for x in buffer type = promote_type(type, typeof(x)) end - return SymbolicUtils.Code.create_array( - container_type, type, Val(ndims(buffer)), Val(length(buffer)), buffer...) + return p_constructor(type.(buffer)) end function narrow_buffer_type( - buffer::AbstractArray{<:AbstractArray}; container_type = typeof(buffer)) + buffer::AbstractArray{<:AbstractArray}; p_constructor = identity) type = Union{} for arr in buffer for x in arr @@ -190,27 +180,23 @@ function narrow_buffer_type( end end buffer = map(buffer) do buf - SymbolicUtils.Code.create_array( - container_type, type, Val(ndims(buf)), Val(size(buf)), buf...) + p_constructor(type.(buf)) end - return SymbolicUtils.Code.create_array( - container_type, nothing, Val(ndims(buffer)), Val(size(buffer)), buffer...) + return p_constructor(buffer) end -function narrow_buffer_type(buffer::BlockedArray; container_type = typeof(parent(buffer))) +function narrow_buffer_type(buffer::BlockedArray; p_constructor = identity) if eltype(buffer) <: AbstractArray - buffer = narrow_buffer_type.(buffer; container_type) + buffer = narrow_buffer_type.(buffer; p_constructor) end type = Union{} for x in buffer type = promote_type(type, typeof(x)) end - tmp = SymbolicUtils.Code.create_array( - container_type, type, Val(ndims(buffer)), Val(size(buffer)), buffer...) + tmp = p_constructor(type.(buffer)) blocks = ntuple(Val(ndims(buffer))) do i bsizes = blocksizes(buffer, i) - SymbolicUtils.Code.create_array( - container_type, Int, Val(1), Val(length(bsizes)), bsizes...) + p_constructor(Int.(bsizes)) end return BlockedArray(tmp, blocks...) end diff --git a/test/mtkparameters.jl b/test/mtkparameters.jl index 9c857b6a55..b7acbb84a8 100644 --- a/test/mtkparameters.jl +++ b/test/mtkparameters.jl @@ -27,8 +27,8 @@ end @test getp(sys, a)(ps) == getp(sys, b)(ps) == getp(sys, c)(ps) == 0.0 @test getp(sys, d)(ps) isa Int -@testset "`container_type`" begin - ps2 = MTKParameters(sys, ivs; container_type = SVector) +@testset "`p_constructor`" begin + ps2 = MTKParameters(sys, ivs; p_constructor = x -> SArray{Tuple{size(x)...}}(x)) @test ps2.tunable isa SVector @test ps2.initials isa SVector @test ps2.discrete isa Tuple{<:BlockedVector{Float64, <:SVector}} From 01dd661057a75f25ebe77a79d2cfe017b2caf15f Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 21 May 2025 13:46:50 +0530 Subject: [PATCH 40/40] fix: fix unwrapping of views in MTKParameters reconstructor --- src/systems/problem_utils.jl | 34 +++++++++++++++++++++++++++++++--- 1 file changed, 31 insertions(+), 3 deletions(-) diff --git a/src/systems/problem_utils.jl b/src/systems/problem_utils.jl index e8ce6bd443..58173dd46c 100644 --- a/src/systems/problem_utils.jl +++ b/src/systems/problem_utils.jl @@ -643,6 +643,32 @@ function concrete_getu(indp, syms::AbstractVector) return Base.Fix1(reduce, vcat) ∘ getu(indp, split_syms) end +""" + $(TYPEDEF) + +A callable struct which applies `p_constructor` to possibly nested arrays. It also +ensures that views (including nested ones) are concretized. +""" +struct PConstructorApplicator{F} + p_constructor::F +end + +function (pca::PConstructorApplicator)(x::AbstractArray) + pca.p_constructor(x) +end + +function (pca::PConstructorApplicator{typeof(identity)})(x::SubArray) + collect(x) +end + +function (pca::PConstructorApplicator{typeof(identity)})(x::SubArray{<:AbstractArray}) + collect(pca.(x)) +end + +function (pca::PConstructorApplicator)(x::AbstractArray{<:AbstractArray}) + pca.p_constructor(pca.(x)) +end + """ $(TYPEDSIGNATURES) @@ -657,6 +683,7 @@ takes a value provider of `srcsys` and a value provider of `dstsys` and returns """ function get_mtkparameters_reconstructor(srcsys::AbstractSystem, dstsys::AbstractSystem; initials = false, unwrap_initials = false, p_constructor = identity) + p_constructor = PConstructorApplicator(p_constructor) # if we call `getu` on this (and it were able to handle empty tuples) we get the # fields of `MTKParameters` except caches. syms = reorder_parameters( @@ -698,7 +725,7 @@ function get_mtkparameters_reconstructor(srcsys::AbstractSystem, dstsys::Abstrac else ic = get_index_cache(dstsys) blockarrsizes = Tuple(map(ic.discrete_buffer_sizes) do bufsizes - map(x -> x.length, bufsizes) + p_constructor(map(x -> x.length, bufsizes)) end) # discretes need to be blocked arrays # the `getu` returns a tuple of arrays corresponding to `p.discretes` @@ -706,7 +733,8 @@ function get_mtkparameters_reconstructor(srcsys::AbstractSystem, dstsys::Abstrac # `Base.Fix2(...)` does `BlockedArray.(tuple_of_arrs, blockarrsizes)` returning a # tuple of `BlockedArray`s Base.Fix2(Broadcast.BroadcastFunction(BlockedArray), blockarrsizes) ∘ - Base.Fix1(broadcast, p_constructor) ∘ getu(srcsys, syms[3]) + Base.Fix1(broadcast, p_constructor) ∘ + getu(srcsys, syms[3]) end rest_getters = map(Base.tail(Base.tail(Base.tail(syms)))) do buf if buf == () @@ -1307,7 +1335,7 @@ function process_SciMLProblem( if !(pType <: AbstractArray) pType = Array end - p = MTKParameters(sys, op; floatT = floatT, container_type = pType, p_constructor) + p = MTKParameters(sys, op; floatT = floatT, p_constructor) else p = p_constructor(better_varmap_to_vars(op, ps; tofloat, container_type = pType)) end