Differentiating ArrayPartitions in DEProblems #221

Merged
merged 4 commits into from
Aug 11, 2022

frankschae
Member

@frankschae frankschae commented Aug 4, 2022

Related issue: SciML/SciMLSensitivity.jl#689

Some ArrayPartition tests pass because, according to @code_typed, Zygote uses the ZygoteRules definition in those cases:

ZygoteRules.@adjoint function ArrayPartition(x::S, ::Type{Val{copy_x}} = Val{false}) where {S<:Tuple,copy_x}
  function ArrayPartition_adjoint(_y)
      y = Array(_y)
      starts = vcat(0,cumsum(reduce(vcat,length.(x))))
      ntuple(i -> reshape(y[starts[i]+1:starts[i+1]], size(x[i])), length(x)), nothing
  end

  ArrayPartition(x, Val{copy_x}), ArrayPartition_adjoint
end
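The index arithmetic inside ArrayPartition_adjoint can be checked on its own, without any AD package. Below is a minimal sketch using only Base Julia; the block shapes in x and the flat vector y are made-up illustration values, not taken from the tests.

```julia
# Hypothetical partition blocks: a length-2 vector and a 2x2 matrix.
x = (zeros(2), zeros(2, 2))

# Hypothetical flattened cotangent, as Array(_y) would produce it.
y = collect(1.0:6.0)

# Cumulative block lengths give the offset at which each block starts.
starts = vcat(0, cumsum(reduce(vcat, length.(x))))

# Slice the flat vector and restore each block's original shape.
blocks = ntuple(i -> reshape(y[starts[i]+1:starts[i+1]], size(x[i])), length(x))
```

Here starts holds the offsets 0, 2 and 6, so the first block takes y[1:2] and the second takes y[3:6], reshaped column-major into a 2x2 matrix.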

The failing ones instead use the ChainRules definition:

function ChainRulesCore.rrule(::Type{<:ArrayPartition}, x::S, ::Type{Val{copy_x}} = Val{false}) where {S<:Tuple,copy_x}
  function ArrayPartition_adjoint(_y)
      y = Array(_y)
      starts = vcat(0,cumsum(reduce(vcat,length.(x))))
      NoTangent(), ArrayPartition(ntuple(i -> reshape(y[starts[i]+1:starts[i+1]], size(x[i]))), length(x)), NoTangent()
  end

  ArrayPartition(x, Val{copy_x}), ArrayPartition_adjoint
end

@frankschae frankschae changed the title Using ArrayPartitions together in DEProblems Using ArrayPartitions in DEProblems Aug 4, 2022
@frankschae frankschae changed the title Using ArrayPartitions in DEProblems Differentiating ArrayPartitions in DEProblems Aug 4, 2022
@frankschae
Member Author

using Zygote, ForwardDiff
using ChainRules, ChainRulesCore
using RecursiveArrayTools

u0 = [1.0; 1.0]

function f(x)
  _x = ArrayPartition(x, x)
  sum(abs, _x.x[1])
end

f(u0)
Fdx = ForwardDiff.gradient(f, u0)
Zdx = Zygote.gradient(f, u0)[1] # works 

struct foo{uType}
  u0::uType
end

# function ChainRulesCore.rrule(::Type{foo}, u0)
#   foo_pullback(Δfoo) = NoTangent(), Δfoo.u0
#   return foo(u0), foo_pullback
# end

function f2(x)
  _x = ArrayPartition(x, x)
  _prob = foo(_x)
  sum(abs, _prob.u0[1])
end

f2(u0)
Fdx = ForwardDiff.gradient(f2, u0)
Zdx = Zygote.gradient(f2, u0)[1]


using DiffEqBase
function f3(x)
  _x = ArrayPartition(x, x)
  _prob = ODEProblem((u, p, t) -> u, _x, (0.0, 1.0))
  sum(abs, _prob.u0[1])
end

f3(u0)
Fdx = ForwardDiff.gradient(f3, u0)
Zdx = Zygote.gradient(f3, u0)[1]
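
The commented-out rrule for the foo constructor can also be exercised directly through ChainRulesCore, independent of Zygote. This is a minimal sketch: the NamedTuple cotangent and its values are assumptions chosen to mirror the NamedTuple{(:u0,), ...} that Zygote passes to the constructor pullback in the stack traces.

```julia
using ChainRulesCore

struct foo{uType}
  u0::uType
end

# Same rule as the commented-out one in the script: the pullback forwards the
# cotangent of the u0 field to the constructor argument.
function ChainRulesCore.rrule(::Type{foo}, u0)
  foo_pullback(Δfoo) = (NoTangent(), Δfoo.u0)
  return foo(u0), foo_pullback
end

# Call the rule by hand: it returns the primal value and the pullback.
prob, back = ChainRulesCore.rrule(foo, [1.0, 1.0])

# Assumption: a NamedTuple stands in for the structural cotangent of foo.
Δ = (u0 = [0.5, 0.25],)
dself, du0 = back(Δ)
```

The pullback returns one entry per argument of the call, starting with the constructor itself (NoTangent()), followed by the cotangent for u0.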

If I uncomment the rrule for the constructor of foo above, Zdx = Zygote.gradient(f2, u0)[1] works! Otherwise,

Zdx = Zygote.gradient(f2, u0)[1]
Zdx = Zygote.gradient(f3, u0)[1]

throw the same error:

ERROR: Can't differentiate foreigncall expression.
You might want to check the Zygote limitations documentation.
https://fluxml.ai/Zygote.jl/dev/limitations.html

Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:33
  [2] Pullback
    @ ./essentials.jl:599 [inlined]
  [3] (::typeof(∂(getindex)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/IoW2g/src/compiler/interface2.jl:0
  [4] Pullback
    @ ~/.julia/dev/RecursiveArrayTools/src/array_partition.jl:503 [inlined]
  [5] Pullback
    @ ./ntuple.jl:49 [inlined]
  [6] Pullback
    @ ~/.julia/dev/RecursiveArrayTools/src/array_partition.jl:502 [inlined]
  [7] (::typeof(∂(convert)))(Δ::Zygote.OneElement{Float64, 1, Tuple{Int64}, Tuple{Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/IoW2g/src/compiler/interface2.jl:0
  [8] Pullback
    @ ~/SciMLSensitivity/689/MWE.jl:21 [inlined]
  [9] (::typeof(∂(foo{ArrayPartition{Float64, Tuple{Vector{Float64}, Vector{Float64}}}})))(Δ::NamedTuple{(:u0,), Tuple{Zygote.OneElement{Float64, 1, Tuple{Int64}, Tuple{Base.OneTo{Int64}}}}})
    @ Zygote ~/.julia/packages/Zygote/IoW2g/src/compiler/interface2.jl:0
 [10] Pullback
    @ ~/SciMLSensitivity/689/MWE.jl:21 [inlined]
 [11] (::typeof(∂(foo)))(Δ::NamedTuple{(:u0,), Tuple{Zygote.OneElement{Float64, 1, Tuple{Int64}, Tuple{Base.OneTo{Int64}}}}})
    @ Zygote ~/.julia/packages/Zygote/IoW2g/src/compiler/interface2.jl:0
 [12] Pullback
    @ ~/SciMLSensitivity/689/MWE.jl:31 [inlined]
 [13] (::typeof(∂(f2)))(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/IoW2g/src/compiler/interface2.jl:0
 [14] (::Zygote.var"#60#61"{typeof(∂(f2))})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/IoW2g/src/compiler/interface.jl:41
 [15] gradient(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/IoW2g/src/compiler/interface.jl:76
 [16] top-level scope
    @ ~/SciMLSensitivity/689/MWE.jl:37
ERROR: Can't differentiate foreigncall expression.
You might want to check the Zygote limitations documentation.
https://fluxml.ai/Zygote.jl/dev/limitations.html

Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:33
  [2] Pullback
    @ ./essentials.jl:599 [inlined]
  [3] (::typeof(∂(getindex)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/IoW2g/src/compiler/interface2.jl:0
  [4] Pullback
    @ ~/.julia/dev/RecursiveArrayTools/src/array_partition.jl:503 [inlined]
  [5] Pullback
    @ ./ntuple.jl:49 [inlined]
  [6] Pullback
    @ ~/.julia/dev/RecursiveArrayTools/src/array_partition.jl:502 [inlined]
  [7] (::typeof(∂(convert)))(Δ::Zygote.OneElement{Float64, 1, Tuple{Int64}, Tuple{Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/IoW2g/src/compiler/interface2.jl:0
  [8] Pullback
    @ ~/.julia/packages/SciMLBase/QzHjf/src/problems/ode_problems.jl:93 [inlined]
  [9] (::typeof(∂(_#241)))(Δ::NamedTuple{(:f, :u0, :tspan, :p, :kwargs, :problem_type), Tuple{Nothing, Zygote.OneElement{Float64, 1, Tuple{Int64}, Tuple{Base.OneTo{Int64}}}, Nothing, Nothing, Nothing, Nothing}})
    @ Zygote ~/.julia/packages/Zygote/IoW2g/src/compiler/interface2.jl:0
 [10] Pullback
    @ ~/.julia/packages/SciMLBase/QzHjf/src/problems/ode_problems.jl:92 [inlined]
 [11] (::typeof(∂(ODEProblem{false})))(Δ::NamedTuple{(:f, :u0, :tspan, :p, :kwargs, :problem_type), Tuple{Nothing, Zygote.OneElement{Float64, 1, Tuple{Int64}, Tuple{Base.OneTo{Int64}}}, Nothing, Nothing, Nothing, Nothing}})
    @ Zygote ~/.julia/packages/Zygote/IoW2g/src/compiler/interface2.jl:0
 [12] Pullback
    @ ~/.julia/packages/SciMLBase/QzHjf/src/problems/ode_problems.jl:92 [inlined]
 [13] (::typeof(∂(ODEProblem{false})))(Δ::NamedTuple{(:f, :u0, :tspan, :p, :kwargs, :problem_type), Tuple{Nothing, Zygote.OneElement{Float64, 1, Tuple{Int64}, Tuple{Base.OneTo{Int64}}}, Nothing, Nothing, Nothing, Nothing}})
    @ Zygote ~/.julia/packages/Zygote/IoW2g/src/compiler/interface2.jl:0
 [14] (::Zygote.var"#216#217"{Tuple{Tuple{Nothing, Nothing, Nothing}, Tuple{Nothing}}, typeof(∂(ODEProblem{false}))})(Δ::NamedTuple{(:f, :u0, :tspan, :p, :kwargs, :problem_type), Tuple{Nothing, Zygote.OneElement{Float64, 1, Tuple{Int64}, Tuple{Base.OneTo{Int64}}}, Nothing, Nothing, Nothing, Nothing}})
    @ Zygote ~/.julia/packages/Zygote/IoW2g/src/lib/lib.jl:207
 [15] (::Zygote.var"#1909#back#218"{Zygote.var"#216#217"{Tuple{Tuple{Nothing, Nothing, Nothing}, Tuple{Nothing}}, typeof(∂(ODEProblem{false}))}})(Δ::NamedTuple{(:f, :u0, :tspan, :p, :kwargs, :problem_type), Tuple{Nothing, Zygote.OneElement{Float64, 1, Tuple{Int64}, Tuple{Base.OneTo{Int64}}}, Nothing, Nothing, Nothing, Nothing}})
    @ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
 [16] Pullback
    @ ~/.julia/packages/SciMLBase/QzHjf/src/problems/ode_problems.jl:131 [inlined]
 [17] (::typeof(∂(#ODEProblem#251)))(Δ::NamedTuple{(:f, :u0, :tspan, :p, :kwargs, :problem_type), Tuple{Nothing, Zygote.OneElement{Float64, 1, Tuple{Int64}, Tuple{Base.OneTo{Int64}}}, Nothing, Nothing, Nothing, Nothing}})
    @ Zygote ~/.julia/packages/Zygote/IoW2g/src/compiler/interface2.jl:0
 [18] (::Zygote.var"#216#217"{Tuple{NTuple{5, Nothing}, Tuple{Nothing}}, typeof(∂(#ODEProblem#251))})(Δ::NamedTuple{(:f, :u0, :tspan, :p, :kwargs, :problem_type), Tuple{Nothing, Zygote.OneElement{Float64, 1, Tuple{Int64}, Tuple{Base.OneTo{Int64}}}, Nothing, Nothing, Nothing, Nothing}})
    @ Zygote ~/.julia/packages/Zygote/IoW2g/src/lib/lib.jl:207
 [19] #1909#back
    @ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67 [inlined]
 [20] Pullback
    @ ~/.julia/packages/SciMLBase/QzHjf/src/problems/ode_problems.jl:131 [inlined]
 [21] (::typeof(∂(ODEProblem)))(Δ::NamedTuple{(:f, :u0, :tspan, :p, :kwargs, :problem_type), Tuple{Nothing, Zygote.OneElement{Float64, 1, Tuple{Int64}, Tuple{Base.OneTo{Int64}}}, Nothing, Nothing, Nothing, Nothing}})
    @ Zygote ~/.julia/packages/Zygote/IoW2g/src/compiler/interface2.jl:0
 [22] Pullback
    @ ~/.julia/packages/SciMLBase/QzHjf/src/problems/ode_problems.jl:135 [inlined]
 [23] (::typeof(∂(#ODEProblem#252)))(Δ::NamedTuple{(:f, :u0, :tspan, :p, :kwargs, :problem_type), Tuple{Nothing, Zygote.OneElement{Float64, 1, Tuple{Int64}, Tuple{Base.OneTo{Int64}}}, Nothing, Nothing, Nothing, Nothing}})
    @ Zygote ~/.julia/packages/Zygote/IoW2g/src/compiler/interface2.jl:0
 [24] Pullback
    @ ~/.julia/packages/SciMLBase/QzHjf/src/problems/ode_problems.jl:135 [inlined]
 [25] (::typeof(∂(ODEProblem)))(Δ::NamedTuple{(:f, :u0, :tspan, :p, :kwargs, :problem_type), Tuple{Nothing, Zygote.OneElement{Float64, 1, Tuple{Int64}, Tuple{Base.OneTo{Int64}}}, Nothing, Nothing, Nothing, Nothing}})
    @ Zygote ~/.julia/packages/Zygote/IoW2g/src/compiler/interface2.jl:0
 [26] Pullback
    @ ~/.julia/packages/SciMLBase/QzHjf/src/problems/ode_problems.jl:135 [inlined]
 [27] (::typeof(∂(ODEProblem)))(Δ::NamedTuple{(:f, :u0, :tspan, :p, :kwargs, :problem_type), Tuple{Nothing, Zygote.OneElement{Float64, 1, Tuple{Int64}, Tuple{Base.OneTo{Int64}}}, Nothing, Nothing, Nothing, Nothing}})
    @ Zygote ~/.julia/packages/Zygote/IoW2g/src/compiler/interface2.jl:0
 [28] Pullback
    @ ~/SciMLSensitivity/689/MWE.jl:43 [inlined]
 [29] (::typeof(∂(f3)))(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/IoW2g/src/compiler/interface2.jl:0
 [30] (::Zygote.var"#60#61"{typeof(∂(f3))})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/IoW2g/src/compiler/interface.jl:41
 [31] gradient(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/IoW2g/src/compiler/interface.jl:76
 [32] top-level scope
    @ ~/SciMLSensitivity/689/MWE.jl:49

@ChrisRackauckas where is the rrule for ODEProblem?
@oxinabox Does this make sense to you? To see this, I deleted all ChainRulesCore.rrule definitions from https://github.com/SciML/RecursiveArrayTools.jl/blob/master/src/zygote.jl; the error then becomes more informative:

using Zygote, ForwardDiff
using ChainRules, ChainRulesCore
using RecursiveArrayTools

u0 = [1.0; 1.0]
struct foo{uType}
  u0::uType
end

# function ChainRulesCore.rrule(::Type{foo}, u0)
#   foo_pullback(Δfoo) = NoTangent(), Δfoo.u0
#   return foo(u0), foo_pullback
# end

function f2(x)
  _x = ArrayPartition(x, x)
  _prob = foo(_x)
  sum(abs, _prob.u0[1])
end

f2(u0)
Fdx = ForwardDiff.gradient(f2, u0)
Zdx = Zygote.gradient(f2, u0)[1]
ERROR: Need an adjoint for constructor ArrayPartition{Float64, Tuple{Vector{Float64}, Vector{Float64}}}. Gradient is of type Zygote.OneElement{Float64, 1, Tuple{Int64}, Tuple{Base.OneTo{Int64}}}
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:33
  [2] (::Zygote.Jnew{ArrayPartition{Float64, Tuple{Vector{Float64}, Vector{Float64}}}, Nothing, false})(Δ::Zygote.OneElement{Float64, 1, Tuple{Int64}, Tuple{Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/IoW2g/src/lib/lib.jl:328
  [3] (::Zygote.var"#1943#back#232"{Zygote.Jnew{ArrayPartition{Float64, Tuple{Vector{Float64}, Vector{Float64}}}, Nothing, false}})(Δ::Zygote.OneElement{Float64, 1, Tuple{Int64}, Tuple{Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
  [4] Pullback
    @ ~/.julia/dev/RecursiveArrayTools/src/array_partition.jl:27 [inlined]
  [5] (::typeof(∂(ArrayPartition{Float64, Tuple{Vector{Float64}, Vector{Float64}}})))(Δ::Zygote.OneElement{Float64, 1, Tuple{Int64}, Tuple{Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/IoW2g/src/compiler/interface2.jl:0
  [6] Pullback
    @ ~/.julia/dev/RecursiveArrayTools/src/array_partition.jl:502 [inlined]
  [7] (::typeof(∂(convert)))(Δ::Zygote.OneElement{Float64, 1, Tuple{Int64}, Tuple{Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/IoW2g/src/compiler/interface2.jl:0
  [8] Pullback
    @ ~/SciMLSensitivity/689/MWE.jl:21 [inlined]
  [9] (::typeof(∂(foo{ArrayPartition{Float64, Tuple{Vector{Float64}, Vector{Float64}}}})))(Δ::NamedTuple{(:u0,), Tuple{Zygote.OneElement{Float64, 1, Tuple{Int64}, Tuple{Base.OneTo{Int64}}}}})
    @ Zygote ~/.julia/packages/Zygote/IoW2g/src/compiler/interface2.jl:0
 [10] Pullback
    @ ~/SciMLSensitivity/689/MWE.jl:21 [inlined]
 [11] (::typeof(∂(foo)))(Δ::NamedTuple{(:u0,), Tuple{Zygote.OneElement{Float64, 1, Tuple{Int64}, Tuple{Base.OneTo{Int64}}}}})
    @ Zygote ~/.julia/packages/Zygote/IoW2g/src/compiler/interface2.jl:0
 [12] Pullback
    @ ~/SciMLSensitivity/689/MWE.jl:31 [inlined]
 [13] (::typeof(∂(f2)))(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/IoW2g/src/compiler/interface2.jl:0
 [14] (::Zygote.var"#60#61"{typeof(∂(f2))})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/IoW2g/src/compiler/interface.jl:41
 [15] gradient(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/IoW2g/src/compiler/interface.jl:76
 [16] top-level scope
    @ ~/SciMLSensitivity/689/MWE.jl:37

@ChrisRackauckas
Member

@ChrisRackauckas where is the rrule for ODEProblem?

https://github.com/SciML/DiffEqBase.jl/blob/master/src/chainrules.jl

@ChrisRackauckas ChrisRackauckas merged commit abeaf3d into SciML:master Aug 11, 2022
@frankschae frankschae deleted the ODEProblem_u0 branch August 11, 2022 12:56

2 participants