From 5b4279b1a88f2f8c94b8b9ec6556707b334fd8df Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Mon, 25 Nov 2024 13:38:20 +0100 Subject: [PATCH 01/11] More strict/explicit signature in step samplers --- pymc/step_methods/arraystep.py | 3 +- pymc/step_methods/hmc/base_hmc.py | 3 +- pymc/step_methods/metropolis.py | 53 ++++++++++++++++++++++++------- pymc/step_methods/slicer.py | 12 +++++-- tests/models.py | 2 +- 5 files changed, 56 insertions(+), 17 deletions(-) diff --git a/pymc/step_methods/arraystep.py b/pymc/step_methods/arraystep.py index b7da80aee0..060557ea23 100644 --- a/pymc/step_methods/arraystep.py +++ b/pymc/step_methods/arraystep.py @@ -174,8 +174,9 @@ class GradientSharedStep(ArrayStepShared): def __init__( self, vars, + *, model=None, - blocked=True, + blocked: bool = True, dtype=None, logp_dlogp_func=None, rng: RandomGenerator = None, diff --git a/pymc/step_methods/hmc/base_hmc.py b/pymc/step_methods/hmc/base_hmc.py index 87daff649c..832fdb1f28 100644 --- a/pymc/step_methods/hmc/base_hmc.py +++ b/pymc/step_methods/hmc/base_hmc.py @@ -82,11 +82,12 @@ class BaseHMC(GradientSharedStep): def __init__( self, vars=None, + *, scaling=None, step_scale=0.25, is_cov=False, model=None, - blocked=True, + blocked: bool = True, potential=None, dtype=None, Emax=1000, diff --git a/pymc/step_methods/metropolis.py b/pymc/step_methods/metropolis.py index 60aa33d45f..601aaf5628 100644 --- a/pymc/step_methods/metropolis.py +++ b/pymc/step_methods/metropolis.py @@ -151,6 +151,7 @@ class Metropolis(ArrayStepShared): def __init__( self, vars=None, + *, S=None, proposal_dist=None, scaling=1.0, @@ -159,7 +160,7 @@ def __init__( model=None, mode=None, rng=None, - **kwargs, + blocked: bool = False, ): """Create an instance of a Metropolis stepper. @@ -251,7 +252,7 @@ def __init__( shared = pm.make_shared_replacements(initial_values, vars, model) self.delta_logp = delta_logp(initial_values, model.logp(), vars, shared) - super().__init__(vars, shared, rng=rng) + super().__init__(vars, shared, blocked=blocked, rng=rng) def reset_tuning(self): """Reset the tuned sampler parameters to their initial values.""" @@ -418,7 +419,17 @@ class BinaryMetropolis(ArrayStep): _state_class = BinaryMetropolisState - def __init__(self, vars, scaling=1.0, tune=True, tune_interval=100, model=None, rng=None): + def __init__( + self, + vars, + *, + scaling=1.0, + tune=True, + tune_interval=100, + model=None, + rng=None, + blocked: bool = True, + ): model = pm.modelcontext(model) self.scaling = scaling @@ -432,7 +443,7 @@ def __init__(self, vars, scaling=1.0, tune=True, tune_interval=100, model=None, if not all(v.dtype in pm.discrete_types for v in vars): raise ValueError("All variables must be Bernoulli for BinaryMetropolis") - super().__init__(vars, [model.compile_logp()], rng=rng) + super().__init__(vars, [model.compile_logp()], blocked=blocked, rng=rng) def astep(self, apoint: RaveledVars, *args) -> tuple[RaveledVars, StatsType]: logp = args[0] @@ -530,7 +541,16 @@ class BinaryGibbsMetropolis(ArrayStep): _state_class = BinaryGibbsMetropolisState - def __init__(self, vars, order="random", transit_p=0.8, model=None, rng=None): + def __init__( + self, + vars, + *, + order="random", + transit_p=0.8, + model=None, + rng=None, + blocked: bool = True, + ): model = pm.modelcontext(model) # Doesn't actually tune, but it's required to emit a sampler stat @@ -556,7 +576,7 @@ def __init__(self, vars, order="random", transit_p=0.8, model=None, rng=None): if not all(v.dtype in pm.discrete_types for v in vars): raise ValueError("All 
variables must be binary for BinaryGibbsMetropolis") - super().__init__(vars, [model.compile_logp()], rng=rng) + super().__init__(vars, [model.compile_logp()], blocked=blocked, rng=rng) def reset_tuning(self): # There are no tuning parameters in this step method. @@ -638,7 +658,14 @@ class CategoricalGibbsMetropolis(ArrayStep): _state_class = CategoricalGibbsMetropolisState def __init__( - self, vars, proposal="uniform", order="random", model=None, rng: RandomGenerator = None + self, + vars, + *, + proposal="uniform", + order="random", + model=None, + rng: RandomGenerator = None, + blocked: bool = True, ): model = pm.modelcontext(model) @@ -693,7 +720,7 @@ def __init__( # that indicates whether a draw was done in a tuning phase. self.tune = True - super().__init__(vars, [model.compile_logp()], rng=rng) + super().__init__(vars, [model.compile_logp()], blocked=blocked, rng=rng) def reset_tuning(self): # There are no tuning parameters in this step method. @@ -858,6 +885,7 @@ class DEMetropolis(PopulationArrayStepShared): def __init__( self, vars=None, + *, S=None, proposal_dist=None, lamb=None, @@ -867,7 +895,7 @@ def __init__( model=None, mode=None, rng=None, - **kwargs, + blocked: bool = True, ): model = pm.modelcontext(model) initial_values = model.initial_point() @@ -902,7 +930,7 @@ def __init__( shared = pm.make_shared_replacements(initial_values, vars, model) self.delta_logp = delta_logp(initial_values, model.logp(), vars, shared) - super().__init__(vars, shared, rng=rng) + super().__init__(vars, shared, blocked=blocked, rng=rng) def astep(self, q0: RaveledVars) -> tuple[RaveledVars, StatsType]: point_map_info = q0.point_map_info @@ -1025,6 +1053,7 @@ class DEMetropolisZ(ArrayStepShared): def __init__( self, vars=None, + *, S=None, proposal_dist=None, lamb=None, @@ -1035,7 +1064,7 @@ def __init__( model=None, mode=None, rng=None, - **kwargs, + blocked: bool = True, ): model = pm.modelcontext(model) initial_values = model.initial_point() @@ -1082,7 +1111,7 @@ def __init__( shared = pm.make_shared_replacements(initial_values, vars, model) self.delta_logp = delta_logp(initial_values, model.logp(), vars, shared) - super().__init__(vars, shared, rng=rng) + super().__init__(vars, shared, blocked=blocked, rng=rng) def reset_tuning(self): """Reset the tuned sampler parameters and history to their initial values.""" diff --git a/pymc/step_methods/slicer.py b/pymc/step_methods/slicer.py index 2ea4b1f55f..c4ca03d125 100644 --- a/pymc/step_methods/slicer.py +++ b/pymc/step_methods/slicer.py @@ -76,7 +76,15 @@ class Slice(ArrayStepShared): _state_class = SliceState def __init__( - self, vars=None, w=1.0, tune=True, model=None, iter_limit=np.inf, rng=None, **kwargs + self, + vars=None, + *, + w=1.0, + tune=True, + model=None, + iter_limit=np.inf, + rng=None, + blocked: bool = False, # Could be true since tuning is independent across dims? 
): model = modelcontext(model) self.w = np.asarray(w).copy() @@ -97,7 +105,7 @@ def __init__( self.logp = compile_pymc([raveled_inp], logp) self.logp.trust_input = True - super().__init__(vars, shared, rng=rng) + super().__init__(vars, shared, blocked=blocked, rng=rng) def astep(self, apoint: RaveledVars) -> tuple[RaveledVars, StatsType]: # The arguments are determined by the list passed via `super().__init__(..., fs, ...)` diff --git a/tests/models.py b/tests/models.py index fd45fb8bdb..abf461fa90 100644 --- a/tests/models.py +++ b/tests/models.py @@ -78,7 +78,7 @@ def arbitrary_det(value): def simple_init(): start, model, moments = simple_model() - step = Metropolis(model.value_vars, np.diag([1.0]), model=model) + step = Metropolis(model.value_vars, S=np.diag([1.0]), model=model) return model, start, step, moments From 24fbbe450c360b5e51628d3d85909f8cf3ea0f24 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Sat, 23 Nov 2024 23:03:15 +0100 Subject: [PATCH 02/11] Traces are already closed on `finally` --- pymc/sampling/mcmc.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index 4ee79607b7..d511fee311 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -1277,14 +1277,11 @@ def _mp_sample( strace = traces[draw.chain] strace.record(draw.point, draw.stats) log_warning_stats(draw.stats) - if draw.is_last: - strace.close() if callback is not None: callback(trace=strace, draw=draw) except ps.ParallelSamplingError as error: - strace = traces[error._chain] for strace in traces: strace.close() raise From e5eacb83ab7405a7f47466e95c4ef459b0135fdd Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 22 Nov 2024 01:59:32 +0100 Subject: [PATCH 03/11] Don't recompile Ndarray function on trace slicing --- pymc/backends/base.py | 45 ++++++++++++++++++++++++++-------------- pymc/backends/ndarray.py | 12 ++++++++--- 2 files changed, 38 insertions(+), 19 deletions(-) diff --git a/pymc/backends/base.py b/pymc/backends/base.py index c0239f8dec..47133b4b13 100644 --- a/pymc/backends/base.py +++ b/pymc/backends/base.py @@ -147,32 +147,45 @@ class BaseTrace(IBaseTrace): use different test point that might be with changed variables shapes """ - def __init__(self, name, model=None, vars=None, test_point=None): - self.name = name - + def __init__( + self, + name=None, + model=None, + vars=None, + test_point=None, + *, + fn=None, + var_shapes=None, + var_dtypes=None, + ): model = modelcontext(model) - self.model = model + if vars is None: vars = model.unobserved_value_vars unnamed_vars = {var for var in vars if var.name is None} if unnamed_vars: raise Exception(f"Can't trace unnamed variables: {unnamed_vars}") - self.vars = vars - self.varnames = [var.name for var in vars] - self.fn = model.compile_fn(vars, inputs=model.value_vars, on_unused_input="ignore") + + if fn is None: + fn = model.compile_fn(vars, inputs=model.value_vars, on_unused_input="ignore") # Get variable shapes. Most backends will need this # information. 
- if test_point is None: - test_point = model.initial_point() - else: - test_point_ = model.initial_point().copy() - test_point_.update(test_point) - test_point = test_point_ - var_values = list(zip(self.varnames, self.fn(test_point))) - self.var_shapes = {var: value.shape for var, value in var_values} - self.var_dtypes = {var: value.dtype for var, value in var_values} + if var_shapes is None or var_dtypes is None: + if test_point is None: + test_point = model.initial_point() + var_values = tuple(zip(vars, fn(**test_point))) + var_shapes = {var.name: value.shape for var, value in var_values} + var_dtypes = {var.name: value.dtype for var, value in var_values} + + self.name = name + self.model = model + self.fn = fn + self.vars = vars + self.varnames = [var.name for var in vars] + self.var_shapes = var_shapes + self.var_dtypes = var_dtypes self.chain = None self._is_base_setup = False self.sampler_vars = None diff --git a/pymc/backends/ndarray.py b/pymc/backends/ndarray.py index 98a11fdeca..cf90043f7c 100644 --- a/pymc/backends/ndarray.py +++ b/pymc/backends/ndarray.py @@ -40,8 +40,8 @@ class NDArray(base.BaseTrace): `model.unobserved_RVs` is used. """ - def __init__(self, name=None, model=None, vars=None, test_point=None): - super().__init__(name, model, vars, test_point) + def __init__(self, name=None, model=None, vars=None, test_point=None, **kwargs): + super().__init__(name, model, vars, test_point, **kwargs) self.draw_idx = 0 self.draws = None self.samples = {} @@ -166,7 +166,13 @@ def _slice(self, idx: slice): # Only the first `draw_idx` value are valid because of preallocation idx = slice(*idx.indices(len(self))) - sliced = NDArray(model=self.model, vars=self.vars) + sliced = type(self)( + model=self.model, + vars=self.vars, + fn=self.fn, + var_shapes=self.var_shapes, + var_dtypes=self.var_dtypes, + ) sliced.chain = self.chain sliced.samples = {varname: values[idx] for varname, values in self.samples.items()} sliced.sampler_vars = self.sampler_vars From b589ce8562fe9fb23be420f4208333bd2db8357e Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 22 Nov 2024 02:04:28 +0100 Subject: [PATCH 04/11] Avoid input copy in Ndarray fn Also initialize empty trace and set `trust_input=True` --- pymc/backends/base.py | 10 +++++++++- pymc/backends/ndarray.py | 4 ++-- pymc/pytensorf.py | 7 ++++++- pymc/variational/opvi.py | 5 ++++- tests/backends/fixtures.py | 5 ++++- 5 files changed, 25 insertions(+), 6 deletions(-) diff --git a/pymc/backends/base.py b/pymc/backends/base.py index 47133b4b13..fe05b8e5ca 100644 --- a/pymc/backends/base.py +++ b/pymc/backends/base.py @@ -30,9 +30,11 @@ ) import numpy as np +import pytensor from pymc.backends.report import SamplerReport from pymc.model import modelcontext +from pymc.pytensorf import compile_pymc from pymc.util import get_var_name logger = logging.getLogger(__name__) @@ -168,7 +170,13 @@ def __init__( raise Exception(f"Can't trace unnamed variables: {unnamed_vars}") if fn is None: - fn = model.compile_fn(vars, inputs=model.value_vars, on_unused_input="ignore") + # borrow=True avoids deepcopy when inputs=output which is the case for untransformed value variables + fn = compile_pymc( + inputs=[pytensor.In(v, borrow=True) for v in model.value_vars], + outputs=[pytensor.Out(v, borrow=True) for v in vars], + on_unused_input="ignore", + ) + fn.trust_input = True # Get variable shapes. Most backends will need this # information. 
diff --git a/pymc/backends/ndarray.py b/pymc/backends/ndarray.py index cf90043f7c..079d6752fd 100644 --- a/pymc/backends/ndarray.py +++ b/pymc/backends/ndarray.py @@ -76,7 +76,7 @@ def setup(self, draws, chain, sampler_vars=None) -> None: else: # Otherwise, make array of zeros for each variable. self.draws = draws for varname, shape in self.var_shapes.items(): - self.samples[varname] = np.zeros((draws, *shape), dtype=self.var_dtypes[varname]) + self.samples[varname] = np.empty((draws, *shape), dtype=self.var_dtypes[varname]) if sampler_vars is None: return @@ -105,7 +105,7 @@ def record(self, point, sampler_stats=None) -> None: point: dict Values mapped to variable names """ - for varname, value in zip(self.varnames, self.fn(point)): + for varname, value in zip(self.varnames, self.fn(*point.values())): self.samples[varname][self.draw_idx] = value if self._stats is not None and sampler_stats is None: diff --git a/pymc/pytensorf.py b/pymc/pytensorf.py index e3b6562f8f..67212b4f5e 100644 --- a/pymc/pytensorf.py +++ b/pymc/pytensorf.py @@ -1024,7 +1024,12 @@ def compile_pymc( """ # Create an update mapping of RandomVariable's RNG so that it is automatically # updated after every function call - rng_updates = collect_default_updates(inputs=inputs, outputs=outputs) + rng_updates = collect_default_updates( + inputs=[inp.variable if isinstance(inp, pytensor.In) else inp for inp in inputs], + outputs=[ + out.variable if isinstance(out, pytensor.Out) else out for out in makeiter(outputs) + ], + ) # We always reseed random variables as this provides RNGs with no chances of collision if rng_updates: diff --git a/pymc/variational/opvi.py b/pymc/variational/opvi.py index b07b9ded84..a03fa58607 100644 --- a/pymc/variational/opvi.py +++ b/pymc/variational/opvi.py @@ -1554,7 +1554,10 @@ def sample( if random_seed is not None: (random_seed,) = _get_seeds_per_chain(random_seed, 1) samples: dict = self.sample_dict_fn(draws, random_seed=random_seed) - points = ({name: records[i] for name, records in samples.items()} for i in range(draws)) + points = ( + {name: np.asarray(records[i]) for name, records in samples.items()} + for i in range(draws) + ) trace = NDArray( model=self.model, diff --git a/tests/backends/fixtures.py b/tests/backends/fixtures.py index 1b02eb8bca..c7a3bdcec3 100644 --- a/tests/backends/fixtures.py +++ b/tests/backends/fixtures.py @@ -238,7 +238,10 @@ class SamplingTestCase(ModelBackendSetupTestCase): """ def record_point(self, val): - point = {varname: np.tile(val, value.shape) for varname, value in self.test_point.items()} + point = { + varname: np.tile(val, value.shape).astype(value.dtype) + for varname, value in self.test_point.items() + } if self.sampler_vars is not None: stats = [{key: dtype(val) for key, dtype in vars.items()} for vars in self.sampler_vars] self.strace.record(point=point, sampler_stats=stats) From f979bd64cff2bef9924e810043fbeb55171cb7e4 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Sat, 23 Nov 2024 11:02:56 +0100 Subject: [PATCH 05/11] Reduce attribute accesses on point record --- pymc/backends/ndarray.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/pymc/backends/ndarray.py b/pymc/backends/ndarray.py index 079d6752fd..70ca60879c 100644 --- a/pymc/backends/ndarray.py +++ b/pymc/backends/ndarray.py @@ -105,17 +105,18 @@ def record(self, point, sampler_stats=None) -> None: point: dict Values mapped to variable names """ + samples = self.samples + draw_idx = self.draw_idx for varname, value in zip(self.varnames, 
self.fn(*point.values())): - self.samples[varname][self.draw_idx] = value + samples[varname][draw_idx] = value - if self._stats is not None and sampler_stats is None: - raise ValueError("Expected sampler_stats") - if self._stats is None and sampler_stats is not None: - raise ValueError("Unknown sampler_stats") if sampler_stats is not None: for data, vars in zip(self._stats, sampler_stats): for key, val in vars.items(): - data[key][self.draw_idx] = val + data[key][draw_idx] = val + elif self._stats is not None: + raise ValueError("Expected sampler_stats") + self.draw_idx += 1 def _get_sampler_stats( From acf5175ade72e6756c7432cef66fd9b39b51d54f Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 22 Nov 2024 01:45:28 +0100 Subject: [PATCH 06/11] ValueGradFunction inner function now accepts a raveled input --- pymc/model/core.py | 80 +++++++++++++-------- pymc/sampling/mcmc.py | 2 + pymc/step_methods/arraystep.py | 18 ++--- pymc/step_methods/hmc/base_hmc.py | 6 +- pymc/step_methods/hmc/integration.py | 51 +++++++------ pymc/step_methods/hmc/nuts.py | 12 ++-- tests/distributions/test_multivariate.py | 2 +- tests/model/test_core.py | 92 ++++++++++++++++-------- tests/step_methods/hmc/test_hmc.py | 3 +- tests/step_methods/hmc/test_nuts.py | 21 ++++++ 10 files changed, 187 insertions(+), 100 deletions(-) diff --git a/pymc/model/core.py b/pymc/model/core.py index ad60a84dfb..94756e6c06 100644 --- a/pymc/model/core.py +++ b/pymc/model/core.py @@ -61,6 +61,7 @@ gradient, hessian, inputvars, + join_nonshared_inputs, rewrite_pregrad, ) from pymc.util import ( @@ -172,6 +173,9 @@ def __init__( dtype=None, casting="no", compute_grads=True, + model=None, + initial_point=None, + ravel_inputs: bool | None = None, **kwargs, ): if extra_vars_and_values is None: @@ -219,9 +223,7 @@ def __init__( givens = [] self._extra_vars_shared = {} for var, value in extra_vars_and_values.items(): - shared = pytensor.shared( - value, var.name + "_shared__", shape=[1 if s == 1 else None for s in value.shape] - ) + shared = pytensor.shared(value, var.name + "_shared__", shape=value.shape) self._extra_vars_shared[var.name] = shared givens.append((var, shared)) @@ -231,13 +233,28 @@ def __init__( grads = pytensor.grad(cost, grad_vars, disconnected_inputs="ignore") for grad_wrt, var in zip(grads, grad_vars): grad_wrt.name = f"{var.name}_grad" - outputs = [cost, *grads] + grads = pt.join(0, *[pt.atleast_1d(grad.ravel()) for grad in grads]) + outputs = [cost, grads] else: outputs = [cost] - inputs = grad_vars + if ravel_inputs: + if initial_point is None: + initial_point = modelcontext(model).initial_point() + outputs, raveled_grad_vars = join_nonshared_inputs( + point=initial_point, inputs=grad_vars, outputs=outputs, make_inputs_shared=False + ) + inputs = [raveled_grad_vars] + else: + if ravel_inputs is None: + warnings.warn( + "ValueGradFunction will become a function of raveled inputs.\n" + "Specify `ravel_inputs` to suppress this warning. Note that setting `ravel_inputs=False` will be forbidden in a future release." 
+ ) + inputs = grad_vars self._pytensor_function = compile_pymc(inputs, outputs, givens=givens, **kwargs) + self._raveled_inputs = ravel_inputs def set_weights(self, values): if values.shape != (self._n_costs - 1,): @@ -247,7 +264,7 @@ def set_weights(self, values): def set_extra_values(self, extra_vars): self._extra_are_set = True for var in self._extra_vars: - self._extra_vars_shared[var.name].set_value(extra_vars[var.name]) + self._extra_vars_shared[var.name].set_value(extra_vars[var.name], borrow=True) def get_extra_values(self): if not self._extra_are_set: @@ -255,30 +272,21 @@ def get_extra_values(self): return {var.name: self._extra_vars_shared[var.name].get_value() for var in self._extra_vars} - def __call__(self, grad_vars, grad_out=None, extra_vars=None): + def __call__(self, grad_vars, *, extra_vars=None): if extra_vars is not None: self.set_extra_values(extra_vars) - - if not self._extra_are_set: + elif not self._extra_are_set: raise ValueError("Extra values are not set.") if isinstance(grad_vars, RaveledVars): - grad_vars = list(DictToArrayBijection.rmap(grad_vars).values()) - - cost, *grads = self._pytensor_function(*grad_vars) - - if grads: - grads_raveled = DictToArrayBijection.map( - {v.name: gv for v, gv in zip(self._grad_vars, grads)} - ) - - if grad_out is None: - return cost, grads_raveled.data + if self._raveled_inputs: + grad_vars = (grad_vars.data,) else: - np.copyto(grad_out, grads_raveled.data) - return cost - else: - return cost + grad_vars = DictToArrayBijection.rmap(grad_vars).values() + elif self._raveled_inputs and not isinstance(grad_vars, Sequence): + grad_vars = (grad_vars,) + + return self._pytensor_function(*grad_vars) @property def profile(self): @@ -521,7 +529,14 @@ def root(self): def isroot(self): return self.parent is None - def logp_dlogp_function(self, grad_vars=None, tempered=False, **kwargs): + def logp_dlogp_function( + self, + grad_vars=None, + tempered=False, + initial_point=None, + ravel_inputs: bool | None = None, + **kwargs, + ): """Compile a PyTensor function that computes logp and gradient. 
Parameters @@ -547,13 +562,22 @@ def logp_dlogp_function(self, grad_vars=None, tempered=False, **kwargs): costs = [self.logp()] input_vars = {i for i in graph_inputs(costs) if not isinstance(i, Constant)} - ip = self.initial_point(0) + if initial_point is None: + initial_point = self.initial_point(0) extra_vars_and_values = { - var: ip[var.name] + var: initial_point[var.name] for var in self.value_vars if var in input_vars and var not in grad_vars } - return ValueGradFunction(costs, grad_vars, extra_vars_and_values, **kwargs) + return ValueGradFunction( + costs, + grad_vars, + extra_vars_and_values, + model=self, + initial_point=initial_point, + ravel_inputs=ravel_inputs, + **kwargs, + ) def compile_logp( self, diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index d511fee311..54dd15d766 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -1441,6 +1441,8 @@ def init_nuts( pm.callbacks.CheckParametersConvergence(tolerance=1e-2, diff="relative"), ] + logp_dlogp_func = model.logp_dlogp_function(ravel_inputs=True) + logp_dlogp_func.trust_input = True initial_points = _init_jitter( model, initvals, diff --git a/pymc/step_methods/arraystep.py b/pymc/step_methods/arraystep.py index 060557ea23..f2b8c39ad5 100644 --- a/pymc/step_methods/arraystep.py +++ b/pymc/step_methods/arraystep.py @@ -185,17 +185,17 @@ def __init__( model = modelcontext(model) if logp_dlogp_func is None: - func = model.logp_dlogp_function(vars, dtype=dtype, **pytensor_kwargs) - else: - func = logp_dlogp_func - - self._logp_dlogp_func = func + logp_dlogp_func = model.logp_dlogp_function( + vars, + dtype=dtype, + ravel_inputs=True, + **pytensor_kwargs, + ) + logp_dlogp_func.trust_input = True - super().__init__(vars, func._extra_vars_shared, blocked, rng=rng) + self._logp_dlogp_func = logp_dlogp_func - def step(self, point) -> tuple[PointType, StatsType]: - self._logp_dlogp_func._extra_are_set = True - return super().step(point) + super().__init__(vars, logp_dlogp_func._extra_vars_shared, blocked, rng=rng) def metrop_select( diff --git a/pymc/step_methods/hmc/base_hmc.py b/pymc/step_methods/hmc/base_hmc.py index 832fdb1f28..7195d6ee63 100644 --- a/pymc/step_methods/hmc/base_hmc.py +++ b/pymc/step_methods/hmc/base_hmc.py @@ -194,8 +194,6 @@ def astep(self, q0: RaveledVars) -> tuple[RaveledVars, StatsType]: process_start = time.process_time() p0 = self.potential.random() - p0 = RaveledVars(p0, q0.point_map_info) - start = self.integrator.compute_state(q0, p0) warning: SamplerWarning | None = None @@ -226,13 +224,13 @@ def astep(self, q0: RaveledVars) -> tuple[RaveledVars, StatsType]: if self._step_rand is not None: step_size = self._step_rand(step_size, rng=self.rng) - hmc_step = self._hamiltonian_step(start, p0.data, step_size) + hmc_step = self._hamiltonian_step(start, p0, step_size) perf_end = time.perf_counter() process_end = time.process_time() self.step_adapt.update(hmc_step.accept_stat, adapt_step) - self.potential.update(hmc_step.end.q, hmc_step.end.q_grad, self.tune) + self.potential.update(hmc_step.end.q.data, hmc_step.end.q_grad, self.tune) if hmc_step.divergence_info: info = hmc_step.divergence_info point = None diff --git a/pymc/step_methods/hmc/integration.py b/pymc/step_methods/hmc/integration.py index 2d1e725cde..067cd239f8 100644 --- a/pymc/step_methods/hmc/integration.py +++ b/pymc/step_methods/hmc/integration.py @@ -18,13 +18,13 @@ from scipy import linalg -from pymc.blocking import RaveledVars +from pymc.blocking import DictToArrayBijection, RaveledVars from 
pymc.step_methods.hmc.quadpotential import QuadPotential class State(NamedTuple): q: RaveledVars - p: RaveledVars + p: np.ndarray v: np.ndarray q_grad: np.ndarray energy: float @@ -40,23 +40,35 @@ class CpuLeapfrogIntegrator: def __init__(self, potential: QuadPotential, logp_dlogp_func): """Leapfrog integrator using CPU.""" self._potential = potential - self._logp_dlogp_func = logp_dlogp_func - self._dtype = self._logp_dlogp_func.dtype + # Sidestep logp_dlogp_function.__call__ + pytensor_function = logp_dlogp_func._pytensor_function + # Create some wrappers for backwards compatibility during transition + # When raveled_inputs=False is forbidden, func = pytensor_function + if logp_dlogp_func._raveled_inputs: + + def func(q, _): + return pytensor_function(q) + + else: + + def func(q, point_map_info): + unraveled_q = DictToArrayBijection.rmap(RaveledVars(q, point_map_info)).values() + return pytensor_function(*unraveled_q) + + self._logp_dlogp_func = func + self._dtype = logp_dlogp_func.dtype if self._potential.dtype != self._dtype: raise ValueError( f"dtypes of potential ({self._potential.dtype}) and logp function ({self._dtype})" "don't match." ) - def compute_state(self, q: RaveledVars, p: RaveledVars): + def compute_state(self, q: RaveledVars, p: np.ndarray): """Compute Hamiltonian functions using a position and momentum.""" - if q.data.dtype != self._dtype or p.data.dtype != self._dtype: - raise ValueError(f"Invalid dtype. Must be {self._dtype}") - - logp, dlogp = self._logp_dlogp_func(q) + logp, dlogp = self._logp_dlogp_func(q.data, q.point_map_info) - v = self._potential.velocity(p.data, out=None) - kinetic = self._potential.energy(p.data, velocity=v) + v = self._potential.velocity(p, out=None) + kinetic = self._potential.energy(p, velocity=v) energy = kinetic - logp return State(q, p, v, dlogp, energy, logp, 0) @@ -96,10 +108,10 @@ def _step(self, epsilon, state): axpy = linalg.blas.get_blas_funcs("axpy", dtype=self._dtype) pot = self._potential - q_new = state.q.data.copy() - p_new = state.p.data.copy() + q = state.q + q_new = q.data.copy() + p_new = state.p.copy() v_new = np.empty_like(q_new) - q_new_grad = np.empty_like(q_new) dt = 0.5 * epsilon @@ -112,19 +124,16 @@ def _step(self, epsilon, state): # q_new = q + epsilon * v_new axpy(v_new, q_new, a=epsilon) - p_new = RaveledVars(p_new, state.p.point_map_info) - q_new = RaveledVars(q_new, state.q.point_map_info) - - logp = self._logp_dlogp_func(q_new, grad_out=q_new_grad) + logp, q_new_grad = self._logp_dlogp_func(q_new, q.point_map_info) # p_new = p_new + dt * q_new_grad - axpy(q_new_grad, p_new.data, a=dt) + axpy(q_new_grad, p_new, a=dt) - kinetic = pot.velocity_energy(p_new.data, v_new) + kinetic = pot.velocity_energy(p_new, v_new) energy = kinetic - logp return State( - q_new, + RaveledVars(q_new, state.q.point_map_info), p_new, v_new, q_new_grad, diff --git a/pymc/step_methods/hmc/nuts.py b/pymc/step_methods/hmc/nuts.py index fb816954b6..64ee97188b 100644 --- a/pymc/step_methods/hmc/nuts.py +++ b/pymc/step_methods/hmc/nuts.py @@ -279,7 +279,7 @@ def __init__( self.log_accept_sum = -np.inf self.mean_tree_accept = 0.0 self.n_proposals = 0 - self.p_sum = start.p.data.copy() + self.p_sum = start.p.copy() self.max_energy_change = 0.0 def extend(self, direction): @@ -330,9 +330,9 @@ def extend(self, direction): left, right = self.left, self.right p_sum = self.p_sum turning = (p_sum.dot(left.v) <= 0) or (p_sum.dot(right.v) <= 0) - p_sum1 = leftmost_p_sum + rightmost_begin.p.data + p_sum1 = leftmost_p_sum + rightmost_begin.p turning1 
= (p_sum1.dot(leftmost_begin.v) <= 0) or (p_sum1.dot(rightmost_begin.v) <= 0) - p_sum2 = leftmost_end.p.data + rightmost_p_sum + p_sum2 = leftmost_end.p + rightmost_p_sum turning2 = (p_sum2.dot(leftmost_end.v) <= 0) or (p_sum2.dot(rightmost_end.v) <= 0) turning = turning | turning1 | turning2 @@ -372,7 +372,7 @@ def _single_step(self, left: State, epsilon: float): right.model_logp, right.index_in_trajectory, ) - tree = Subtree(right, right, right.p.data, proposal, log_size) + tree = Subtree(right, right, right.p, proposal, log_size) return tree, None, False else: error_msg = f"Energy change in leapfrog step is too large: {energy_change}." @@ -400,9 +400,9 @@ def _build_subtree(self, left, depth, epsilon): turning = (p_sum.dot(left.v) <= 0) or (p_sum.dot(right.v) <= 0) # Additional U turn check only when depth > 1 to avoid redundant work. if depth - 1 > 0: - p_sum1 = tree1.p_sum + tree2.left.p.data + p_sum1 = tree1.p_sum + tree2.left.p turning1 = (p_sum1.dot(tree1.left.v) <= 0) or (p_sum1.dot(tree2.left.v) <= 0) - p_sum2 = tree1.right.p.data + tree2.p_sum + p_sum2 = tree1.right.p + tree2.p_sum turning2 = (p_sum2.dot(tree1.right.v) <= 0) or (p_sum2.dot(tree2.right.v) <= 0) turning = turning | turning1 | turning2 diff --git a/tests/distributions/test_multivariate.py b/tests/distributions/test_multivariate.py index 1fd1b7e6d8..6503050c91 100644 --- a/tests/distributions/test_multivariate.py +++ b/tests/distributions/test_multivariate.py @@ -2395,7 +2395,7 @@ def test_mvnormal_no_cholesky_in_model_logp(): d2logp = m.compile_d2logp() assert not contains_cholesky_op(d2logp.f.maker.fgraph) - logp_dlogp = m.logp_dlogp_function() + logp_dlogp = m.logp_dlogp_function(ravel_inputs=True) assert not contains_cholesky_op(logp_dlogp._pytensor_function.maker.fgraph) diff --git a/tests/model/test_core.py b/tests/model/test_core.py index 17304fedc4..41e85ce8a1 100644 --- a/tests/model/test_core.py +++ b/tests/model/test_core.py @@ -15,7 +15,6 @@ import pickle import threading import traceback -import unittest import warnings from unittest.mock import patch @@ -302,23 +301,26 @@ def test_empty_observed(): assert not hasattr(a.tag, "observations") -class TestValueGradFunction(unittest.TestCase): +class TestValueGradFunction: def test_no_extra(self): a = pt.vector("a") - a.tag.test_value = np.zeros(3, dtype=a.dtype) - f_grad = ValueGradFunction([a.sum()], [a], {}, mode="FAST_COMPILE") + a_ = np.zeros(3, dtype=a.dtype) + f_grad = ValueGradFunction( + [a.sum()], [a], {}, ravel_inputs=True, initial_point={"a": a_}, mode="FAST_COMPILE" + ) assert f_grad._extra_vars == [] def test_invalid_type(self): a = pt.ivector("a") - a.tag.test_value = np.zeros(3, dtype=a.dtype) + a_ = np.zeros(3, dtype=a.dtype) a.dshape = (3,) a.dsize = 3 - with pytest.raises(TypeError) as err: - ValueGradFunction([a.sum()], [a], {}, mode="FAST_COMPILE") - err.match("Invalid dtype") + with pytest.raises(TypeError, match="Invalid dtype"): + ValueGradFunction( + [a.sum()], [a], {}, ravel_inputs=True, initial_point={"a": a_}, mode="FAST_COMPILE" + ) - def setUp(self): + def setup_method(self, test_method): extra1 = pt.iscalar("extra1") extra1_ = np.array(0, dtype=extra1.dtype) extra1.dshape = () @@ -340,41 +342,68 @@ def setUp(self): self.cost = extra1 * val1.sum() + val2.sum() - self.f_grad = ValueGradFunction( - [self.cost], [val1, val2], {extra1: extra1_}, mode="FAST_COMPILE" + self.initial_point = { + "extra1": extra1_, + "val1": val1_, + "val2": val2_, + } + + with pytest.warns( + UserWarning, match="ValueGradFunction will become a function 
of raveled inputs" + ): + self.f_grad = ValueGradFunction( + [self.cost], + [val1, val2], + {extra1: extra1_}, + mode="FAST_COMPILE", + ) + + self.f_grad_raveled_inputs = ValueGradFunction( + [self.cost], + [val1, val2], + {extra1: extra1_}, + initial_point=self.initial_point, + mode="FAST_COMPILE", + ravel_inputs=True, ) + self.f_grad_raveled_inputs.trust_input = True - def test_extra_not_set(self): + @pytest.mark.parametrize("raveled_fn", (False, True)) + def test_extra_not_set(self, raveled_fn): + f_grad = self.f_grad_raveled_inputs if raveled_fn else self.f_grad with pytest.raises(ValueError) as err: - self.f_grad.get_extra_values() + f_grad.get_extra_values() err.match("Extra values are not set") with pytest.raises(ValueError) as err: size = self.val1_.size + self.val2_.size - self.f_grad(np.zeros(size, dtype=self.f_grad.dtype)) + f_grad(np.zeros(size, dtype=self.f_grad.dtype)) err.match("Extra values are not set") - def test_grad(self): - self.f_grad.set_extra_values({"extra1": 5}) + @pytest.mark.parametrize("raveled_fn", (False, True)) + def test_grad(self, raveled_fn): + f_grad = self.f_grad_raveled_inputs if raveled_fn else self.f_grad + f_grad.set_extra_values({"extra1": 5}) + size = self.val1_.size + self.val2_.size array = RaveledVars( np.ones(size, dtype=self.f_grad.dtype), ( - ("val1", self.val1_.shape, self.val1_.dtype), - ("val2", self.val2_.shape, self.val2_.dtype), + ("val1", self.val1_.shape, self.val1_.size, self.val1_.dtype), + ("val2", self.val2_.shape, self.val2_.size, self.val2_.dtype), ), ) - val, grad = self.f_grad(array) + + val, grad = f_grad(array) assert val == 21 npt.assert_allclose(grad, [5, 5, 5, 1, 1, 1, 1, 1, 1]) - @pytest.mark.xfail(reason="Test not refactored for v4") def test_edge_case(self): # Edge case discovered in #2948 ndim = 3 with pm.Model() as m: pm.LogNormal( - "sigma", mu=np.zeros(ndim), tau=np.ones(ndim), shape=ndim + "sigma", mu=np.zeros(ndim), tau=np.ones(ndim), initval=np.ones(ndim), shape=ndim ) # variance for the correlation matrix pm.HalfCauchy("nu", beta=10) step = pm.NUTS() @@ -382,7 +411,7 @@ def test_edge_case(self): func = step._logp_dlogp_func initial_point = m.initial_point() func.set_extra_values(initial_point) - q = func.dict_to_array(initial_point) + q = DictToArrayBijection.map(initial_point) logp, dlogp = func(q) assert logp.size == 1 assert dlogp.size == 4 @@ -398,7 +427,7 @@ def test_missing_data(self): with pytest.warns(ImputationWarning): x2 = pm.Bernoulli("x2", x1, observed=X) - gf = m.logp_dlogp_function() + gf = m.logp_dlogp_function(ravel_inputs=True) gf._extra_are_set = True assert m["x2_unobserved"].type == gf._extra_vars_shared["x2_unobserved"].type @@ -414,6 +443,8 @@ def test_missing_data(self): # Assert that all the elements of res are equal assert res[1:] == res[:-1] + +class TestPytensorRelatedLogpBugs: def test_pytensor_switch_broadcast_edge_cases_1(self): # Tests against two subtle issues related to a previous bug in Theano # where `tt.switch` would not always broadcast tensors with single @@ -460,25 +491,28 @@ def test_multiple_observed_rv(): assert model["x"] not in model.value_vars -def test_tempered_logp_dlogp(): +@pytest.mark.parametrize("ravel_inputs", (False, True)) +def test_tempered_logp_dlogp(ravel_inputs): with pm.Model() as model: pm.Normal("x") pm.Normal("y", observed=1) pm.Potential("z", pt.constant(-1.0, dtype=pytensor.config.floatX)) - func = model.logp_dlogp_function() + func = model.logp_dlogp_function(ravel_inputs=ravel_inputs) func.set_extra_values({}) - func_temp = 
model.logp_dlogp_function(tempered=True) + func_temp = model.logp_dlogp_function(tempered=True, ravel_inputs=ravel_inputs) func_temp.set_extra_values({}) - func_nograd = model.logp_dlogp_function(compute_grads=False) + func_nograd = model.logp_dlogp_function(compute_grads=False, ravel_inputs=ravel_inputs) func_nograd.set_extra_values({}) - func_temp_nograd = model.logp_dlogp_function(tempered=True, compute_grads=False) + func_temp_nograd = model.logp_dlogp_function( + tempered=True, compute_grads=False, ravel_inputs=ravel_inputs + ) func_temp_nograd.set_extra_values({}) - x = np.ones(1, dtype=func.dtype) + x = np.ones((1,), dtype=func.dtype) npt.assert_allclose(func(x)[0], func_temp(x)[0]) npt.assert_allclose(func(x)[1], func_temp(x)[1]) diff --git a/tests/step_methods/hmc/test_hmc.py b/tests/step_methods/hmc/test_hmc.py index 96840eed07..d228820328 100644 --- a/tests/step_methods/hmc/test_hmc.py +++ b/tests/step_methods/hmc/test_hmc.py @@ -59,10 +59,9 @@ def _hamiltonian_step(self, *args, **kwargs): step = HMC(vars=model.value_vars, model=model, scaling=scaling) - step.integrator._logp_dlogp_func.set_extra_values({}) astart = DictToArrayBijection.map(start) p = RaveledVars(floatX(step.potential.random()), astart.point_map_info) - q = RaveledVars(floatX(np.random.randn(size)), astart.point_map_info) + q = floatX(np.random.randn(size)) start = step.integrator.compute_state(p, q) for epsilon in [0.01, 0.1]: for n_steps in [1, 2, 3, 4, 20]: diff --git a/tests/step_methods/hmc/test_nuts.py b/tests/step_methods/hmc/test_nuts.py index 2bb71b893e..d37782fb78 100644 --- a/tests/step_methods/hmc/test_nuts.py +++ b/tests/step_methods/hmc/test_nuts.py @@ -204,3 +204,24 @@ class TestRVsAssignmentNUTS(RVsAssignmentStepsTester): @pytest.mark.parametrize("step, step_kwargs", [(NUTS, {})]) def test_continuous_steps(self, step, step_kwargs): self.continuous_steps(step, step_kwargs) + + +def test_nuts_step_legacy_value_grad_function(): + # This test can be removed once ravel_inputs=False is deprecated + with pm.Model() as m: + x = pm.Normal("x", shape=(2,)) + y = pm.Normal("y", x, shape=(3, 2)) + + legacy_value_grad_fn = m.logp_dlogp_function(ravel_inputs=False, mode="FAST_COMPILE") + legacy_value_grad_fn.set_extra_values({}) + nuts = NUTS(model=m, logp_dlogp_func=legacy_value_grad_fn) + + # Confirm it is a function of multiple variables + logp, dlogp = nuts._logp_dlogp_func([np.zeros((2,)), np.zeros((3, 2))]) + np.testing.assert_allclose(dlogp, np.zeros(8)) + + # Confirm we can perform a NUTS step + ip = m.initial_point() + new_ip, _ = nuts.step(ip) + assert np.all(new_ip["x"] != ip["x"]) + assert np.all(new_ip["y"] != ip["y"]) From 812d985bcc71ee2abdf1186b33b42628b5f9a3d7 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 22 Nov 2024 02:05:40 +0100 Subject: [PATCH 07/11] Cache size in DictToArrayBijection --- pymc/blocking.py | 25 ++++++++++--------------- pymc/sampling/parallel.py | 11 ++++------- pymc/step_methods/hmc/quadpotential.py | 16 ++++++++-------- pymc/tuning/starting.py | 2 +- 4 files changed, 23 insertions(+), 31 deletions(-) diff --git a/pymc/blocking.py b/pymc/blocking.py index dcbfe0ead3..2aad656128 100644 --- a/pymc/blocking.py +++ b/pymc/blocking.py @@ -39,11 +39,11 @@ StatShape: TypeAlias = Sequence[int | None] | None -# `point_map_info` is a tuple of tuples containing `(name, shape, dtype)` for +# `point_map_info` is a tuple of tuples containing `(name, shape, size, dtype)` for # each of the raveled variables. 
class RaveledVars(NamedTuple): data: np.ndarray - point_map_info: tuple[tuple[str, tuple[int, ...], np.dtype], ...] + point_map_info: tuple[tuple[str, tuple[int, ...], int, np.dtype], ...] class Compose(Generic[T]): @@ -67,10 +67,9 @@ class DictToArrayBijection: @staticmethod def map(var_dict: PointType) -> RaveledVars: """Map a dictionary of names and variables to a concatenated 1D array space.""" - vars_info = tuple((v, k, v.shape, v.dtype) for k, v in var_dict.items()) - raveled_vars = [v[0].ravel() for v in vars_info] - if raveled_vars: - result = np.concatenate(raveled_vars) + vars_info = tuple((v, k, v.shape, v.size, v.dtype) for k, v in var_dict.items()) + if vars_info: + result = np.concatenate(tuple(v[0].ravel() for v in vars_info)) else: result = np.array([]) return RaveledVars(result, tuple(v[1:] for v in vars_info)) @@ -91,19 +90,15 @@ def rmap( """ if start_point: - result = dict(start_point) + result = start_point.copy() else: result = {} - if not isinstance(array, RaveledVars): - raise TypeError("`array` must be a `RaveledVars` type") - last_idx = 0 - for name, shape, dtype in array.point_map_info: - arr_len = np.prod(shape, dtype=int) - var = array.data[last_idx : last_idx + arr_len].reshape(shape).astype(dtype) - result[name] = var - last_idx += arr_len + for name, shape, size, dtype in array.point_map_info: + end = last_idx + size + result[name] = array.data[last_idx:end].reshape(shape).astype(dtype) + last_idx = end return result diff --git a/pymc/sampling/parallel.py b/pymc/sampling/parallel.py index a94863738a..4edc80433d 100644 --- a/pymc/sampling/parallel.py +++ b/pymc/sampling/parallel.py @@ -228,15 +228,12 @@ def __init__( self._shared_point = {} self._point = {} - for name, shape, dtype in DictToArrayBijection.map(start).point_map_info: - size = 1 - for dim in shape: - size *= int(dim) - size *= dtype.itemsize - if size != ctypes.c_size_t(size).value: + for name, shape, size, dtype in DictToArrayBijection.map(start).point_map_info: + byte_size = size * dtype.itemsize + if byte_size != ctypes.c_size_t(byte_size).value: raise ValueError(f"Variable {name} is too large") - array = mp_ctx.RawArray("c", size) + array = mp_ctx.RawArray("c", byte_size) self._shared_point[name] = (array, shape, dtype) array_np = np.frombuffer(array, dtype).reshape(shape) array_np[...] = start[name] diff --git a/pymc/step_methods/hmc/quadpotential.py b/pymc/step_methods/hmc/quadpotential.py index 53185bbb85..59df52fd31 100644 --- a/pymc/step_methods/hmc/quadpotential.py +++ b/pymc/step_methods/hmc/quadpotential.py @@ -363,11 +363,11 @@ def raise_ok(self, map_info): if np.any(self._stds == 0): errmsg = ["Mass matrix contains zeros on the diagonal. "] last_idx = 0 - for name, shape, dtype in map_info: - arr_len = np.prod(shape, dtype=int) - index = np.where(self._stds[last_idx : last_idx + arr_len] == 0)[0] + for name, shape, size, dtype in map_info: + end = last_idx + size + index = np.where(self._stds[last_idx:end] == 0)[0] errmsg.append(f"The derivative of RV `{name}`.ravel()[{index}] is zero.") - last_idx += arr_len + last_idx += end raise ValueError("\n".join(errmsg)) @@ -375,11 +375,11 @@ def raise_ok(self, map_info): errmsg = ["Mass matrix contains non-finite values on the diagonal. 
"] last_idx = 0 - for name, shape, dtype in map_info: - arr_len = np.prod(shape, dtype=int) - index = np.where(~np.isfinite(self._stds[last_idx : last_idx + arr_len]))[0] + for name, shape, size, dtype in map_info: + end = last_idx + size + index = np.where(~np.isfinite(self._stds[last_idx:end]))[0] errmsg.append(f"The derivative of RV `{name}`.ravel()[{index}] is non-finite.") - last_idx += arr_len + last_idx = end raise ValueError("\n".join(errmsg)) diff --git a/pymc/tuning/starting.py b/pymc/tuning/starting.py index c085af5d25..22d3ffb415 100644 --- a/pymc/tuning/starting.py +++ b/pymc/tuning/starting.py @@ -143,7 +143,7 @@ def find_MAP( compiled_logp_func = DictToArrayBijection.mapf(model.compile_logp(jacobian=False), start) logp_func = lambda x: compiled_logp_func(RaveledVars(x, x0.point_map_info)) # noqa: E731 - rvs = [model.values_to_rvs[vars_dict[name]] for name, _, _ in x0.point_map_info] + rvs = [model.values_to_rvs[vars_dict[name]] for name, _, _, _ in x0.point_map_info] try: # This might be needed for calls to `dlogp_func` # start_map_info = tuple((v.name, v.shape, v.dtype) for v in vars) From 8708f21279d6ff751511ef57f539273649135cee Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 22 Nov 2024 02:07:55 +0100 Subject: [PATCH 08/11] Optimize ArrayStepShared.step --- pymc/step_methods/arraystep.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/pymc/step_methods/arraystep.py b/pymc/step_methods/arraystep.py index f2b8c39ad5..dda81b5403 100644 --- a/pymc/step_methods/arraystep.py +++ b/pymc/step_methods/arraystep.py @@ -99,26 +99,27 @@ def __init__(self, vars, shared, blocked=True, rng: RandomGenerator = None): :py:func:`pymc.util.get_random_generator` for more information. """ self.vars = vars + self.var_names = tuple(cast(str, var.name) for var in vars) self.shared = {get_var_name(var): shared for var, shared in shared.items()} self.blocked = blocked self.rng = get_random_generator(rng) def step(self, point: PointType) -> tuple[PointType, StatsType]: - for name, shared_var in self.shared.items(): - shared_var.set_value(point[name]) - - var_dict = {cast(str, v.name): point[cast(str, v.name)] for v in self.vars} - q = DictToArrayBijection.map(var_dict) - + full_point = None + if self.shared: + for name, shared_var in self.shared.items(): + shared_var.set_value(point[name], borrow=True) + full_point = point + point = {name: point[name] for name in self.var_names} + + q = DictToArrayBijection.map(point) apoint, stats = self.astep(q) if not isinstance(apoint, RaveledVars): # We assume that the mapping has stayed the same apoint = RaveledVars(apoint, q.point_map_info) - new_point = DictToArrayBijection.rmap(apoint, start_point=point) - - return new_point, stats + return DictToArrayBijection.rmap(apoint, start_point=full_point), stats @abstractmethod def astep(self, q0: RaveledVars) -> tuple[RaveledVars, StatsType]: From 5efe09e4dd61e8ec34f8a0ce4d7ceb36df53dba4 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 22 Nov 2024 02:11:26 +0100 Subject: [PATCH 09/11] Optimize NUTS --- pymc/step_methods/hmc/nuts.py | 83 +++++++++++++++++++++++------------ 1 file changed, 56 insertions(+), 27 deletions(-) diff --git a/pymc/step_methods/hmc/nuts.py b/pymc/step_methods/hmc/nuts.py index 64ee97188b..cc29e0334a 100644 --- a/pymc/step_methods/hmc/nuts.py +++ b/pymc/step_methods/hmc/nuts.py @@ -19,8 +19,8 @@ import numpy as np -from pymc.math import logbern -from pymc.pytensorf import floatX +from pytensor import config + from pymc.stats.convergence 
import SamplerWarning from pymc.step_methods.compound import Competence from pymc.step_methods.hmc import integration @@ -205,11 +205,12 @@ def _hamiltonian_step(self, start, p0, step_size): else: max_treedepth = self.max_treedepth - tree = _Tree(len(p0), self.integrator, start, step_size, self.Emax, rng=self.rng) + rng = self.rng + tree = _Tree(len(p0), self.integrator, start, step_size, self.Emax, rng=rng) reached_max_treedepth = False for _ in range(max_treedepth): - direction = logbern(np.log(0.5), rng=self.rng) * 2 - 1 + direction = (rng.random() < 0.5) * 2 - 1 divergence_info, turning = tree.extend(direction) if divergence_info or turning: @@ -218,9 +219,8 @@ def _hamiltonian_step(self, start, p0, step_size): reached_max_treedepth = not self.tune stats = tree.stats() - accept_stat = stats["mean_tree_accept"] stats["reached_max_treedepth"] = reached_max_treedepth - return HMCStepData(tree.proposal, accept_stat, divergence_info, stats) + return HMCStepData(tree.proposal, stats["mean_tree_accept"], divergence_info, stats) @staticmethod def competence(var, has_grad): @@ -241,6 +241,27 @@ def competence(var, has_grad): class _Tree: + __slots__ = ( + "ndim", + "integrator", + "start", + "step_size", + "Emax", + "start_energy", + "rng", + "left", + "right", + "proposal", + "depth", + "log_size", + "log_accept_sum", + "mean_tree_accept", + "n_proposals", + "p_sum", + "max_energy_change", + "floatX", + ) + def __init__( self, ndim: int, @@ -273,7 +294,7 @@ def __init__( self.rng = rng self.left = self.right = start - self.proposal = Proposal(start.q.data, start.q_grad, start.energy, start.model_logp, 0) + self.proposal = Proposal(start.q, start.q_grad, start.energy, start.model_logp, 0) self.depth = 0 self.log_size = 0.0 self.log_accept_sum = -np.inf @@ -281,6 +302,7 @@ def __init__( self.n_proposals = 0 self.p_sum = start.p.copy() self.max_energy_change = 0.0 + self.floatX = config.floatX def extend(self, direction): """Double the treesize by extending the tree in the given direction. 
@@ -296,7 +318,7 @@ def extend(self, direction): """ if direction > 0: tree, diverging, turning = self._build_subtree( - self.right, self.depth, floatX(np.asarray(self.step_size)) + self.right, self.depth, np.asarray(self.step_size, dtype=self.floatX) ) leftmost_begin, leftmost_end = self.left, self.right rightmost_begin, rightmost_end = tree.left, tree.right @@ -305,7 +327,7 @@ def extend(self, direction): self.right = tree.right else: tree, diverging, turning = self._build_subtree( - self.left, self.depth, floatX(np.asarray(-self.step_size)) + self.left, self.depth, np.asarray(-self.step_size, dtype=self.floatX) ) leftmost_begin, leftmost_end = tree.right, tree.left rightmost_begin, rightmost_end = self.left, self.right @@ -318,23 +340,27 @@ def extend(self, direction): if diverging or turning: return diverging, turning - size1, size2 = self.log_size, tree.log_size - if logbern(size2 - size1, rng=self.rng): + self_log_size, tree_log_size = self.log_size, tree.log_size + if np.log(self.rng.random()) < (tree_log_size - self_log_size): self.proposal = tree.proposal - self.log_size = np.logaddexp(self.log_size, tree.log_size) - self.p_sum[:] += tree.p_sum + self.log_size = np.logaddexp(tree_log_size, self_log_size) + + p_sum = self.p_sum + p_sum[:] += tree.p_sum # Additional turning check only when tree depth > 0 to avoid redundant work if self.depth > 0: left, right = self.left, self.right - p_sum = self.p_sum turning = (p_sum.dot(left.v) <= 0) or (p_sum.dot(right.v) <= 0) - p_sum1 = leftmost_p_sum + rightmost_begin.p - turning1 = (p_sum1.dot(leftmost_begin.v) <= 0) or (p_sum1.dot(rightmost_begin.v) <= 0) - p_sum2 = leftmost_end.p + rightmost_p_sum - turning2 = (p_sum2.dot(leftmost_end.v) <= 0) or (p_sum2.dot(rightmost_end.v) <= 0) - turning = turning | turning1 | turning2 + if not turning: + p_sum1 = leftmost_p_sum + rightmost_begin.p + turning = (p_sum1.dot(leftmost_begin.v) <= 0) or ( + p_sum1.dot(rightmost_begin.v) <= 0 + ) + if not turning: + p_sum2 = leftmost_end.p + rightmost_p_sum + turning = (p_sum2.dot(leftmost_end.v) <= 0) or (p_sum2.dot(rightmost_end.v) <= 0) return diverging, turning @@ -356,7 +382,10 @@ def _single_step(self, left: State, epsilon: float): if np.isnan(energy_change): energy_change = np.inf - self.log_accept_sum = np.logaddexp(self.log_accept_sum, min(0, -energy_change)) + self.log_accept_sum = np.logaddexp( + self.log_accept_sum, (-energy_change if energy_change > 0 else 0) + ) + # self.log_accept_sum = np.logaddexp(self.log_accept_sum, min(0, -energy_change)) if np.abs(energy_change) > np.abs(self.max_energy_change): self.max_energy_change = energy_change @@ -366,7 +395,7 @@ def _single_step(self, left: State, epsilon: float): # Saturated Metropolis accept probability with Boltzmann weight log_size = -energy_change proposal = Proposal( - right.q.data, + right.q, right.q_grad, right.energy, right.model_logp, @@ -399,15 +428,15 @@ def _build_subtree(self, left, depth, epsilon): p_sum = tree1.p_sum + tree2.p_sum turning = (p_sum.dot(left.v) <= 0) or (p_sum.dot(right.v) <= 0) # Additional U turn check only when depth > 1 to avoid redundant work. 
- if depth - 1 > 0: + if (not turning) and (depth - 1 > 0): p_sum1 = tree1.p_sum + tree2.left.p - turning1 = (p_sum1.dot(tree1.left.v) <= 0) or (p_sum1.dot(tree2.left.v) <= 0) - p_sum2 = tree1.right.p + tree2.p_sum - turning2 = (p_sum2.dot(tree1.right.v) <= 0) or (p_sum2.dot(tree2.right.v) <= 0) - turning = turning | turning1 | turning2 + turning = (p_sum1.dot(tree1.left.v) <= 0) or (p_sum1.dot(tree2.left.v) <= 0) + if not turning: + p_sum2 = tree1.right.p + tree2.p_sum + turning = (p_sum2.dot(tree1.right.v) <= 0) or (p_sum2.dot(tree2.right.v) <= 0) log_size = np.logaddexp(tree1.log_size, tree2.log_size) - if logbern(tree2.log_size - log_size, rng=self.rng): + if np.log(self.rng.random()) < (tree2.log_size - log_size): proposal = tree2.proposal else: proposal = tree1.proposal From 8ac7108fa0293668d2c45f6b05a820ecfb08379e Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 22 Nov 2024 02:23:24 +0100 Subject: [PATCH 10/11] Avoid recompiling initial_point and logp functions in sample Also removes default `model.check_start_vals()` --- pymc/backends/__init__.py | 7 +- pymc/model/core.py | 6 +- pymc/sampling/mcmc.py | 172 ++++++++++++++++---------- pymc/step_methods/arraystep.py | 2 + pymc/step_methods/hmc/base_hmc.py | 23 ++-- pymc/step_methods/metropolis.py | 42 ++++--- pymc/step_methods/slicer.py | 10 +- tests/helpers.py | 10 +- tests/sampling/test_mcmc.py | 57 ++++----- tests/step_methods/test_metropolis.py | 11 +- 10 files changed, 210 insertions(+), 130 deletions(-) diff --git a/pymc/backends/__init__.py b/pymc/backends/__init__.py index 986a34f4ba..e348b12747 100644 --- a/pymc/backends/__init__.py +++ b/pymc/backends/__init__.py @@ -72,6 +72,7 @@ from pymc.backends.arviz import predictions_to_inference_data, to_inference_data from pymc.backends.base import BaseTrace, IBaseTrace from pymc.backends.ndarray import NDArray +from pymc.blocking import PointType from pymc.model import Model from pymc.step_methods.compound import BlockedStep, CompoundStep @@ -100,11 +101,12 @@ def _init_trace( trace: BaseTrace | None, model: Model, trace_vars: list[TensorVariable] | None = None, + initial_point: PointType | None = None, ) -> BaseTrace: """Initialize a trace backend for a chain.""" strace: BaseTrace if trace is None: - strace = NDArray(model=model, vars=trace_vars) + strace = NDArray(model=model, vars=trace_vars, test_point=initial_point) elif isinstance(trace, BaseTrace): if len(trace) > 0: raise ValueError("Continuation of traces is no longer supported.") @@ -122,7 +124,7 @@ def init_traces( chains: int, expected_length: int, step: BlockedStep | CompoundStep, - initial_point: Mapping[str, np.ndarray], + initial_point: PointType, model: Model, trace_vars: list[TensorVariable] | None = None, ) -> tuple[RunType | None, Sequence[IBaseTrace]]: @@ -145,6 +147,7 @@ def init_traces( trace=backend, model=model, trace_vars=trace_vars, + initial_point=initial_point, ) for chain_number in range(chains) ] diff --git a/pymc/model/core.py b/pymc/model/core.py index 94756e6c06..f3a0658cc3 100644 --- a/pymc/model/core.py +++ b/pymc/model/core.py @@ -48,7 +48,7 @@ ShapeError, ShapeWarning, ) -from pymc.initial_point import make_initial_point_fn +from pymc.initial_point import PointType, make_initial_point_fn from pymc.logprob.basic import transformed_conditional_logp from pymc.logprob.transforms import Transform from pymc.logprob.utils import ParameterValueError, replace_rvs_by_values @@ -174,7 +174,7 @@ def __init__( casting="no", compute_grads=True, model=None, - initial_point=None, + initial_point: 
PointType | None = None, ravel_inputs: bool | None = None, **kwargs, ): @@ -533,7 +533,7 @@ def logp_dlogp_function( self, grad_vars=None, tempered=False, - initial_point=None, + initial_point: PointType | None = None, ravel_inputs: bool | None = None, **kwargs, ): diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index 54dd15d766..defa5b5383 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -101,7 +101,9 @@ def instantiate_steppers( model: Model, steps: list[Step], selected_steps: Mapping[type[BlockedStep], list[Any]], + *, step_kwargs: dict[str, dict] | None = None, + initial_point: PointType | None = None, ) -> Step | list[Step]: """Instantiate steppers assigned to the model variables. @@ -131,13 +133,22 @@ def instantiate_steppers( step_kwargs = {} used_keys = set() - for step_class, vars in selected_steps.items(): - if vars: - name = getattr(step_class, "name") - args = step_kwargs.get(name, {}) - used_keys.add(name) - step = step_class(vars=vars, model=model, **args) - steps.append(step) + if selected_steps: + if initial_point is None: + initial_point = model.initial_point() + + for step_class, vars in selected_steps.items(): + if vars: + name = getattr(step_class, "name") + kwargs = step_kwargs.get(name, {}) + used_keys.add(name) + step = step_class( + vars=vars, + model=model, + initial_point=initial_point, + **kwargs, + ) + steps.append(step) unused_args = set(step_kwargs).difference(used_keys) if unused_args: @@ -161,18 +172,22 @@ def assign_step_methods( model: Model, step: Step | Sequence[Step] | None = None, methods: Sequence[type[BlockedStep]] | None = None, - step_kwargs: dict[str, Any] | None = None, -) -> Step | list[Step]: +) -> tuple[list[Step], dict[type[BlockedStep], list[Variable]]]: """Assign model variables to appropriate step methods. - Passing a specified model will auto-assign its constituent stochastic - variables to step methods based on the characteristics of the variables. + Passing a specified model will auto-assign its constituent value + variables to step methods based on the characteristics of the respective + random variables, and whether the logp can be differentiated with respect to it. + This function is intended to be called automatically from ``sample()``, but may be called manually. Each step method passed should have a ``competence()`` method that returns an ordinal competence value corresponding to the variable passed to it. This value quantifies the appropriateness of the step method for sampling the variable. + The outputs of this function can then be passed to `instantiate_steppers()` + to initialize the assigned step samplers. + Parameters ---------- model : Model object @@ -183,24 +198,32 @@ def assign_step_methods( methods : iterable of step method classes, optional The set of step methods from which the function may choose. Defaults to the main step methods provided by PyMC. - step_kwargs : dict, optional - Parameters for the samplers. Keys are the lower case names of - the step method, values a dict of arguments. Returns ------- - methods : list - List of step methods associated with the model's variables. 
+ provided_steps: list of Step instances + List of user provided instantiated step(s) + assigned_steps: dict of Step class to Variable + Dictionary with automatically selected step classes as keys and associated value variables as values """ - steps: list[Step] = [] + provided_steps: list[Step] = [] assigned_vars: set[Variable] = set() if step is not None: if isinstance(step, BlockedStep | CompoundStep): - steps.append(step) + provided_steps = [step] + elif isinstance(step, Sequence): + provided_steps = list(step) else: - steps.extend(step) - for step in steps: + raise ValueError(f"Step should be a Step or a sequence of Steps, got {step}") + + for step in provided_steps: + if not isinstance(step, BlockedStep | CompoundStep): + if issubclass(step, BlockedStep | CompoundStep): + raise ValueError(f"Provided {step} was not initialized") + else: + raise ValueError(f"{step} is not a Step instance") + for var in step.vars: if var not in model.value_vars: raise ValueError( @@ -235,7 +258,7 @@ def assign_step_methods( ) selected_steps.setdefault(selected, []).append(var) - return instantiate_steppers(model, steps, selected_steps, step_kwargs) + return provided_steps, selected_steps def _print_step_hierarchy(s: Step, level: int = 0) -> None: @@ -719,22 +742,23 @@ def joined_blas_limiter(): msg = f"Only {draws} samples per chain. Reliable r-hat and ESS diagnostics require longer chains for accurate estimate." _log.warning(msg) - auto_nuts_init = True - if step is not None: - if isinstance(step, CompoundStep): - for method in step.methods: - if isinstance(method, NUTS): - auto_nuts_init = False - elif isinstance(step, NUTS): - auto_nuts_init = False - - initial_points = None - step = assign_step_methods(model, step, methods=pm.STEP_METHODS, step_kwargs=kwargs) + provided_steps, selected_steps = assign_step_methods(model, step, methods=pm.STEP_METHODS) + exclusive_nuts = ( + # User provided an instantiated NUTS step, and nothing else is needed + (not selected_steps and len(provided_steps) == 1 and isinstance(provided_steps[0], NUTS)) + or + # Only automatically selected NUTS step is needed + ( + not provided_steps + and len(selected_steps) == 1 + and issubclass(next(iter(selected_steps)), NUTS) + ) + ) if nuts_sampler != "pymc": - if not isinstance(step, NUTS): + if not exclusive_nuts: raise ValueError( - "Model can not be sampled with NUTS alone. Your model is probably not continuous." + "Model can not be sampled with NUTS alone. It either has discrete variables or a non-differentiable log-probability." ) with joined_blas_limiter(): @@ -755,13 +779,11 @@ def joined_blas_limiter(): **kwargs, ) - if isinstance(step, list): - step = CompoundStep(step) - elif isinstance(step, NUTS) and auto_nuts_init: + if exclusive_nuts and not provided_steps: + # Special path for NUTS initialization if "nuts" in kwargs: nuts_kwargs = kwargs.pop("nuts") [kwargs.setdefault(k, v) for k, v in nuts_kwargs.items()] - _log.info("Auto-assigning NUTS sampler...") with joined_blas_limiter(): initial_points, step = init_nuts( init=init, @@ -775,9 +797,8 @@ def joined_blas_limiter(): initvals=initvals, **kwargs, ) - - if initial_points is None: - # Time to draw/evaluate numeric start points for each chain. + else: + # Get initial points ipfns = make_initial_point_fns_per_chain( model=model, overrides=initvals, @@ -786,11 +807,16 @@ def joined_blas_limiter(): ) initial_points = [ipfn(seed) for ipfn, seed in zip(ipfns, random_seed_list)] - # One final check that shapes and logps at the starting points are okay. 
- ip: dict[str, np.ndarray] - for ip in initial_points: - model.check_start_vals(ip) - _check_start_shape(model, ip) + # Instantiate automatically selected steps + step = instantiate_steppers( + model, + steps=provided_steps, + selected_steps=selected_steps, + step_kwargs=kwargs, + initial_point=initial_points[0], + ) + if isinstance(step, list): + step = CompoundStep(step) if var_names is not None: trace_vars = [v for v in model.unobserved_RVs if v.name in var_names] @@ -806,7 +832,7 @@ def joined_blas_limiter(): expected_length=draws + tune, step=step, trace_vars=trace_vars, - initial_point=ip, + initial_point=initial_points[0], model=model, ) @@ -954,7 +980,6 @@ def _sample_return( f"took {t_sampling:.0f} seconds." ) - idata = None if compute_convergence_checks or return_inferencedata: ikwargs: dict[str, Any] = {"model": model, "save_warmup": not discard_tuned_samples} ikwargs.update(idata_kwargs) @@ -1159,7 +1184,6 @@ def _iter_sample( diverging : bool Indicates if the draw is divergent. Only available with some samplers. """ - model = modelcontext(model) draws = int(draws) if draws < 1: @@ -1174,8 +1198,6 @@ def _iter_sample( if hasattr(step, "reset_tuning"): step.reset_tuning() for i in range(draws): - diverging = False - if i == 0 and hasattr(step, "iter_count"): step.iter_count = 0 if i == tune: @@ -1298,6 +1320,7 @@ def _init_jitter( seeds: Sequence[int] | np.ndarray, jitter: bool, jitter_max_retries: int, + logp_dlogp_func=None, ) -> list[PointType]: """Apply a uniform jitter in [-1, 1] to the test value as starting point in each chain. @@ -1328,19 +1351,30 @@ def _init_jitter( if not jitter: return [ipfn(seed) for ipfn, seed in zip(ipfns, seeds)] + model_logp_fn: Callable + if logp_dlogp_func is None: + model_logp_fn = model.compile_logp() + else: + + def model_logp_fn(ip): + q, _ = DictToArrayBijection.map(ip) + return logp_dlogp_func([q], extra_vars={})[0] + initial_points = [] for ipfn, seed in zip(ipfns, seeds): - rng = np.random.RandomState(seed) + rng = np.random.default_rng(seed) for i in range(jitter_max_retries + 1): point = ipfn(seed) - if i < jitter_max_retries: - try: + point_logp = model_logp_fn(point) + if not np.isfinite(point_logp): + if i == jitter_max_retries: + # Print informative message on last attempted point model.check_start_vals(point) - except SamplingError: - # Retry with a new seed - seed = rng.randint(2**30, dtype=np.int64) - else: - break + # Retry with a new seed + seed = rng.integers(2**30, dtype=np.int64) + else: + break + initial_points.append(point) return initial_points @@ -1436,10 +1470,12 @@ def init_nuts( _log.info(f"Initializing NUTS using {init}...") - cb = [ - pm.callbacks.CheckParametersConvergence(tolerance=1e-2, diff="absolute"), - pm.callbacks.CheckParametersConvergence(tolerance=1e-2, diff="relative"), - ] + cb = [] + if "advi" in init: + cb = [ + pm.callbacks.CheckParametersConvergence(tolerance=1e-2, diff="absolute"), + pm.callbacks.CheckParametersConvergence(tolerance=1e-2, diff="relative"), + ] logp_dlogp_func = model.logp_dlogp_function(ravel_inputs=True) logp_dlogp_func.trust_input = True @@ -1449,6 +1485,7 @@ def init_nuts( seeds=random_seed_list, jitter="jitter" in init, jitter_max_retries=jitter_max_retries, + logp_dlogp_func=logp_dlogp_func, ) apoints = [DictToArrayBijection.map(point) for point in initial_points] @@ -1562,7 +1599,14 @@ def init_nuts( else: raise ValueError(f"Unknown initializer: {init}.") - step = pm.NUTS(potential=potential, model=model, rng=random_seed_list[0], **kwargs) + step = pm.NUTS( + 
potential=potential, + model=model, + rng=random_seed_list[0], + initial_point=initial_points[0], + logp_dlogp_func=logp_dlogp_func, + **kwargs, + ) # Filter deterministics from initial_points value_var_names = [var.name for var in model.value_vars] diff --git a/pymc/step_methods/arraystep.py b/pymc/step_methods/arraystep.py index dda81b5403..7ddfb65f06 100644 --- a/pymc/step_methods/arraystep.py +++ b/pymc/step_methods/arraystep.py @@ -181,6 +181,7 @@ def __init__( dtype=None, logp_dlogp_func=None, rng: RandomGenerator = None, + initial_point: PointType | None = None, **pytensor_kwargs, ): model = modelcontext(model) @@ -190,6 +191,7 @@ def __init__( vars, dtype=dtype, ravel_inputs=True, + initial_point=initial_point, **pytensor_kwargs, ) logp_dlogp_func.trust_input = True diff --git a/pymc/step_methods/hmc/base_hmc.py b/pymc/step_methods/hmc/base_hmc.py index 7195d6ee63..564daebed4 100644 --- a/pymc/step_methods/hmc/base_hmc.py +++ b/pymc/step_methods/hmc/base_hmc.py @@ -22,7 +22,7 @@ import numpy as np -from pymc.blocking import DictToArrayBijection, RaveledVars, StatsType +from pymc.blocking import DictToArrayBijection, PointType, RaveledVars, StatsType from pymc.exceptions import SamplingError from pymc.model import Point, modelcontext from pymc.pytensorf import floatX @@ -98,6 +98,7 @@ def __init__( adapt_step_size=True, step_rand=None, rng=None, + initial_point: PointType | None = None, **pytensor_kwargs, ): """Set up Hamiltonian samplers with common structures. @@ -138,20 +139,23 @@ def __init__( else: vars = get_value_vars_from_user_vars(vars, self._model) super().__init__( - vars, blocked=blocked, model=self._model, dtype=dtype, rng=rng, **pytensor_kwargs + vars, + blocked=blocked, + model=self._model, + dtype=dtype, + rng=rng, + initial_point=initial_point, + **pytensor_kwargs, ) self.adapt_step_size = adapt_step_size self.Emax = Emax self.iter_count = 0 - # We're using the initial/test point to determine the (initial) step - # size. - # XXX: If the dimensions of these terms change, the step size - # dimension-scaling should change as well, no? - test_point = self._model.initial_point() + if initial_point is None: + initial_point = self._model.initial_point() - nuts_vars = [test_point[v.name] for v in vars] + nuts_vars = [initial_point[v.name] for v in vars] size = sum(v.size for v in nuts_vars) self.step_size = step_scale / (size**0.25) @@ -207,7 +211,8 @@ def astep(self, q0: RaveledVars) -> tuple[RaveledVars, StatsType]: self.potential.raise_ok(q0.point_map_info) message_energy = ( "Bad initial energy, check any log probabilities that " - f"are inf or -inf, nan or very small:\n{error_logp}" + f"are inf or -inf, nan or very small:\n{error_logp}\n." + f"Try model.debug() to identify parametrization problems." ) warning = SamplerWarning( WarningType.BAD_ENERGY, diff --git a/pymc/step_methods/metropolis.py b/pymc/step_methods/metropolis.py index 601aaf5628..8aed4bfd67 100644 --- a/pymc/step_methods/metropolis.py +++ b/pymc/step_methods/metropolis.py @@ -28,6 +28,7 @@ import pymc as pm from pymc.blocking import DictToArrayBijection, RaveledVars +from pymc.initial_point import PointType from pymc.pytensorf import ( CallableTensor, compile_pymc, @@ -160,6 +161,7 @@ def __init__( model=None, mode=None, rng=None, + initial_point: PointType | None = None, blocked: bool = False, ): """Create an instance of a Metropolis stepper. @@ -189,14 +191,15 @@ def __init__( :py:func:`pymc.util.get_random_generator` for more information. 
""" model = pm.modelcontext(model) - initial_values = model.initial_point() + if initial_point is None: + initial_point = model.initial_point() if vars is None: vars = model.value_vars else: vars = get_value_vars_from_user_vars(vars, model) - initial_values_shape = [initial_values[v.name].shape for v in vars] + initial_values_shape = [initial_point[v.name].shape for v in vars] if S is None: S = np.ones(int(sum(np.prod(ivs) for ivs in initial_values_shape))) @@ -216,7 +219,7 @@ def __init__( # Determine type of variables self.discrete = np.concatenate( - [[v.dtype in pm.discrete_types] * (initial_values[v.name].size or 1) for v in vars] + [[v.dtype in pm.discrete_types] * (initial_point[v.name].size or 1) for v in vars] ) self.any_discrete = self.discrete.any() self.all_discrete = self.discrete.all() @@ -250,8 +253,8 @@ def __init__( # TODO: This is not being used when compiling the logp function! self.mode = mode - shared = pm.make_shared_replacements(initial_values, vars, model) - self.delta_logp = delta_logp(initial_values, model.logp(), vars, shared) + shared = pm.make_shared_replacements(initial_point, vars, model) + self.delta_logp = delta_logp(initial_point, model.logp(), vars, shared) super().__init__(vars, shared, blocked=blocked, rng=rng) def reset_tuning(self): @@ -428,6 +431,7 @@ def __init__( tune_interval=100, model=None, rng=None, + initial_point: PointType | None = None, blocked: bool = True, ): model = pm.modelcontext(model) @@ -549,6 +553,7 @@ def __init__( transit_p=0.8, model=None, rng=None, + initial_point: PointType | None = None, blocked: bool = True, ): model = pm.modelcontext(model) @@ -561,7 +566,8 @@ def __init__( vars = get_value_vars_from_user_vars(vars, model) - initial_point = model.initial_point() + if initial_point is None: + initial_point = model.initial_point() self.dim = sum(initial_point[v.name].size for v in vars) if order == "random": @@ -665,13 +671,15 @@ def __init__( order="random", model=None, rng: RandomGenerator = None, + initial_point: PointType | None = None, blocked: bool = True, ): model = pm.modelcontext(model) vars = get_value_vars_from_user_vars(vars, model) - initial_point = model.initial_point() + if initial_point is None: + initial_point = model.initial_point() dimcats: list[tuple[int, int]] = [] # The above variable is a list of pairs (aggregate dimension, number @@ -895,11 +903,13 @@ def __init__( model=None, mode=None, rng=None, + initial_point: PointType | None = None, blocked: bool = True, ): model = pm.modelcontext(model) - initial_values = model.initial_point() - initial_values_size = sum(initial_values[n.name].size for n in model.value_vars) + if initial_point is None: + initial_point = model.initial_point() + initial_values_size = sum(initial_point[n.name].size for n in model.value_vars) if vars is None: vars = model.continuous_value_vars @@ -928,8 +938,8 @@ def __init__( self.mode = mode - shared = pm.make_shared_replacements(initial_values, vars, model) - self.delta_logp = delta_logp(initial_values, model.logp(), vars, shared) + shared = pm.make_shared_replacements(initial_point, vars, model) + self.delta_logp = delta_logp(initial_point, model.logp(), vars, shared) super().__init__(vars, shared, blocked=blocked, rng=rng) def astep(self, q0: RaveledVars) -> tuple[RaveledVars, StatsType]: @@ -1062,13 +1072,15 @@ def __init__( tune_interval=100, tune_drop_fraction: float = 0.9, model=None, + initial_point: PointType | None = None, mode=None, rng=None, blocked: bool = True, ): model = pm.modelcontext(model) - initial_values = 
model.initial_point() - initial_values_size = sum(initial_values[n.name].size for n in model.value_vars) + if initial_point is None: + initial_point = model.initial_point() + initial_values_size = sum(initial_point[n.name].size for n in model.value_vars) if vars is None: vars = model.continuous_value_vars @@ -1109,8 +1121,8 @@ def __init__( self.mode = mode - shared = pm.make_shared_replacements(initial_values, vars, model) - self.delta_logp = delta_logp(initial_values, model.logp(), vars, shared) + shared = pm.make_shared_replacements(initial_point, vars, model) + self.delta_logp = delta_logp(initial_point, model.logp(), vars, shared) super().__init__(vars, shared, blocked=blocked, rng=rng) def reset_tuning(self): diff --git a/pymc/step_methods/slicer.py b/pymc/step_methods/slicer.py index c4ca03d125..57b25e9512 100644 --- a/pymc/step_methods/slicer.py +++ b/pymc/step_methods/slicer.py @@ -18,6 +18,7 @@ import numpy as np from pymc.blocking import RaveledVars, StatsType +from pymc.initial_point import PointType from pymc.model import modelcontext from pymc.pytensorf import compile_pymc, join_nonshared_inputs, make_shared_replacements from pymc.step_methods.arraystep import ArrayStepShared @@ -84,6 +85,7 @@ def __init__( model=None, iter_limit=np.inf, rng=None, + initial_point: PointType | None = None, blocked: bool = False, # Could be true since tuning is independent across dims? ): model = modelcontext(model) @@ -97,10 +99,12 @@ def __init__( else: vars = get_value_vars_from_user_vars(vars, model) - point = model.initial_point() - shared = make_shared_replacements(point, vars, model) + if initial_point is None: + initial_point = model.initial_point() + + shared = make_shared_replacements(initial_point, vars, model) [logp], raveled_inp = join_nonshared_inputs( - point=point, outputs=[model.logp()], inputs=vars, shared_inputs=shared + point=initial_point, outputs=[model.logp()], inputs=vars, shared_inputs=shared ) self.logp = compile_pymc([raveled_inp], logp) self.logp.trust_input = True diff --git a/tests/helpers.py b/tests/helpers.py index ae62d72d56..ba481d6763 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -198,11 +198,15 @@ def continuous_steps(self, step, step_kwargs): c1 = pm.HalfNormal("c1") c2 = pm.HalfNormal("c2") + # Test methods can handle initial_point + step_kwargs.setdefault( + "initial_point", {"c1_log__": np.array(0.5), "c2_log__": np.array(0.9)} + ) with pytensor.config.change_flags(mode=fast_unstable_sampling_mode): assert [m.rvs_to_values[c1]] == step([c1], **step_kwargs).vars - assert {m.rvs_to_values[c1], m.rvs_to_values[c2]} == set( - step([c1, c2], **step_kwargs).vars - ) + assert {m.rvs_to_values[c1], m.rvs_to_values[c2]} == set( + step([c1, c2], **step_kwargs).vars + ) def equal_sampling_states(this, other): diff --git a/tests/sampling/test_mcmc.py b/tests/sampling/test_mcmc.py index 8ba7b133ba..3219d45b76 100644 --- a/tests/sampling/test_mcmc.py +++ b/tests/sampling/test_mcmc.py @@ -742,39 +742,35 @@ class TestAssignStepMethods: def test_bernoulli(self): """Test bernoulli distribution is assigned binary gibbs metropolis method""" with pm.Model() as model: - pm.Bernoulli("x", 0.5) - with pytensor.config.change_flags(mode=fast_unstable_sampling_mode): - steps = assign_step_methods(model, []) - assert isinstance(steps, BinaryGibbsMetropolis) + x = pm.Bernoulli("x", 0.5) + _, selected_steps = assign_step_methods(model, []) + assert selected_steps == {BinaryGibbsMetropolis: [model.rvs_to_values[x]]} def test_normal(self): """Test normal distribution is 
assigned NUTS method""" with pm.Model() as model: - pm.Normal("x", 0, 1) - with pytensor.config.change_flags(mode=fast_unstable_sampling_mode): - steps = assign_step_methods(model, []) - assert isinstance(steps, NUTS) + x = pm.Normal("x", 0, 1) + _, selected_steps = assign_step_methods(model, []) + assert selected_steps == {NUTS: [model.rvs_to_values[x]]} def test_categorical(self): """Test categorical distribution is assigned categorical gibbs metropolis method""" with pm.Model() as model: - pm.Categorical("x", np.array([0.25, 0.75])) - with pytensor.config.change_flags(mode=fast_unstable_sampling_mode): - steps = assign_step_methods(model, []) - assert isinstance(steps, BinaryGibbsMetropolis) + x = pm.Categorical("x", np.array([0.25, 0.75])) + _, selected_steps = assign_step_methods(model, []) + assert selected_steps == {BinaryGibbsMetropolis: [model.rvs_to_values[x]]} + with pm.Model() as model: - pm.Categorical("y", np.array([0.25, 0.70, 0.05])) - with pytensor.config.change_flags(mode=fast_unstable_sampling_mode): - steps = assign_step_methods(model, []) - assert isinstance(steps, CategoricalGibbsMetropolis) + y = pm.Categorical("y", np.array([0.25, 0.70, 0.05])) + _, selected_steps = assign_step_methods(model, []) + assert selected_steps == {CategoricalGibbsMetropolis: [model.rvs_to_values[y]]} def test_binomial(self): """Test binomial distribution is assigned metropolis method.""" with pm.Model() as model: - pm.Binomial("x", 10, 0.5) - with pytensor.config.change_flags(mode=fast_unstable_sampling_mode): - steps = assign_step_methods(model, []) - assert isinstance(steps, Metropolis) + x = pm.Binomial("x", 10, 0.5) + _, selected_steps = assign_step_methods(model, []) + assert selected_steps == {Metropolis: [model.rvs_to_values[x]]} def test_normal_nograd_op(self): """Test normal distribution without an implemented gradient is assigned slice method""" @@ -791,11 +787,12 @@ def kill_grad(x): return x data = np.random.normal(size=(100,)) - pm.Normal("y", mu=kill_grad(x), sigma=1, observed=data.astype(pytensor.config.floatX)) + y = pm.Normal( + "y", mu=kill_grad(x), sigma=1, observed=data.astype(pytensor.config.floatX) + ) - with pytensor.config.change_flags(mode=fast_unstable_sampling_mode): - steps = assign_step_methods(model, []) - assert isinstance(steps, Slice) + _, selected_steps = assign_step_methods(model, []) + assert selected_steps == {Slice: [model.rvs_to_values[x]]} @pytest.fixture def step_methods(self): @@ -812,18 +809,18 @@ def test_modify_step_methods(self, step_methods): with pm.Model() as model: pm.Normal("x", 0, 1) - with pytensor.config.change_flags(mode=fast_unstable_sampling_mode): - steps = assign_step_methods(model, []) - assert not isinstance(steps, NUTS) + + _, selected_steps = assign_step_methods(model, []) + assert NUTS not in selected_steps # add back nuts step_methods.append(NUTS) with pm.Model() as model: pm.Normal("x", 0, 1) - with pytensor.config.change_flags(mode=fast_unstable_sampling_mode): - steps = assign_step_methods(model, []) - assert isinstance(steps, NUTS) + + _, selected_steps = assign_step_methods(model, []) + assert NUTS in selected_steps def test_step_vars_in_model(self): """Test if error is raised if step variable is not found in model.value_vars""" diff --git a/tests/step_methods/test_metropolis.py b/tests/step_methods/test_metropolis.py index 63262759cf..259b6e0546 100644 --- a/tests/step_methods/test_metropolis.py +++ b/tests/step_methods/test_metropolis.py @@ -368,9 +368,18 @@ def test_discrete_steps(self, step): d1 = pm.Bernoulli("d1", 
p=0.5) d2 = pm.Bernoulli("d2", p=0.5) + # Test it can take initial_point as a kwarg + step_kwargs = { + "initial_point": { + "d1": np.array(0, dtype="int64"), + "d2": np.array(1, dtype="int64"), + }, + } with pytensor.config.change_flags(mode=fast_unstable_sampling_mode): assert [m.rvs_to_values[d1]] == step([d1]).vars - assert {m.rvs_to_values[d1], m.rvs_to_values[d2]} == set(step([d1, d2]).vars) + assert {m.rvs_to_values[d1], m.rvs_to_values[d2]} == set( + step([d1, d2]).vars + ) @pytest.mark.parametrize( "step, step_kwargs", [(Metropolis, {}), (DEMetropolis, {}), (DEMetropolisZ, {})] From bd232d2b20cb613f6e4374182648ba9e168522fb Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Tue, 19 Nov 2024 14:19:12 +0100 Subject: [PATCH 11/11] Allow passing compile_kwargs to step inner functions --- pymc/sampling/mcmc.py | 16 +++++++++++++- pymc/step_methods/arraystep.py | 4 ++++ pymc/step_methods/metropolis.py | 30 ++++++++++++++++++++------- pymc/step_methods/slicer.py | 5 ++++- tests/helpers.py | 4 +++- tests/step_methods/test_metropolis.py | 12 +++++------ 6 files changed, 55 insertions(+), 16 deletions(-) diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index defa5b5383..11cff18bb6 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -104,6 +104,7 @@ def instantiate_steppers( *, step_kwargs: dict[str, dict] | None = None, initial_point: PointType | None = None, + compile_kwargs: dict | None = None, ) -> Step | list[Step]: """Instantiate steppers assigned to the model variables. @@ -146,6 +147,7 @@ def instantiate_steppers( vars=vars, model=model, initial_point=initial_point, + compile_kwargs=compile_kwargs, **kwargs, ) steps.append(step) @@ -434,6 +436,7 @@ def sample( callback=None, mp_ctx=None, blas_cores: int | None | Literal["auto"] = "auto", + compile_kwargs: dict | None = None, **kwargs, ) -> InferenceData: ... @@ -466,6 +469,7 @@ def sample( mp_ctx=None, model: Model | None = None, blas_cores: int | None | Literal["auto"] = "auto", + compile_kwargs: dict | None = None, **kwargs, ) -> MultiTrace: ... @@ -497,6 +501,7 @@ def sample( mp_ctx=None, blas_cores: int | None | Literal["auto"] = "auto", model: Model | None = None, + compile_kwargs: dict | None = None, **kwargs, ) -> InferenceData | MultiTrace: r"""Draw samples from the posterior using the given step methods. @@ -598,6 +603,9 @@ def sample( See multiprocessing documentation for details. model : Model (optional if in ``with`` context) Model to sample from. The model needs to have free random variables. + compile_kwargs: dict, optional + Dictionary with keyword argument to pass to the functions compiled by the step methods. + Returns ------- @@ -795,6 +803,7 @@ def joined_blas_limiter(): jitter_max_retries=jitter_max_retries, tune=tune, initvals=initvals, + compile_kwargs=compile_kwargs, **kwargs, ) else: @@ -814,6 +823,7 @@ def joined_blas_limiter(): selected_steps=selected_steps, step_kwargs=kwargs, initial_point=initial_points[0], + compile_kwargs=compile_kwargs, ) if isinstance(step, list): step = CompoundStep(step) @@ -1390,6 +1400,7 @@ def init_nuts( jitter_max_retries: int = 10, tune: int | None = None, initvals: StartDict | Sequence[StartDict | None] | None = None, + compile_kwargs: dict | None = None, **kwargs, ) -> tuple[Sequence[PointType], NUTS]: """Set up the mass matrix initialization for NUTS. 
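For context, a minimal usage sketch of the compile_kwargs argument that this patch threads from pm.sample into init_nuts; the FAST_COMPILE mode string and the small draw counts are illustrative choices, not part of the change itself:

import pymc as pm

with pm.Model() as model:
    x = pm.Normal("x", 0, 1)
    # compile_kwargs is forwarded to the functions compiled by the step
    # methods; for this all-continuous model that is the NUTS logp/dlogp
    # function built via model.logp_dlogp_function(ravel_inputs=True, ...).
    idata = pm.sample(
        draws=100,
        tune=100,
        chains=1,
        compile_kwargs={"mode": "FAST_COMPILE"},
    )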
@@ -1466,6 +1477,9 @@ def init_nuts( if init == "auto": init = "jitter+adapt_diag" + if compile_kwargs is None: + compile_kwargs = {} + random_seed_list = _get_seeds_per_chain(random_seed, chains) _log.info(f"Initializing NUTS using {init}...") @@ -1477,7 +1491,7 @@ def init_nuts( pm.callbacks.CheckParametersConvergence(tolerance=1e-2, diff="relative"), ] - logp_dlogp_func = model.logp_dlogp_function(ravel_inputs=True) + logp_dlogp_func = model.logp_dlogp_function(ravel_inputs=True, **compile_kwargs) logp_dlogp_func.trust_input = True initial_points = _init_jitter( model, diff --git a/pymc/step_methods/arraystep.py b/pymc/step_methods/arraystep.py index 7ddfb65f06..d15b14499c 100644 --- a/pymc/step_methods/arraystep.py +++ b/pymc/step_methods/arraystep.py @@ -182,16 +182,20 @@ def __init__( logp_dlogp_func=None, rng: RandomGenerator = None, initial_point: PointType | None = None, + compile_kwargs: dict | None = None, **pytensor_kwargs, ): model = modelcontext(model) if logp_dlogp_func is None: + if compile_kwargs is None: + compile_kwargs = {} logp_dlogp_func = model.logp_dlogp_function( vars, dtype=dtype, ravel_inputs=True, initial_point=initial_point, + **compile_kwargs, **pytensor_kwargs, ) logp_dlogp_func.trust_input = True diff --git a/pymc/step_methods/metropolis.py b/pymc/step_methods/metropolis.py index 8aed4bfd67..b6d82243a1 100644 --- a/pymc/step_methods/metropolis.py +++ b/pymc/step_methods/metropolis.py @@ -162,6 +162,7 @@ def __init__( mode=None, rng=None, initial_point: PointType | None = None, + compile_kwargs: dict | None = None, blocked: bool = False, ): """Create an instance of a Metropolis stepper. @@ -254,7 +255,7 @@ def __init__( self.mode = mode shared = pm.make_shared_replacements(initial_point, vars, model) - self.delta_logp = delta_logp(initial_point, model.logp(), vars, shared) + self.delta_logp = delta_logp(initial_point, model.logp(), vars, shared, compile_kwargs) super().__init__(vars, shared, blocked=blocked, rng=rng) def reset_tuning(self): @@ -432,6 +433,7 @@ def __init__( model=None, rng=None, initial_point: PointType | None = None, + compile_kwargs: dict | None = None, blocked: bool = True, ): model = pm.modelcontext(model) @@ -447,7 +449,9 @@ def __init__( if not all(v.dtype in pm.discrete_types for v in vars): raise ValueError("All variables must be Bernoulli for BinaryMetropolis") - super().__init__(vars, [model.compile_logp()], blocked=blocked, rng=rng) + if compile_kwargs is None: + compile_kwargs = {} + super().__init__(vars, [model.compile_logp(**compile_kwargs)], blocked=blocked, rng=rng) def astep(self, apoint: RaveledVars, *args) -> tuple[RaveledVars, StatsType]: logp = args[0] @@ -554,6 +558,7 @@ def __init__( model=None, rng=None, initial_point: PointType | None = None, + compile_kwargs: dict | None = None, blocked: bool = True, ): model = pm.modelcontext(model) @@ -582,7 +587,10 @@ def __init__( if not all(v.dtype in pm.discrete_types for v in vars): raise ValueError("All variables must be binary for BinaryGibbsMetropolis") - super().__init__(vars, [model.compile_logp()], blocked=blocked, rng=rng) + if compile_kwargs is None: + compile_kwargs = {} + + super().__init__(vars, [model.compile_logp(**compile_kwargs)], blocked=blocked, rng=rng) def reset_tuning(self): # There are no tuning parameters in this step method. 
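As a usage sketch of the new keyword-only initial_point and compile_kwargs arguments on the Metropolis family (mirroring what the updated test helpers exercise); the specific Mode object is illustrative:

import pymc as pm
from pytensor.compile.mode import Mode

with pm.Model() as model:
    x = pm.Normal("x", 0, 1)
    # Reuse a precomputed initial point so the stepper does not call
    # model.initial_point() again, and control how delta_logp is compiled.
    ip = model.initial_point()
    step = pm.Metropolis(
        initial_point=ip,
        compile_kwargs={"mode": Mode(linker="py", optimizer=None)},
    )
    idata = pm.sample(step=step, draws=50, tune=50, chains=1)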
@@ -672,6 +680,7 @@ def __init__( model=None, rng: RandomGenerator = None, initial_point: PointType | None = None, + compile_kwargs: dict | None = None, blocked: bool = True, ): model = pm.modelcontext(model) @@ -728,7 +737,9 @@ def __init__( # that indicates whether a draw was done in a tuning phase. self.tune = True - super().__init__(vars, [model.compile_logp()], blocked=blocked, rng=rng) + if compile_kwargs is None: + compile_kwargs = {} + super().__init__(vars, [model.compile_logp(**compile_kwargs)], blocked=blocked, rng=rng) def reset_tuning(self): # There are no tuning parameters in this step method. @@ -904,6 +915,7 @@ def __init__( mode=None, rng=None, initial_point: PointType | None = None, + compile_kwargs: dict | None = None, blocked: bool = True, ): model = pm.modelcontext(model) @@ -939,7 +951,7 @@ def __init__( self.mode = mode shared = pm.make_shared_replacements(initial_point, vars, model) - self.delta_logp = delta_logp(initial_point, model.logp(), vars, shared) + self.delta_logp = delta_logp(initial_point, model.logp(), vars, shared, compile_kwargs) super().__init__(vars, shared, blocked=blocked, rng=rng) def astep(self, q0: RaveledVars) -> tuple[RaveledVars, StatsType]: @@ -1073,6 +1085,7 @@ def __init__( tune_drop_fraction: float = 0.9, model=None, initial_point: PointType | None = None, + compile_kwargs: dict | None = None, mode=None, rng=None, blocked: bool = True, @@ -1122,7 +1135,7 @@ def __init__( self.mode = mode shared = pm.make_shared_replacements(initial_point, vars, model) - self.delta_logp = delta_logp(initial_point, model.logp(), vars, shared) + self.delta_logp = delta_logp(initial_point, model.logp(), vars, shared, compile_kwargs) super().__init__(vars, shared, blocked=blocked, rng=rng) def reset_tuning(self): @@ -1213,6 +1226,7 @@ def delta_logp( logp: pt.TensorVariable, vars: list[pt.TensorVariable], shared: dict[pt.TensorVariable, pt.sharedvar.TensorSharedVariable], + compile_kwargs: dict | None, ) -> pytensor.compile.Function: [logp0], inarray0 = join_nonshared_inputs( point=point, outputs=[logp], inputs=vars, shared_inputs=shared @@ -1225,6 +1239,8 @@ def delta_logp( # Replace any potential duplicated RNG nodes (logp1,) = replace_rng_nodes((logp1,)) - f = compile_pymc([inarray1, inarray0], logp1 - logp0) + if compile_kwargs is None: + compile_kwargs = {} + f = compile_pymc([inarray1, inarray0], logp1 - logp0, **compile_kwargs) f.trust_input = True return f diff --git a/pymc/step_methods/slicer.py b/pymc/step_methods/slicer.py index 57b25e9512..b84674390d 100644 --- a/pymc/step_methods/slicer.py +++ b/pymc/step_methods/slicer.py @@ -86,6 +86,7 @@ def __init__( iter_limit=np.inf, rng=None, initial_point: PointType | None = None, + compile_kwargs: dict | None = None, blocked: bool = False, # Could be true since tuning is independent across dims? 
): model = modelcontext(model) @@ -106,7 +107,9 @@ def __init__( [logp], raveled_inp = join_nonshared_inputs( point=initial_point, outputs=[model.logp()], inputs=vars, shared_inputs=shared ) - self.logp = compile_pymc([raveled_inp], logp) + if compile_kwargs is None: + compile_kwargs = {} + self.logp = compile_pymc([raveled_inp], logp, **compile_kwargs) self.logp.trust_input = True super().__init__(vars, shared, blocked=blocked, rng=rng) diff --git a/tests/helpers.py b/tests/helpers.py index ba481d6763..e4b6248930 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -26,6 +26,7 @@ import pytensor from numpy.testing import assert_array_less +from pytensor.compile.mode import Mode from pytensor.gradient import verify_grad as at_verify_grad import pymc as pm @@ -198,10 +199,11 @@ def continuous_steps(self, step, step_kwargs): c1 = pm.HalfNormal("c1") c2 = pm.HalfNormal("c2") - # Test methods can handle initial_point + # Test methods can handle initial_point and compile_kwargs step_kwargs.setdefault( "initial_point", {"c1_log__": np.array(0.5), "c2_log__": np.array(0.9)} ) + step_kwargs.setdefault("compile_kwargs", {"mode": Mode(linker="py", optimizer=None)}) with pytensor.config.change_flags(mode=fast_unstable_sampling_mode): assert [m.rvs_to_values[c1]] == step([c1], **step_kwargs).vars assert {m.rvs_to_values[c1], m.rvs_to_values[c2]} == set( diff --git a/tests/step_methods/test_metropolis.py b/tests/step_methods/test_metropolis.py index 259b6e0546..a01e75506b 100644 --- a/tests/step_methods/test_metropolis.py +++ b/tests/step_methods/test_metropolis.py @@ -22,6 +22,8 @@ import pytensor import pytest +from pytensor.compile.mode import Mode + import pymc as pm from pymc.step_methods.metropolis import ( @@ -368,18 +370,16 @@ def test_discrete_steps(self, step): d1 = pm.Bernoulli("d1", p=0.5) d2 = pm.Bernoulli("d2", p=0.5) - # Test it can take initial_point as a kwarg + # Test it can take initial_point, and compile_kwargs as a kwarg step_kwargs = { "initial_point": { "d1": np.array(0, dtype="int64"), "d2": np.array(1, dtype="int64"), }, + "compile_kwargs": {"mode": Mode(linker="py", optimizer=None)}, } - with pytensor.config.change_flags(mode=fast_unstable_sampling_mode): - assert [m.rvs_to_values[d1]] == step([d1]).vars - assert {m.rvs_to_values[d1], m.rvs_to_values[d2]} == set( - step([d1, d2]).vars - ) + assert [m.rvs_to_values[d1]] == step([d1]).vars + assert {m.rvs_to_values[d1], m.rvs_to_values[d2]} == set(step([d1, d2]).vars) @pytest.mark.parametrize( "step, step_kwargs", [(Metropolis, {}), (DEMetropolis, {}), (DEMetropolisZ, {})]