diff --git a/Project.toml b/Project.toml index f75148b..b043522 100644 --- a/Project.toml +++ b/Project.toml @@ -4,6 +4,8 @@ authors = ["Chris Rackauckas "] version = "4.26.1" [deps] +Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" +ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" @@ -15,15 +17,20 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Sundials_jll = "fb77eaff-e24c-56d4-86b1-d163f2edb164" +SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" [compat] +Accessors = "0.1.38" +ArrayInterface = "7.17.1" CEnum = "0.5" DataStructures = "0.18" DiffEqBase = "6.154" +ModelingToolkit = "9.54" PrecompileTools = "1" Reexport = "1.0" -SciMLBase = "2.9" +SciMLBase = "2.63.1" Sundials_jll = "5.2" +SymbolicIndexingInterface = "0.3.35" julia = "1.9" [extras] @@ -34,9 +41,10 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" IncompleteLU = "40713840-3770-5561-ab4c-a76e7d0d7895" ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78" ODEProblemLibrary = "fdc4e326-1af4-4b90-96e7-779fcce2daa5" -SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804" +SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5" +SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Test", "AlgebraicMultigrid", "DiffEqCallbacks", "ODEProblemLibrary", "DAEProblemLibrary", "ForwardDiff", "SparseDiffTools", "SparseConnectivityTracer", "IncompleteLU", "ModelingToolkit"] +test = ["Test", "AlgebraicMultigrid", "DiffEqCallbacks", "ODEProblemLibrary", "DAEProblemLibrary", "ForwardDiff", "SparseDiffTools", "SparseConnectivityTracer", "IncompleteLU", "ModelingToolkit", "SafeTestsets"] diff --git a/src/Sundials.jl b/src/Sundials.jl index 6a1474a..03a6f6c 100644 --- a/src/Sundials.jl +++ b/src/Sundials.jl @@ -5,6 +5,10 @@ module Sundials import Reexport Reexport.@reexport using DiffEqBase using SciMLBase: AbstractSciMLOperator +import Accessors: @reset +import ArrayInterface +import SymbolicIndexingInterface as SII +import SymbolicIndexingInterface: ParameterIndexingProxy import DataStructures import Logging import DiffEqBase @@ -81,6 +85,7 @@ include("common_interface/verbosity.jl") include("common_interface/algorithms.jl") include("common_interface/integrator_types.jl") include("common_interface/integrator_utils.jl") +include("common_interface/initialize_dae.jl") include("common_interface/solve.jl") import PrecompileTools diff --git a/src/common_interface/initialize_dae.jl b/src/common_interface/initialize_dae.jl new file mode 100644 index 0000000..eb66f17 --- /dev/null +++ b/src/common_interface/initialize_dae.jl @@ -0,0 +1,78 @@ +struct SundialsDefaultInit <: DiffEqBase.DAEInitializationAlgorithm end + +function DiffEqBase.initialize_dae!(integrator::AbstractSundialsIntegrator, initializealg = integrator.initializealg) + _initialize_dae!(integrator, integrator.sol.prob, initializealg, Val(DiffEqBase.isinplace(integrator.sol.prob))) +end + +struct IDADefaultInit <: DiffEqBase.DAEInitializationAlgorithm +end + +function _initialize_dae!(integrator::IDAIntegrator, prob, + initializealg::IDADefaultInit, isinplace) + if integrator.u_modified + IDAReinit!(integrator) + end + integrator.f(integrator.tmp, integrator.du, integrator.u, integrator.p, integrator.t) + tstart, tend = integrator.sol.prob.tspan + if any(abs.(integrator.tmp) .>= integrator.opts.reltol) + if integrator.sol.prob.differential_vars === nothing && !integrator.alg.init_all + error("Must supply differential_vars argument to DAEProblem constructor to use IDA initial value solver.") + end + if integrator.alg.init_all + init_type = IDA_Y_INIT + else + init_type = IDA_YA_YDP_INIT + integrator.flag = IDASetId(integrator.mem, + vec(integrator.sol.prob.differential_vars)) + end + dt = integrator.dt == tstart ? tend : integrator.dt + integrator.flag = IDACalcIC(integrator.mem, init_type, dt) + + # Reflect consistent initial conditions back into the integrator's + # shadow copy. N.B.: ({du, u}_nvec are aliased to {du, u}). + IDAGetConsistentIC(integrator.mem, integrator.u_nvec, integrator.du_nvec) + end + if integrator.t == tstart && integrator.flag < 0 + integrator.sol = SciMLBase.solution_new_retcode(integrator.sol, + ReturnCode.InitialFailure) + end +end + +function _initialize_dae!(integrator, prob, ::SundialsDefaultInit, isinplace) + if SciMLBase.has_initializeprob(prob.f) + _initialize_dae!(integrator, prob, SciMLBase.OverrideInit(), isinplace) + elseif integrator isa IDAIntegrator + _initialize_dae!(integrator, prob, IDADefaultInit(), isinplace) + end +end + +function _initialize_dae!(integrator, prob, initalg::SciMLBase.NoInit, isinplace) end + +function _initialize_dae!(integrator, prob, initalg::SciMLBase.OverrideInit, isinplace::Union{Val{true}, Val{false}}) + nlsolve_alg = KINSOL() + u0, p, success = SciMLBase.get_initial_values(prob, integrator, prob.f, initalg, isinplace; nlsolve_alg, abstol = integrator.opts.abstol, reltol = integrator.opts.reltol) + + if isinplace === Val{true}() + integrator.u .= u0 + if length(integrator.sol.u) == 1 + integrator.sol.u[1] .= u0 + end + else + integrator.u = u0 + if length(integrator.sol.u) == 1 + integrator.sol.u[1] = u0 + end + end + integrator.p = p + sol = integrator.sol + @reset sol.prob.p = integrator.p + integrator.sol = sol + + if !success + integrator.sol = SciMLBase.solution_new_retcode(integrator.sol, ReturnCode.InitialFailure) + end +end + +function _initialize_dae!(integrator, prob, initalg::SciMLBase.CheckInit, isinplace::Union{Val{true}, Val{false}}) + SciMLBase.get_initial_values(prob, integrator, prob.f, initalg, isinplace; abstol = integrator.opts.abstol) +end diff --git a/src/common_interface/integrator_types.jl b/src/common_interface/integrator_types.jl index a1b1605..022fdc8 100644 --- a/src/common_interface/integrator_types.jl +++ b/src/common_interface/integrator_types.jl @@ -40,7 +40,8 @@ mutable struct CVODEIntegrator{N, oType, LStype, Atype, - CallbackCacheType} <: AbstractSundialsIntegrator{algType} + CallbackCacheType, + IA} <: AbstractSundialsIntegrator{algType} u::Array{Float64, N} u_nvec::NVector p::pType @@ -66,6 +67,7 @@ mutable struct CVODEIntegrator{N, vector_event_last_time::Int callback_cache::CallbackCacheType last_event_error::Float64 + initializealg::IA end function (integrator::CVODEIntegrator)(t::Number, @@ -96,7 +98,8 @@ mutable struct ARKODEIntegrator{N, Atype, MLStype, Mtype, - CallbackCacheType} <: AbstractSundialsIntegrator{ARKODE} + CallbackCacheType, + IA} <: AbstractSundialsIntegrator{ARKODE} u::Array{Float64, N} u_nvec::NVector p::pType @@ -124,6 +127,7 @@ mutable struct ARKODEIntegrator{N, vector_event_last_time::Int callback_cache::CallbackCacheType last_event_error::Float64 + initializealg::IA end function (integrator::ARKODEIntegrator)(t::Number, diff --git a/src/common_interface/integrator_utils.jl b/src/common_interface/integrator_utils.jl index d0192ca..30c5626 100644 --- a/src/common_interface/integrator_utils.jl +++ b/src/common_interface/integrator_utils.jl @@ -168,6 +168,8 @@ end @inline function Base.getproperty(integrator::AbstractSundialsIntegrator, sym::Symbol) if sym == :dt return integrator.t - integrator.tprev + elseif sym == :ps + return ParameterIndexingProxy(integrator) else return getfield(integrator, sym) end @@ -185,42 +187,6 @@ end # Required for callbacks DiffEqBase.set_proposed_dt!(i::AbstractSundialsIntegrator, dt) = nothing -DiffEqBase.initialize_dae!(integrator::AbstractSundialsIntegrator) = nothing - -struct IDADefaultInit <: DiffEqBase.DAEInitializationAlgorithm -end - -function DiffEqBase.initialize_dae!(integrator::IDAIntegrator, - initializealg::IDADefaultInit) - if integrator.u_modified - IDAReinit!(integrator) - end - integrator.f(integrator.tmp, integrator.du, integrator.u, integrator.p, integrator.t) - tstart, tend = integrator.sol.prob.tspan - if any(abs.(integrator.tmp) .>= integrator.opts.reltol) - if integrator.sol.prob.differential_vars === nothing && !integrator.alg.init_all - error("Must supply differential_vars argument to DAEProblem constructor to use IDA initial value solver.") - end - if integrator.alg.init_all - init_type = IDA_Y_INIT - else - init_type = IDA_YA_YDP_INIT - integrator.flag = IDASetId(integrator.mem, - vec(integrator.sol.prob.differential_vars)) - end - dt = integrator.dt == tstart ? tend : integrator.dt - integrator.flag = IDACalcIC(integrator.mem, init_type, dt) - - # Reflect consistent initial conditions back into the integrator's - # shadow copy. N.B.: ({du, u}_nvec are aliased to {du, u}). - IDAGetConsistentIC(integrator.mem, integrator.u_nvec, integrator.du_nvec) - end - if integrator.t == tstart && integrator.flag < 0 - integrator.sol = SciMLBase.solution_new_retcode(integrator.sol, - ReturnCode.InitialFailure) - end -end - DiffEqBase.has_reinit(integrator::AbstractSundialsIntegrator) = true function DiffEqBase.reinit!(integrator::AbstractSundialsIntegrator, u0 = integrator.sol.prob.u0; @@ -294,3 +260,6 @@ DiffEqBase.get_tstops(integ::AbstractSundialsIntegrator) = integ.opts.tstops DiffEqBase.get_tstops_array(integ::AbstractSundialsIntegrator) = get_tstops(integ).valtree DiffEqBase.get_tstops_max(integ::AbstractSundialsIntegrator) = maximum(get_tstops_array(integ)) + +# SII +SII.symbolic_container(integ::AbstractSundialsIntegrator) = integ.sol diff --git a/src/common_interface/solve.jl b/src/common_interface/solve.jl index a1cb3c0..42b854f 100644 --- a/src/common_interface/solve.jl +++ b/src/common_interface/solve.jl @@ -124,6 +124,7 @@ function DiffEqBase.__init(prob::DiffEqBase.AbstractODEProblem{uType, tupType, i stop_at_next_tstop = false, userdata = nothing, alias_u0 = false, + initializealg = SundialsDefaultInit(), kwargs...) where {uType, tupType, isinplace, Method, LinearSolver } tType = eltype(tupType) @@ -457,7 +458,9 @@ function DiffEqBase.__init(prob::DiffEqBase.AbstractODEProblem{uType, tupType, i 0, 1, callback_cache, - 0.0) + 0.0, + initializealg) + DiffEqBase.initialize_dae!(integrator) initialize_callbacks!(integrator) integrator end # function solve @@ -499,6 +502,7 @@ function DiffEqBase.__init(prob::DiffEqBase.AbstractODEProblem{uType, tupType, i stop_at_next_tstop = false, userdata = nothing, alias_u0 = false, + initializealg = SundialsDefaultInit(), kwargs...) where {uType, tupType, isinplace, Method, LinearSolver, MassLinearSolver} @@ -945,8 +949,10 @@ function DiffEqBase.__init(prob::DiffEqBase.AbstractODEProblem{uType, tupType, i 0, 1, callback_cache, - 0.0) + 0.0, + initializealg) + DiffEqBase.initialize_dae!(integrator) initialize_callbacks!(integrator) integrator end # function solve @@ -1010,7 +1016,7 @@ function DiffEqBase.__init(prob::DiffEqBase.AbstractDAEProblem{uType, duType, tu advance_to_tstop = false, stop_at_next_tstop = false, userdata = nothing, - initializealg = IDADefaultInit(), + initializealg = SundialsDefaultInit(), kwargs...) where {uType, duType, tupType, isinplace, LinearSolver } tType = eltype(tupType) @@ -1313,7 +1319,7 @@ function DiffEqBase.__init(prob::DiffEqBase.AbstractDAEProblem{uType, duType, tu dutmp, initializealg) - DiffEqBase.initialize_dae!(integrator, initializealg) + DiffEqBase.initialize_dae!(integrator) integrator.u_modified && IDAReinit!(integrator) if save_start diff --git a/test/common_interface/initialization.jl b/test/common_interface/initialization.jl new file mode 100644 index 0000000..8a6876c --- /dev/null +++ b/test/common_interface/initialization.jl @@ -0,0 +1,61 @@ +using ModelingToolkit, SciMLBase, Sundials, Test +using SymbolicIndexingInterface +using ModelingToolkit: t_nounits as t, D_nounits as D + +@testset "ODE" begin + @variables x(t) [guess = 1.0] y(t) [guess = 1.0] + @parameters p = missing [guess = 1.0] q = missing [guess = 1.0] + @mtkbuild sys = ODESystem([D(x) ~ p * y + q * t, D(y) ~ 5x + q], t; initialization_eqs = [p ^2 + q^2 ~ 3, x^3 + y^3 ~ 5]) + + @testset "IIP: $iip" for iip in [true, false] + prob = ODEProblem{iip}(sys, [x => 1.0], (0.0, 1.0), [p => 1.0]) + + @testset "$alg" for alg in [CVODE_BDF, CVODE_Adams, ARKODE] + integ = init(prob, alg()) + @test integ.initializealg isa Sundials.SundialsDefaultInit + @test integ[x] ≈ 1.0 + @test integ[y] ≈ cbrt(4) + @test integ.ps[p] ≈ 1.0 + @test integ.ps[q] ≈ sqrt(2) + sol = solve(prob, alg()) + @test SciMLBase.successful_retcode(sol) + @test sol[x, 1] ≈ 1.0 + @test sol[y, 1] ≈ cbrt(4) + @test sol.ps[p] ≈ 1.0 + @test sol.ps[q] ≈ sqrt(2) + end + end +end +@testset "DAE" begin + @variables x(t) [guess = 1.0] y(t) [guess = 1.0] + @parameters p = missing [guess = 1.0] q = missing [guess = 1.0] + @mtkbuild sys = ODESystem([D(x) ~ p * y + q * t, x^3 + y^3 ~ 5], t; initialization_eqs = [p ^2 + q^2 ~ 3]) + + @testset "DAEProblem{$iip}" for iip in [true, false] + prob = DAEProblem{iip}(sys, [D(x) => cbrt(4), D(y) => -1 / cbrt(4)], [x => 1.0], (0.0, 1.0), [p => 1.0]) + + @testset "OverrideInit" begin + integ = init(prob, IDA()) + @test integ.initializealg isa Sundials.SundialsDefaultInit + @test integ[x] ≈ 1.0 + @test integ[y] ≈ cbrt(4) + @test integ.ps[p] ≈ 1.0 + @test integ.ps[q] ≈ sqrt(2) + sol = solve(prob, IDA()) + @test SciMLBase.successful_retcode(sol) + @test sol[x, 1] ≈ 1.0 + @test sol[y, 1] ≈ cbrt(4) + @test sol.ps[p] ≈ 1.0 + @test sol.ps[q] ≈ sqrt(2) + end + @testset "CheckInit" begin + @test_throws SciMLBase.CheckInitFailureError init(prob, IDA(); initializealg = SciMLBase.CheckInit()) + prob[x] = 1.0 + prob[y] = cbrt(4) + prob.ps[p] = 1 + prob.ps[q] = sqrt(2) + @test_nowarn init(prob, IDA(); initializealg = SciMLBase.CheckInit()) + + end + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 26ecf42..4814e4e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,4 +1,5 @@ using Sundials +using SafeTestsets using Test @testset "Generator" begin @@ -6,23 +7,23 @@ using Test end @testset "CVODE" begin - @testset "Roberts CVODE Simplified" begin + @safetestset "Roberts CVODE Simplified" begin include("cvode_Roberts_simplified.jl") end - @testset "Roberts CVODE Direct" begin + @safetestset "Roberts CVODE Direct" begin include("cvode_Roberts_dns.jl") end #@testset "CVODES Direct" begin include("cvodes_dns.jl") end end @testset "IDA" begin - @testset "Roberts IDA Simplified" begin + @safetestset "Roberts IDA Simplified" begin include("ida_Roberts_simplified.jl") end - @testset "Roberts IDA Direct" begin + @safetestset "Roberts IDA Direct" begin include("ida_Roberts_dns.jl") end - @testset "Heat IDA Direct" begin + @safetestset "Heat IDA Direct" begin include("ida_Heat2D.jl") end # Commented out because still uses the syntax from Grid which is a deprecated package @@ -30,26 +31,26 @@ end end @testset "ARK" begin - @testset "Roberts ARKStep Direct" begin + @safetestset "Roberts ARKStep Direct" begin include("arkstep_Roberts_dns.jl") end - @testset "NonLinear ERKStep Direct" begin + @safetestset "NonLinear ERKStep Direct" begin include("erkstep_nonlin.jl") end #@testset "MRI two way couple" begin include("mri_twowaycouple.jl") end end @testset "Kinsol" begin - @testset "Kinsol Simplified" begin + @safetestset "Kinsol Simplified" begin include("kinsol_mkin_simplified.jl") end - @testset "Kinsol MKin" begin + @safetestset "Kinsol MKin" begin include("kinsol_mkinTest.jl") end - @testset "Kinsol Banded" begin + @safetestset "Kinsol Banded" begin include("kinsol_banded.jl") end - @testset "Kinsol NonlinearSolve" begin + @safetestset "Kinsol NonlinearSolve" begin include("kinsol_nonlinear_solve.jl") end end @@ -58,33 +59,36 @@ end end @testset "Common Interface" begin - @testset "CVODE" begin + @safetestset "CVODE" begin include("common_interface/cvode.jl") end - @testset "ARKODE" begin + @safetestset "ARKODE" begin include("common_interface/arkode.jl") end - @testset "IDA" begin + @safetestset "IDA" begin include("common_interface/ida.jl") end - @testset "Jacobians" begin + @safetestset "Jacobians" begin include("common_interface/jacobians.jl") end - @testset "Callbacks" begin + @safetestset "Callbacks" begin include("common_interface/callbacks.jl") end - @testset "Iterator" begin + @safetestset "Iterator" begin include("common_interface/iterators.jl") end - @testset "Errors" begin + @safetestset "Errors" begin include("common_interface/errors.jl") end - @testset "Mass Matrix" begin + @safetestset "Mass Matrix" begin include("common_interface/mass_matrix.jl") end - @testset "Preconditioners" begin + @safetestset "Preconditioners" begin include("common_interface/precs.jl") end + @safetestset "Initialization" begin + include("common_interface/initialization.jl") + end end @testset "Interpolation" begin