-
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Added base class for variational methods #1600
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
+1,192
−27
Closed
Changes from all commits
Commits
Show all changes
87 commits
Select commit
Hold shift + click to select a range
2d6fee8
Added mode argument to several step methods and advi to allow mode se…
fonnesbeck 4e5b9c2
Fixed namespace bugs in mode attribute
fonnesbeck 0ebaacd
Reverted function in delta_logp to not accept mode argument
fonnesbeck 55c8ce6
ENH User model (#1525)
ferrine 9ab04da
added new elbo implementation
ferrine 9811220
Added mode argument to several step methods and advi to allow mode se…
fonnesbeck fbd1d5b
Fixed namespace bugs in mode attribute
fonnesbeck 208aa79
Reverted function in delta_logp to not accept mode argument
fonnesbeck 40d0146
ENH User model (#1525)
ferrine fc0673b
ENH User model (#1525)
ferrine ea82ebd
Refactor Hamiltonian methods into single class
140a80c
Reformat docs
168b113
added replacements class and mean field approximation
ferrine c1211a6
moved local to local constructor
ferrine 9690562
property for deterministic replacements
ferrine 34da7c8
refactored replacements to make them more unitary
ferrine 07a248a
shape problem when sampling
ferrine 889b50e
tests passed
ferrine 0d486fb
deleted unused modules
ferrine 125f6ad
added replacement names for global/local dict
ferrine 1af91c0
Merge branch '3.1' into refactor_advi
ferrine 69f07a1
refactored replacements
ferrine 9614bf9
refactored replacements
ferrine 32a2eb7
refactored GARCH and added Mv(Gaussian/StudentT)RandomWalk (#1603)
ferrine 5e68b95
Merge branch '3.1' into refactor_advi
ferrine 0f2c38f
added flatten_list
ferrine 63e57d7
added tests
ferrine 4d4cb82
refactored local/global dicts
ferrine 82c7996
moved __type__ assignment to better place
ferrine 2cd6bc5
Don't do replacements too early or else it will be not possible to tr…
ferrine 4d810f2
refactored docs
ferrine 16a226b
fixed memory consumption during test
ferrine d8e9886
set nmc samples to 1000 in test
ferrine 1bb349e
optimized code a lot
ferrine 87e7e2d
changed expectations to sampling, added docs
ferrine 9eb79a0
code style
ferrine be1ca80
validate model
ferrine e8f6644
added tests for dynamic number of samples
ferrine 4add3bc
added `set_params` method
ferrine 43a8638
added `params` property
ferrine 6a88fde
ENH KL-weighting
taku-y a3bad35
Fix bugs
taku-y 7ed2cb5
Remove unnecessary comments
taku-y 163b1be
Fix typo
taku-y fad9410
Minor fixes
taku-y e1a88e0
Check transformed RVs using hasattr
taku-y ae349e9
Update conv-vae notebook
taku-y 9e237ef
Implementation of path derivative gradient estimator (NIPS 2016) #1615
ferrine 63c1285
local vars nee this path trick too
ferrine 02f5fa6
bug in local size calculation
ferrine 8cc9558
bug in global subset view
ferrine 4e302e6
improved performance
ferrine e5df6ee
changed the way for calling posterior
ferrine 23ed175
deleted accidental added nuts file
ferrine 7a7cdc3
Merge remote-tracking branch 'upstream/3.1' into refactor_advi
ferrine ac949d2
changed zero grad usage
ferrine 26adf3b
refactor apply replacements
ferrine 63000fb
added useful functions to replacements
ferrine 5240260
added `approximate` function
ferrine 7802a78
changed name MeanFieald to Advi
ferrine 2407d78
added docs, renamed classes
ferrine fbf26d4
add deterministics to posterior to point function
ferrine c394d5e
trying to fix reweighting
ferrine 2162d4c
weight log_p_W{local|global} correctly
ferrine 7609f72
local and global weighting
ferrine d55d258
added docs
ferrine dca919c
preparing mnist vae, fixed bugs
ferrine cb2e219
Took in account suggestions for refactoring
ferrine d94e7e7
refactored dist math
ferrine 8d1f088
Added mode argument to several step methods and advi to allow mode se…
fonnesbeck 37843af
Created Generator Op with simple Test
ferrine 3dc6f1b
added ndim test
ferrine a16512e
updated test
ferrine 7127c23
updated test, added test value check
ferrine 23b14ff
added test for replacing generator with shared variable
ferrine 96cd5bb
added shortcut for generator op
ferrine 633e4e9
refactored test
ferrine 0629adc
added population kwarg (no tests yet)
ferrine 06099a2
added population kwarg for free var(autoencoder case)
ferrine 75a4849
Revert "Added mode argument to several step methods and advi to allow…
ferrine ff325d8
add docstring to generator Op
ferrine 79ac934
rename population -> total_size
ferrine 57dbe47
update docstrings in model
ferrine f8bce58
fix typo in `as_tensor` function
ferrine 244bf21
Merge branch 'generator_op' into refactor_advi
ferrine 8d91fee
add simple test for density scaling via `total_size`
ferrine 1a9fa3d
raise an error when density scaling is done on scalar
ferrine File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,6 +10,9 @@ | |
|
||
from .special import gammaln, multigammaln | ||
|
||
c = - 0.5 * np.log(2 * np.pi) | ||
|
||
|
||
def bound(logp, *conditions, **kwargs): | ||
""" | ||
Bounds a log probability density with several conditions. | ||
|
@@ -95,3 +98,64 @@ def i1(x): | |
x**9 / 1474560 + x**11 / 176947200 + x**13 / 29727129600, | ||
np.e**x / (2 * np.pi * x)**0.5 * (1 - 3 / (8 * x) + 15 / (128 * x**2) + 315 / (3072 * x**3) | ||
+ 14175 / (98304 * x**4))) | ||
|
||
|
||
def sd2rho(sd): | ||
""" | ||
`sd -> rho` theano converter | ||
:math:`mu + sd*e = mu + log(1+exp(rho))*e`""" | ||
return tt.log(tt.exp(sd) - 1) | ||
|
||
|
||
def rho2sd(rho): | ||
""" | ||
`rho -> sd` theano converter | ||
:math:`mu + sd*e = mu + log(1+exp(rho))*e`""" | ||
return tt.log1p(tt.exp(rho)) | ||
|
||
|
||
def log_normal(x, mean, **kwargs): | ||
""" | ||
Calculate logarithm of normal distribution at point `x` | ||
with given `mean` and `std` | ||
Parameters | ||
---------- | ||
x : Tensor | ||
point of evaluation | ||
mean : Tensor | ||
mean of normal distribution | ||
kwargs : one of parameters `{sd, tau, w, rho}` | ||
Notes | ||
----- | ||
There are four variants for density parametrization. | ||
They are: | ||
1) standard deviation - `std` | ||
2) `w`, logarithm of `std` :math:`w = log(std)` | ||
3) `rho` that follows this equation :math:`rho = log(exp(std) - 1)` | ||
4) `tau` that follows this equation :math:`tau = std^{-1}` | ||
---- | ||
""" | ||
sd = kwargs.get('sd') | ||
w = kwargs.get('w') | ||
rho = kwargs.get('rho') | ||
tau = kwargs.get('tau') | ||
eps = kwargs.get('eps', 0.0) | ||
check = sum(map(lambda a: a is not None, [sd, w, rho, tau])) | ||
if check > 1: | ||
raise ValueError('more than one required kwarg is passed') | ||
if check == 0: | ||
raise ValueError('none of required kwarg is passed') | ||
if sd is not None: | ||
std = sd | ||
elif w is not None: | ||
std = tt.exp(w) | ||
elif rho is not None: | ||
std = rho2sd(rho) | ||
else: | ||
std = tau**(-1) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nice! |
||
std += eps | ||
return c - tt.log(tt.abs_(std)) - (x - mean) ** 2 / (2 * std ** 2) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this function used?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it is not used, I can delete it