Adding logdet #1777
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Adding logdet #1777
Changes from all commits
95b2046
9c034c7
5b0efc9
f2729f2
a3602bb
3145e97
ffa85be
11b6990
1e0d9b7
@@ -1,18 +1,20 @@
import numpy as np
import theano

from .vartypes import typefilter, continuous_types
from theano import theano, scalar, tensor as tt
from theano.gof.graph import inputs
-from theano.gof import Op
+from theano.gof import Op, Apply
from theano.configparser import change_flags
from theano.tensor.nlinalg import matrix_inverse
from .memoize import memoize
from .blocking import ArrayOrdering
from .data import DataGenerator

__all__ = ['gradient', 'hessian', 'hessian_diag', 'inputvars',
           'cont_inputs', 'floatX', 'jacobian',
           'CallableTensor', 'join_nonshared_inputs',
-           'make_shared_replacements', 'generator']
+           'make_shared_replacements', 'generator', 'LogDet', 'logdet']


def inputvars(a):
@@ -59,11 +61,44 @@ def floatX(X):
Theano derivative functions
"""

class LogDet(Op):
    """Computes the logarithm of the absolute determinant of a square
    matrix M, log(abs(det(M))), on CPU. Avoids det(M) overflow/
    underflow.

    Note: once PR #3959 (https://github.com/Theano/Theano/pull/3959/) by harpone is merged,
    this must be removed.
    """
    def make_node(self, x):
        x = theano.tensor.as_tensor_variable(x)
        o = theano.tensor.scalar(dtype=x.dtype)
        return Apply(self, [x], [o])

    def perform(self, node, inputs, outputs):
Review comment: Signature of the overridden method changed, should be

Review comment: That's not a crucial thing, so you can take it into account only if you decide on moving
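(For reference, the base-class signature in Theano is Op.perform(self, node, inputs, output_storage), so the truncated comment presumably asks for the last parameter to be renamed output_storage.)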
        try:
            (x,) = inputs
            (z,) = outputs
            s = np.linalg.svd(x, compute_uv=False)
            log_det = np.sum(np.log(np.abs(s)))
            z[0] = np.asarray(log_det, dtype=x.dtype)
        except Exception:
            print('Failed to compute logdet of {}.'.format(x))
            raise

    def grad(self, inputs, g_outputs):
        [gz] = g_outputs
        [x] = inputs
        return [gz * matrix_inverse(x).T]

    def __str__(self):
        return "LogDet"

logdet = LogDet()

def gradient1(f, v):
    """flat gradient of f wrt v"""
    return tt.flatten(tt.grad(f, v, disconnected_inputs='warn'))


empty_gradient = tt.zeros(0, dtype='float32')
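For context, a minimal usage sketch of the new Op (not part of the PR; the variable names and the pymc3.theanof import path are assumptions). It exercises both the forward pass, which computes log|det(M)| as the sum of log singular values so that det(M) itself never has to be formed, and the gradient, which relies on the identity d log|det(M)| / dM = (M**-1).T:

    import numpy as np
    import theano
    import theano.tensor as tt
    from pymc3.theanof import logdet  # assumed import path

    X = tt.dmatrix('X')
    f = theano.function([X], logdet(X))
    g = theano.function([X], tt.grad(logdet(X), X))

    M = np.array([[2.0, 0.5],
                  [0.5, 1.0]])
    # Forward pass matches numpy's numerically stable log-determinant:
    assert np.isclose(f(M), np.linalg.slogdet(M)[1])
    # Gradient matches the closed form (M**-1).T:
    assert np.allclose(g(M), np.linalg.inv(M).T)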
Review comment: The first term can be moved up top; then the if-clause just does -logdet or -tt.log(det(tau)).

Review comment: You can use this code in the future, it worked for me
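A hypothetical sketch of the refactor the first comment seems to suggest, assuming an MvNormal-style logp with separate cov/tau parametrizations (the function and argument names are illustrative, not from the PR, and the reviewer's own snippet is not preserved on this page):

    import numpy as np
    import theano.tensor as tt
    from theano.tensor.nlinalg import det, matrix_inverse
    from pymc3.theanof import logdet  # assumed import path

    def mvnormal_logp_sketch(delta, k, cov=None, tau=None):
        # Constant first term hoisted up top, as suggested.
        result = k * tt.log(2 * np.pi)
        if tau is not None:
            # Precision parametrization: log|cov| = -log(det(tau)).
            result += -tt.log(det(tau))
            result += (delta.dot(tau) * delta).sum()
        else:
            # Covariance parametrization: use the new stable logdet Op.
            result += logdet(cov)
            result += (delta.dot(matrix_inverse(cov)) * delta).sum()
        return -0.5 * result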