diff --git a/Project.toml b/Project.toml index 1b7c1df58..396693f09 100644 --- a/Project.toml +++ b/Project.toml @@ -12,6 +12,7 @@ DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" FunctionWrappers = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e" +FunctionWrappersWrappers = "77dc65aa-8811-40c2-897b-53d922fa7daf" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" MuladdMacro = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221" diff --git a/src/DiffEqBase.jl b/src/DiffEqBase.jl index eb5259c63..9c39e8cbb 100644 --- a/src/DiffEqBase.jl +++ b/src/DiffEqBase.jl @@ -34,6 +34,7 @@ using Setfield using ForwardDiff +import FunctionWrappersWrappers @reexport using SciMLBase using SciMLBase: @def, DEIntegrator, DEProblem, AbstractDiffEqOperator, @@ -119,6 +120,7 @@ include("init.jl") include("forwarddiff.jl") include("chainrules.jl") +include("norecompile.jl") # This is only used for oop stiff solvers default_factorize(A) = lu(A; check = false) diff --git a/src/norecompile.jl b/src/norecompile.jl new file mode 100644 index 000000000..387c40517 --- /dev/null +++ b/src/norecompile.jl @@ -0,0 +1,168 @@ +struct OrdinaryDiffEqTag end + +const dualT = ForwardDiff.Dual{ForwardDiff.Tag{OrdinaryDiffEqTag, Float64}, Float64, 1} +const NORECOMPILE_IIP_SUPPORTED_ARGS = (Tuple{Vector{Float64}, Vector{Float64}, + Vector{Float64}, Float64}, + Tuple{Vector{Float64}, Vector{Float64}, + SciMLBase.NullParameters, Float64}) +const iip_arglists = (Tuple{Vector{Float64}, Vector{Float64}, Vector{Float64}, Float64}, + Tuple{Vector{Float64}, Vector{Float64}, SciMLBase.NullParameters, + Float64 + }, + Tuple{Vector{dualT}, Vector{Float64}, Vector{Float64}, dualT}, + Tuple{Vector{dualT}, Vector{dualT}, Vector{Float64}, Float64}, + Tuple{Vector{dualT}, Vector{dualT}, SciMLBase.NullParameters, Float64 + }, + Tuple{Vector{dualT}, Vector{Float64}, SciMLBase.NullParameters, dualT + }) +const iip_returnlists = ntuple(x -> Nothing, length(iip_arglists)) +function void(@nospecialize(f::Function)) + function f2(@nospecialize(du::Vector{Float64}), @nospecialize(u::Vector{Float64}), + @nospecialize(p::Vector{Float64}), @nospecialize(t::Float64)) + f(du, u, p, t) + nothing + end + + function f2(@nospecialize(du::Vector{Float64}), @nospecialize(u::Vector{Float64}), + @nospecialize(p::SciMLBase.NullParameters), @nospecialize(t::Float64)) + f(du, u, p, t) + nothing + end + + function f2(@nospecialize(du::Vector{dualT}), @nospecialize(u::Vector{dualT}), + @nospecialize(p::Vector{Float64}), @nospecialize(t::Float64)) + f(du, u, p, t) + nothing + end + + function f2(@nospecialize(du::Vector{dualT}), @nospecialize(u::Vector{dualT}), + @nospecialize(p::SciMLBase.NullParameters), @nospecialize(t::Float64)) + f(du, u, p, t) + nothing + end + + function f2(@nospecialize(du::Vector{dualT}), @nospecialize(u::Vector{Float64}), + @nospecialize(p::Vector{Float64}), @nospecialize(t::dualT)) + f(du, u, p, t) + nothing + end + + function f2(@nospecialize(du::Vector{dualT}), @nospecialize(u::Vector{Float64}), + @nospecialize(p::SciMLBase.NullParameters), @nospecialize(t::dualT)) + f(du, u, p, t) + nothing + end + precompile(f, (Vector{Float64}, Vector{Float64}, Vector{Float64}, Float64)) + precompile(f, (Vector{Float64}, Vector{Float64}, SciMLBase.NullParameters, Float64)) + precompile(f, (Vector{dualT}, Vector{dualT}, Vector{Float64}, Float64)) + precompile(f, (Vector{dualT}, Vector{dualT}, SciMLBase.NullParameters, Float64)) + precompile(f, (Vector{dualT}, Vector{Float64}, Vector{Float64}, dualT)) + precompile(f, (Vector{dualT}, Vector{Float64}, SciMLBase.NullParameters, dualT)) + + precompile(f2, (Vector{Float64}, Vector{Float64}, Vector{Float64}, Float64)) + precompile(f2, (Vector{Float64}, Vector{Float64}, SciMLBase.NullParameters, Float64)) + precompile(f2, (Vector{dualT}, Vector{dualT}, Vector{Float64}, Float64)) + precompile(f2, (Vector{dualT}, Vector{dualT}, SciMLBase.NullParameters, Float64)) + precompile(f2, (Vector{dualT}, Vector{Float64}, Vector{Float64}, dualT)) + precompile(f2, (Vector{dualT}, Vector{Float64}, SciMLBase.NullParameters, dualT)) + f2 +end + +const oop_arglists = (Tuple{Vector{Float64}, Vector{Float64}, Float64}, + Tuple{Vector{Float64}, SciMLBase.NullParameters, Float64}, + Tuple{Vector{Float64}, Vector{Float64}, dualT}, + Tuple{Vector{dualT}, Vector{Float64}, Float64}, + Tuple{Vector{dualT}, SciMLBase.NullParameters, Float64}, + Tuple{Vector{Float64}, SciMLBase.NullParameters, dualT}) + +const NORECOMPILE_OOP_SUPPORTED_ARGS = (Tuple{Vector{Float64}, + Vector{Float64}, Float64}, + Tuple{Vector{Float64}, + SciMLBase.NullParameters, Float64}) +const oop_returnlists = (Vector{Float64}, Vector{Float64}, + ntuple(x -> Vector{dualT}, length(oop_arglists) - 2)...) + +function typestablemapping(@nospecialize(f::Function)) + function f2(@nospecialize(u::Vector{Float64}), + @nospecialize(p::Vector{Float64}), @nospecialize(t::Float64)) + f(u, p, t)::Vector{Float64} + end + + function f2(@nospecialize(u::Vector{Float64}), + @nospecialize(p::SciMLBase.NullParameters), @nospecialize(t::Float64)) + f(u, p, t)::Vector{Float64} + end + + function f2(@nospecialize(u::Vector{dualT}), + @nospecialize(p::Vector{Float64}), @nospecialize(t::Float64)) + f(u, p, t)::Vector{dualT} + end + + function f2(@nospecialize(u::Vector{dualT}), + @nospecialize(p::SciMLBase.NullParameters), @nospecialize(t::Float64)) + f(u, p, t)::Vector{dualT} + end + + function f2(@nospecialize(u::Vector{Float64}), + @nospecialize(p::Vector{Float64}), @nospecialize(t::dualT)) + f(u, p, t)::Vector{dualT} + end + + function f2(@nospecialize(u::Vector{Float64}), + @nospecialize(p::SciMLBase.NullParameters), @nospecialize(t::dualT)) + f(u, p, t)::Vector{dualT} + end + precompile(f, (Vector{Float64}, Vector{Float64}, Float64)) + precompile(f, (Vector{Float64}, SciMLBase.NullParameters, Float64)) + precompile(f, (Vector{dualT}, Vector{Float64}, Float64)) + precompile(f, (Vector{dualT}, SciMLBase.NullParameters, Float64)) + precompile(f, (Vector{Float64}, Vector{Float64}, dualT)) + precompile(f, (Vector{Float64}, SciMLBase.NullParameters, dualT)) + + precompile(f2, (Vector{Float64}, Vector{Float64}, Float64)) + precompile(f2, (Vector{Float64}, SciMLBase.NullParameters, Float64)) + precompile(f2, (Vector{dualT}, Vector{Float64}, Float64)) + precompile(f2, (Vector{dualT}, SciMLBase.NullParameters, Float64)) + precompile(f2, (Vector{Float64}, Vector{Float64}, dualT)) + precompile(f2, (Vector{Float64}, SciMLBase.NullParameters, dualT)) + f2 +end + +const NORECOMPILE_ARGUMENT_MESSAGE = """ + No-recompile mode is only supported for state arguments + of type `Vector{Float64}`, time arguments of `Float64` + and parameter arguments of type `Vector{Float64}` or + `SciMLBase.NullParameters`. + """ + +struct NoRecompileArgumentError <: Exception + args::Any +end + +function Base.showerror(io::IO, e::NoRecompileArgumentError) + println(io, NORECOMPILE_ARGUMENT_MESSAGE) + print(io, "Attempted arguments: ") + print(io, e.args) +end + +function wrapfun_oop(ff, inputs::Tuple) + IT = Tuple{map(typeof, inputs)...} + if IT ∉ NORECOMPILE_OOP_SUPPORTED_ARGS + throw(NoRecompileArgumentError(IT)) + end + FunctionWrappersWrappers.FunctionWrappersWrapper(typestablemapping(ff), oop_arglists, + oop_returnlists) +end + +function wrapfun_iip(ff, inputs::Tuple) + IT = Tuple{map(typeof, inputs)...} + if IT ∉ NORECOMPILE_IIP_SUPPORTED_ARGS + throw(NoRecompileArgumentError(IT)) + end + FunctionWrappersWrappers.FunctionWrappersWrapper(void(ff), iip_arglists, + iip_returnlists) +end + +function unwrap_fw(fw::FunctionWrapper) + fw.obj[] +end diff --git a/src/utils.jl b/src/utils.jl index bfda082f8..7fcfd3369 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,17 +1,3 @@ -function wrapfun_oop(ff, inputs::Tuple) - IT = map(typeof, inputs) - FunctionWrapper{IT[1], Tuple{IT...}}((args...) -> (ff(args...))) -end - -function wrapfun_iip(ff, inputs::Tuple) - IT = map(typeof, inputs) - FunctionWrapper{Nothing, Tuple{IT...}}((args...) -> (ff(args...); nothing)) -end - -function unwrap_fw(fw::FunctionWrapper) - fw.obj[] -end - _vec(v) = vec(v) _vec(v::Number) = v _vec(v::AbstractVector) = v