diff --git a/.github/workflows/Downstream.yml b/.github/workflows/Downstream.yml index 4f84e2c45d..6d79ee746a 100644 --- a/.github/workflows/Downstream.yml +++ b/.github/workflows/Downstream.yml @@ -38,6 +38,7 @@ jobs: - {user: SciML, repo: MethodOfLines.jl, group: 2D_Diffusion} - {user: SciML, repo: MethodOfLines.jl, group: DAE} - {user: SciML, repo: ModelingToolkitNeuralNets.jl, group: All} + - {user: SciML, repo: SciMLSensitivity.jl, group: Core8} - {user: Neuroblox, repo: Neuroblox.jl, group: All} steps: diff --git a/Project.toml b/Project.toml index 6d637ef0a0..ed9981dfdd 100644 --- a/Project.toml +++ b/Project.toml @@ -8,6 +8,7 @@ ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e" +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" @@ -65,7 +66,6 @@ Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" [weakdeps] BifurcationKit = "0f109fa4-8a5d-4b75-95aa-f515264e7665" CasADi = "c49709b8-5c63-11e9-2fb2-69db5844192f" -ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" DeepDiffs = "ab62b9b5-e342-54a8-a765-a90f495de1a6" FMI = "14a09403-18e3-468f-ad8a-74f8dda2d9ac" InfiniteOpt = "20393b10-9daf-11e9-18c9-8db751c92c57" @@ -74,7 +74,6 @@ LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800" [extensions] MTKBifurcationKitExt = "BifurcationKit" MTKCasADiDynamicOptExt = "CasADi" -MTKChainRulesCoreExt = "ChainRulesCore" MTKDeepDiffsExt = "DeepDiffs" MTKFMIExt = "FMI" MTKInfiniteOptExt = "InfiniteOpt" @@ -142,7 +141,7 @@ RecursiveArrayTools = "3.26" Reexport = "0.2, 1" RuntimeGeneratedFunctions = "0.5.9" SCCNonlinearSolve = "1.0.0" -SciMLBase = "2.84" +SciMLBase = "2.91.1" SciMLStructures = "1.7" Serialization = "1" Setfield = "0.7, 0.8, 1" @@ -150,7 +149,7 @@ SimpleNonlinearSolve = "0.1.0, 1, 2" SparseArrays = "1" SpecialFunctions = "0.7, 0.8, 0.9, 0.10, 1.0, 2" StaticArrays = "0.10, 0.11, 0.12, 1.0" -StochasticDelayDiffEq = "1.8.1" +StochasticDelayDiffEq = "1.10" StochasticDiffEq = "6.72.1" SymbolicIndexingInterface = "0.3.39" SymbolicUtils = "3.26.1" diff --git a/src/ModelingToolkit.jl b/src/ModelingToolkit.jl index 6d8d89a976..9ed6b9994a 100644 --- a/src/ModelingToolkit.jl +++ b/src/ModelingToolkit.jl @@ -62,6 +62,8 @@ import BlockArrays: BlockArray, BlockedArray, Block, blocksize, blocksizes, bloc using OffsetArrays: Origin import CommonSolve import EnumX +import ChainRulesCore +import ChainRulesCore: Tangent, ZeroTangent, NoTangent, zero_tangent, unthunk using RuntimeGeneratedFunctions using RuntimeGeneratedFunctions: drop_expr @@ -204,6 +206,8 @@ include("structural_transformation/StructuralTransformations.jl") @reexport using .StructuralTransformations include("inputoutput.jl") +include("adjoints.jl") + for S in subtypes(ModelingToolkit.AbstractSystem) S = nameof(S) @eval convert_system(::Type{<:$S}, sys::$S) = sys diff --git a/ext/MTKChainRulesCoreExt.jl b/src/adjoints.jl similarity index 73% rename from ext/MTKChainRulesCoreExt.jl rename to src/adjoints.jl index a2974ea2dd..98266de938 100644 --- a/ext/MTKChainRulesCoreExt.jl +++ b/src/adjoints.jl @@ -1,17 +1,11 @@ -module MTKChainRulesCoreExt - -import ModelingToolkit as MTK -import ChainRulesCore -import ChainRulesCore: Tangent, ZeroTangent, NoTangent, zero_tangent, unthunk - -function ChainRulesCore.rrule(::Type{MTK.MTKParameters}, tunables, args...) +function ChainRulesCore.rrule(::Type{MTKParameters}, tunables, args...) function mtp_pullback(dt) dt = unthunk(dt) dtunables = dt isa AbstractArray ? dt : dt.tunable (NoTangent(), dtunables[1:length(tunables)], ntuple(_ -> NoTangent(), length(args))...) end - MTK.MTKParameters(tunables, args...), mtp_pullback + MTKParameters(tunables, args...), mtp_pullback end function subset_idxs(idxs, portion, template) @@ -70,23 +64,23 @@ function selected_tangents( end function ChainRulesCore.rrule( - ::typeof(MTK.remake_buffer), indp, oldbuf::MTK.MTKParameters, idxs, vals) + ::typeof(remake_buffer), indp, oldbuf::MTKParameters, idxs, vals) if idxs isa AbstractSet idxs = collect(idxs) end idxs = map(idxs) do i - i isa MTK.ParameterIndex ? i : MTK.parameter_index(indp, i) + i isa ParameterIndex ? i : parameter_index(indp, i) end - newbuf = MTK.remake_buffer(indp, oldbuf, idxs, vals) + newbuf = remake_buffer(indp, oldbuf, idxs, vals) tunable_idxs = reduce( - vcat, (idx.idx for idx in idxs if idx.portion isa MTK.SciMLStructures.Tunable); + vcat, (idx.idx for idx in idxs if idx.portion isa SciMLStructures.Tunable); init = Union{Int, AbstractVector{Int}}[]) initials_idxs = reduce( - vcat, (idx.idx for idx in idxs if idx.portion isa MTK.SciMLStructures.Initials); + vcat, (idx.idx for idx in idxs if idx.portion isa SciMLStructures.Initials); init = Union{Int, AbstractVector{Int}}[]) - disc_idxs = subset_idxs(idxs, MTK.SciMLStructures.Discrete(), oldbuf.discrete) - const_idxs = subset_idxs(idxs, MTK.SciMLStructures.Constants(), oldbuf.constant) - nn_idxs = subset_idxs(idxs, MTK.NONNUMERIC_PORTION, oldbuf.nonnumeric) + disc_idxs = subset_idxs(idxs, SciMLStructures.Discrete(), oldbuf.discrete) + const_idxs = subset_idxs(idxs, SciMLStructures.Constants(), oldbuf.constant) + nn_idxs = subset_idxs(idxs, NONNUMERIC_PORTION, oldbuf.nonnumeric) pullback = let idxs = idxs function remake_buffer_pullback(buf′) @@ -102,13 +96,11 @@ function ChainRulesCore.rrule( oldbuf′ = Tangent{typeof(oldbuf)}(; tunable, initials, discrete, constant, nonnumeric) idxs′ = NoTangent() - vals′ = map(i -> MTK._ducktyped_parameter_values(buf′, i), idxs) + vals′ = map(i -> _ducktyped_parameter_values(buf′, i), idxs) return f′, indp′, oldbuf′, idxs′, vals′ end end newbuf, pullback end -ChainRulesCore.@non_differentiable Base.getproperty(sys::MTK.AbstractSystem, x::Symbol) - -end +ChainRulesCore.@non_differentiable Base.getproperty(sys::AbstractSystem, x::Symbol) diff --git a/src/linearization.jl b/src/linearization.jl index 77f4422b63..b30d275818 100644 --- a/src/linearization.jl +++ b/src/linearization.jl @@ -285,7 +285,7 @@ function (linfun::LinearizationFunction)(u, p, t) linfun.num_states == length(u) || error("Number of unknown variables ($(linfun.num_states)) does not match the number of input unknowns ($(length(u)))") integ_cache = (linfun.caches,) - integ = MockIntegrator{true}(u, p, t, integ_cache, nothing) + integ = MockIntegrator{true}(u, p, t, fun, integ_cache, nothing) u, p, success = SciMLBase.get_initial_values( linfun.prob, integ, fun, linfun.initializealg, Val(true); linfun.initialize_kwargs...) @@ -325,7 +325,7 @@ Mock `DEIntegrator` to allow using `CheckInit` without having to create a new in $(TYPEDFIELDS) """ -struct MockIntegrator{iip, U, P, T, C, O} <: SciMLBase.DEIntegrator{Nothing, iip, U, T} +struct MockIntegrator{iip, U, P, T, F, C, O} <: SciMLBase.DEIntegrator{Nothing, iip, U, T} """ The state vector. """ @@ -339,6 +339,10 @@ struct MockIntegrator{iip, U, P, T, C, O} <: SciMLBase.DEIntegrator{Nothing, iip """ t::T """ + The wrapped `SciMLFunction`. + """ + f::F + """ The integrator cache. """ cache::C @@ -348,8 +352,9 @@ struct MockIntegrator{iip, U, P, T, C, O} <: SciMLBase.DEIntegrator{Nothing, iip opts::O end -function MockIntegrator{iip}(u::U, p::P, t::T, cache::C, opts::O) where {iip, U, P, T, C, O} - return MockIntegrator{iip, U, P, T, C, O}(u, p, t, cache, opts) +function MockIntegrator{iip}( + u::U, p::P, t::T, f::F, cache::C, opts::O) where {iip, U, P, T, F, C, O} + return MockIntegrator{iip, U, P, T, F, C, O}(u, p, t, f, cache, opts) end SymbolicIndexingInterface.state_values(integ::MockIntegrator) = integ.u diff --git a/src/structural_transformation/symbolics_tearing.jl b/src/structural_transformation/symbolics_tearing.jl index 548c7da519..d41cbf12d2 100644 --- a/src/structural_transformation/symbolics_tearing.jl +++ b/src/structural_transformation/symbolics_tearing.jl @@ -860,12 +860,9 @@ function cse_and_array_hacks(sys, obs, subeqs, unknowns, neweqs; cse = true, arr # map of array observed variable (unscalarized) to number of its # scalarized terms that appear in observed equations arr_obs_occurrences = Dict() - # to check if array variables occur in unscalarized form anywhere - all_vars = Set() for (i, eq) in enumerate(obs) lhs = eq.lhs rhs = eq.rhs - vars!(all_vars, rhs) # HACK 1 if cse && is_getindexed_array(rhs) @@ -920,7 +917,6 @@ function cse_and_array_hacks(sys, obs, subeqs, unknowns, neweqs; cse = true, arr tempvar; T = Symbolics.symtype(rhs_arr))) tempvar = setmetadata( tempvar, Symbolics.ArrayShapeCtx, Symbolics.shape(rhs_arr)) - vars!(all_vars, rhs_arr) tempeq = tempvar ~ rhs_arr rhs_to_tempvar[rhs_arr] = tempvar push!(obs, tempeq) @@ -946,18 +942,10 @@ function cse_and_array_hacks(sys, obs, subeqs, unknowns, neweqs; cse = true, arr cnt == 0 && continue arr_obs_occurrences[arg1] = cnt + 1 end - for eq in neweqs - vars!(all_vars, eq.rhs) - end - # also count unscalarized variables used in callbacks - for ev in Iterators.flatten((continuous_events(sys), discrete_events(sys))) - vars!(all_vars, ev) - end obs_arr_eqs = Equation[] for (arrvar, cnt) in arr_obs_occurrences cnt == length(arrvar) || continue - arrvar in all_vars || continue # firstindex returns 1 for multidimensional array symbolics firstind = first(eachindex(arrvar)) scal = [arrvar[i] for i in eachindex(arrvar)] diff --git a/src/systems/diffeqs/abstractodesystem.jl b/src/systems/diffeqs/abstractodesystem.jl index 942e508644..d802f49fee 100644 --- a/src/systems/diffeqs/abstractodesystem.jl +++ b/src/systems/diffeqs/abstractodesystem.jl @@ -1479,7 +1479,7 @@ function InitializationProblem{iip, specialize}(sys::AbstractSystem, end if simplify_system - isys = structural_simplify(isys; fully_determined) + isys = structural_simplify(isys; fully_determined, split = is_split(sys)) end ts = get_tearing_state(isys) @@ -1554,6 +1554,6 @@ function InitializationProblem{iip, specialize}(sys::AbstractSystem, else NonlinearLeastSquaresProblem end - TProb(isys, u0map, parammap; kwargs..., + TProb{iip}(isys, u0map, parammap; kwargs..., build_initializeprob = false, is_initializeprob = true) end diff --git a/src/systems/jumps/jumpsystem.jl b/src/systems/jumps/jumpsystem.jl index 06f5e1b623..34b6df3bf5 100644 --- a/src/systems/jumps/jumpsystem.jl +++ b/src/systems/jumps/jumpsystem.jl @@ -408,7 +408,7 @@ function DiffEqBase.DiscreteProblem(sys::JumpSystem, u0map, tspan::Union{Tuple, error("The passed in JumpSystem contains `Equation`s or continuous events, please use a problem type that supports these features, such as ODEProblem.") end - _f, u0, p = process_SciMLProblem(EmptySciMLFunction, sys, u0map, parammap; + _f, u0, p = process_SciMLProblem(EmptySciMLFunction{true}, sys, u0map, parammap; t = tspan === nothing ? nothing : tspan[1], tofloat = false, check_length = false, build_initializeprob = false, cse) f = DiffEqBase.DISCRETE_INPLACE_DEFAULT @@ -449,7 +449,7 @@ function DiscreteProblemExpr{iip}(sys::JumpSystem, u0map, tspan::Union{Tuple, No error("A completed `JumpSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `DiscreteProblemExpr`") end - _, u0, p = process_SciMLProblem(EmptySciMLFunction, sys, u0map, parammap; + _, u0, p = process_SciMLProblem(EmptySciMLFunction{iip}, sys, u0map, parammap; t = tspan === nothing ? nothing : tspan[1], tofloat = false, check_length = false) # identity function to make syms works quote @@ -506,7 +506,7 @@ function DiffEqBase.ODEProblem(sys::JumpSystem, u0map, tspan::Union{Tuple, Nothi return ODEProblem(osys, u0map, tspan, parammap; check_length = false, build_initializeprob = false, kwargs...) else - _, u0, p = process_SciMLProblem(EmptySciMLFunction, sys, u0map, parammap; + _, u0, p = process_SciMLProblem(EmptySciMLFunction{true}, sys, u0map, parammap; t = tspan === nothing ? nothing : tspan[1], tofloat = false, check_length = false, build_initializeprob = false, cse) f = (du, u, p, t) -> (du .= 0; nothing) diff --git a/src/systems/nonlinear/initializesystem.jl b/src/systems/nonlinear/initializesystem.jl index c9d0c8f3a5..d2a988dc07 100644 --- a/src/systems/nonlinear/initializesystem.jl +++ b/src/systems/nonlinear/initializesystem.jl @@ -513,8 +513,7 @@ function SciMLBase.remake_initialization_data( length(oldinitprob.f.resid_prototype), new_initu0, new_initp)) end initprob = remake(oldinitprob; f = newf, u0 = new_initu0, p = new_initp) - return SciMLBase.OverrideInitData(initprob, oldinitdata.update_initializeprob!, - oldinitdata.initializeprobmap, oldinitdata.initializeprobpmap; metadata = oldinitdata.metadata) + return @set oldinitdata.initializeprob = initprob end dvs = unknowns(sys) @@ -582,21 +581,35 @@ function SciMLBase.remake_initialization_data( op, missing_unknowns, missing_pars = build_operating_point!(sys, u0map, pmap, defs, cmap, dvs, ps) floatT = float_type_from_varmap(op) + u0_constructor = p_constructor = identity + if newu0 isa StaticArray + u0_constructor = vals -> SymbolicUtils.Code.create_array( + typeof(newu0), floatT, Val(1), Val(length(vals)), vals...) + end + if newp isa StaticArray || newp isa MTKParameters && newp.initials isa StaticArray + p_constructor = vals -> SymbolicUtils.Code.create_array( + typeof(newp.initials), floatT, Val(1), Val(length(vals)), vals...) + end kws = maybe_build_initialization_problem( - sys, op, u0map, pmap, t0, defs, guesses, missing_unknowns; - use_scc, initialization_eqs, floatT, allow_incomplete = true) + sys, SciMLBase.isinplace(odefn), op, u0map, pmap, t0, defs, guesses, missing_unknowns; + use_scc, initialization_eqs, floatT, u0_constructor, p_constructor, allow_incomplete = true) - return SciMLBase.remake_initialization_data(sys, kws, newu0, t0, newp, newu0, newp) + odefn = remake(odefn; kws...) + return SciMLBase.remake_initialization_data(sys, odefn, newu0, t0, newp, newu0, newp) end function promote_u0_p(u0, p::MTKParameters, t0) u0 = DiffEqBase.promote_u0(u0, p.tunable, t0) u0 = DiffEqBase.promote_u0(u0, p.initials, t0) - tunables = DiffEqBase.promote_u0(p.tunable, u0, t0) - initials = DiffEqBase.promote_u0(p.initials, u0, t0) - p = SciMLStructures.replace(SciMLStructures.Tunable(), p, tunables) - p = SciMLStructures.replace(SciMLStructures.Initials(), p, initials) + if !isempty(p.tunable) + tunables = DiffEqBase.promote_u0(p.tunable, u0, t0) + p = SciMLStructures.replace(SciMLStructures.Tunable(), p, tunables) + end + if !isempty(p.initials) + initials = DiffEqBase.promote_u0(p.initials, u0, t0) + p = SciMLStructures.replace(SciMLStructures.Initials(), p, initials) + end return u0, p end @@ -627,12 +640,12 @@ function SciMLBase.late_binding_update_u0_p( if length(newu0) != length(prob.u0) throw(ArgumentError("Expected `newu0` to be of same length as unknowns ($(length(prob.u0))). Got $(typeof(newu0)) of length $(length(newu0))")) end - meta.set_initial_unknowns!(newp, newu0) + newp = meta.set_initial_unknowns!(newp, newu0) return newu0, newp end - newp = p === missing ? copy(newp) : newp - + syms = [] + vals = [] allsyms = all_symbols(sys) for (k, v) in u0 v === nothing && continue @@ -644,9 +657,11 @@ function SciMLBase.late_binding_update_u0_p( k = k2 end is_parameter(sys, Initial(k)) || continue - setp(sys, Initial(k))(newp, v) + push!(syms, Initial(k)) + push!(vals, v) end + newp = setp_oop(sys, syms)(newp, vals) return newu0, newp end diff --git a/src/systems/nonlinear/nonlinearsystem.jl b/src/systems/nonlinear/nonlinearsystem.jl index 7146fb6b5e..cf1c46207f 100644 --- a/src/systems/nonlinear/nonlinearsystem.jl +++ b/src/systems/nonlinear/nonlinearsystem.jl @@ -345,16 +345,6 @@ function hessian_sparsity(sys::NonlinearSystem) unknowns(sys)) for eq in equations(sys)] end -function calculate_resid_prototype(N, u0, p) - u0ElType = u0 === nothing ? Float64 : eltype(u0) - if SciMLStructures.isscimlstructure(p) - u0ElType = promote_type( - eltype(SciMLStructures.canonicalize(SciMLStructures.Tunable(), p)[1]), - u0ElType) - end - return zeros(u0ElType, N) -end - """ ```julia SciMLBase.NonlinearFunction{iip}(sys::NonlinearSystem, dvs = unknowns(sys), @@ -381,6 +371,7 @@ function SciMLBase.NonlinearFunction{iip}(sys::NonlinearSystem, dvs = unknowns(s eval_module = @__MODULE__, sparse = false, simplify = false, initialization_data = nothing, cse = true, + resid_prototype = nothing, kwargs...) where {iip} if !iscomplete(sys) error("A completed `NonlinearSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `NonlinearFunction`") @@ -402,12 +393,6 @@ function SciMLBase.NonlinearFunction{iip}(sys::NonlinearSystem, dvs = unknowns(s observedfun = ObservedFunctionCache( sys; eval_expression, eval_module, checkbounds = get(kwargs, :checkbounds, false), cse) - if length(dvs) == length(equations(sys)) - resid_prototype = nothing - else - resid_prototype = calculate_resid_prototype(length(equations(sys)), u0, p) - end - NonlinearFunction{iip}(f; sys = sys, jac = _jac === nothing ? nothing : _jac, @@ -510,6 +495,8 @@ function NonlinearFunctionExpr{iip}(sys::NonlinearSystem, dvs = unknowns(sys), !linenumbers ? Base.remove_linenums!(ex) : ex end +struct IntervalNonlinearFunctionExpr end + """ $(TYPEDSIGNATURES) @@ -705,7 +692,7 @@ function SciMLBase.SCCNonlinearProblem{iip}(sys::NonlinearSystem, u0map, obs = observed(sys) _, u0, p = process_SciMLProblem( - EmptySciMLFunction, sys, u0map, parammap; eval_expression, eval_module, kwargs...) + EmptySciMLFunction{iip}, sys, u0map, parammap; eval_expression, eval_module, kwargs...) explicitfuns = [] nlfuns = [] @@ -835,7 +822,9 @@ function SciMLBase.SCCNonlinearProblem{iip}(sys::NonlinearSystem, u0map, subprobs = [] for (f, vscc) in zip(nlfuns, var_sccs) - prob = NonlinearProblem(f, u0[vscc], p) + _u0 = SymbolicUtils.Code.create_array( + typeof(u0), eltype(u0), Val(1), Val(length(vscc)), u0[vscc]...) + prob = NonlinearProblem{iip}(f, _u0, p) push!(subprobs, prob) end diff --git a/src/systems/parameter_buffer.jl b/src/systems/parameter_buffer.jl index 6142c95776..c3d2a0e831 100644 --- a/src/systems/parameter_buffer.jl +++ b/src/systems/parameter_buffer.jl @@ -28,7 +28,8 @@ the default behavior). """ function MTKParameters( sys::AbstractSystem, p, u0 = Dict(); tofloat = false, - t0 = nothing, substitution_limit = 1000, floatT = nothing) + t0 = nothing, substitution_limit = 1000, floatT = nothing, + p_constructor = identity) ic = if has_index_cache(sys) && get_index_cache(sys) !== nothing get_index_cache(sys) else @@ -133,18 +134,20 @@ function MTKParameters( end end end - tunable_buffer = narrow_buffer_type(tunable_buffer) + tunable_buffer = narrow_buffer_type(tunable_buffer; p_constructor) if isempty(tunable_buffer) tunable_buffer = SizedVector{0, Float64}() end - initials_buffer = narrow_buffer_type(initials_buffer) + initials_buffer = narrow_buffer_type(initials_buffer; p_constructor) if isempty(initials_buffer) initials_buffer = SizedVector{0, Float64}() end - disc_buffer = narrow_buffer_type.(disc_buffer) - const_buffer = narrow_buffer_type.(const_buffer) + disc_buffer = narrow_buffer_type.(disc_buffer; p_constructor) + const_buffer = narrow_buffer_type.(const_buffer; p_constructor) # Don't narrow nonnumeric types - nonnumeric_buffer = nonnumeric_buffer + if !isempty(nonnumeric_buffer) + nonnumeric_buffer = map(p_constructor, nonnumeric_buffer) + end mtkps = MTKParameters{ typeof(tunable_buffer), typeof(initials_buffer), typeof(disc_buffer), @@ -160,21 +163,42 @@ function rebuild_with_caches(p::MTKParameters, cache_templates::BufferTemplate.. @set p.caches = buffers end -function narrow_buffer_type(buffer::AbstractArray) +function narrow_buffer_type(buffer::AbstractArray; p_constructor = identity) type = Union{} for x in buffer type = promote_type(type, typeof(x)) end - return convert.(type, buffer) + return p_constructor(type.(buffer)) end -function narrow_buffer_type(buffer::AbstractArray{<:AbstractArray}) - buffer = narrow_buffer_type.(buffer) +function narrow_buffer_type( + buffer::AbstractArray{<:AbstractArray}; p_constructor = identity) + type = Union{} + for arr in buffer + for x in arr + type = promote_type(type, typeof(x)) + end + end + buffer = map(buffer) do buf + p_constructor(type.(buf)) + end + return p_constructor(buffer) +end + +function narrow_buffer_type(buffer::BlockedArray; p_constructor = identity) + if eltype(buffer) <: AbstractArray + buffer = narrow_buffer_type.(buffer; p_constructor) + end type = Union{} for x in buffer - type = promote_type(type, eltype(x)) + type = promote_type(type, typeof(x)) end - return broadcast.(convert, type, buffer) + tmp = p_constructor(type.(buffer)) + blocks = ntuple(Val(ndims(buffer))) do i + bsizes = blocksizes(buffer, i) + p_constructor(Int.(bsizes)) + end + return BlockedArray(tmp, blocks...) end function buffer_to_arraypartition(buf) @@ -331,6 +355,14 @@ function Base.copy(p::MTKParameters) ) end +function ArrayInterface.ismutable(::Type{MTKParameters{ + T, I, D, C, N, H}}) where {T, I, D, C, N, H} + ArrayInterface.ismutable(T) || ArrayInterface.ismutable(I) || + any(ArrayInterface.ismutable, fieldtypes(D)) || + any(ArrayInterface.ismutable, fieldtypes(C)) || + any(ArrayInterface.ismutable, fieldtypes(N)) +end + function SymbolicIndexingInterface.parameter_values(p::MTKParameters, pind::ParameterIndex) _ducktyped_parameter_values(p, pind) end @@ -594,8 +626,9 @@ end nonnumerics = $(Expr(:tuple, (:($similar(oldbuf.nonnumeric[$i], $(nonnumericT[i]))) for i in 1:length(nonnumericT))...)) $((:($copyto!(nonnumerics[$i], oldbuf.nonnumeric[$i])) for i in 1:length(nonnumericT))...) + caches = copy.(oldbuf.caches) newbuf = MTKParameters( - tunables, initials, discretes, constants, nonnumerics, copy.(oldbuf.caches)) + tunables, initials, discretes, constants, nonnumerics, caches) end if idxs <: AbstractArray push!(expr.args, :(for (idx, val) in zip(idxs, vals) @@ -606,6 +639,22 @@ end push!(expr.args, :($setindex!(newbuf, vals[$i], idxs[$i]))) end end + if !ArrayInterface.ismutable(oldbuf) + push!(expr.args, :(tunables = $similar_type($T, $tunablesT)(tunables))) + push!(expr.args, :(initials = $similar_type($I, $initialsT)(initials))) + push!(expr.args, + :(discretes = $(Expr(:tuple, + (:($similar_type($(fieldtype(D, i)), $(discretesT[i]))(discretes[$i])) for i in 1:length(discretesT))...)))) + push!(expr.args, + :(constants = $(Expr(:tuple, + (:($similar_type($(fieldtype(C, i)), $(constantsT[i]))(constants[$i])) for i in 1:length(constantsT))...)))) + push!(expr.args, + :(nonnumerics = $(Expr(:tuple, + (:($similar_type($(fieldtype(C, i)), $(nonnumericT[i]))(nonnumerics[$i])) for i in 1:length(nonnumericT))...)))) + push!(expr.args, + :(newbuf = MTKParameters( + tunables, initials, discretes, constants, nonnumerics, caches))) + end push!(expr.args, :(return newbuf)) return expr @@ -705,6 +754,19 @@ function __remake_buffer(indp, oldbuf::MTKParameters, idxs, vals; validate = tru oldbuf.constant, newbuf.constant) @set! newbuf.nonnumeric = narrow_buffer_type_and_fallback_undefs.( oldbuf.nonnumeric, newbuf.nonnumeric) + if !ArrayInterface.ismutable(oldbuf) + @set! newbuf.tunable = similar_type(oldbuf.tunable, eltype(newbuf.tunable))(newbuf.tunable) + @set! newbuf.initials = similar_type(oldbuf.initials, eltype(newbuf.initials))(newbuf.initials) + @set! newbuf.discrete = ntuple(Val(length(newbuf.discrete))) do i + similar_type.(oldbuf.discrete[i], eltype(newbuf.discrete[i]))(newbuf.discrete[i]) + end + @set! newbuf.constant = ntuple(Val(length(newbuf.constant))) do i + similar_type.(oldbuf.constant[i], eltype(newbuf.constant[i]))(newbuf.constant[i]) + end + @set! newbuf.nonnumeric = ntuple(Val(length(newbuf.nonnumeric))) do i + similar_type.(oldbuf.nonnumeric[i], eltype(newbuf.nonnumeric[i]))(newbuf.nonnumeric[i]) + end + end return newbuf end diff --git a/src/systems/problem_utils.jl b/src/systems/problem_utils.jl index 4525b0e46b..58173dd46c 100644 --- a/src/systems/problem_utils.jl +++ b/src/systems/problem_utils.jl @@ -491,32 +491,6 @@ function scalarize_varmap!(varmap::AbstractDict) return varmap end -struct GetUpdatedMTKParameters{G, S} - # `getu` functor which gets parameters that are unknowns during initialization - getpunknowns::G - # `setu` functor which returns a modified MTKParameters using those parameters - setpunknowns::S -end - -function (f::GetUpdatedMTKParameters)(prob, initializesol) - p = parameter_values(prob) - p === nothing && return nothing - mtkp = copy(p) - f.setpunknowns(mtkp, f.getpunknowns(initializesol)) - mtkp -end - -struct UpdateInitializeprob{G, S} - # `getu` functor which gets all values from prob - getvals::G - # `setu` functor which updates initializeprob with values - setvals::S -end - -function (f::UpdateInitializeprob)(initializeprob, prob) - f.setvals(initializeprob, f.getvals(prob)) -end - function get_temporary_value(p, floatT = Float64) stype = symtype(unwrap(p)) return if stype == Real @@ -539,13 +513,13 @@ A simple utility meant to be used as the `constructor` passed to `process_SciMLP case constructing a SciMLFunction is not required. The arguments passed to it are available in the `args` field, and the keyword arguments in the `kwargs` field. """ -struct EmptySciMLFunction{A, K} +struct EmptySciMLFunction{iip, A, K} <: SciMLBase.AbstractSciMLFunction{iip} args::A kwargs::K end -function EmptySciMLFunction(args...; kwargs...) - return EmptySciMLFunction{typeof(args), typeof(kwargs)}(args, kwargs) +function EmptySciMLFunction{iip}(args...; kwargs...) where {iip} + return EmptySciMLFunction{iip, typeof(args), typeof(kwargs)}(args, kwargs) end """ @@ -669,6 +643,119 @@ function concrete_getu(indp, syms::AbstractVector) return Base.Fix1(reduce, vcat) ∘ getu(indp, split_syms) end +""" + $(TYPEDEF) + +A callable struct which applies `p_constructor` to possibly nested arrays. It also +ensures that views (including nested ones) are concretized. +""" +struct PConstructorApplicator{F} + p_constructor::F +end + +function (pca::PConstructorApplicator)(x::AbstractArray) + pca.p_constructor(x) +end + +function (pca::PConstructorApplicator{typeof(identity)})(x::SubArray) + collect(x) +end + +function (pca::PConstructorApplicator{typeof(identity)})(x::SubArray{<:AbstractArray}) + collect(pca.(x)) +end + +function (pca::PConstructorApplicator)(x::AbstractArray{<:AbstractArray}) + pca.p_constructor(pca.(x)) +end + +""" + $(TYPEDSIGNATURES) + +Given a source system `srcsys` and destination system `dstsys`, return a function that +takes a value provider of `srcsys` and a value provider of `dstsys` and returns the +`MTKParameters` object of the latter with values from the former. + +# Keyword Arguments +- `initials`: Whether to include the `Initial` parameters of `dstsys` among the values + to be transferred. +- `p_constructor`: The `p_constructor` argument to `process_SciMLProblem`. +""" +function get_mtkparameters_reconstructor(srcsys::AbstractSystem, dstsys::AbstractSystem; + initials = false, unwrap_initials = false, p_constructor = identity) + p_constructor = PConstructorApplicator(p_constructor) + # if we call `getu` on this (and it were able to handle empty tuples) we get the + # fields of `MTKParameters` except caches. + syms = reorder_parameters( + dstsys, parameters(dstsys; initial_parameters = initials); flatten = false) + # `dstsys` is an initialization system, do basically everything is a tunable + # and tunables are a mix of different types in `srcsys`. No initials. Constants + # are going to be constants in `srcsys`, as are `nonnumeric`. + + # `syms[1]` is always the tunables because `srcsys` will have initials. + tunable_syms = syms[1] + tunable_getter = if isempty(tunable_syms) + Returns(SizedVector{0, Float64}()) + else + p_constructor ∘ concrete_getu(srcsys, tunable_syms) + end + initials_getter = if initials && !isempty(syms[2]) + initsyms = Vector{Any}(syms[2]) + allsyms = Set(all_symbols(srcsys)) + if unwrap_initials + for i in eachindex(initsyms) + sym = initsyms[i] + innersym = if operation(sym) === getindex + sym, idxs... = arguments(sym) + only(arguments(sym))[idxs...] + else + only(arguments(sym)) + end + if innersym in allsyms + initsyms[i] = innersym + end + end + end + p_constructor ∘ concrete_getu(srcsys, initsyms) + else + Returns(SizedVector{0, Float64}()) + end + discs_getter = if isempty(syms[3]) + Returns(()) + else + ic = get_index_cache(dstsys) + blockarrsizes = Tuple(map(ic.discrete_buffer_sizes) do bufsizes + p_constructor(map(x -> x.length, bufsizes)) + end) + # discretes need to be blocked arrays + # the `getu` returns a tuple of arrays corresponding to `p.discretes` + # `Base.Fix1(...)` applies `p_constructor` to each of the arrays in the tuple + # `Base.Fix2(...)` does `BlockedArray.(tuple_of_arrs, blockarrsizes)` returning a + # tuple of `BlockedArray`s + Base.Fix2(Broadcast.BroadcastFunction(BlockedArray), blockarrsizes) ∘ + Base.Fix1(broadcast, p_constructor) ∘ + getu(srcsys, syms[3]) + end + rest_getters = map(Base.tail(Base.tail(Base.tail(syms)))) do buf + if buf == () + return Returns(()) + else + return Base.Fix1(broadcast, p_constructor) ∘ getu(srcsys, buf) + end + end + getters = (tunable_getter, initials_getter, discs_getter, rest_getters...) + getter = let getters = getters + function _getter(valp, initprob) + oldcache = parameter_values(initprob).caches + MTKParameters(getters[1](valp), getters[2](valp), getters[3](valp), + getters[4](valp), getters[5](valp), oldcache isa Tuple{} ? () : + copy.(oldcache)) + end + end + + return getter +end + """ $(TYPEDSIGNATURES) @@ -676,41 +763,16 @@ Construct a `ReconstructInitializeprob` which reconstructs the `u0` and `p` of ` with values from `srcsys`. """ function ReconstructInitializeprob( - srcsys::AbstractSystem, dstsys::AbstractSystem) + srcsys::AbstractSystem, dstsys::AbstractSystem; u0_constructor = identity, p_constructor = identity) @assert is_initializesystem(dstsys) - ugetter = getu(srcsys, unknowns(dstsys)) + ugetter = u0_constructor ∘ getu(srcsys, unknowns(dstsys)) if is_split(dstsys) - # if we call `getu` on this (and it were able to handle empty tuples) we get the - # fields of `MTKParameters` except caches. - syms = reorder_parameters(dstsys, parameters(dstsys); flatten = false) - # `dstsys` is an initialization system, do basically everything is a tunable - # and tunables are a mix of different types in `srcsys`. No initials. Constants - # are going to be constants in `srcsys`, as are `nonnumeric`. - - # `syms[1]` is always the tunables because `srcsys` will have initials. - tunable_syms = syms[1] - tunable_getter = concrete_getu(srcsys, tunable_syms) - rest_getters = map(Base.tail(Base.tail(syms))) do buf - if buf == () - return Returns(()) - else - return getu(srcsys, buf) - end - end - getters = (tunable_getter, Returns(SizedVector{0, Float64}()), rest_getters...) - pgetter = let getters = getters - function _getter(valp, initprob) - oldcache = parameter_values(initprob).caches - MTKParameters(getters[1](valp), getters[2](valp), getters[3](valp), - getters[4](valp), getters[5](valp), oldcache isa Tuple{} ? () : - copy.(oldcache)) - end - end + pgetter = get_mtkparameters_reconstructor(srcsys, dstsys; p_constructor) else syms = parameters(dstsys) - pgetter = let inner = concrete_getu(srcsys, syms) + pgetter = let inner = concrete_getu(srcsys, syms), p_constructor = p_constructor function _getter2(valp, initprob) - inner(valp) + p_constructor(inner(valp)) end end end @@ -753,16 +815,66 @@ function (rip::ReconstructInitializeprob)(srcvalp, dstvalp) copyto!(newbuf, buf) newp = repack(newbuf) end - # and initials portion - buf, repack, alias = SciMLStructures.canonicalize(SciMLStructures.Initials(), newp) - if eltype(buf) != T - newbuf = similar(buf, T) - copyto!(newbuf, buf) - newp = repack(newbuf) + if newp isa MTKParameters + # and initials portion + buf, repack, alias = SciMLStructures.canonicalize(SciMLStructures.Initials(), newp) + if eltype(buf) != T + newbuf = similar(buf, T) + copyto!(newbuf, buf) + newp = repack(newbuf) + end end return u0, newp end +""" + $(TYPEDSIGNATURES) + +Given `sys` and its corresponding initialization system `initsys`, return the +`initializeprobpmap` function in `OverrideInitData` for the systems. +""" +function construct_initializeprobpmap( + sys::AbstractSystem, initsys::AbstractSystem; p_constructor = identity) + @assert is_initializesystem(initsys) + if is_split(sys) + return let getter = get_mtkparameters_reconstructor( + initsys, sys; initials = true, unwrap_initials = true, p_constructor) + function initprobpmap_split(prob, initsol) + getter(initsol, prob) + end + end + else + return let getter = getu(initsys, parameters(sys; initial_parameters = true)), + p_constructor = p_constructor + + function initprobpmap_nosplit(prob, initsol) + return p_constructor(getter(initsol)) + end + end + end +end + +function get_scimlfn(valp) + valp isa SciMLBase.AbstractSciMLFunction && return valp + if hasmethod(symbolic_container, Tuple{typeof(valp)}) && + (sc = symbolic_container(valp)) !== valp + return get_scimlfn(sc) + end + throw(ArgumentError("SciMLFunction not found. This should never happen.")) +end + +""" + $(TYPEDSIGNATURES) + +A function to be used as `update_initializeprob!` in `OverrideInitData`. Requires +`is_update_oop = Val(true)` to be passed to `update_initializeprob!`. +""" +function update_initializeprob!(initprob, prob) + pgetter = ChainRulesCore.@ignore_derivatives get_scimlfn(prob).initialization_data.metadata.oop_reconstruct_u0_p.pgetter + p = pgetter(prob, initprob) + return remake(initprob; p) +end + """ $(TYPEDEF) @@ -804,8 +916,8 @@ struct InitializationMetadata{R <: ReconstructInitializeprob, GUU, SIU} """ get_updated_u0::GUU """ - A function which takes the `u0` of the problem and sets - `Initial.(unknowns(sys))`. + A function which takes parameter object and `u0` of the problem and sets + `Initial.(unknowns(sys))` in the former, returning the updated parameter object. """ set_initial_unknowns!::SIU end @@ -856,6 +968,38 @@ function (guu::GetUpdatedU0)(prob, initprob) return buffer end +struct SetInitialUnknowns{S} + setter!::S +end + +function SetInitialUnknowns(sys::AbstractSystem) + return SetInitialUnknowns(setu(sys, Initial.(unknowns(sys)))) +end + +function (siu::SetInitialUnknowns)(p::MTKParameters, u0) + if ArrayInterface.ismutable(p.initials) + siu.setter!(p, u0) + else + originalT = similar_type(p.initials) + @set! p.initials = MVector{length(p.initials), eltype(p.initials)}(p.initials) + siu.setter!(p, u0) + @set! p.initials = originalT(p.initials) + end + return p +end + +function (siu::SetInitialUnknowns)(p::Vector, u0) + if ArrayInterface.ismutable(p) + siu.setter!(p, u0) + else + originalT = similar_type(p) + p = MVector{length(p), eltype(p)}(p) + siu.setter!(p, u0) + p = originalT(p) + end + return p +end + """ $(TYPEDSIGNATURES) @@ -867,19 +1011,27 @@ denotes whether the `SciMLProblem` being constructed is in implicit DAE form (`D All other keyword arguments are forwarded to `InitializationProblem`. """ function maybe_build_initialization_problem( - sys::AbstractSystem, op::AbstractDict, u0map, pmap, t, defs, + sys::AbstractSystem, iip, op::AbstractDict, u0map, pmap, t, defs, guesses, missing_unknowns; implicit_dae = false, u0_constructor = identity, - floatT = Float64, initialization_eqs = [], use_scc = true, kwargs...) + p_constructor = identity, floatT = Float64, initialization_eqs = [], + use_scc = true, kwargs...) guesses = merge(ModelingToolkit.guesses(sys), todict(guesses)) if t === nothing && is_time_dependent(sys) t = zero(floatT) end - initializeprob = ModelingToolkit.InitializationProblem{true, SciMLBase.FullSpecialize}( - sys, t, u0map, pmap; guesses, initialization_eqs, use_scc, kwargs...) + initializeprob = ModelingToolkit.InitializationProblem{iip}( + sys, t, u0map, pmap; guesses, initialization_eqs, + use_scc, u0_constructor, p_constructor, kwargs...) if state_values(initializeprob) !== nothing - initializeprob = remake(initializeprob; u0 = floatT.(state_values(initializeprob))) + _u0 = state_values(initializeprob) + if ArrayInterface.ismutable(_u0) + _u0 = floatT.(_u0) + else + _u0 = similar_type(_u0, floatT)(_u0) + end + initializeprob = remake(initializeprob; u0 = _u0) end initp = parameter_values(initializeprob) if is_split(sys) @@ -888,9 +1040,13 @@ function maybe_build_initialization_problem( buffer, repack, _ = SciMLStructures.canonicalize(SciMLStructures.Initials(), initp) initp = repack(floatT.(buffer)) elseif initp isa AbstractArray - initp′ = similar(initp, floatT) - copyto!(initp′, initp) - initp = initp′ + if ArrayInterface.ismutable(initp) + initp′ = similar(initp, floatT) + copyto!(initp′, initp) + initp = initp′ + else + initp = similar_type(initp, floatT)(initp) + end end initializeprob = remake(initializeprob; p = initp) @@ -901,8 +1057,9 @@ function maybe_build_initialization_problem( end meta = InitializationMetadata( u0map, pmap, guesses, Vector{Equation}(initialization_eqs), - use_scc, ReconstructInitializeprob(sys, initializeprob.f.sys), - get_initial_unknowns, setp(sys, Initial.(unknowns(sys)))) + use_scc, ReconstructInitializeprob( + sys, initializeprob.f.sys; u0_constructor, p_constructor), + get_initial_unknowns, SetInitialUnknowns(sys)) if is_time_dependent(sys) all_init_syms = Set(all_symbols(initializeprob)) @@ -918,11 +1075,8 @@ function maybe_build_initialization_problem( if initializeprobmap === nothing && isempty(punknowns) initializeprobpmap = nothing else - allsyms = all_symbols(initializeprob) - initdvs = filter(x -> any(isequal(x), allsyms), unknowns(sys)) - getpunknowns = getu(initializeprob, [punknowns; initdvs]) - setpunknowns = setp(sys, [punknowns; Initial.(initdvs)]) - initializeprobpmap = GetUpdatedMTKParameters(getpunknowns, setpunknowns) + initializeprobpmap = construct_initializeprobpmap( + sys, initializeprob.f.sys; p_constructor) end reqd_syms = parameter_symbols(initializeprob) @@ -930,8 +1084,7 @@ function maybe_build_initialization_problem( if initializeprobmap === nothing && initializeprobpmap === nothing update_initializeprob! = nothing else - update_initializeprob! = UpdateInitializeprob( - getu(sys, reqd_syms), setu(initializeprob, reqd_syms)) + update_initializeprob! = ModelingToolkit.update_initializeprob! end for p in punknowns @@ -955,7 +1108,7 @@ function maybe_build_initialization_problem( return (; initialization_data = SciMLBase.OverrideInitData( initializeprob, update_initializeprob!, initializeprobmap, - initializeprobpmap; metadata = meta)) + initializeprobpmap; metadata = meta, is_update_oop = Val(true))) end """ @@ -978,6 +1131,22 @@ function float_type_from_varmap(varmap, floatT = Bool) return float(floatT) end +""" + $(TYPEDSIGNATURES) + +Calculate the `resid_prototype` for a `NonlinearFunction` with `N` equations and the +provided `u0` and `p`. +""" +function calculate_resid_prototype(N::Int, u0, p) + u0ElType = u0 === nothing ? Float64 : eltype(u0) + if SciMLStructures.isscimlstructure(p) + u0ElType = promote_type( + eltype(SciMLStructures.canonicalize(SciMLStructures.Tunable(), p)[1]), + u0ElType) + end + return zeros(u0ElType, N) +end + """ $(TYPEDSIGNATURES) @@ -1016,6 +1185,7 @@ Keyword arguments: - `tofloat`, `is_initializeprob`: Passed to [`better_varmap_to_vars`](@ref) for building `u0` (and possibly `p`). - `u0_constructor`: A function to apply to the `u0` value returned from `better_varmap_to_vars` to construct the final `u0` value. +- `p_constructor`: A function to apply to each array buffer created when constructing the parameter object. - `du0map`: A map of derivatives to values. See `implicit_dae`. - `check_length`: Whether to check the number of equations along with number of unknowns and length of `u0` vector for consistency. If `false`, do not check with equations. This is @@ -1044,8 +1214,8 @@ function process_SciMLProblem( warn_initialize_determined = true, initialization_eqs = [], eval_expression = false, eval_module = @__MODULE__, fully_determined = nothing, check_initialization_units = false, tofloat = true, - u0_constructor = identity, du0map = nothing, check_length = true, - symbolic_u0 = false, warn_cyclic_dependency = false, + u0_constructor = identity, p_constructor = identity, du0map = nothing, + check_length = true, symbolic_u0 = false, warn_cyclic_dependency = false, circular_dependency_max_cycle_length = length(all_symbols(sys)), circular_dependency_max_cycles = 10, substitution_limit = 100, use_scc = true, @@ -1082,7 +1252,7 @@ function process_SciMLProblem( u0map, pmap, defs, cmap, dvs, ps) floatT = Bool - if u0Type <: AbstractArray && eltype(u0Type) <: Real + if u0Type <: AbstractArray && eltype(u0Type) <: Real && eltype(u0Type) != Union{} floatT = float(eltype(u0Type)) else floatT = float_type_from_varmap(op, floatT) @@ -1095,15 +1265,21 @@ function process_SciMLProblem( u0_constructor = vals -> SymbolicUtils.Code.create_array( u0Type, floatT, Val(1), Val(length(vals)), vals...) end + if p_constructor === identity && pType <: StaticArray + p_constructor = vals -> SymbolicUtils.Code.create_array( + pType, floatT, Val(1), Val(length(vals)), vals...) + end + if build_initializeprob kws = maybe_build_initialization_problem( - sys, op, u0map, pmap, t, defs, guesses, missing_unknowns; + sys, constructor <: SciMLBase.AbstractSciMLFunction{true}, + op, u0map, pmap, t, defs, guesses, missing_unknowns; implicit_dae, warn_initialize_determined, initialization_eqs, eval_expression, eval_module, fully_determined, warn_cyclic_dependency, check_units = check_initialization_units, circular_dependency_max_cycle_length, circular_dependency_max_cycles, use_scc, force_time_independent = force_initialization_time_independent, algebraic_only, allow_incomplete, - u0_constructor, floatT) + u0_constructor, p_constructor, floatT) kwargs = merge(kwargs, kws) end @@ -1155,9 +1331,13 @@ function process_SciMLProblem( end evaluate_varmap!(op, ps; limit = substitution_limit) if is_split(sys) - p = MTKParameters(sys, op; floatT = floatT) + # `pType` is usually `Dict` when the user passes key-value pairs. + if !(pType <: AbstractArray) + pType = Array + end + p = MTKParameters(sys, op; floatT = floatT, p_constructor) else - p = better_varmap_to_vars(op, ps; tofloat, container_type = pType) + p = p_constructor(better_varmap_to_vars(op, ps; tofloat, container_type = pType)) end if implicit_dae && du0map !== nothing @@ -1177,8 +1357,15 @@ function process_SciMLProblem( t0 = zero(floatT) end initialization_data = SciMLBase.remake_initialization_data( - kwargs.initialization_data, kwargs, u0, t0, p, u0, p) - kwargs = merge(kwargs,) + sys, kwargs, u0, t0, p, u0, p) + kwargs = merge(kwargs, (; initialization_data)) + end + + if constructor <: NonlinearFunction && length(dvs) != length(eqs) + kwargs = merge(kwargs, + (; + resid_prototype = u0_constructor(calculate_resid_prototype( + length(eqs), u0, p)))) end f = constructor(sys, dvs, ps, u0; p = p, diff --git a/test/initial_values.jl b/test/initial_values.jl index b3614de0f4..0ed8f7bffe 100644 --- a/test/initial_values.jl +++ b/test/initial_values.jl @@ -2,7 +2,8 @@ using ModelingToolkit using ModelingToolkit: t_nounits as t, D_nounits as D, get_u0 using OrdinaryDiffEq using DataInterpolations -using SymbolicIndexingInterface: getu +using StaticArrays +using SymbolicIndexingInterface @variables x(t)[1:3]=[1.0, 2.0, 3.0] y(t) z(t)[1:2] @@ -309,3 +310,55 @@ end @test prob[w2] ≈ -1.0 @test prob.ps[β] ≈ 8 / 3 end + +@testset "MTKParameters uses given `pType` for inner buffers" begin + @parameters σ ρ β + @variables x(t) y(t) z(t) + + eqs = [D(D(x)) ~ σ * (y - x), + D(y) ~ x * (ρ - z) - y, + D(z) ~ x * y - β * z] + + @mtkbuild sys = ODESystem(eqs, t) + + u0 = SA[D(x) => 2.0f0, + x => 1.0f0, + y => 0.0f0, + z => 0.0f0] + + p = SA[σ => 28.0f0, + ρ => 10.0f0, + β => 8.0f0 / 3] + + tspan = (0.0f0, 100.0f0) + prob = ODEProblem(sys, u0, tspan, p) + @test prob.p.tunable isa SVector + @test prob.p.initials isa SVector +end + +@testset "`p_constructor` keyword argument" begin + @parameters g = 1.0 + @variables x(t) y(t) [state_priority = 10, guess = 1.0] λ(t) [guess = 1.0] + eqs = [D(D(x)) ~ λ * x + D(D(y)) ~ λ * y - g + x^2 + y^2 ~ 1] + @mtkbuild pend = ODESystem(eqs, t) + + u0 = [x => 1.0, D(x) => 0.0] + u0_constructor = p_constructor = vals -> SVector{length(vals)}(vals...) + tspan = (0.0, 5.0) + prob = ODEProblem(pend, u0, tspan; u0_constructor, p_constructor) + @test prob.u0 isa SVector + @test prob.p.tunable isa SVector + @test prob.p.initials isa SVector + initdata = prob.f.initialization_data + @test state_values(initdata.initializeprob) isa SVector + @test parameter_values(initdata.initializeprob).tunable isa SVector + + @mtkbuild pend=ODESystem(eqs, t) split=false + prob = ODEProblem(pend, u0, tspan; u0_constructor, p_constructor) + @test prob.p isa SVector + initdata = prob.f.initialization_data + @test state_values(initdata.initializeprob) isa SVector + @test parameter_values(initdata.initializeprob) isa SVector +end diff --git a/test/initializationsystem.jl b/test/initializationsystem.jl index fea989f0f3..7c512d37af 100644 --- a/test/initializationsystem.jl +++ b/test/initializationsystem.jl @@ -1,6 +1,6 @@ using ModelingToolkit, OrdinaryDiffEq, NonlinearSolve, Test using StochasticDiffEq, DelayDiffEq, StochasticDelayDiffEq, JumpProcesses -using ForwardDiff +using ForwardDiff, StaticArrays using SymbolicIndexingInterface, SciMLStructures using SciMLStructures: Tunable using ModelingToolkit: t_nounits as t, D_nounits as D, observed @@ -594,22 +594,36 @@ end @parameters p q @brownian a b x = _x(t) - + sarray_ctor = splat(SVector) # `System` constructor creates appropriate type with mtkbuild # `Problem` and `alg` create the problem to test and allow calling `init` with # the correct solver. # `rhss` allows adding terms to the end of equations (only 2 equations allowed) to influence # the system type (brownian vars to turn it into an SDE). - @testset "$Problem with $(SciMLBase.parameterless_type(alg))" for (System, Problem, alg, rhss) in [ - (ModelingToolkit.System, ODEProblem, Tsit5(), zeros(2)), - (ModelingToolkit.System, SDEProblem, ImplicitEM(), [a, b]), - (ModelingToolkit.System, DDEProblem, MethodOfSteps(Tsit5()), [_x(t - 0.1), 0.0]), - (ModelingToolkit.System, SDDEProblem, ImplicitEM(), [_x(t - 0.1) + a, b]) - ] + @testset "$Problem with $(SciMLBase.parameterless_type(alg)) and $ctor ctor" for ((System, Problem, alg, rhss), (ctor, expectedT)) in Iterators.product( + [ + (ModelingToolkit.System, ODEProblem, Tsit5(), zeros(2)), + (ModelingToolkit.System, SDEProblem, ImplicitEM(), [a, b]), + (ModelingToolkit.System, DDEProblem, + MethodOfSteps(Tsit5()), [_x(t - 0.1), 0.0]), + (ModelingToolkit.System, SDDEProblem, ImplicitEM(), [_x(t - 0.1) + a, b]) + ], + [(identity, Any), (sarray_ctor, SVector)]) + u0_constructor = p_constructor = ctor + if ctor !== identity + Problem = Problem{false} + end function test_parameter(prob, sym, val) if prob.u0 !== nothing + @test prob.u0 isa expectedT @test init(prob, alg).ps[sym] ≈ val end + @test prob.p.tunable isa expectedT + initprob = prob.f.initialization_data.initializeprob + if state_values(initprob) !== nothing + @test state_values(initprob) isa expectedT + end + @test parameter_values(initprob).tunable isa expectedT @test solve(prob, alg).ps[sym] ≈ val end function test_initializesystem(sys, u0map, pmap, p, equation) @@ -626,64 +640,64 @@ end @mtkbuild sys = System( [D(x) ~ x * q + rhss[1], D(y) ~ y * p + rhss[2]], t; defaults = [p => missing], guesses = [p => 1.0]) pmap[p] = 2q - prob = Problem(sys, u0map, (0.0, 1.0), pmap) + prob = Problem(sys, u0map, (0.0, 1.0), pmap; u0_constructor, p_constructor) test_parameter(prob, p, 2.0) prob2 = remake(prob; u0 = u0map, p = pmap) - prob2.ps[p] = 0.0 + prob2 = remake(prob2; p = setp_oop(prob2, p)(prob2, 0.0)) test_parameter(prob2, p, 2.0) # `missing` default, provided guess @mtkbuild sys = System( [D(x) ~ x + rhss[1], p ~ x + y + rhss[2]], t; defaults = [p => missing], guesses = [p => 0.0]) - prob = Problem(sys, u0map, (0.0, 1.0)) + prob = Problem(sys, u0map, (0.0, 1.0); u0_constructor, p_constructor) test_parameter(prob, p, 2.0) test_initializesystem(sys, u0map, pmap, p, 0 ~ p - x - y) prob2 = remake(prob; u0 = u0map) - prob2.ps[p] = 0.0 + prob2 = remake(prob2; p = setp_oop(prob2, p)(prob2, 0.0)) test_parameter(prob2, p, 2.0) # `missing` to Problem, equation from default @mtkbuild sys = System( [D(x) ~ x * q + rhss[1], D(y) ~ y * p + rhss[2]], t; defaults = [p => 2q], guesses = [p => 1.0]) pmap[p] = missing - prob = Problem(sys, u0map, (0.0, 1.0), pmap) + prob = Problem(sys, u0map, (0.0, 1.0), pmap; u0_constructor, p_constructor) test_parameter(prob, p, 2.0) test_initializesystem(sys, u0map, pmap, p, 0 ~ 2q - p) prob2 = remake(prob; u0 = u0map, p = pmap) - prob2.ps[p] = 0.0 + prob2 = remake(prob2; p = setp_oop(prob2, p)(prob2, 0.0)) test_parameter(prob2, p, 2.0) # `missing` to Problem, provided guess @mtkbuild sys = System( [D(x) ~ x + rhss[1], p ~ x + y + rhss[2]], t; guesses = [p => 0.0]) - prob = Problem(sys, u0map, (0.0, 1.0), pmap) + prob = Problem(sys, u0map, (0.0, 1.0), pmap; u0_constructor, p_constructor) test_parameter(prob, p, 2.0) test_initializesystem(sys, u0map, pmap, p, 0 ~ x + y - p) prob2 = remake(prob; u0 = u0map, p = pmap) - prob2.ps[p] = 0.0 + prob2 = remake(prob2; p = setp_oop(prob2, p)(prob2, 0.0)) test_parameter(prob2, p, 2.0) # No `missing`, default and guess @mtkbuild sys = System( [D(x) ~ x * q + rhss[1], D(y) ~ y * p + rhss[2]], t; defaults = [p => 2q], guesses = [p => 0.0]) delete!(pmap, p) - prob = Problem(sys, u0map, (0.0, 1.0), pmap) + prob = Problem(sys, u0map, (0.0, 1.0), pmap; u0_constructor, p_constructor) test_parameter(prob, p, 2.0) test_initializesystem(sys, u0map, pmap, p, 0 ~ 2q - p) prob2 = remake(prob; u0 = u0map, p = pmap) - prob2.ps[p] = 0.0 + prob2 = remake(prob2; p = setp_oop(prob2, p)(prob2, 0.0)) test_parameter(prob2, p, 2.0) # Default overridden by Problem, guess provided @mtkbuild sys = System( [D(x) ~ q * x + rhss[1], D(y) ~ y * p + rhss[2]], t; defaults = [p => 2q], guesses = [p => 1.0]) _pmap = merge(pmap, Dict(p => q)) - prob = Problem(sys, u0map, (0.0, 1.0), _pmap) + prob = Problem(sys, u0map, (0.0, 1.0), _pmap; u0_constructor, p_constructor) test_parameter(prob, p, _pmap[q]) test_initializesystem(sys, u0map, _pmap, p, 0 ~ q - p) # Problem dependent value with guess, no `missing` @mtkbuild sys = System( [D(x) ~ y * q + p + rhss[1], D(y) ~ x * p + q + rhss[2]], t; guesses = [p => 0.0]) _pmap = merge(pmap, Dict(p => 3q)) - prob = Problem(sys, u0map, (0.0, 1.0), _pmap) + prob = Problem(sys, u0map, (0.0, 1.0), _pmap; u0_constructor, p_constructor) test_parameter(prob, p, 3pmap[q]) # Should not be solved for: @@ -691,7 +705,7 @@ end @mtkbuild sys = System( [D(x) ~ q * x + rhss[1], D(y) ~ y * p + rhss[2]], t; defaults = [p => 2q], guesses = [p => 1.0]) _pmap = merge(pmap, Dict(p => 1.0)) - prob = Problem(sys, u0map, (0.0, 1.0), _pmap) + prob = Problem(sys, u0map, (0.0, 1.0), _pmap; u0_constructor, p_constructor) @test prob.ps[p] ≈ 1.0 initsys = prob.f.initialization_data.initializeprob.f.sys @test is_parameter(initsys, p) @@ -700,7 +714,7 @@ end @parameters r::Int s::Int @mtkbuild sys = System( [D(x) ~ s * x + rhss[1], D(y) ~ y * r + rhss[2]], t; defaults = [s => 2r], guesses = [s => 1.0]) - prob = Problem(sys, u0map, (0.0, 1.0), [r => 1]) + prob = Problem(sys, u0map, (0.0, 1.0), [r => 1]; u0_constructor, p_constructor) @test prob.ps[r] == 1 @test prob.ps[s] == 2 initsys = prob.f.initialization_data.initializeprob.f.sys @@ -714,7 +728,7 @@ end # Unsatisfiable initialization prob = Problem(sys, [x => 1.0, y => 1.0], (0.0, 1.0), - [p => 2.0]; initialization_eqs = [x^2 + y^2 ~ 3]) + [p => 2.0]; initialization_eqs = [x^2 + y^2 ~ 3], u0_constructor, p_constructor) @test prob.f.initialization_data !== nothing @test solve(prob, alg).retcode == ReturnCode.InitialFailure cache = init(prob, alg) @@ -791,8 +805,17 @@ end prob_alg_combinations = zip( [NonlinearProblem, NonlinearLeastSquaresProblem], [nl_algs, nlls_algs]) - @testset "Parameter initialization" begin + sarray_ctor = splat(SVector) + @testset "Parameter initialization with ctor $ctor" for (ctor, expectedT) in [ + (identity, Any), + (sarray_ctor, SVector) + ] + u0_constructor = p_constructor = ctor function test_parameter(prob, alg, param, val) + if prob.u0 !== nothing + @test prob.u0 isa expectedT + end + @test prob.p.tunable isa expectedT integ = init(prob, alg) @test integ.ps[param]≈val rtol=1e-5 # some algorithms are a little temperamental @@ -818,7 +841,10 @@ end # guesses = [q => 1.0], initialization_eqs = [p^2 + q^2 + 2p * q ~ 0]) for (probT, algs) in prob_alg_combinations - prob = probT(sys, []) + if ctor != identity + probT = probT{false} + end + prob = probT(sys, []; u0_constructor, p_constructor) @test prob.f.initialization_data !== nothing @test prob.f.initialization_data.initializeprobmap === nothing for alg in algs @@ -826,11 +852,11 @@ end end # `update_initializeprob!` works - prob.ps[p] = -2.0 + prob = remake(prob; p = setp_oop(prob, p)(prob, -2.0)) for alg in algs test_parameter(prob, alg, q, 2.0) end - prob.ps[p] = 2.0 + prob = remake(prob; p = setp_oop(prob, p)(prob, 2.0)) # `remake` works prob2 = remake(prob; p = [p => -2.0]) @@ -1321,8 +1347,8 @@ end guesses = [p[1] => q, p[2] => 2q]) @test ModelingToolkit.is_parameter_solvable(p, Dict(), defaults(sys), guesses(sys)) prob = ODEProblem(sys, [x => 1.0], (0.0, 1.0), [q => 2.0]) - @test length(ModelingToolkit.observed(prob.f.initialization_data.initializeprob.f.sys)) == - 3 + initsys = prob.f.initialization_data.initializeprob.f.sys + @test length(ModelingToolkit.observed(initsys)) == 4 sol = solve(prob, Tsit5()) @test sol.ps[p] ≈ [2.0, 4.0] end @@ -1599,3 +1625,28 @@ end @test SciMLBase.successful_retcode(sol) @test sol.u[1] ≈ new_u0 end + +@testset "Initialization system retains `split` kwarg of parent" begin + @parameters g + @variables x(t) y(t) [state_priority = 10] λ(t) + eqs = [D(D(x)) ~ λ * x + D(D(y)) ~ λ * y - g + x^2 + y^2 ~ 1] + @mtkbuild pend=ODESystem(eqs, t) split=false + prob = ODEProblem(pend, [x => 1.0, D(x) => 0.0], (0.0, 1.0), + [g => 1.0]; guesses = [y => 1.0, λ => 1.0]) + @test !ModelingToolkit.is_split(prob.f.initialization_data.initializeprob.f.sys) +end + +@testset "`InitializationProblem` retains `iip` of parent" begin + @parameters g + @variables x(t) y(t) [state_priority = 10] λ(t) + eqs = [D(D(x)) ~ λ * x + D(D(y)) ~ λ * y - g + x^2 + y^2 ~ 1] + @mtkbuild pend = ODESystem(eqs, t) + prob = ODEProblem(pend, SA[x => 1.0, D(x) => 0.0], (0.0, 1.0), + SA[g => 1.0]; guesses = [y => 1.0, λ => 1.0]) + @test !SciMLBase.isinplace(prob) + @test !SciMLBase.isinplace(prob.f.initialization_data.initializeprob) +end diff --git a/test/mtkparameters.jl b/test/mtkparameters.jl index 55de0768e0..b7acbb84a8 100644 --- a/test/mtkparameters.jl +++ b/test/mtkparameters.jl @@ -1,8 +1,8 @@ using ModelingToolkit using ModelingToolkit: t_nounits as t, D_nounits as D, MTKParameters -using SymbolicIndexingInterface +using SymbolicIndexingInterface, StaticArrays using SciMLStructures: SciMLStructures, canonicalize, Tunable, Discrete, Constants -using BlockArrays: BlockedArray, Block +using BlockArrays: BlockedArray, BlockedVector, Block using OrdinaryDiffEq using ForwardDiff using JET @@ -27,6 +27,15 @@ end @test getp(sys, a)(ps) == getp(sys, b)(ps) == getp(sys, c)(ps) == 0.0 @test getp(sys, d)(ps) isa Int +@testset "`p_constructor`" begin + ps2 = MTKParameters(sys, ivs; p_constructor = x -> SArray{Tuple{size(x)...}}(x)) + @test ps2.tunable isa SVector + @test ps2.initials isa SVector + @test ps2.discrete isa Tuple{<:BlockedVector{Float64, <:SVector}} + @test ps2.constant isa Tuple{<:SVector, <:SVector, <:SVector{1, <:SMatrix}} + @test ps2.nonnumeric isa Tuple{<:SVector} +end + ivs[a] = 1.0 ps = MTKParameters(sys, ivs) for (p, val) in ivs diff --git a/test/nonlinearsystem.jl b/test/nonlinearsystem.jl index a315371141..158475f7c9 100644 --- a/test/nonlinearsystem.jl +++ b/test/nonlinearsystem.jl @@ -442,3 +442,19 @@ end @test !in(D(y), vs) end end + +@testset "oop `NonlinearLeastSquaresProblem` with `u0 === nothing`" begin + @variables x y + @named sys = NonlinearSystem([0 ~ x - y], [], []; observed = [x ~ 1.0, y ~ 1.0]) + prob = NonlinearLeastSquaresProblem{false}(complete(sys), nothing) + sol = solve(prob) + resid = sol.resid + @test resid == [0.0] + @test resid isa Vector + prob = NonlinearLeastSquaresProblem{false}( + complete(sys), nothing; u0_constructor = splat(SVector)) + sol = solve(prob) + resid = sol.resid + @test resid == [0.0] + @test resid isa SVector +end diff --git a/test/reduction.jl b/test/reduction.jl index fa9029a652..adeb4005d7 100644 --- a/test/reduction.jl +++ b/test/reduction.jl @@ -176,7 +176,7 @@ A = reshape(1:(N^2), N, N) eqs = xs ~ A * xs @named sys′ = NonlinearSystem(eqs, [xs], []) sys = structural_simplify(sys′) -@test length(equations(sys)) == 3 && length(observed(sys)) == 2 +@test length(equations(sys)) == 3 && length(observed(sys)) == 3 # issue 958 @parameters k₁ k₂ k₋₁ E₀ diff --git a/test/runtests.jl b/test/runtests.jl index 1b8b5e03db..86d69228ce 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -138,9 +138,9 @@ end activate_extensions_env() @safetestset "Dynamic Optimization Collocation Solvers" include("extensions/dynamic_optimization.jl") @safetestset "HomotopyContinuation Extension Test" include("extensions/homotopy_continuation.jl") - @safetestset "Auto Differentiation Test" include("extensions/ad.jl") @safetestset "LabelledArrays Test" include("labelledarrays.jl") @safetestset "BifurcationKit Extension Test" include("extensions/bifurcationkit.jl") @safetestset "InfiniteOpt Extension Test" include("extensions/test_infiniteopt.jl") + @safetestset "Auto Differentiation Test" include("extensions/ad.jl") end end diff --git a/test/scc_nonlinear_problem.jl b/test/scc_nonlinear_problem.jl index b2b326d090..ac7978d269 100644 --- a/test/scc_nonlinear_problem.jl +++ b/test/scc_nonlinear_problem.jl @@ -2,6 +2,7 @@ using ModelingToolkit using NonlinearSolve, SCCNonlinearSolve using OrdinaryDiffEq using SciMLBase, Symbolics +using StaticArrays using LinearAlgebra, Test using ModelingToolkit: t_nounits as t, D_nounits as D @@ -32,6 +33,12 @@ using ModelingToolkit: t_nounits as t, D_nounits as D @test SciMLBase.successful_retcode(sol1) @test SciMLBase.successful_retcode(sol2) @test sol1[u] ≈ sol2[u] + + sccprob = SCCNonlinearProblem{false}(model, SA[u => zeros(8)]) + for prob in sccprob.probs + @test prob.u0 isa SVector + @test !SciMLBase.isinplace(prob) + end end @testset "With parameters" begin diff --git a/test/structural_transformation/utils.jl b/test/structural_transformation/utils.jl index b5335ad6b1..67cf2f72a0 100644 --- a/test/structural_transformation/utils.jl +++ b/test/structural_transformation/utils.jl @@ -76,7 +76,7 @@ end end @mtkbuild sys = ODESystem([D(x) ~ y[1] + y[2], y ~ foo(x)], t) @test length(equations(sys)) == 1 - @test length(observed(sys)) == 3 + @test length(observed(sys)) == 4 prob = ODEProblem(sys, [x => 1.0], (0.0, 1.0), [foo => _tmp_fn2]) val[] = 0 @test_nowarn prob.f(prob.u0, prob.p, 0.0)