DhairyaLGandhi
Member

Fixes

julia> FastDense(1,12).initial_params()
13-element Vector{Union{Nothing, Float32}}:
  0.6627236f0
  0.676602f0
 -0.59016824f0
 -0.53156716f0
 -0.58506316f0
  0.00657402f0
  ⋮
 -0.58002657f0
 -0.092621975f0
  0.30483305f0
  0.30343133f0
   nothing

This still needs a fix in Flux.
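
For readers following along, here is a minimal sketch of why the `nothing` entry matters: it widens the element type of the parameter vector to `Union{Nothing, Float32}`. The vector `p` and the helper `replace_nothing` are hypothetical illustrations written in plain Julia. They are not Flux or DiffEqFlux functions, and this is not necessarily how the PR resolves the problem.

```julia
# Hypothetical parameter vector mirroring the eltype shown in the output above.
p = Union{Nothing, Float32}[0.66f0, -0.59f0, nothing]

# `replace_nothing` is a hypothetical helper (an assumption for illustration):
# it maps every `nothing` entry to zero and returns a concretely typed vector.
replace_nothing(v) = Float32[x === nothing ? 0f0 : x for x in v]

p_clean = replace_nothing(p)

println(eltype(p))        # element type still includes Nothing
println(eltype(p_clean))  # concrete Float32 element type
```

A concretely typed vector is what downstream numeric code generally expects, which is why a stray `nothing` in the flattened parameters is worth fixing at the source.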

@ChrisRackauckas
Member

@DhairyaLGandhi I tracked it down to:

using DiffEqFlux, Flux, OrdinaryDiffEq, DiffEqSensitivity, Zygote
n = 1  # number of ODEs
x = rand(n, 5)
y = rand(n, 5)

tspan = (0.0, 1.0)
NN = Dense(n, n, tanh)
nODE = NeuralODE(NN, tspan, ROCK4(), reltol=1e-4, saveat=[tspan[end]], sensealg=InterpolatingAdjoint())
function loss_function(θ)
    z = nODE(x, θ)
    Flux.mse(y, z)
end
size(nODE(x, nODE.p)[:,1]) == size(y)
yy,back = Zygote.pullback(loss_function,nODE.p)
back(1)

I thought I fixed it with:

SciML/RecursiveArrayTools.jl#164

For some reason, the new getindex fallbacks were causing things to fail. This overload now catches the call, but it has an issue of its own. What was the old dispatch it was hitting?
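
One way to answer a "which dispatch is it hitting" question is to ask Julia directly. The sketch below uses `@which` from the `InteractiveUtils` standard library and `which` from Base on a plain `Vector`. The vector `v` is a stand-in chosen for illustration; the actual call in question would involve the RecursiveArrayTools types from the linked PR. Running the same query under the old and new package versions should show whether the selected method changed.

```julia
using InteractiveUtils  # provides @which outside the REPL

# Stand-in array; substitute the object that is actually being indexed.
v = Float32[0.1, 0.2, 0.3]

# Method selected for scalar indexing on this value.
m_scalar = @which getindex(v, 1)

# Same kind of query by signature, without constructing the arguments.
m_range = which(getindex, Tuple{Vector{Float32}, UnitRange{Int}})

println(m_scalar)
println(m_range)
```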

@DhairyaLGandhi
Member Author

DhairyaLGandhi commented Jul 26, 2021

@ChrisRackauckas
Member

No, the getindex methods on vector must've recently changed.

ChrisRackauckas merged commit 10db629 into master Jul 26, 2021
ChrisRackauckas deleted the dg/zeros branch July 26, 2021 11:28
@DhairyaLGandhi
Member Author

Should we be pulling out the last index here? Seems like something that shouldn't need manual handling. ref FluxML/Flux.jl#1636
