Skip to content

Conversation

frankschae
Copy link
Member

Fixes the remaining error in the pullback of SciML/RecursiveArrayTools.jl#221.

The MWE

function loss(u0, p)
  _u0 = ArrayPartition(u0)
  _prob = ODEProblem(fiip, _u0, (0.0, 10.0), p)
  _sol = solve(_prob, Tsit5(), abstol=1e-14, reltol=1e-14, saveat=0.1)
  sum(_sol)
end

from SciML/SciMLSensitivity.jl#689 then works. However, it does not see this rrule definition for remake https://github.com/SciML/SciMLBase.jl/blob/0d55f56283fda3b6accb12fb282531447bdabe7c/src/remake.jl#L77 yet... Any ideas how to generalize it to account for the isinplace?

@ChrisRackauckas
Copy link
Member

Do you need to? It seems like just handling the problem constructors is fine?

@frankschae
Copy link
Member Author

Oh, yes. Probably a version/restart thing. I tried

function myremake(prob::ODEProblem; f=missing,
  u0=missing,
  tspan=missing,
  p=missing,
  kwargs=missing,
  _kwargs...)

  if f === missing
    f = prob.f
  elseif !isrecompile(prob)
    if isinplace(prob)
      f = wrapfun_iip(unwrap_fw(f), (u0, u0, p, tspan[1]))
    else
      f = wrapfun_oop(unwrap_fw(f), (u0, p, tspan[1]))
    end
    f = convert(ODEFunction{isinplace(prob)}, f)
  elseif prob.f isa ODEFunction # avoid the SplitFunction etc. cases
    f = convert(ODEFunction{isinplace(prob)}, f)
  end

  if u0 === missing
    u0 = prob.u0
  end

  if tspan === missing
    tspan = prob.tspan
  end

  if p === missing
    p = prob.p
  end

  if kwargs === missing
    ODEProblem{isinplace(prob)}(f, u0, tspan, p, prob.problem_type; prob.kwargs...,
      _kwargs...)
  else
    ODEProblem{isinplace(prob)}(f, u0, tspan, p, prob.problem_type; kwargs...)
  end
end

prob = ODEProblem((u, p, t) -> u, u0, (0.0, 1.0))
function f(x)
  _x = ArrayPartition(x) #ArrayPartition(x, x)
  _prob = myremake(prob, u0=x)
  sum(abs, _prob.u0[1])
end

f(u0)
Fdx = ForwardDiff.gradient(f, u0)
Zdx = Zygote.gradient(f, u0)[1]

which also works.

@ChrisRackauckas ChrisRackauckas merged commit 84ee3db into master Aug 9, 2022
@ChrisRackauckas ChrisRackauckas deleted the rrule_DEProblem branch August 9, 2022 10:54
BenChung pushed a commit to BenChung/DiffEqBase.jl that referenced this pull request Oct 31, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants