destructure's gradient is confused by trainable #72

@mcabbott

Example:

using Optimisers, Functors, Zygote

struct TwoThirds a; b; c; end  # from the tests
Functors.@functor TwoThirds (a, c)
Optimisers.trainable(x::TwoThirds) = (a = x.a,)

mtt = TwoThirds([1.0, 2.0], [3.0], [4.0, 5.0])
v, re = destructure(mtt)
re(100v)  # TwoThirds([100.0, 200.0], [3.0], [4.0, 5.0])

gradient(mtt) do x
  w, _ = destructure(x)
  1000 * prod(w)
end  # ((a = [2000.0, 1000.0], b = nothing, c = [4.0, 5.0]),)

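Here the entry for b is correct (b is excluded from the children, so it gets nothing). But c, a child which is not trainable, comes back as its own value [4.0, 5.0] rather than a gradient. Since c never enters w, its gradient should be nothing (or zero).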
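
For comparison, a minimal sketch of the same loss written directly on the field a, without going through destructure. This assumes plain Zygote on the struct and is only meant to show the shape of gradient one would expect above, with both b and c coming back as nothing:

```julia
using Zygote

struct TwoThirds a; b; c; end  # same struct as in the example above

mtt = TwoThirds([1.0, 2.0], [3.0], [4.0, 5.0])

# Same loss as before, but reading the trainable field directly
# instead of flattening the model with destructure.
g = gradient(x -> 1000 * prod(x.a), mtt)[1]

g.a  # gradient with respect to a
g.b  # b does not enter the loss
g.c  # c does not enter the loss either
```

The a entry should match the one from the destructure version; the difference to look at is the c entry.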
