Commit 10db629

Merge pull request #592 from SciML/dg/zeros
use Flux.zeros32
2 parents 91a843b + b5cae13 · commit 10db629

2 files changed: +6 −6 lines changed

src/fast_layers.jl

Lines changed: 4 additions & 4 deletions
@@ -24,7 +24,7 @@ initial_params(c::FastChain) = vcat(initial_params.(c.layers)...)
 
 """
 FastDense(in,out,activation=identity;
-          bias = true ,initW = Flux.glorot_uniform, initb = Flux.zeros)
+          bias = true ,initW = Flux.glorot_uniform, initb = Flux.zeros32)
 
 A Dense layer `activation.(W*x + b)` with input size `in` and output size `out`.
 The `activation` function defaults to `identity`, meaning the layer is an affine
@@ -41,7 +41,7 @@ struct FastDense{F,F2} <: FastLayer
   initial_params::F2
   bias::Bool
   function FastDense(in::Integer, out::Integer, σ = identity;
-                     bias = true, initW = Flux.glorot_uniform, initb = Flux.zeros)
+                     bias = true, initW = Flux.glorot_uniform, initb = Flux.zeros32)
     temp = ((bias == false) ? vcat(vec(initW(out, in))) : vcat(vec(initW(out, in)),initb(out)))
     initial_params() = temp
     new{typeof(σ),typeof(initial_params)}(out,in,σ,initial_params,bias)
@@ -108,7 +108,7 @@ initial_params(f::FastDense) = f.initial_params()
 
 """
 StaticDense(in,out,activation=identity;
-            initW = Flux.glorot_uniform, initb = Flux.zeros)
+            initW = Flux.glorot_uniform, initb = Flux.zeros32)
 
 A Dense layer `activation.(W*x + b)` with input size `in` and output size `out`.
 The `activation` function defaults to `identity`, meaning the layer is an affine
@@ -124,7 +124,7 @@ struct StaticDense{out,in,bias,F,F2} <: FastLayer
   σ::F
   initial_params::F2
   function StaticDense(in::Integer, out::Integer, σ = identity;
-                       bias::Bool = true, initW = Flux.glorot_uniform, initb = Flux.zeros)
+                       bias::Bool = true, initW = Flux.glorot_uniform, initb = Flux.zeros32)
     temp = ((bias == true ) ? vcat(vec(initW(out, in)),initb(out)) : vcat(vec(initW(out, in))) )
     initial_params() = temp
     new{out,in,bias,typeof(σ),typeof(initial_params)}(σ,initial_params)
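
Note: this hunk swaps the default bias initializer from `Flux.zeros` to `Flux.zeros32`, the explicitly `Float32`-typed helper that newer Flux releases provide. Since `Flux.glorot_uniform` also produces `Float32` values, the concatenated parameter vector returned by `initial_params` stays a `Vector{Float32}`. A minimal usage sketch of the resulting behavior (assuming DiffEqFlux at roughly this commit and a compatible Flux version; the layer sizes are arbitrary illustration values):

# Minimal sketch, not part of the commit.
using DiffEqFlux, Flux

layer = FastDense(3, 5, tanh)          # picks up the new default initb = Flux.zeros32
p = DiffEqFlux.initial_params(layer)   # vec(W) from glorot_uniform followed by b from zeros32
eltype(p) == Float32                   # both initializers return Float32 arrays
length(p) == 5*3 + 5                   # out*in weight entries plus out bias entries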

test/newton_neural_ode.jl

Lines changed: 2 additions & 2 deletions
@@ -21,7 +21,7 @@ NN = Chain(Dense(n, 5n, tanh),
 @info "ROCK4"
 nODE = NeuralODE(NN, tspan, ROCK4(), reltol=1e-4, saveat=[tspan[end]])
 
-loss_function(θ) = Flux.mse(y, nODE(x, θ))
+loss_function(θ) = Flux.mse(y, nODE(x, θ)[end])
 l1 = loss_function(nODE.p)
 
 res = DiffEqFlux.sciml_train(loss_function, nODE.p, NewtonTrustRegion(), GalacticOptim.AutoZygote(), maxiters = 100, cb=cb)
@@ -35,7 +35,7 @@ NN = FastChain(FastDense(n, 5n, tanh),
 @info "ROCK2"
 nODE = NeuralODE(NN, tspan, ROCK2(), reltol=1e-4, saveat=[tspan[end]])
 
-loss_function(θ) = Flux.mse(y, nODE(x, θ))
+loss_function(θ) = Flux.mse(y, nODE(x, θ)[end])
 l1 = loss_function(nODE.p)
 optfunc = GalacticOptim.OptimizationFunction((x, p) -> loss_function(x), GalacticOptim.AutoZygote())
 optprob = GalacticOptim.OptimizationProblem(optfunc, nODE.p,)
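
Note: the test change indexes the NeuralODE output before computing the loss. `nODE(x, θ)` returns the ODE solution (an array-like object of saved states), and with `saveat=[tspan[end]]` the `[end]` index pulls out the final saved state as a plain array, so `Flux.mse` compares two arrays of matching shape. Roughly, with the names as defined in test/newton_neural_ode.jl (an illustrative sketch, not part of the commit):

sol = nODE(x, θ)        # solution saved only at tspan[end]
ŷ = sol[end]            # final state, a plain array the same shape as y
loss = Flux.mse(y, ŷ)   # mse now compares arrays instead of a solution object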

0 commit comments
