Skip to content
Merged
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
GlobalSensitivity = "af5da776-676b-467e-8baf-acd8249e4f0f"
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
Expand All @@ -30,7 +31,6 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
RandomNumbers = "e6cf234a-135c-5ec9-84dd-332b85af5143"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
SharedArrays = "1a1011a3-84de-559e-8e89-a11a2f7dc383"
Expand All @@ -57,6 +57,7 @@ FFTW = "1.1"
FiniteDiff = "2"
ForwardDiff = "0.10"
GlobalSensitivity = "1.0"
GPUArrays = "8"
LinearSolve = "1"
OrdinaryDiffEq = "5.60, 6"
Parameters = "0.12"
Expand All @@ -65,7 +66,6 @@ QuasiMonteCarlo = "0.1, 0.2"
RandomNumbers = "1.5.3"
RecursiveArrayTools = "2.4.2"
Reexport = "0.2, 1.0"
Requires = "1"
ReverseDiff = "1.9"
SciMLBase = "1.24"
StochasticDiffEq = "6.20"
Expand Down
1 change: 0 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,3 @@
DiffEqSensitivity.jl is a component package in the [SciML Scientific Machine Learning ecosystem](https://sciml.ai/). It holds the
sensitivity analysis utilities. Users interested in using this
functionality should check out [DifferentialEquations.jl](https://github.com/JuliaDiffEq/DifferentialEquations.jl).

7 changes: 5 additions & 2 deletions src/DiffEqSensitivity.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ using DiffEqOperators
using Adapt
using LinearSolve
using Parameters: @unpack
using Requires
using StochasticDiffEq
using SharedArrays
using EllipsisNotation
Expand All @@ -16,6 +15,7 @@ using Random
import ZygoteRules, Zygote, ReverseDiff
import ArrayInterface
import Enzyme
import GPUArrays

using Cassette, DiffRules
using Core: CodeInfo, SlotNumber, SSAValue, ReturnNode, GotoIfNot
Expand All @@ -26,7 +26,6 @@ import ChainRulesCore: @thunk, NoTangent, @not_implemented
abstract type SensitivityFunction end
abstract type TransformedFunction end

include("require.jl")
include("hasbranching.jl")
include("sensitivity_algorithms.jl")
include("derivative_wrappers.jl")
Expand All @@ -45,6 +44,10 @@ include("second_order.jl")
include("steadystate_adjoint.jl")
include("sde_tools.jl")

# AD Extensions
include("reversediff.jl")
include("tracker.jl")
include("zygote.jl")


export extract_local_sensitivities
Expand Down
4 changes: 2 additions & 2 deletions src/concrete_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ function DiffEqBase._concrete_solve_adjoint(prob::Union{ODEProblem,SDEProblem},
!(eltype(p) <: Complex) &&
length(u0) + length(p) <= 100
ForwardDiffSensitivity()
elseif isgpu(u0) || !DiffEqBase.isinplace(prob)
elseif u0 isa GPUArrays.AbstractGPUArray || !DiffEqBase.isinplace(prob)
# only Zygote is GPU compatible and fast
# so if out-of-place, try Zygote
if p === nothing || p === DiffEqBase.NullParameters()
Expand All @@ -66,7 +66,7 @@ end
function DiffEqBase._concrete_solve_adjoint(prob::Union{NonlinearProblem,SteadyStateProblem},alg,
sensealg::Nothing,u0,p,args...;kwargs...)

default_sensealg = if isgpu(u0) || !DiffEqBase.isinplace(prob)
default_sensealg = if u0 isa GPUArrays.AbstractGPUArray || !DiffEqBase.isinplace(prob)
# autodiff = false because forwarddiff fails on many GPU kernels
# this only effects the Jacobian calculation and is same computation order
SteadyStateAdjoint(autodiff = false, autojacvec = ZygoteVJP())
Expand Down
32 changes: 0 additions & 32 deletions src/require.jl

This file was deleted.

57 changes: 57 additions & 0 deletions src/reversediff.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# Piracy that used to be requires, allowing ReverseDiff.jl to be specialized for SciML

DiffEqBase.value(x::ReverseDiff.TrackedReal) = x.value
DiffEqBase.value(x::ReverseDiff.TrackedArray) = x.value

DiffEqBase.promote_u0(u0::ReverseDiff.TrackedArray, p::ReverseDiff.TrackedArray, t0) = u0
DiffEqBase.promote_u0(u0::AbstractArray{<:ReverseDiff.TrackedReal}, p::ReverseDiff.TrackedArray, t0) = u0
DiffEqBase.promote_u0(u0::ReverseDiff.TrackedArray, p::AbstractArray{<:ReverseDiff.TrackedReal}, t0) = u0
DiffEqBase.promote_u0(u0::AbstractArray{<:ReverseDiff.TrackedReal}, p::AbstractArray{<:ReverseDiff.TrackedReal}, t0) = u0
DiffEqBase.promote_u0(u0, p::ReverseDiff.TrackedArray, t0) = ReverseDiff.track(u0)
DiffEqBase.promote_u0(u0, p::AbstractArray{<:ReverseDiff.TrackedReal}, t0) = eltype(p).(u0)

# Support adaptive with non-tracked time
@inline function DiffEqBase.ODE_DEFAULT_NORM(u::ReverseDiff.TrackedArray, t) where {N}
sqrt(sum(abs2, DiffEqBase.value(u)) / length(u))
end
@inline function DiffEqBase.ODE_DEFAULT_NORM(u::AbstractArray{<:ReverseDiff.TrackedReal,N}, t) where {N}
sqrt(sum(x -> DiffEqBase.ODE_DEFAULT_NORM(x[1], x[2]), zip((DiffEqBase.value(x) for x in u), Iterators.repeated(t))) / length(u))
end
@inline function DiffEqBase.ODE_DEFAULT_NORM(u::Array{<:ReverseDiff.TrackedReal,N}, t) where {N}
sqrt(sum(x -> DiffEqBase.ODE_DEFAULT_NORM(x[1], x[2]), zip((DiffEqBase.value(x) for x in u), Iterators.repeated(t))) / length(u))
end
@inline DiffEqBase.ODE_DEFAULT_NORM(u::ReverseDiff.TrackedReal, t) = abs(DiffEqBase.value(u))

# Support TrackedReal time, don't drop tracking on the adaptivity there
@inline function DiffEqBase.ODE_DEFAULT_NORM(u::ReverseDiff.TrackedArray, t::ReverseDiff.TrackedReal) where {N}
sqrt(sum(abs2, u) / length(u))
end
@inline function DiffEqBase.ODE_DEFAULT_NORM(u::AbstractArray{<:ReverseDiff.TrackedReal,N}, t::ReverseDiff.TrackedReal) where {N}
sqrt(sum(x -> DiffEqBase.ODE_DEFAULT_NORM(x[1], x[2]), zip(u, Iterators.repeated(t))) / length(u))
end
@inline function DiffEqBase.ODE_DEFAULT_NORM(u::Array{<:ReverseDiff.TrackedReal,N}, t::ReverseDiff.TrackedReal) where {N}
sqrt(sum(x -> DiffEqBase.ODE_DEFAULT_NORM(x[1], x[2]), zip(u, Iterators.repeated(t))) / length(u))
end
@inline DiffEqBase.ODE_DEFAULT_NORM(u::ReverseDiff.TrackedReal, t::ReverseDiff.TrackedReal) = abs(u)

function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem, sensealg::Union{DiffEqBase.AbstractSensitivityAlgorithm,Nothing}, u0::ReverseDiff.TrackedArray, p::ReverseDiff.TrackedArray, args...; kwargs...)
ReverseDiff.track(solve_up, prob, sensealg, u0, p, args...; kwargs...)
end

function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem, sensealg::Union{DiffEqBase.AbstractSensitivityAlgorithm,Nothing}, u0, p::ReverseDiff.TrackedArray, args...; kwargs...)
ReverseDiff.track(solve_up, prob, sensealg, u0, p, args...; kwargs...)
end

function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem, sensealg::Union{DiffEqBase.AbstractSensitivityAlgorithm,Nothing}, u0::ReverseDiff.TrackedArray, p, args...; kwargs...)
ReverseDiff.track(solve_up, prob, sensealg, u0, p, args...; kwargs...)
end

@inline function DiffEqNoiseProcess.wiener_randn(rng::Random.AbstractRNG, proto::ReverseDiff.TrackedArray)
ReverseDiff.track(convert.(eltype(proto.value), randn(rng, size(proto))))
end
@inline function DiffEqNoiseProcess.wiener_randn!(rng::AbstractRNG, rand_vec::Array{<:ReverseDiff.TrackedReal})
rand_vec .= ReverseDiff.track.(randn.((rng,), typeof.(DiffEqBase.value.(rand_vec))))
end
@inline function DiffEqNoiseProcess.wiener_randn!(rng::AbstractRNG, rand_vec::AbstractArray{<:ReverseDiff.TrackedReal})
rand_vec .= ReverseDiff.track.(randn.((rng,), typeof.(DiffEqBase.value.(rand_vec))))
end
65 changes: 65 additions & 0 deletions src/tracker.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# Piracy that used to be requires, allowing Tracker.jl to be specialized for SciML

function RecursiveArrayTools.recursivecopy!(b::AbstractArray{T,N}, a::AbstractArray{T2,N}) where {T<:Tracker.TrackedArray,T2<:Tracker.TrackedArray,N}
@inbounds for i in eachindex(a)
b[i] = copy(a[i])
end
end

DiffEqBase.value(x::Type{Tracker.TrackedReal{T}}) where {T} = T
DiffEqBase.value(x::Type{Tracker.TrackedArray{T,N,A}}) where {T,N,A} = Array{T,N}
DiffEqBase.value(x::Tracker.TrackedReal) = x.data
DiffEqBase.value(x::Tracker.TrackedArray) = x.data

DiffEqBase.promote_u0(u0::Tracker.TrackedArray, p::Tracker.TrackedArray, t0) = u0
DiffEqBase.promote_u0(u0::AbstractArray{<:Tracker.TrackedReal}, p::Tracker.TrackedArray, t0) = u0
DiffEqBase.promote_u0(u0::Tracker.TrackedArray, p::AbstractArray{<:Tracker.TrackedReal}, t0) = u0
DiffEqBase.promote_u0(u0::AbstractArray{<:Tracker.TrackedReal}, p::AbstractArray{<:Tracker.TrackedReal}, t0) = u0
DiffEqBase.promote_u0(u0, p::Tracker.TrackedArray, t0) = Tracker.track(u0)
DiffEqBase.promote_u0(u0, p::AbstractArray{<:Tracker.TrackedReal}, t0) = eltype(p).(u0)


@inline DiffEqBase.fastpow(x::Tracker.TrackedReal, y::Tracker.TrackedReal) = x^y
@inline Base.any(f::Function, x::Tracker.TrackedArray) = any(f, Tracker.data(x))

# Support adaptive with non-tracked time
@inline function DiffEqBase.ODE_DEFAULT_NORM(u::Tracker.TrackedArray, t) where {N}
sqrt(sum(abs2, DiffEqBase.value(u)) / length(u))
end
@inline function DiffEqBase.ODE_DEFAULT_NORM(u::AbstractArray{<:Tracker.TrackedReal,N}, t) where {N}
sqrt(sum(x -> DiffEqBase.ODE_DEFAULT_NORM(x[1], x[2]), zip((DiffEqBase.value(x) for x in u), Iterators.repeated(t))) / length(u))
end
@inline function DiffEqBase.ODE_DEFAULT_NORM(u::Array{<:Tracker.TrackedReal,N}, t) where {N}
sqrt(sum(x -> DiffEqBase.ODE_DEFAULT_NORM(x[1], x[2]), zip((DiffEqBase.value(x) for x in u), Iterators.repeated(t))) / length(u))
end
@inline DiffEqBase.ODE_DEFAULT_NORM(u::Tracker.TrackedReal, t) = abs(DiffEqBase.value(u))

# Support TrackedReal time, don't drop tracking on the adaptivity there
@inline function DiffEqBase.ODE_DEFAULT_NORM(u::Tracker.TrackedArray, t::Tracker.TrackedReal) where {N}
sqrt(sum(abs2, u) / length(u))
end
@inline function DiffEqBase.ODE_DEFAULT_NORM(u::AbstractArray{<:Tracker.TrackedReal,N}, t::Tracker.TrackedReal) where {N}
sqrt(sum(x -> DiffEqBase.ODE_DEFAULT_NORM(x[1], x[2]), zip(u, Iterators.repeated(t))) / length(u))
end
@inline function DiffEqBase.ODE_DEFAULT_NORM(u::Array{<:Tracker.TrackedReal,N}, t::Tracker.TrackedReal) where {N}
sqrt(sum(x -> DiffEqBase.ODE_DEFAULT_NORM(x[1], x[2]), zip(u, Iterators.repeated(t))) / length(u))
end
@inline DiffEqBase.ODE_DEFAULT_NORM(u::Tracker.TrackedReal, t::Tracker.TrackedReal) = abs(u)

function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem, sensealg::Union{DiffEqBase.AbstractSensitivityAlgorithm,Nothing}, u0::Tracker.TrackedArray, p::Tracker.TrackedArray, args...; kwargs...)
Tracker.track(solve_up, prob, sensealg, u0, p, args...; kwargs...)
end

function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem, sensealg::Union{DiffEqBase.AbstractSensitivityAlgorithm,Nothing}, u0::Tracker.TrackedArray, p, args...; kwargs...)
Tracker.track(solve_up, prob, sensealg, u0, p, args...; kwargs...)
end

function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem, sensealg::Union{DiffEqBase.AbstractSensitivityAlgorithm,Nothing}, u0, p::Tracker.TrackedArray, args...; kwargs...)
Tracker.track(solve_up, prob, sensealg, u0, p, args...; kwargs...)
end

Tracker.@grad function DiffEqBase.solve_up(prob, sensealg::Union{Nothing,DiffEqBase.AbstractSensitivityAlgorithm},
u0, p, args...;
kwargs...)
_solve_adjoint(prob, sensealg, Tracker.data(u0), Tracker.data(p), args...; kwargs...)
end
41 changes: 41 additions & 0 deletions src/zygote.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Piracy that used to be requires, allowing Zyogote.jl to be specialized for SciML

function ∇tmap(cx, f, args...)
ys_and_backs = SciMLBase.tmap((args...) -> Zygote._pullback(cx, f, args...), args...)
if isempty(ys_and_backs)
ys_and_backs, _ -> (NoTangent(), NoTangent())
else
ys, backs = Zygote.unzip(ys_and_backs)
function ∇tmap_internal(Δ)
Δf_and_args_zipped = SciMLBase.tmap((f, δ) -> f(δ), backs, Δ)
Δf_and_args = Zygote.unzip(Δf_and_args_zipped)
Δf = reduce(Zygote.accum, Δf_and_args[1])
(Δf, Δf_and_args[2:end]...)
end
ys, ∇tmap_internal
end
end

function ∇responsible_map(cx, f, args...)
ys_and_backs = SciMLBase.responsible_map((args...) -> Zygote._pullback(cx, f, args...), args...)
if isempty(ys_and_backs)
ys_and_backs, _ -> (NoTangent(), NoTangent())
else
ys, backs = Zygote.unzip(ys_and_backs)
ys, function ∇responsible_map_internal(Δ)
# Apply pullbacks in reverse order. Needed for correctness if `f` is stateful.
Δf_and_args_zipped = SciMLBase.responsible_map((f, δ) -> f(δ), Zygote._tryreverse(SciMLBase.responsible_map, backs, Δ)...)
Δf_and_args = Zygote.unzip(Zygote._tryreverse(SciMLBase.responsible_map, Δf_and_args_zipped))
Δf = reduce(Zygote.accum, Δf_and_args[1])
(Δf, Δf_and_args[2:end]...)
end
end
end

ZygoteRules.@adjoint function SciMLBase.tmap(f, args::Union{AbstractArray,Tuple}...)
∇tmap(__context__, f, args...)
end

ZygoteRules.@adjoint function SciMLBase.responsible_map(f, args::Union{AbstractArray,Tuple}...)
∇responsible_map(__context__, f, args...)
end