diff --git a/Project.toml b/Project.toml index d17ebc20ca..64862eeef2 100644 --- a/Project.toml +++ b/Project.toml @@ -16,6 +16,7 @@ FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898" FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a" FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +FunctionWrappersWrappers = "77dc65aa-8811-40c2-897b-53d922fa7daf" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae" Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" @@ -58,7 +59,7 @@ Polyester = "0.3, 0.4, 0.5, 0.6" PreallocationTools = "0.2, 0.3, 0.4" RecursiveArrayTools = "2.26.3" Reexport = "0.2, 1.0" -SciMLBase = "1.44" +SciMLBase = "1.50" SnoopPrecompile = "1" SparseDiffTools = "1.19.1" StaticArrays = "0.11, 0.12, 1.0" diff --git a/src/OrdinaryDiffEq.jl b/src/OrdinaryDiffEq.jl index 14c3723f35..ce3a626f62 100644 --- a/src/OrdinaryDiffEq.jl +++ b/src/OrdinaryDiffEq.jl @@ -66,7 +66,8 @@ using DocStringExtensions using DiffEqBase: Convergence, Divergence const TryAgain = DiffEqBase.SlowConvergence - import DiffEqBase: calculate_residuals, calculate_residuals!, unwrap_cache, @tight_loop_macros, islinear, timedepentdtmin + import DiffEqBase: calculate_residuals, calculate_residuals!, unwrap_cache, @tight_loop_macros, + islinear, timedepentdtmin, OrdinaryDiffEqTag import SparseDiffTools import SparseDiffTools: matrix_colors, forwarddiff_color_jacobian!, @@ -80,9 +81,8 @@ using DocStringExtensions ForwardDiff.Dual{ForwardDiff.Tag{T,W},K,3} where {T,W<:Union{Float64,Float32}, K<:Union{Float64,Float32}}} - struct OrdinaryDiffEqTag end - import ArrayInterfaceStaticArrays, ArrayInterfaceGPUArrays + import FunctionWrappersWrappers DEFAULT_PRECS(W,du,u,p,t,newW,Plprev,Prprev,solverdata) = nothing,nothing @@ -193,31 +193,87 @@ using DocStringExtensions import SnoopPrecompile SnoopPrecompile.@precompile_all_calls begin - function lorenz(du,u,p,t) - du[1] = 10.0(u[2]-u[1]) - du[2] = u[1]*(28.0-u[3]) - u[2] - du[3] = u[1]*u[2] - (8/3)*u[3] - end - lorenzprob = ODEProblem(lorenz,[1.0;0.0;0.0],(0.0,1.0)) - solve(lorenzprob,BS3()) - solve(lorenzprob,Tsit5()) - solve(lorenzprob,Vern7()) - solve(lorenzprob,Vern9()) - solve(lorenzprob,Rosenbrock23())(5.0) - solve(lorenzprob,TRBDF2()) - solve(lorenzprob,Rodas4(autodiff=false)) - solve(lorenzprob,KenCarp4(autodiff=false)) - solve(lorenzprob,Rodas5()) - solve(lorenzprob,QNDF()) - solve(lorenzprob,QNDF(autodiff=false)) - solve(lorenzprob,AutoTsit5(Rosenbrock23())) - solve(lorenzprob,AutoTsit5(Rosenbrock23(autodiff=false))) - solve(lorenzprob,AutoTsit5(TRBDF2(autodiff=false))) - solve(lorenzprob,AutoVern7(Rodas4(autodiff=false))) - solve(lorenzprob,AutoVern7(TRBDF2(autodiff=false))) - solve(lorenzprob,AutoVern9(Rodas5(autodiff=false))) - solve(lorenzprob,AutoVern9(KenCarp47(autodiff=false))) - solve(lorenzprob,AutoVern7(Rodas5())) + function lorenz(du,u,p,t) + du[1] = 10.0(u[2]-u[1]) + du[2] = u[1]*(28.0-u[3]) - u[2] + du[3] = u[1]*u[2] - (8/3)*u[3] + end + + function lorenz_oop(u,p,t) + [10.0(u[2]-u[1]),u[1]*(28.0-u[3]) - u[2],u[1]*u[2] - (8/3)*u[3]] + end + + solver_list = [ + BS3(), Tsit5(), Vern7(), Vern9(), + + Rosenbrock23(), Rosenbrock23(autodiff=false), + Rosenbrock23(chunk_size = 1), Rosenbrock23(chunk_size = Val{1}()), + + Rodas4(), Rodas4(autodiff=false), + #Rodas4(chunk_size = 1), Rodas4(chunk_size = Val{1}()), + + Rodas5(), Rodas5(autodiff=false), + #Rodas5(chunk_size = 1), Rodas5(chunk_size = Val{1}()), + + Rodas5P(), Rodas5P(autodiff=false), + Rodas5P(chunk_size = 1), Rodas5P(chunk_size = Val{1}()), + + TRBDF2(), TRBDF2(autodiff=false), + #TRBDF2(chunk_size = 1), TRBDF2(chunk_size = Val{1}()), + + KenCarp4(), KenCarp4(autodiff=false), + #KenCarp4(chunk_size = 1), KenCarp4(chunk_size = Val{1}()), + + QNDF(), QNDF(autodiff=false), + #QNDF(chunk_size = 1), QNDF(chunk_size = Val{1}()), + + AutoTsit5(Rosenbrock23()), AutoTsit5(Rosenbrock23(autodiff=false)), + AutoTsit5(Rosenbrock23(chunk_size = 1)), + AutoTsit5(Rosenbrock23(chunk_size = Val{1}())), + + AutoTsit5(TRBDF2()), AutoTsit5(TRBDF2(autodiff=false)), + #AutoTsit5(TRBDF2(chunk_size = 1)), + #AutoTsit5(TRBDF2(chunk_size = Val{1}())), + + AutoVern9(KenCarp47()), AutoVern9(KenCarp47(autodiff=false)), + #AutoVern9(KenCarp47(chunk_size = 1)), + #AutoVern9(KenCarp47(chunk_size = Val{1}())), + + AutoVern9(Rodas5()), AutoVern9(Rodas5(autodiff=false)), + AutoVern9(Rodas5(chunk_size = 1)), + AutoVern9(Rodas5(chunk_size = Val{1}())), + + AutoVern9(Rodas5P()), AutoVern9(Rodas5P(autodiff=false)), + AutoVern9(Rodas5P(chunk_size = 1)), + AutoVern9(Rodas5P(chunk_size = Val{1}())), + + AutoVern7(Rodas4()), AutoVern7(Rodas4(autodiff=false)), + #AutoVern7(Rodas4(chunk_size = 1)), + #AutoVern7(Rodas4(chunk_size = Val{1}())), + + #AutoVern7(Rodas5P()), AutoVern7(Rodas5P(autodiff=false)), + #AutoVern7(Rodas5P(chunk_size = 1)), + #AutoVern7(Rodas5P(chunk_size = Val{1}())), + + AutoVern7(TRBDF2()), AutoVern7(TRBDF2(autodiff=false)), + #AutoVern7(TRBDF2(chunk_size = 1)), + #AutoVern7(TRBDF2(chunk_size = Val{1}())), + ] + + prob_list = [ + ODEProblem(lorenz,[1.0;0.0;0.0],(0.0,1.0)) + ODEProblem{true,false}(lorenz,[1.0;0.0;0.0],(0.0,1.0)) + ODEProblem{true,false}(lorenz,[1.0;0.0;0.0],(0.0,1.0),Float64[]) + ODEProblem(lorenz_oop,[1.0;0.0;0.0],(0.0,1.0)) + #ODEProblem{false,false}(lorenz_oop,[1.0;0.0;0.0],(0.0,1.0)) + #ODEProblem{false,false}(lorenz_oop,[1.0;0.0;0.0],(0.0,1.0),Float64[]) + ] + + for prob in prob_list, solver in solver_list + solve(prob,solver)(5.0) + end + + prob_list = nothing end #General Functions diff --git a/src/alg_utils.jl b/src/alg_utils.jl index 84d748934f..937fbad257 100644 --- a/src/alg_utils.jl +++ b/src/alg_utils.jl @@ -199,7 +199,12 @@ function DiffEqBase.prepare_alg(alg::Union{OrdinaryDiffEqAdaptiveImplicitAlgorit linsolve = alg.linsolve end - isbitstype(T) && sizeof(T) > 24 && return remake(alg, chunk_size=Val{1}(), linsolve=linsolve) + # If norecompile mode or very large bitsize, like a dual number u0 already, then + # don't use a large chunksize as it will either error or not be beneficial + if (isbitstype(T) && sizeof(T) > 24) || (prob.f isa ODEFunction && prob.f.f isa + FunctionWrappersWrappers.FunctionWrappersWrapper) + return remake(alg, chunk_size=Val{1}(), linsolve=linsolve) + end L = ArrayInterface.known_length(typeof(u0)) if L === nothing # dynamic sized diff --git a/test/interface/norecompile.jl b/test/interface/norecompile.jl new file mode 100644 index 0000000000..70fb935815 --- /dev/null +++ b/test/interface/norecompile.jl @@ -0,0 +1,57 @@ +using OrdinaryDiffEq, Test +function f(du, u, p, t) + du[1] = 0.2u[1] + du[2] = 0.4u[2] +end +u0 = ones(2) +tspan = (0.0, 1.0) +prob = ODEProblem{true,false}(f, u0, tspan, Float64[]) + +function lorenz(du, u, p, t) + du[1] = 10.0(u[2] - u[1]) + du[2] = u[1] * (28.0 - u[3]) - u[2] + du[3] = u[1] * u[2] - (8 / 3) * u[3] +end +lorenzprob = ODEProblem{true,false}(lorenz, [1.0; 0.0; 0.0], (0.0, 1.0), Float64[]) +@test typeof(prob) === typeof(lorenzprob) + +t1 = @elapsed sol = solve(lorenzprob, Rosenbrock23()) +t2 = @elapsed sol = solve(lorenzprob, Rosenbrock23(autodiff=false)) + +lorenzprob2 = ODEProblem(lorenz, [1.0; 0.0; 0.0], (0.0, 1.0), Float64[]) + +t3 = @elapsed sol = solve(lorenzprob2, Rosenbrock23()) +t4 = @elapsed sol = solve(lorenzprob2, Rosenbrock23(autodiff=false)) + +if VERSION >= v"1.8" + @test 5t1 < t3 + @test t2 < t4 +end + +function f_oop(u, p, t) + [0.2u[1], 0.4u[2]] +end +u0 = ones(2) +tspan = (0.0, 1.0) +prob = ODEProblem{false,false}(f_oop, u0, tspan, Float64[]) + +function lorenz_oop(u, p, t) + [10.0(u[2] - u[1]), u[1] * (28.0 - u[3]) - u[2], u[1] * u[2] - (8 / 3) * u[3]] +end +lorenzprob = ODEProblem{false,false}(lorenz_oop, [1.0; 0.0; 0.0], (0.0, 1.0), Float64[]) +@test typeof(prob) === typeof(lorenzprob) + +# This one is fundamentally hard / broken +# Since the equation is not dependent on `t`, the output is not dual of t +# This is problem-dependent, so it is hard to deduce a priori +@test_broken t1 = @elapsed sol = solve(lorenzprob, Rosenbrock23()) + +t2 = @elapsed sol = solve(lorenzprob, Rosenbrock23(autodiff=false)) + +lorenzprob2 = ODEProblem(lorenz, [1.0; 0.0; 0.0], (0.0, 1.0), Float64[]) + +t3 = @elapsed sol = solve(lorenzprob2, Rosenbrock23()) +t4 = @elapsed sol = solve(lorenzprob2, Rosenbrock23(autodiff=false)) + +#@test 5t1 < t3 +#@test t2 < t4 diff --git a/test/runtests.jl b/test/runtests.jl index c32b92bea2..b93b8e7eda 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -42,6 +42,7 @@ if GROUP == "All" || GROUP == "InterfaceI" || GROUP == "Interface" end if !is_APPVEYOR && (GROUP == "All" || GROUP == "InterfaceII" || GROUP == "Interface") + @time @safetestset "No Recompile Tests" begin include("interface/norecompile.jl") end @time @safetestset "Linear Nonlinear Solver Tests" begin include("interface/linear_nonlinear_tests.jl") end @time @safetestset "Linear Solver Tests" begin include("interface/linear_solver_test.jl") end @time @safetestset "Linear Solver Split ODE Tests" begin include("interface/linear_solver_split_ode_test.jl") end