Skip to content

Commit 1cdd163

Browse files
ferrineJunpeng Lao
authored and
Junpeng Lao
committed
OPVI speedup (#2759)
* fix scan op redundancy
* fix clear cache
* Better solution for caching
* Redundant usages of memoize
* Clear cache function
* clear cache
* fix testing
* fix unused import
* fix imports
* fix imports
1 parent de1b8c8 commit 1cdd163

File tree

7 files changed

+98
-36
lines changed

7 files changed

+98
-36
lines changed

pymc3/memoize.py

+26-1
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,20 @@
11
import functools
22
import pickle
33

4+
CACHE_REGISTRY = []
5+
46

57
def memoize(obj):
68
"""
79
An expensive memoizer that works with unhashables
810
"""
911
cache = obj.cache = {}
12+
CACHE_REGISTRY.append(cache)
1013

1114
@functools.wraps(obj)
1215
def memoizer(*args, **kwargs):
13-
key = (hashable(args), hashable(kwargs))
16+
# remember first argument as well, used to clear cache for particular instance
17+
key = (hashable(args[:1]), hashable(args), hashable(kwargs))
1418

1519
if key not in cache:
1620
cache[key] = obj(*args, **kwargs)
@@ -19,6 +23,27 @@ def memoizer(*args, **kwargs):
1923
return memoizer
2024

2125

26+
def clear_cache():
    """Empty every cache dict that has been appended to ``CACHE_REGISTRY``."""
    for registered_cache in CACHE_REGISTRY:
        registered_cache.clear()
29+
30+
31+
class WithMemoization(object):
    """Mixin that removes an instance's entries from the registered caches on deletion."""

    def __hash__(self):
        # Instances hash by identity.
        return hash(id(self))

    def __del__(self):
        # The memoizer puts hashable(args[:1]) first in every cache key; for a
        # regular property call the arguments are (self, ), so the leading key
        # element is compared against the hash of that one-element tuple.
        own_key = hash((self, ))
        # Collect matches first so no cache is mutated while it is iterated.
        doomed = [
            (cache, entry)
            for cache in CACHE_REGISTRY
            for entry in cache
            if entry[0] == own_key
        ]
        for cache, entry in doomed:
            del cache[entry]
45+
46+
2247
def hashable(a):
2348
"""
2449
Turn some unhashable objects into hashable ones.

pymc3/model.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from pymc3.theanof import set_theano_conf
1414
import pymc3 as pm
1515
from pymc3.math import flatten_list
16-
from .memoize import memoize
16+
from .memoize import memoize, WithMemoization
1717
from .theanof import gradient, hessian, inputvars, generator
1818
from .vartypes import typefilter, discrete_types, continuous_types, isgenerator
1919
from .blocking import DictToArrayBijection, ArrayOrdering
@@ -487,7 +487,7 @@ def _build_joined(self, cost, args, vmap):
487487
return args_joined, theano.clone(cost, replace=replace)
488488

489489

490-
class Model(six.with_metaclass(InitContextMeta, Context, Factor)):
490+
class Model(six.with_metaclass(InitContextMeta, Context, Factor, WithMemoization)):
491491
"""Encapsulates the variables and likelihood factors of a model.
492492
493493
Model class can be used for creating class based models. To create

pymc3/tests/test_variational_inference.py

+26-5
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
import pytest
2+
import six
23
import functools
34
import operator
45
import numpy as np
56
from theano import theano, tensor as tt
67

78

89
import pymc3 as pm
10+
import pymc3.memoize
911
import pymc3.util
1012
from pymc3.theanof import change_flags
1113
from pymc3.variational.approximations import (
@@ -596,11 +598,30 @@ def test_fit_oo(inference,
596598

597599

598600
def test_profile(inference):
599-
try:
600-
inference.run_profiling(n=100).summary()
601-
except ZeroDivisionError:
602-
# weird error in SVGD, ASVGD
603-
pass
601+
inference.run_profiling(n=100).summary()
602+
603+
604+
def test_remove_scan_op():
    """The ADVI profiling summary must not mention a theano Scan op."""
    with pm.Model():
        pm.Normal('n', 0, 1)
        inference = ADVI()
        summary_stream = six.StringIO()
        profile = inference.run_profiling(n=10)
        profile.summary(summary_stream)
        report = summary_stream.getvalue()
        assert 'theano.scan_module.scan_op.Scan' not in report
        summary_stream.close()
612+
613+
614+
def test_clear_cache():
    """Deleting a fitted inference object must leave every registered cache empty."""
    pymc3.memoize.clear_cache()
    with pm.Model():
        pm.Normal('n', 0, 1)
        inference = ADVI()
        inference.fit(n=10)
        # Cache dict attached to the memoized getter of Approximation.logp.
        logp_cache = pm.variational.opvi.Approximation.logp.fget.cache
        assert len(logp_cache) == 1
        del inference
        assert len(logp_cache) == 0
        assert all(len(cache) == 0 for cache in pymc3.memoize.CACHE_REGISTRY)
604625

605626

606627
@pytest.fixture('module')

pymc3/theanof.py

+2-7
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88

99
from .blocking import ArrayOrdering
1010
from .data import GeneratorAdapter
11-
from .memoize import memoize
1211
from .vartypes import typefilter, continuous_types
1312

1413
__all__ = ['gradient',
@@ -85,10 +84,10 @@ def gradient1(f, v):
8584
"""flat gradient of f wrt v"""
8685
return tt.flatten(tt.grad(f, v, disconnected_inputs='warn'))
8786

87+
8888
empty_gradient = tt.zeros(0, dtype='float32')
8989

9090

91-
@memoize
9291
def gradient(f, vars=None):
9392
if vars is None:
9493
vars = cont_inputs(f)
@@ -110,7 +109,6 @@ def grad_i(i):
110109
return theano.map(grad_i, idx)[0]
111110

112111

113-
@memoize
114112
def jacobian(f, vars=None):
115113
if vars is None:
116114
vars = cont_inputs(f)
@@ -132,7 +130,6 @@ def grad_ii(i):
132130
name='jacobian_diag')[0]
133131

134132

135-
@memoize
136133
@change_flags(compute_test_value='ignore')
137134
def hessian(f, vars=None):
138135
return -jacobian(gradient(f, vars), vars)
@@ -149,7 +146,6 @@ def hess_ii(i):
149146
return theano.map(hess_ii, idx)[0]
150147

151148

152-
@memoize
153149
@change_flags(compute_test_value='ignore')
154150
def hessian_diag(f, vars=None):
155151
if vars is None:
@@ -276,6 +272,7 @@ def __call__(self, input):
276272
oldinput, = inputvars(self.tensor)
277273
return theano.clone(self.tensor, {oldinput: input}, strict=False)
278274

275+
279276
scalar_identity = IdentityOp(scalar.upgrade_to_float, name='scalar_identity')
280277
identity = tt.Elemwise(scalar_identity, name='identity')
281278

@@ -463,5 +460,3 @@ def largest_common_dtype(tensors):
463460
else smartfloatX(np.asarray(t)).dtype
464461
for t in tensors)
465462
return np.stack([np.ones((), dtype=dtype) for dtype in dtypes]).dtype
466-
467-

pymc3/variational/flows.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,17 @@
22
import theano
33
from theano import tensor as tt
44

5-
from pymc3.distributions.dist_math import rho2sd
6-
from pymc3.theanof import change_flags
5+
from ..distributions.dist_math import rho2sd
6+
from ..theanof import change_flags
7+
from ..memoize import WithMemoization
78
from .opvi import node_property, collect_shared_to_list
89
from . import opvi
910

1011
__all__ = [
1112
'Formula',
1213
'PlanarFlow',
14+
'HouseholderFlow',
15+
'RadialFlow',
1316
'LocFlow',
1417
'ScaleFlow'
1518
]
@@ -97,7 +100,7 @@ def seems_like_flow_params(params):
97100
return False
98101

99102

100-
class AbstractFlow(object):
103+
class AbstractFlow(WithMemoization):
101104
shared_params = None
102105
__param_spec__ = dict()
103106
short_name = ''
@@ -255,6 +258,7 @@ def __repr__(self):
255258
def __str__(self):
256259
return self.short_name
257260

261+
258262
flow_for_params = AbstractFlow.flow_for_params
259263
flow_for_short_name = AbstractFlow.flow_for_short_name
260264

pymc3/variational/opvi.py

+32-16
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,9 @@
4646
ArrayOrdering, DictToArrayBijection, VarMap
4747
)
4848
from ..model import modelcontext
49-
from ..theanof import tt_rng, memoize, change_flags, identity
49+
from ..theanof import tt_rng, change_flags, identity
5050
from ..util import get_default_varnames
51+
from ..memoize import WithMemoization, memoize
5152

5253
__all__ = [
5354
'ObjectiveFunction',
@@ -86,10 +87,29 @@ class LocalGroupError(BatchedGroupError, AEVBInferenceError):
8687
"""Error raised in case of bad local_rv usage"""
8788

8889

90+
def append_name(name):
    """Build a decorator that sets ``name`` on the result of the wrapped function.

    Parameters
    ----------
    name : str or None
        Value assigned to the ``name`` attribute of every result returned by
        the wrapped function. With ``None`` the decorator is a no-op and the
        function is handed back unchanged.

    Returns
    -------
    callable
        Decorator that takes a function ``f`` and returns either ``f`` itself
        (``name is None``) or a wrapper around it.
    """
    import functools  # function-scope import keeps this block self-contained

    def wrap(f):
        if name is None:
            return f

        # Copy __name__/__doc__ from ``f`` so that wrappers stacked on top of
        # this one still expose the metadata of the original function.
        @functools.wraps(f)
        def inner(*args, **kwargs):
            res = f(*args, **kwargs)
            res.name = name
            return res
        return inner
    return wrap
return wrap
101+
102+
89103
def node_property(f):
90104
"""A shortcut for wrapping method to accessible tensor
91105
"""
92-
return property(memoize(change_flags(compute_test_value='off')(f)))
106+
if isinstance(f, str):
107+
108+
def wrapper(fn):
109+
return property(memoize(change_flags(compute_test_value='off')(append_name(f)(fn))))
110+
return wrapper
111+
else:
112+
return property(memoize(change_flags(compute_test_value='off')(f)))
93113

94114

95115
@change_flags(compute_test_value='ignore')
@@ -134,7 +154,6 @@ class ObjectiveFunction(object):
134154
tf : :class:`TestFunction`
135155
OPVI TestFunction
136156
"""
137-
__hash__ = id
138157

139158
def __init__(self, op, tf):
140159
self.op = op
@@ -351,7 +370,6 @@ class Operator(object):
351370
-----
352371
For implementing custom operator it is needed to define :func:`Operator.apply` method
353372
"""
354-
__hash__ = id
355373

356374
has_test_function = False
357375
returns_loss = True
@@ -444,8 +462,6 @@ def collect_shared_to_list(params):
444462

445463

446464
class TestFunction(object):
447-
__hash__ = id
448-
449465
def __init__(self):
450466
self._inited = False
451467
self.shared_params = None
@@ -469,7 +485,7 @@ def from_function(cls, f):
469485
return obj
470486

471487

472-
class Group(object):
488+
class Group(WithMemoization):
473489
R"""**Base class for grouping variables in VI**
474490
475491
Grouped Approximation is used for modelling mutual dependencies
@@ -682,8 +698,7 @@ class Group(object):
682698
- Kingma, D. P., & Welling, M. (2014).
683699
`Auto-Encoding Variational Bayes. stat, 1050, 1. <https://arxiv.org/abs/1312.6114>`_
684700
"""
685-
__hash__ = id
686-
# need to be defined in init
701+
# needs to be defined in init
687702
shared_params = None
688703
symbolic_initial = None
689704
replacements = None
@@ -1064,14 +1079,14 @@ def set_size_and_deterministic(self, node, s, d, more_replacements=None):
10641079
:class:`Variable` with applied replacements, ready to use
10651080
"""
10661081
flat2rand = self.make_size_and_deterministic_replacements(s, d, more_replacements)
1067-
node_out = theano.clone(node, flat2rand, strict=False)
1082+
node_out = theano.clone(node, flat2rand)
10681083
try_to_set_test_value(node, node_out, s)
10691084
return node_out
10701085

10711086
def to_flat_input(self, node):
10721087
"""*Dev* - replace vars with flattened view stored in `self.inputs`
10731088
"""
1074-
return theano.clone(node, self.replacements, strict=False)
1089+
return theano.clone(node, self.replacements)
10751090

10761091
def symbolic_sample_over_posterior(self, node):
10771092
"""*Dev* - performs sampling of node applying independent samples from posterior each time.
@@ -1184,11 +1199,12 @@ def cov(self):
11841199
def mean(self):
11851200
raise NotImplementedError
11861201

1202+
11871203
group_for_params = Group.group_for_params
11881204
group_for_short_name = Group.group_for_short_name
11891205

11901206

1191-
class Approximation(object):
1207+
class Approximation(WithMemoization):
11921208
"""**Wrapper for grouped approximations**
11931209
11941210
Wraps list of groups, creates an Approximation instance that collects
@@ -1217,7 +1233,6 @@ class Approximation(object):
12171233
--------
12181234
:class:`Group`
12191235
"""
1220-
__hash__ = id
12211236

12221237
def __init__(self, groups, model=None):
12231238
self._scale_cost_to_minibatch = theano.shared(np.int8(1))
@@ -1374,12 +1389,13 @@ def set_size_and_deterministic(self, node, s, d, more_replacements=None):
13741389
-------
13751390
:class:`Variable` with applied replacements, ready to use
13761391
"""
1392+
_node = node
13771393
optimizations = self.get_optimization_replacements(s, d)
13781394
flat2rand = self.make_size_and_deterministic_replacements(s, d, more_replacements)
13791395
node = theano.clone(node, optimizations)
1380-
node_out = theano.clone(node, flat2rand, strict=False)
1381-
try_to_set_test_value(node, node_out, s)
1382-
return node_out
1396+
node = theano.clone(node, flat2rand)
1397+
try_to_set_test_value(_node, node, s)
1398+
return node
13831399

13841400
def to_flat_input(self, node):
13851401
"""*Dev* - replace vars with flattened view stored in `self.inputs`

pymc3/variational/stein.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
from theano import theano, tensor as tt
22
from pymc3.variational.opvi import node_property
33
from pymc3.variational.test_functions import rbf
4-
from pymc3.theanof import memoize, floatX, change_flags
4+
from pymc3.theanof import floatX, change_flags
5+
from pymc3.memoize import WithMemoization, memoize
56

67
__all__ = [
78
'Stein'
89
]
910

1011

11-
class Stein(object):
12+
class Stein(WithMemoization):
1213
def __init__(self, approx, kernel=rbf, use_histogram=True, temperature=1):
1314
self.approx = approx
1415
self.temperature = floatX(temperature)

0 commit comments

Comments
 (0)