Rule for svdvals #587

Closed
@mcabbott

Description

Requested here: https://discourse.julialang.org/t/implementation-of-spectral-normalization-for-machine-learning/76074

The workaround is to call svd(X).S, which is slower on the forward pass. But it looks like the gradient calculation with something like svd_rev((; U=NoTangent(), s=s, V=NoTangent(), Vt=NoTangent()), NoTangent(), S̄, NoTangent()) is probably fairly efficient, and could easily be extracted to its own rule:

function rrule(::typeof(svd), X::AbstractMatrix{<:Real})
    F = svd(X)
    svd_pullback(ȳ) = _svd_pullback(ȳ, F)
    return F, svd_pullback
end
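
For illustration, a dedicated rule might look roughly like the sketch below. It is not the implementation from any merged PR, just a minimal version under stated assumptions: X is real, its singular values are distinct (so that ∂sᵢ/∂X = uᵢ vᵢᵀ holds), and only ChainRulesCore is loaded. The name svdvals_pullback is made up here. Note that this sketch still calls svd in the forward pass, because the pullback needs U and Vt; what it saves is the extra work svd_rev does for the U and V cotangents.

```julia
using LinearAlgebra
using ChainRulesCore

# Hypothetical sketch of a standalone rule for svdvals (assumes distinct singular values).
function ChainRulesCore.rrule(::typeof(svdvals), X::AbstractMatrix{<:Real})
    F = svd(X)  # thin SVD: X ≈ F.U * Diagonal(F.S) * F.Vt
    function svdvals_pullback(s̄)
        # A zero cotangent propagates as zero without touching U or Vt.
        s̄ isa AbstractZero && return NoTangent(), ZeroTangent()
        # Since ∂sᵢ/∂X = uᵢ vᵢᵀ, the cotangent of X is U * Diagonal(s̄) * Vᵀ.
        X̄ = F.U * Diagonal(unthunk(s̄)) * F.Vt
        return NoTangent(), X̄
    end
    return F.S, svdvals_pullback
end
```

Calling the pullback with a vector s̄ of the same length as svdvals(X) returns a cotangent with the same shape as X.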
