46
46
ArrayOrdering , DictToArrayBijection , VarMap
47
47
)
48
48
from ..model import modelcontext
49
- from ..theanof import tt_rng , memoize , change_flags , identity
49
+ from ..theanof import tt_rng , change_flags , identity
50
50
from ..util import get_default_varnames
51
+ from ..memoize import WithMemoization , memoize
51
52
52
53
__all__ = [
53
54
'ObjectiveFunction' ,
@@ -86,10 +87,29 @@ class LocalGroupError(BatchedGroupError, AEVBInferenceError):
86
87
"""Error raised in case of bad local_rv usage"""
87
88
88
89
90
+ def append_name (name ):
91
+ def wrap (f ):
92
+ if name is None :
93
+ return f
94
+
95
+ def inner (* args , ** kwargs ):
96
+ res = f (* args , ** kwargs )
97
+ res .name = name
98
+ return res
99
+ return inner
100
+ return wrap
101
+
102
+
89
103
def node_property (f ):
90
104
"""A shortcut for wrapping method to accessible tensor
91
105
"""
92
- return property (memoize (change_flags (compute_test_value = 'off' )(f )))
106
+ if isinstance (f , str ):
107
+
108
+ def wrapper (fn ):
109
+ return property (memoize (change_flags (compute_test_value = 'off' )(append_name (f )(fn ))))
110
+ return wrapper
111
+ else :
112
+ return property (memoize (change_flags (compute_test_value = 'off' )(f )))
93
113
94
114
95
115
@change_flags (compute_test_value = 'ignore' )
@@ -134,7 +154,6 @@ class ObjectiveFunction(object):
134
154
tf : :class:`TestFunction`
135
155
OPVI TestFunction
136
156
"""
137
- __hash__ = id
138
157
139
158
def __init__ (self , op , tf ):
140
159
self .op = op
@@ -351,7 +370,6 @@ class Operator(object):
351
370
-----
352
371
For implementing custom operator it is needed to define :func:`Operator.apply` method
353
372
"""
354
- __hash__ = id
355
373
356
374
has_test_function = False
357
375
returns_loss = True
@@ -444,8 +462,6 @@ def collect_shared_to_list(params):
444
462
445
463
446
464
class TestFunction (object ):
447
- __hash__ = id
448
-
449
465
def __init__ (self ):
450
466
self ._inited = False
451
467
self .shared_params = None
@@ -469,7 +485,7 @@ def from_function(cls, f):
469
485
return obj
470
486
471
487
472
- class Group (object ):
488
+ class Group (WithMemoization ):
473
489
R"""**Base class for grouping variables in VI**
474
490
475
491
Grouped Approximation is used for modelling mutual dependencies
@@ -682,8 +698,7 @@ class Group(object):
682
698
- Kingma, D. P., & Welling, M. (2014).
683
699
`Auto-Encoding Variational Bayes. stat, 1050, 1. <https://arxiv.org/abs/1312.6114>`_
684
700
"""
685
- __hash__ = id
686
- # need to be defined in init
701
+ # needs to be defined in init
687
702
shared_params = None
688
703
symbolic_initial = None
689
704
replacements = None
@@ -1064,14 +1079,14 @@ def set_size_and_deterministic(self, node, s, d, more_replacements=None):
1064
1079
:class:`Variable` with applied replacements, ready to use
1065
1080
"""
1066
1081
flat2rand = self .make_size_and_deterministic_replacements (s , d , more_replacements )
1067
- node_out = theano .clone (node , flat2rand , strict = False )
1082
+ node_out = theano .clone (node , flat2rand )
1068
1083
try_to_set_test_value (node , node_out , s )
1069
1084
return node_out
1070
1085
1071
1086
def to_flat_input (self , node ):
1072
1087
"""*Dev* - replace vars with flattened view stored in `self.inputs`
1073
1088
"""
1074
- return theano .clone (node , self .replacements , strict = False )
1089
+ return theano .clone (node , self .replacements )
1075
1090
1076
1091
def symbolic_sample_over_posterior (self , node ):
1077
1092
"""*Dev* - performs sampling of node applying independent samples from posterior each time.
@@ -1184,11 +1199,12 @@ def cov(self):
1184
1199
def mean (self ):
1185
1200
raise NotImplementedError
1186
1201
1202
+
1187
1203
group_for_params = Group .group_for_params
1188
1204
group_for_short_name = Group .group_for_short_name
1189
1205
1190
1206
1191
- class Approximation (object ):
1207
+ class Approximation (WithMemoization ):
1192
1208
"""**Wrapper for grouped approximations**
1193
1209
1194
1210
Wraps list of groups, creates an Approximation instance that collects
@@ -1217,7 +1233,6 @@ class Approximation(object):
1217
1233
--------
1218
1234
:class:`Group`
1219
1235
"""
1220
- __hash__ = id
1221
1236
1222
1237
def __init__ (self , groups , model = None ):
1223
1238
self ._scale_cost_to_minibatch = theano .shared (np .int8 (1 ))
@@ -1374,12 +1389,13 @@ def set_size_and_deterministic(self, node, s, d, more_replacements=None):
1374
1389
-------
1375
1390
:class:`Variable` with applied replacements, ready to use
1376
1391
"""
1392
+ _node = node
1377
1393
optimizations = self .get_optimization_replacements (s , d )
1378
1394
flat2rand = self .make_size_and_deterministic_replacements (s , d , more_replacements )
1379
1395
node = theano .clone (node , optimizations )
1380
- node_out = theano .clone (node , flat2rand , strict = False )
1381
- try_to_set_test_value (node , node_out , s )
1382
- return node_out
1396
+ node = theano .clone (node , flat2rand )
1397
+ try_to_set_test_value (_node , node , s )
1398
+ return node
1383
1399
1384
1400
def to_flat_input (self , node ):
1385
1401
"""*Dev* - replace vars with flattened view stored in `self.inputs`
0 commit comments