Skip to content

Setup OrdinaryDiffEq for norecompile mode #1627

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 17 commits into from
Aug 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898"
FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
FunctionWrappersWrappers = "77dc65aa-8811-40c2-897b-53d922fa7daf"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
Expand Down Expand Up @@ -58,7 +59,7 @@ Polyester = "0.3, 0.4, 0.5, 0.6"
PreallocationTools = "0.2, 0.3, 0.4"
RecursiveArrayTools = "2.26.3"
Reexport = "0.2, 1.0"
SciMLBase = "1.44"
SciMLBase = "1.50"
SnoopPrecompile = "1"
SparseDiffTools = "1.19.1"
StaticArrays = "0.11, 0.12, 1.0"
Expand Down
112 changes: 84 additions & 28 deletions src/OrdinaryDiffEq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@ using DocStringExtensions
using DiffEqBase: Convergence, Divergence
const TryAgain = DiffEqBase.SlowConvergence

import DiffEqBase: calculate_residuals, calculate_residuals!, unwrap_cache, @tight_loop_macros, islinear, timedepentdtmin
import DiffEqBase: calculate_residuals, calculate_residuals!, unwrap_cache, @tight_loop_macros,
islinear, timedepentdtmin, OrdinaryDiffEqTag

import SparseDiffTools
import SparseDiffTools: matrix_colors, forwarddiff_color_jacobian!,
Expand All @@ -80,9 +81,8 @@ using DocStringExtensions
ForwardDiff.Dual{ForwardDiff.Tag{T,W},K,3} where {T,W<:Union{Float64,Float32},
K<:Union{Float64,Float32}}}

struct OrdinaryDiffEqTag end

import ArrayInterfaceStaticArrays, ArrayInterfaceGPUArrays
import FunctionWrappersWrappers

DEFAULT_PRECS(W,du,u,p,t,newW,Plprev,Prprev,solverdata) = nothing,nothing

Expand Down Expand Up @@ -193,31 +193,87 @@ using DocStringExtensions
import SnoopPrecompile

SnoopPrecompile.@precompile_all_calls begin
function lorenz(du,u,p,t)
du[1] = 10.0(u[2]-u[1])
du[2] = u[1]*(28.0-u[3]) - u[2]
du[3] = u[1]*u[2] - (8/3)*u[3]
end
lorenzprob = ODEProblem(lorenz,[1.0;0.0;0.0],(0.0,1.0))
solve(lorenzprob,BS3())
solve(lorenzprob,Tsit5())
solve(lorenzprob,Vern7())
solve(lorenzprob,Vern9())
solve(lorenzprob,Rosenbrock23())(5.0)
solve(lorenzprob,TRBDF2())
solve(lorenzprob,Rodas4(autodiff=false))
solve(lorenzprob,KenCarp4(autodiff=false))
solve(lorenzprob,Rodas5())
solve(lorenzprob,QNDF())
solve(lorenzprob,QNDF(autodiff=false))
solve(lorenzprob,AutoTsit5(Rosenbrock23()))
solve(lorenzprob,AutoTsit5(Rosenbrock23(autodiff=false)))
solve(lorenzprob,AutoTsit5(TRBDF2(autodiff=false)))
solve(lorenzprob,AutoVern7(Rodas4(autodiff=false)))
solve(lorenzprob,AutoVern7(TRBDF2(autodiff=false)))
solve(lorenzprob,AutoVern9(Rodas5(autodiff=false)))
solve(lorenzprob,AutoVern9(KenCarp47(autodiff=false)))
solve(lorenzprob,AutoVern7(Rodas5()))
function lorenz(du,u,p,t)
du[1] = 10.0(u[2]-u[1])
du[2] = u[1]*(28.0-u[3]) - u[2]
du[3] = u[1]*u[2] - (8/3)*u[3]
end

function lorenz_oop(u,p,t)
[10.0(u[2]-u[1]),u[1]*(28.0-u[3]) - u[2],u[1]*u[2] - (8/3)*u[3]]
end

solver_list = [
BS3(), Tsit5(), Vern7(), Vern9(),

Rosenbrock23(), Rosenbrock23(autodiff=false),
Rosenbrock23(chunk_size = 1), Rosenbrock23(chunk_size = Val{1}()),

Rodas4(), Rodas4(autodiff=false),
#Rodas4(chunk_size = 1), Rodas4(chunk_size = Val{1}()),

Rodas5(), Rodas5(autodiff=false),
#Rodas5(chunk_size = 1), Rodas5(chunk_size = Val{1}()),

Rodas5P(), Rodas5P(autodiff=false),
Rodas5P(chunk_size = 1), Rodas5P(chunk_size = Val{1}()),

TRBDF2(), TRBDF2(autodiff=false),
#TRBDF2(chunk_size = 1), TRBDF2(chunk_size = Val{1}()),

KenCarp4(), KenCarp4(autodiff=false),
#KenCarp4(chunk_size = 1), KenCarp4(chunk_size = Val{1}()),

QNDF(), QNDF(autodiff=false),
#QNDF(chunk_size = 1), QNDF(chunk_size = Val{1}()),

AutoTsit5(Rosenbrock23()), AutoTsit5(Rosenbrock23(autodiff=false)),
AutoTsit5(Rosenbrock23(chunk_size = 1)),
AutoTsit5(Rosenbrock23(chunk_size = Val{1}())),

AutoTsit5(TRBDF2()), AutoTsit5(TRBDF2(autodiff=false)),
#AutoTsit5(TRBDF2(chunk_size = 1)),
#AutoTsit5(TRBDF2(chunk_size = Val{1}())),

AutoVern9(KenCarp47()), AutoVern9(KenCarp47(autodiff=false)),
#AutoVern9(KenCarp47(chunk_size = 1)),
#AutoVern9(KenCarp47(chunk_size = Val{1}())),

AutoVern9(Rodas5()), AutoVern9(Rodas5(autodiff=false)),
AutoVern9(Rodas5(chunk_size = 1)),
AutoVern9(Rodas5(chunk_size = Val{1}())),

AutoVern9(Rodas5P()), AutoVern9(Rodas5P(autodiff=false)),
AutoVern9(Rodas5P(chunk_size = 1)),
AutoVern9(Rodas5P(chunk_size = Val{1}())),

AutoVern7(Rodas4()), AutoVern7(Rodas4(autodiff=false)),
#AutoVern7(Rodas4(chunk_size = 1)),
#AutoVern7(Rodas4(chunk_size = Val{1}())),

#AutoVern7(Rodas5P()), AutoVern7(Rodas5P(autodiff=false)),
#AutoVern7(Rodas5P(chunk_size = 1)),
#AutoVern7(Rodas5P(chunk_size = Val{1}())),

AutoVern7(TRBDF2()), AutoVern7(TRBDF2(autodiff=false)),
#AutoVern7(TRBDF2(chunk_size = 1)),
#AutoVern7(TRBDF2(chunk_size = Val{1}())),
]

prob_list = [
ODEProblem(lorenz,[1.0;0.0;0.0],(0.0,1.0))
ODEProblem{true,false}(lorenz,[1.0;0.0;0.0],(0.0,1.0))
ODEProblem{true,false}(lorenz,[1.0;0.0;0.0],(0.0,1.0),Float64[])
ODEProblem(lorenz_oop,[1.0;0.0;0.0],(0.0,1.0))
#ODEProblem{false,false}(lorenz_oop,[1.0;0.0;0.0],(0.0,1.0))
#ODEProblem{false,false}(lorenz_oop,[1.0;0.0;0.0],(0.0,1.0),Float64[])
]

for prob in prob_list, solver in solver_list
solve(prob,solver)(5.0)
end

prob_list = nothing
end

#General Functions
Expand Down
7 changes: 6 additions & 1 deletion src/alg_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,12 @@ function DiffEqBase.prepare_alg(alg::Union{OrdinaryDiffEqAdaptiveImplicitAlgorit
linsolve = alg.linsolve
end

isbitstype(T) && sizeof(T) > 24 && return remake(alg, chunk_size=Val{1}(), linsolve=linsolve)
# If norecompile mode or very large bitsize, like a dual number u0 already, then
# don't use a large chunksize as it will either error or not be beneficial
if (isbitstype(T) && sizeof(T) > 24) || (prob.f isa ODEFunction && prob.f.f isa
FunctionWrappersWrappers.FunctionWrappersWrapper)
return remake(alg, chunk_size=Val{1}(), linsolve=linsolve)
end

L = ArrayInterface.known_length(typeof(u0))
if L === nothing # dynamic sized
Expand Down
57 changes: 57 additions & 0 deletions test/interface/norecompile.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
using OrdinaryDiffEq, Test
function f(du, u, p, t)
du[1] = 0.2u[1]
du[2] = 0.4u[2]
end
u0 = ones(2)
tspan = (0.0, 1.0)
prob = ODEProblem{true,false}(f, u0, tspan, Float64[])

function lorenz(du, u, p, t)
du[1] = 10.0(u[2] - u[1])
du[2] = u[1] * (28.0 - u[3]) - u[2]
du[3] = u[1] * u[2] - (8 / 3) * u[3]
end
lorenzprob = ODEProblem{true,false}(lorenz, [1.0; 0.0; 0.0], (0.0, 1.0), Float64[])
@test typeof(prob) === typeof(lorenzprob)

t1 = @elapsed sol = solve(lorenzprob, Rosenbrock23())
t2 = @elapsed sol = solve(lorenzprob, Rosenbrock23(autodiff=false))

lorenzprob2 = ODEProblem(lorenz, [1.0; 0.0; 0.0], (0.0, 1.0), Float64[])

t3 = @elapsed sol = solve(lorenzprob2, Rosenbrock23())
t4 = @elapsed sol = solve(lorenzprob2, Rosenbrock23(autodiff=false))

if VERSION >= v"1.8"
@test 5t1 < t3
@test t2 < t4
end

function f_oop(u, p, t)
[0.2u[1], 0.4u[2]]
end
u0 = ones(2)
tspan = (0.0, 1.0)
prob = ODEProblem{false,false}(f_oop, u0, tspan, Float64[])

function lorenz_oop(u, p, t)
[10.0(u[2] - u[1]), u[1] * (28.0 - u[3]) - u[2], u[1] * u[2] - (8 / 3) * u[3]]
end
lorenzprob = ODEProblem{false,false}(lorenz_oop, [1.0; 0.0; 0.0], (0.0, 1.0), Float64[])
@test typeof(prob) === typeof(lorenzprob)

# This one is fundamentally hard / broken
# Since the equation is not dependent on `t`, the output is not dual of t
# This is problem-dependent, so it is hard to deduce a priori
@test_broken t1 = @elapsed sol = solve(lorenzprob, Rosenbrock23())

t2 = @elapsed sol = solve(lorenzprob, Rosenbrock23(autodiff=false))

lorenzprob2 = ODEProblem(lorenz, [1.0; 0.0; 0.0], (0.0, 1.0), Float64[])

t3 = @elapsed sol = solve(lorenzprob2, Rosenbrock23())
t4 = @elapsed sol = solve(lorenzprob2, Rosenbrock23(autodiff=false))

#@test 5t1 < t3
#@test t2 < t4
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ if GROUP == "All" || GROUP == "InterfaceI" || GROUP == "Interface"
end

if !is_APPVEYOR && (GROUP == "All" || GROUP == "InterfaceII" || GROUP == "Interface")
@time @safetestset "No Recompile Tests" begin include("interface/norecompile.jl") end
@time @safetestset "Linear Nonlinear Solver Tests" begin include("interface/linear_nonlinear_tests.jl") end
@time @safetestset "Linear Solver Tests" begin include("interface/linear_solver_test.jl") end
@time @safetestset "Linear Solver Split ODE Tests" begin include("interface/linear_solver_split_ode_test.jl") end
Expand Down