diff --git a/Project.toml b/Project.toml index 76c04c372..f6e4137c0 100644 --- a/Project.toml +++ b/Project.toml @@ -7,8 +7,10 @@ version = "6.70.0" ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" +Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" FunctionWrappers = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e" IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153" LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800" @@ -17,6 +19,7 @@ Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" MuladdMacro = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221" NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec" Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a" +PreallocationTools = "d236fae5-4411-538c-8e31-a6e3d9e00b46" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" RecursiveFactorization = "f2c3362d-daeb-58d1-803e-2bc74f2840b4" @@ -36,12 +39,14 @@ ChainRulesCore = "0.10, 1" DataStructures = "0.18" DocStringExtensions = "0.8" FastBroadcast = "0.1.4" +ForwardDiff = "0.10" FunctionWrappers = "1.0" IterativeSolvers = "0.9" LabelledArrays = "1.1" MuladdMacro = "0.2.1" NonlinearSolve = "0.3.0" Parameters = "0.12.0" +PreallocationTools = "0.1.0" RecursiveArrayTools = "2" RecursiveFactorization = "0.1" Reexport = "0.2, 1.0" diff --git a/src/DiffEqBase.jl b/src/DiffEqBase.jl index f64612120..0ec1b679f 100644 --- a/src/DiffEqBase.jl +++ b/src/DiffEqBase.jl @@ -20,6 +20,10 @@ using Statistics using FastBroadcast: @.. +using PreallocationTools +import PreallocationTools: get_tmp + +import Distributions import ChainRulesCore import LabelledArrays import RecursiveArrayTools @@ -29,6 +33,8 @@ import ZygoteRules using Setfield +using ForwardDiff + @reexport using SciMLBase using SciMLBase: @def, DEIntegrator, DEProblem, AbstractDiffEqOperator, @@ -121,6 +127,7 @@ include("data_array.jl") include("solve.jl") include("internal_euler.jl") include("init.jl") +include("forwarddiff.jl") include("chainrules.jl") """ diff --git a/src/forwarddiff.jl b/src/forwarddiff.jl new file mode 100644 index 000000000..5a848e8bc --- /dev/null +++ b/src/forwarddiff.jl @@ -0,0 +1,42 @@ +promote_u0(u0::AbstractArray{<:ForwardDiff.Dual},p::AbstractArray{<:ForwardDiff.Dual},t0) = u0 +promote_u0(u0,p::AbstractArray{<:ForwardDiff.Dual},t0) = eltype(p).(u0) +promote_u0(u0,p::NTuple{N,<:ForwardDiff.Dual},t0) where N = eltype(p).(u0) +promote_u0(u0,p::ForwardDiff.Dual,t0) where N = eltype(p).(u0) + +function promote_tspan(u0::AbstractArray{<:ForwardDiff.Dual},p,tspan::Tuple{<:ForwardDiff.Dual,<:ForwardDiff.Dual},prob,kwargs) + return tspan +end + +function promote_tspan(u0::AbstractArray{<:ForwardDiff.Dual},p,tspan,prob,kwargs) + if (haskey(kwargs,:callback) && has_continuous_callback(kwargs[:callback])) || + (haskey(prob.kwargs,:callback) && has_continuous_callback(prob.kwargs[:callback])) + + return eltype(u0).(tspan) + else + return tspan + end +end + +value(x::Type{ForwardDiff.Dual{T,V,N}}) where {T,V,N} = V +value(x::ForwardDiff.Dual) = value(ForwardDiff.value(x)) + +@inline fastpow(x::ForwardDiff.Dual, y::ForwardDiff.Dual) = x^y + +sse(x::Number) = x^2 +sse(x::ForwardDiff.Dual) = sse(ForwardDiff.value(x)) + sum(sse, ForwardDiff.partials(x)) +totallength(x::Number) = 1 +totallength(x::ForwardDiff.Dual) = totallength(ForwardDiff.value(x)) + sum(totallength, ForwardDiff.partials(x)) +totallength(x::AbstractArray) = sum(totallength,x) + +@inline ODE_DEFAULT_NORM(u::ForwardDiff.Dual,::Any) = sqrt(sse(u)) +@inline ODE_DEFAULT_NORM(u::AbstractArray{<:ForwardDiff.Dual},t::Any) = sqrt(sum(sse,u) / totallength(u)) +@inline ODE_DEFAULT_NORM(u::ForwardDiff.Dual,::ForwardDiff.Dual) = sqrt(sse(u)) +@inline ODE_DEFAULT_NORM(u::AbstractArray{<:ForwardDiff.Dual},::ForwardDiff.Dual) = sqrt(sum(x->sse(x),u) / totallength(u)) + +if !hasmethod(nextfloat, Tuple{ForwardDiff.Dual}) + # Type piracy. Should upstream + Base.nextfloat(d::ForwardDiff.Dual{T,V,N}) where {T,V,N} = ForwardDiff.Dual{T}(nextfloat(d.value), d.partials) + Base.prevfloat(d::ForwardDiff.Dual{T,V,N}) where {T,V,N} = ForwardDiff.Dual{T}(prevfloat(d.value), d.partials) +end + +# bisection(f, tup::Tuple{T,T}, t_forward::Bool) where {T<:ForwardDiff.Dual} = find_zero(f, tup, Roots.AlefeldPotraShi()) diff --git a/src/init.jl b/src/init.jl index e738fd796..5f3c737f4 100644 --- a/src/init.jl +++ b/src/init.jl @@ -14,85 +14,6 @@ function __init__() eval_u0(u0::ApproxFun.Fun) = false end - @require Distributions="31c24e10-a181-5473-b8eb-7969acd0382f" begin - handle_distribution_u0(_u0::Distributions.Sampleable) = rand(_u0) - isdistribution(_u0::Distributions.Sampleable) = true - end - - @require ForwardDiff="f6369f11-7733-5829-9624-2563aa707210" begin - - promote_u0(u0::AbstractArray{<:ForwardDiff.Dual},p::AbstractArray{<:ForwardDiff.Dual},t0) = u0 - promote_u0(u0,p::AbstractArray{<:ForwardDiff.Dual},t0) = eltype(p).(u0) - promote_u0(u0,p::NTuple{N,<:ForwardDiff.Dual},t0) where N = eltype(p).(u0) - promote_u0(u0,p::ForwardDiff.Dual,t0) where N = eltype(p).(u0) - - function promote_tspan(u0::AbstractArray{<:ForwardDiff.Dual},p,tspan::Tuple{<:ForwardDiff.Dual,<:ForwardDiff.Dual},prob,kwargs) - return tspan - end - - function promote_tspan(u0::AbstractArray{<:ForwardDiff.Dual},p,tspan,prob,kwargs) - if (haskey(kwargs,:callback) && has_continuous_callback(kwargs[:callback])) || - (haskey(prob.kwargs,:callback) && has_continuous_callback(prob.kwargs[:callback])) - - return eltype(u0).(tspan) - else - return tspan - end - end - - value(x::Type{ForwardDiff.Dual{T,V,N}}) where {T,V,N} = V - value(x::ForwardDiff.Dual) = value(ForwardDiff.value(x)) - - @inline fastpow(x::ForwardDiff.Dual, y::ForwardDiff.Dual) = x^y - - sse(x::Number) = x^2 - sse(x::ForwardDiff.Dual) = sse(ForwardDiff.value(x)) + sum(sse, ForwardDiff.partials(x)) - totallength(x::Number) = 1 - totallength(x::ForwardDiff.Dual) = totallength(ForwardDiff.value(x)) + sum(totallength, ForwardDiff.partials(x)) - totallength(x::AbstractArray) = sum(totallength,x) - - @inline ODE_DEFAULT_NORM(u::ForwardDiff.Dual,::Any) = sqrt(sse(u)) - @inline ODE_DEFAULT_NORM(u::AbstractArray{<:ForwardDiff.Dual},t::Any) = sqrt(sum(sse,u) / totallength(u)) - @inline ODE_DEFAULT_NORM(u::ForwardDiff.Dual,::ForwardDiff.Dual) = sqrt(sse(u)) - @inline ODE_DEFAULT_NORM(u::AbstractArray{<:ForwardDiff.Dual},::ForwardDiff.Dual) = sqrt(sum(x->sse(x),u) / totallength(u)) - - if !hasmethod(nextfloat, Tuple{ForwardDiff.Dual}) - # Type piracy. Should upstream - Base.nextfloat(d::ForwardDiff.Dual{T,V,N}) where {T,V,N} = ForwardDiff.Dual{T}(nextfloat(d.value), d.partials) - Base.prevfloat(d::ForwardDiff.Dual{T,V,N}) where {T,V,N} = ForwardDiff.Dual{T}(prevfloat(d.value), d.partials) - end - - struct DiffCache{T<:AbstractArray, S<:AbstractArray} - du::T - dual_du::S - end - - function DiffCache(u::AbstractArray{T}, siz, ::Type{Val{chunk_size}}) where {T, chunk_size} - x = ArrayInterface.restructure(u,zeros(ForwardDiff.Dual{nothing,T,chunk_size}, siz...)) - DiffCache(u, x) - end - - dualcache(u::AbstractArray, N=Val{ForwardDiff.pickchunksize(length(u))}) = DiffCache(u, size(u), N) - - function get_tmp(dc::DiffCache, u::T) where T<:ForwardDiff.Dual - x = reinterpret(T, dc.dual_du) - end - - function get_tmp(dc::DiffCache, u::AbstractArray{T}) where T<:ForwardDiff.Dual - x = reinterpret(T, dc.dual_du) - end - - function DiffEqBase.get_tmp(dc::DiffEqBase.DiffCache, u::LabelledArrays.LArray{T,N,D,Syms}) where {T,N,D,Syms} - x = reinterpret(T, dc.dual_du.__x) - LabelledArrays.LArray{T,N,D,Syms}(x) - end - - get_tmp(dc::DiffCache, u::Number) = dc.du - get_tmp(dc::DiffCache, u::AbstractArray) = dc.du - - # bisection(f, tup::Tuple{T,T}, t_forward::Bool) where {T<:ForwardDiff.Dual} = find_zero(f, tup, Roots.AlefeldPotraShi()) - end - @require Measurements="eff96d63-e80a-5855-80a2-b1b0885c5ab7" begin promote_u0(u0::AbstractArray{<:Measurements.Measurement},p::AbstractArray{<:Measurements.Measurement},t0) = u0 @@ -152,40 +73,6 @@ function __init__() end # Piracy, should get upstreamed - @require CuArrays="3a865a2d-5b23-5a0f-bc46-62713ec82fae" begin - cuify(x::AbstractArray) = CuArrays.CuArray(x) - function LinearAlgebra.ldiv!(x::CuArrays.CuArray,_qr::CuArrays.CUSOLVER.CuQR,b::CuArrays.CuArray) - _x = UpperTriangular(_qr.R) \ (_qr.Q' * reshape(b,length(b),1)) - x .= vec(_x) - CuArrays.unsafe_free!(_x) - return x - end - # make `\` work - LinearAlgebra.ldiv!(F::CuArrays.CUSOLVER.CuQR, b::CuArrays.CuArray) = (x = similar(b); ldiv!(x, F, b); x) - default_factorize(A::CuArrays.CuArray) = qr(A) - function findall_events(affect!,affect_neg!,prev_sign::CuArrays.CuArray,next_sign::CuArrays.CuArray) - hasaffect::Bool = affect! !== nothing - hasaffectneg::Bool = affect_neg! !== nothing - f = (p,n)-> ((p < 0 && hasaffect) || (p > 0 && hasaffectneg)) && p*n<=0 - A = map(f,prev_sign,next_sign) - out = findall(A) - CuArrays.unsafe_free!(A) - out - end - - ODE_DEFAULT_NORM(u::CuArrays.CuArray,t) = sqrt(real(sum(abs2,u))/length(u)) - - @require ForwardDiff="f6369f11-7733-5829-9624-2563aa707210" begin - @inline function ODE_DEFAULT_NORM(u::CuArrays.CuArray{<:ForwardDiff.Dual},t) - sqrt(sum(abs2,value.(u)) / length(u)) - end - - @inline function ODE_DEFAULT_NORM(u::CuArrays.CuArray{<:ForwardDiff.Dual},t::ForwardDiff.Dual) - sqrt(sum(abs2,u) / length(u)) - end - end - end - @require CUDA="052768ef-5323-5732-b1bb-66c8b64840ba" begin cuify(x::AbstractArray) = CUDA.CuArray(x) function LinearAlgebra.ldiv!(x::CUDA.CuArray,_qr::CUDA.CUSOLVER.CuQR,b::CUDA.CuArray) @@ -209,14 +96,12 @@ function __init__() ODE_DEFAULT_NORM(u::CUDA.CuArray,t) = sqrt(real(sum(abs2,u))/length(u)) - @require ForwardDiff="f6369f11-7733-5829-9624-2563aa707210" begin - @inline function ODE_DEFAULT_NORM(u::CUDA.CuArray{<:ForwardDiff.Dual},t) - sqrt(sum(abs2,value.(u)) / length(u)) - end + @inline function ODE_DEFAULT_NORM(u::CUDA.CuArray{<:ForwardDiff.Dual},t) + sqrt(sum(abs2,value.(u)) / length(u)) + end - @inline function ODE_DEFAULT_NORM(u::CUDA.CuArray{<:ForwardDiff.Dual},t::ForwardDiff.Dual) - sqrt(sum(abs2,u) / length(u)) - end + @inline function ODE_DEFAULT_NORM(u::CUDA.CuArray{<:ForwardDiff.Dual},t::ForwardDiff.Dual) + sqrt(sum(abs2,u) / length(u)) end end diff --git a/src/solve.jl b/src/solve.jl index 514774650..1754387a9 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -67,7 +67,7 @@ function solve(prob::DEProblem,args...;sensealg=nothing, u0 = nothing, p = nothing, kwargs...) u0 = u0 !== nothing ? u0 : prob.u0 p = p !== nothing ? p : prob.p - if sensealg === nothing && hasproperty(prob,:kwargs) && haskey(prob.kwargs,:sensealg) + if sensealg === nothing && haskey(prob.kwargs,:sensealg) sensealg = prob.kwargs[:sensealg] end solve_up(prob,sensealg,u0,p,args...;kwargs...) @@ -265,6 +265,9 @@ function get_concrete_p(prob, kwargs) end handle_distribution_u0(_u0) = _u0 +handle_distribution_u0(_u0::Distributions.Sampleable) = rand(_u0) +isdistribution(_u0::Distributions.Sampleable) = true + eval_u0(u0::Function) = true eval_u0(u0) = false