diff --git a/Project.toml b/Project.toml index 3be52cd14..638a0ca22 100644 --- a/Project.toml +++ b/Project.toml @@ -20,6 +20,7 @@ FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" GlobalSensitivity = "af5da776-676b-467e-8baf-acd8249e4f0f" +GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae" OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" @@ -30,7 +31,6 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" RandomNumbers = "e6cf234a-135c-5ec9-84dd-332b85af5143" RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" -Requires = "ae029012-a4dd-5104-9daa-d747884805df" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" SharedArrays = "1a1011a3-84de-559e-8e89-a11a2f7dc383" @@ -57,6 +57,7 @@ FFTW = "1.1" FiniteDiff = "2" ForwardDiff = "0.10" GlobalSensitivity = "1.0" +GPUArrays = "8" LinearSolve = "1" OrdinaryDiffEq = "5.60, 6" Parameters = "0.12" @@ -65,7 +66,6 @@ QuasiMonteCarlo = "0.1, 0.2" RandomNumbers = "1.5.3" RecursiveArrayTools = "2.4.2" Reexport = "0.2, 1.0" -Requires = "1" ReverseDiff = "1.9" SciMLBase = "1.24" StochasticDiffEq = "6.20" diff --git a/README.md b/README.md index 4b9ab4619..9e38fad05 100644 --- a/README.md +++ b/README.md @@ -9,4 +9,3 @@ DiffEqSensitivity.jl is a component package in the [SciML Scientific Machine Learning ecosystem](https://sciml.ai/). It holds the sensitivity analysis utilities. Users interested in using this functionality should check out [DifferentialEquations.jl](https://github.com/JuliaDiffEq/DifferentialEquations.jl). - diff --git a/src/DiffEqSensitivity.jl b/src/DiffEqSensitivity.jl index 037914ef0..fdd30e2e3 100644 --- a/src/DiffEqSensitivity.jl +++ b/src/DiffEqSensitivity.jl @@ -6,7 +6,6 @@ using DiffEqOperators using Adapt using LinearSolve using Parameters: @unpack -using Requires using StochasticDiffEq using SharedArrays using EllipsisNotation @@ -16,6 +15,7 @@ using Random import ZygoteRules, Zygote, ReverseDiff import ArrayInterface import Enzyme +import GPUArrays using Cassette, DiffRules using Core: CodeInfo, SlotNumber, SSAValue, ReturnNode, GotoIfNot @@ -26,7 +26,6 @@ import ChainRulesCore: @thunk, NoTangent, @not_implemented abstract type SensitivityFunction end abstract type TransformedFunction end -include("require.jl") include("hasbranching.jl") include("sensitivity_algorithms.jl") include("derivative_wrappers.jl") @@ -45,6 +44,10 @@ include("second_order.jl") include("steadystate_adjoint.jl") include("sde_tools.jl") +# AD Extensions +include("reversediff.jl") +include("tracker.jl") +include("zygote.jl") export extract_local_sensitivities diff --git a/src/concrete_solve.jl b/src/concrete_solve.jl index 79f44ae1c..103808c8b 100644 --- a/src/concrete_solve.jl +++ b/src/concrete_solve.jl @@ -42,7 +42,7 @@ function DiffEqBase._concrete_solve_adjoint(prob::Union{ODEProblem,SDEProblem}, !(eltype(p) <: Complex) && length(u0) + length(p) <= 100 ForwardDiffSensitivity() - elseif isgpu(u0) || !DiffEqBase.isinplace(prob) + elseif u0 isa GPUArrays.AbstractGPUArray || !DiffEqBase.isinplace(prob) # only Zygote is GPU compatible and fast # so if out-of-place, try Zygote if p === nothing || p === DiffEqBase.NullParameters() @@ -66,7 +66,7 @@ end function DiffEqBase._concrete_solve_adjoint(prob::Union{NonlinearProblem,SteadyStateProblem},alg, sensealg::Nothing,u0,p,args...;kwargs...) - default_sensealg = if isgpu(u0) || !DiffEqBase.isinplace(prob) + default_sensealg = if u0 isa GPUArrays.AbstractGPUArray || !DiffEqBase.isinplace(prob) # autodiff = false because forwarddiff fails on many GPU kernels # this only effects the Jacobian calculation and is same computation order SteadyStateAdjoint(autodiff = false, autojacvec = ZygoteVJP()) diff --git a/src/require.jl b/src/require.jl deleted file mode 100644 index 1a3ba6ffe..000000000 --- a/src/require.jl +++ /dev/null @@ -1,32 +0,0 @@ -isgpu(x) = false -function __init__() - @require CuArrays="3a865a2d-5b23-5a0f-bc46-62713ec82fae" begin - gpu_or_cpu(x::CuArrays.CuArray) = CuArrays.CuArray - gpu_or_cpu(x::TrackedArray{<:Any,<:Any,<:CuArrays.CuArray}) = CuArrays.CuArray - gpu_or_cpu(x::Transpose{<:Any,<:CuArrays.CuArray}) = CuArrays.CuArray - gpu_or_cpu(x::Adjoint{<:Any,<:CuArrays.CuArray}) = CuArrays.CuArray - gpu_or_cpu(x::Adjoint{<:Any,TrackedArray{<:Any,<:Any,<:CuArrays.CuArray}}) = CuArrays.CuArray - gpu_or_cpu(x::Transpose{<:Any,TrackedArray{<:Any,<:Any,<:CuArrays.CuArray}}) = CuArrays.CuArray - isgpu(::CuArrays.CuArray) = true - isgpu(::TrackedArray{<:Any,<:Any,<:CuArrays.CuArray}) = true - isgpu(::Transpose{<:Any,<:CuArrays.CuArray}) = true - isgpu(::Adjoint{<:Any,<:CuArrays.CuArray}) = true - isgpu(::Adjoint{<:Any,TrackedArray{<:Any,<:Any,<:CuArrays.CuArray}}) = true - isgpu(::Transpose{<:Any,TrackedArray{<:Any,<:Any,<:CuArrays.CuArray}}) = true - end - - @require CUDA="052768ef-5323-5732-b1bb-66c8b64840ba" begin - gpu_or_cpu(x::CUDA.CuArray) = CUDA.CuArray - gpu_or_cpu(x::TrackedArray{<:Any,<:Any,<:CUDA.CuArray}) = CUDA.CuArray - gpu_or_cpu(x::Transpose{<:Any,<:CUDA.CuArray}) = CUDA.CuArray - gpu_or_cpu(x::Adjoint{<:Any,<:CUDA.CuArray}) = CUDA.CuArray - gpu_or_cpu(x::Adjoint{<:Any,TrackedArray{<:Any,<:Any,<:CUDA.CuArray}}) = CUDA.CuArray - gpu_or_cpu(x::Transpose{<:Any,TrackedArray{<:Any,<:Any,<:CUDA.CuArray}}) = CUDA.CuArray - isgpu(::CUDA.CuArray) = true - isgpu(::TrackedArray{<:Any,<:Any,<:CUDA.CuArray}) = true - isgpu(::Transpose{<:Any,<:CUDA.CuArray}) = true - isgpu(::Adjoint{<:Any,<:CUDA.CuArray}) = true - isgpu(::Adjoint{<:Any,TrackedArray{<:Any,<:Any,<:CUDA.CuArray}}) = true - isgpu(::Transpose{<:Any,TrackedArray{<:Any,<:Any,<:CUDA.CuArray}}) = true - end -end diff --git a/src/reversediff.jl b/src/reversediff.jl new file mode 100644 index 000000000..5c4407923 --- /dev/null +++ b/src/reversediff.jl @@ -0,0 +1,57 @@ +# Piracy that used to be requires, allowing ReverseDiff.jl to be specialized for SciML + +DiffEqBase.value(x::ReverseDiff.TrackedReal) = x.value +DiffEqBase.value(x::ReverseDiff.TrackedArray) = x.value + +DiffEqBase.promote_u0(u0::ReverseDiff.TrackedArray, p::ReverseDiff.TrackedArray, t0) = u0 +DiffEqBase.promote_u0(u0::AbstractArray{<:ReverseDiff.TrackedReal}, p::ReverseDiff.TrackedArray, t0) = u0 +DiffEqBase.promote_u0(u0::ReverseDiff.TrackedArray, p::AbstractArray{<:ReverseDiff.TrackedReal}, t0) = u0 +DiffEqBase.promote_u0(u0::AbstractArray{<:ReverseDiff.TrackedReal}, p::AbstractArray{<:ReverseDiff.TrackedReal}, t0) = u0 +DiffEqBase.promote_u0(u0, p::ReverseDiff.TrackedArray, t0) = ReverseDiff.track(u0) +DiffEqBase.promote_u0(u0, p::AbstractArray{<:ReverseDiff.TrackedReal}, t0) = eltype(p).(u0) + +# Support adaptive with non-tracked time +@inline function DiffEqBase.ODE_DEFAULT_NORM(u::ReverseDiff.TrackedArray, t) where {N} + sqrt(sum(abs2, DiffEqBase.value(u)) / length(u)) +end +@inline function DiffEqBase.ODE_DEFAULT_NORM(u::AbstractArray{<:ReverseDiff.TrackedReal,N}, t) where {N} + sqrt(sum(x -> DiffEqBase.ODE_DEFAULT_NORM(x[1], x[2]), zip((DiffEqBase.value(x) for x in u), Iterators.repeated(t))) / length(u)) +end +@inline function DiffEqBase.ODE_DEFAULT_NORM(u::Array{<:ReverseDiff.TrackedReal,N}, t) where {N} + sqrt(sum(x -> DiffEqBase.ODE_DEFAULT_NORM(x[1], x[2]), zip((DiffEqBase.value(x) for x in u), Iterators.repeated(t))) / length(u)) +end +@inline DiffEqBase.ODE_DEFAULT_NORM(u::ReverseDiff.TrackedReal, t) = abs(DiffEqBase.value(u)) + +# Support TrackedReal time, don't drop tracking on the adaptivity there +@inline function DiffEqBase.ODE_DEFAULT_NORM(u::ReverseDiff.TrackedArray, t::ReverseDiff.TrackedReal) where {N} + sqrt(sum(abs2, u) / length(u)) +end +@inline function DiffEqBase.ODE_DEFAULT_NORM(u::AbstractArray{<:ReverseDiff.TrackedReal,N}, t::ReverseDiff.TrackedReal) where {N} + sqrt(sum(x -> DiffEqBase.ODE_DEFAULT_NORM(x[1], x[2]), zip(u, Iterators.repeated(t))) / length(u)) +end +@inline function DiffEqBase.ODE_DEFAULT_NORM(u::Array{<:ReverseDiff.TrackedReal,N}, t::ReverseDiff.TrackedReal) where {N} + sqrt(sum(x -> DiffEqBase.ODE_DEFAULT_NORM(x[1], x[2]), zip(u, Iterators.repeated(t))) / length(u)) +end +@inline DiffEqBase.ODE_DEFAULT_NORM(u::ReverseDiff.TrackedReal, t::ReverseDiff.TrackedReal) = abs(u) + +function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem, sensealg::Union{DiffEqBase.AbstractSensitivityAlgorithm,Nothing}, u0::ReverseDiff.TrackedArray, p::ReverseDiff.TrackedArray, args...; kwargs...) + ReverseDiff.track(solve_up, prob, sensealg, u0, p, args...; kwargs...) +end + +function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem, sensealg::Union{DiffEqBase.AbstractSensitivityAlgorithm,Nothing}, u0, p::ReverseDiff.TrackedArray, args...; kwargs...) + ReverseDiff.track(solve_up, prob, sensealg, u0, p, args...; kwargs...) +end + +function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem, sensealg::Union{DiffEqBase.AbstractSensitivityAlgorithm,Nothing}, u0::ReverseDiff.TrackedArray, p, args...; kwargs...) + ReverseDiff.track(solve_up, prob, sensealg, u0, p, args...; kwargs...) +end + +@inline function DiffEqNoiseProcess.wiener_randn(rng::Random.AbstractRNG, proto::ReverseDiff.TrackedArray) + ReverseDiff.track(convert.(eltype(proto.value), randn(rng, size(proto)))) +end +@inline function DiffEqNoiseProcess.wiener_randn!(rng::AbstractRNG, rand_vec::Array{<:ReverseDiff.TrackedReal}) + rand_vec .= ReverseDiff.track.(randn.((rng,), typeof.(DiffEqBase.value.(rand_vec)))) +end +@inline function DiffEqNoiseProcess.wiener_randn!(rng::AbstractRNG, rand_vec::AbstractArray{<:ReverseDiff.TrackedReal}) + rand_vec .= ReverseDiff.track.(randn.((rng,), typeof.(DiffEqBase.value.(rand_vec)))) +end \ No newline at end of file diff --git a/src/tracker.jl b/src/tracker.jl new file mode 100644 index 000000000..7a84b84f2 --- /dev/null +++ b/src/tracker.jl @@ -0,0 +1,65 @@ +# Piracy that used to be requires, allowing Tracker.jl to be specialized for SciML + +function RecursiveArrayTools.recursivecopy!(b::AbstractArray{T,N}, a::AbstractArray{T2,N}) where {T<:Tracker.TrackedArray,T2<:Tracker.TrackedArray,N} + @inbounds for i in eachindex(a) + b[i] = copy(a[i]) + end +end + +DiffEqBase.value(x::Type{Tracker.TrackedReal{T}}) where {T} = T +DiffEqBase.value(x::Type{Tracker.TrackedArray{T,N,A}}) where {T,N,A} = Array{T,N} +DiffEqBase.value(x::Tracker.TrackedReal) = x.data +DiffEqBase.value(x::Tracker.TrackedArray) = x.data + +DiffEqBase.promote_u0(u0::Tracker.TrackedArray, p::Tracker.TrackedArray, t0) = u0 +DiffEqBase.promote_u0(u0::AbstractArray{<:Tracker.TrackedReal}, p::Tracker.TrackedArray, t0) = u0 +DiffEqBase.promote_u0(u0::Tracker.TrackedArray, p::AbstractArray{<:Tracker.TrackedReal}, t0) = u0 +DiffEqBase.promote_u0(u0::AbstractArray{<:Tracker.TrackedReal}, p::AbstractArray{<:Tracker.TrackedReal}, t0) = u0 +DiffEqBase.promote_u0(u0, p::Tracker.TrackedArray, t0) = Tracker.track(u0) +DiffEqBase.promote_u0(u0, p::AbstractArray{<:Tracker.TrackedReal}, t0) = eltype(p).(u0) + + +@inline DiffEqBase.fastpow(x::Tracker.TrackedReal, y::Tracker.TrackedReal) = x^y +@inline Base.any(f::Function, x::Tracker.TrackedArray) = any(f, Tracker.data(x)) + +# Support adaptive with non-tracked time +@inline function DiffEqBase.ODE_DEFAULT_NORM(u::Tracker.TrackedArray, t) where {N} + sqrt(sum(abs2, DiffEqBase.value(u)) / length(u)) +end +@inline function DiffEqBase.ODE_DEFAULT_NORM(u::AbstractArray{<:Tracker.TrackedReal,N}, t) where {N} + sqrt(sum(x -> DiffEqBase.ODE_DEFAULT_NORM(x[1], x[2]), zip((DiffEqBase.value(x) for x in u), Iterators.repeated(t))) / length(u)) +end +@inline function DiffEqBase.ODE_DEFAULT_NORM(u::Array{<:Tracker.TrackedReal,N}, t) where {N} + sqrt(sum(x -> DiffEqBase.ODE_DEFAULT_NORM(x[1], x[2]), zip((DiffEqBase.value(x) for x in u), Iterators.repeated(t))) / length(u)) +end +@inline DiffEqBase.ODE_DEFAULT_NORM(u::Tracker.TrackedReal, t) = abs(DiffEqBase.value(u)) + +# Support TrackedReal time, don't drop tracking on the adaptivity there +@inline function DiffEqBase.ODE_DEFAULT_NORM(u::Tracker.TrackedArray, t::Tracker.TrackedReal) where {N} + sqrt(sum(abs2, u) / length(u)) +end +@inline function DiffEqBase.ODE_DEFAULT_NORM(u::AbstractArray{<:Tracker.TrackedReal,N}, t::Tracker.TrackedReal) where {N} + sqrt(sum(x -> DiffEqBase.ODE_DEFAULT_NORM(x[1], x[2]), zip(u, Iterators.repeated(t))) / length(u)) +end +@inline function DiffEqBase.ODE_DEFAULT_NORM(u::Array{<:Tracker.TrackedReal,N}, t::Tracker.TrackedReal) where {N} + sqrt(sum(x -> DiffEqBase.ODE_DEFAULT_NORM(x[1], x[2]), zip(u, Iterators.repeated(t))) / length(u)) +end +@inline DiffEqBase.ODE_DEFAULT_NORM(u::Tracker.TrackedReal, t::Tracker.TrackedReal) = abs(u) + +function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem, sensealg::Union{DiffEqBase.AbstractSensitivityAlgorithm,Nothing}, u0::Tracker.TrackedArray, p::Tracker.TrackedArray, args...; kwargs...) + Tracker.track(solve_up, prob, sensealg, u0, p, args...; kwargs...) +end + +function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem, sensealg::Union{DiffEqBase.AbstractSensitivityAlgorithm,Nothing}, u0::Tracker.TrackedArray, p, args...; kwargs...) + Tracker.track(solve_up, prob, sensealg, u0, p, args...; kwargs...) +end + +function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem, sensealg::Union{DiffEqBase.AbstractSensitivityAlgorithm,Nothing}, u0, p::Tracker.TrackedArray, args...; kwargs...) + Tracker.track(solve_up, prob, sensealg, u0, p, args...; kwargs...) +end + +Tracker.@grad function DiffEqBase.solve_up(prob, sensealg::Union{Nothing,DiffEqBase.AbstractSensitivityAlgorithm}, + u0, p, args...; + kwargs...) + _solve_adjoint(prob, sensealg, Tracker.data(u0), Tracker.data(p), args...; kwargs...) +end diff --git a/src/zygote.jl b/src/zygote.jl new file mode 100644 index 000000000..9c4a6e9d1 --- /dev/null +++ b/src/zygote.jl @@ -0,0 +1,41 @@ +# Piracy that used to be requires, allowing Zyogote.jl to be specialized for SciML + +function ∇tmap(cx, f, args...) + ys_and_backs = SciMLBase.tmap((args...) -> Zygote._pullback(cx, f, args...), args...) + if isempty(ys_and_backs) + ys_and_backs, _ -> (NoTangent(), NoTangent()) + else + ys, backs = Zygote.unzip(ys_and_backs) + function ∇tmap_internal(Δ) + Δf_and_args_zipped = SciMLBase.tmap((f, δ) -> f(δ), backs, Δ) + Δf_and_args = Zygote.unzip(Δf_and_args_zipped) + Δf = reduce(Zygote.accum, Δf_and_args[1]) + (Δf, Δf_and_args[2:end]...) + end + ys, ∇tmap_internal + end +end + +function ∇responsible_map(cx, f, args...) + ys_and_backs = SciMLBase.responsible_map((args...) -> Zygote._pullback(cx, f, args...), args...) + if isempty(ys_and_backs) + ys_and_backs, _ -> (NoTangent(), NoTangent()) + else + ys, backs = Zygote.unzip(ys_and_backs) + ys, function ∇responsible_map_internal(Δ) + # Apply pullbacks in reverse order. Needed for correctness if `f` is stateful. + Δf_and_args_zipped = SciMLBase.responsible_map((f, δ) -> f(δ), Zygote._tryreverse(SciMLBase.responsible_map, backs, Δ)...) + Δf_and_args = Zygote.unzip(Zygote._tryreverse(SciMLBase.responsible_map, Δf_and_args_zipped)) + Δf = reduce(Zygote.accum, Δf_and_args[1]) + (Δf, Δf_and_args[2:end]...) + end + end +end + +ZygoteRules.@adjoint function SciMLBase.tmap(f, args::Union{AbstractArray,Tuple}...) + ∇tmap(__context__, f, args...) +end + +ZygoteRules.@adjoint function SciMLBase.responsible_map(f, args::Union{AbstractArray,Tuple}...) + ∇responsible_map(__context__, f, args...) +end \ No newline at end of file