Closed
Requested here: https://discourse.julialang.org/t/implementation-of-spectral-normalization-for-machine-learning/76074
The workaround is to call `svd(X).S`, which is slower in the forward pass. But it looks like the gradient calculation with something like `svd_rev((; U=NoTangent(), s=s, V=NoTangent(), Vt=NoTangent()), NoTangent(), S̄, NoTangent())` is probably fairly efficient, and could easily be extracted to its own rule:
ChainRules.jl/src/rulesets/LinearAlgebra/factorization.jl
Lines 221 to 225 in 3590f94
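
For illustration only, here is a minimal sketch of what the pullback reduces to when only the singular-value cotangent `S̄` is nonzero: `X̄ = U * Diagonal(S̄) * V'`. The function name `svdvals_pullback_sketch` is hypothetical and not part of ChainRules; the sketch assumes real input with distinct singular values, and it still calls the full `svd`, so it shows the gradient formula rather than a faster forward pass.

```julia
using LinearAlgebra

# Hypothetical helper (not a ChainRules function): pullback for s = svdvals(X).
# Assumes distinct singular values, where ∂s_i/∂X = u_i * v_i'.
function svdvals_pullback_sketch(X::AbstractMatrix, s̄::AbstractVector)
    F = svd(X)                        # thin SVD: X ≈ F.U * Diagonal(F.S) * F.Vt
    return F.U * Diagonal(s̄) * F.Vt   # X̄ = U * Diagonal(s̄) * V'
end

# Finite-difference check on the scalar loss sum(svdvals(X)).
X = [2.0 0.5 0.1; 0.3 1.0 0.4; 0.2 0.1 3.0; 1.0 0.7 0.2]
X̄ = svdvals_pullback_sketch(X, ones(3))

ε = 1e-6
E = zeros(size(X)); E[2, 3] = 1.0     # perturb a single entry of X
fd = (sum(svdvals(X + ε * E)) - sum(svdvals(X - ε * E))) / (2ε)
```

Here `X̄[2, 3]` should agree with the central difference `fd` up to finite-difference error. A dedicated rule would presumably wrap this in an `rrule` for `svdvals`, but that wiring is left out here.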