Closed
Requested here: https://discourse.julialang.org/t/implementation-of-spectral-normalization-for-machine-learning/76074
The workaround is to call `svd(X).S`, which is slower in the forward pass. But the gradient calculation, with something like `svd_rev((; U=NoTangent(), s=s, V=NoTangent(), Vt=NoTangent()), NoTangent(), S̄, NoTangent())`, is probably fairly efficient and could easily be extracted into its own rule:
ChainRules.jl/src/rulesets/LinearAlgebra/factorization.jl
Lines 221 to 225 in 3590f94

```julia
function rrule(::typeof(svd), X::AbstractMatrix{<:Real})
    F = svd(X)
    svd_pullback(ȳ) = _svd_pullback(ȳ, F)
    return F, svd_pullback
end
```
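
For illustration, here is a rough sketch of what a singular-values-only pullback might compute. It assumes a real matrix with distinct singular values, in which case the cotangent for `X` reduces to `U * Diagonal(Δs) * Vt`. The name `svdvals_with_pullback` is hypothetical and not part of ChainRules.jl, and the sketch still calls the full `svd` in the forward pass, so it only shows that the reverse pass is cheap; it does not address forward speed.

```julia
using LinearAlgebra

# Hypothetical helper (not part of ChainRules.jl): returns the singular values
# of X together with a pullback that maps a cotangent Δs for the singular
# values to a cotangent for X.
function svdvals_with_pullback(X::AbstractMatrix{<:Real})
    F = svd(X)  # thin SVD, X ≈ F.U * Diagonal(F.S) * F.Vt
    # Assuming distinct singular values, dσᵢ = uᵢ' * dX * vᵢ,
    # so the cotangent for X is U * Diagonal(Δs) * Vt.
    pullback(Δs::AbstractVector) = F.U * Diagonal(Δs) * F.Vt
    return F.S, pullback
end

X = [3.0 0.0; 0.0 1.0]
s, back = svdvals_with_pullback(X)
dX = back([1.0, 0.0])  # cotangent that selects the largest singular value
```

Written as an actual `rrule` with ChainRulesCore, the pullback would additionally return `NoTangent()` for the function argument and would need to handle zero cotangents.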