
Commit f571e5d

Allow registering XTensorVariables directly in model
1 parent 0e03123 commit f571e5d
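
The changes below touch the experimental dims module, the logprob machinery, and initial-point generation so that dims distributions can hand their XTensorVariable straight to Model.register_rv instead of lowering it with .values first. As a rough usage sketch (not part of the diff; the pymc.dims import path and the Normal signature here are assumptions for illustration):

import pymc as pm
import pymc.dims as pmd  # assumed import path for the experimental dims API

coords = {"a": range(3), "b": range(5)}
with pm.Model(coords=coords) as model:
    # The dims-aware Normal returns an XTensorVariable; after this commit it is
    # registered in the model as-is, without an explicit .values lowering step.
    x = pmd.Normal("x", 0, 1, dims=("a", "b"))
    prior = pm.sample_prior_predictive()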

File tree

9 files changed: +152 -32 lines


pymc/dims/__init__.py

Lines changed: 9 additions & 3 deletions
@@ -36,9 +36,15 @@ def __init__():
 
     # Make PyMC aware of xtensor functionality
     MeasurableOp.register(XRV)
-    lower_xtensor_query = optdb.query("+lower_xtensor")
-    logprob_rewrites_db.register("lower_xtensor", lower_xtensor_query, "basic", position=0.1)
-    initial_point_rewrites_db.register("lower_xtensor", lower_xtensor_query, "basic", position=0.1)
+    logprob_rewrites_db.register(
+        "pre_lower_xtensor", optdb.query("+lower_xtensor"), "basic", position=0.1
+    )
+    logprob_rewrites_db.register(
+        "post_lower_xtensor", optdb.query("+lower_xtensor"), "cleanup", position=5.1
+    )
+    initial_point_rewrites_db.register(
+        "lower_xtensor", optdb.query("+lower_xtensor"), "basic", position=0.1
+    )
 
     # TODO: Better model of probability of bugs
     day_of_conception = datetime.date(2025, 6, 17)

pymc/dims/distribution_core.py

Lines changed: 42 additions & 5 deletions
@@ -14,15 +14,20 @@
 from collections.abc import Callable, Sequence
 from itertools import chain
 
+from pytensor.graph import node_rewriter
 from pytensor.tensor.elemwise import DimShuffle
 from pytensor.xtensor import as_xtensor
+from pytensor.xtensor.basic import XTensorFromTensor, xtensor_from_tensor
 from pytensor.xtensor.type import XTensorVariable
 
 from pymc import modelcontext
 from pymc.dims.model import with_dims
-from pymc.distributions import transforms
+from pymc.dims.transforms import log_odds_transform, log_transform
 from pymc.distributions.distribution import _support_point, support_point
 from pymc.distributions.shape_utils import DimsWithEllipsis, convert_dims
+from pymc.logprob.abstract import MeasurableOp, _logprob
+from pymc.logprob.rewriting import measurable_ir_rewrites_db
+from pymc.logprob.utils import filter_measurable_variables
 from pymc.util import UNSET
 
 
@@ -34,6 +39,38 @@ def dimshuffle_support_point(ds_op, _, rv):
     return ds_op(support_point(rv))
 
 
+@_support_point.register(XTensorFromTensor)
+def xtensor_from_tensor_support_point(xtensor_op, _, rv):
+    # We remove the xtensor_from_tensor operation, so initial_point doesn't have to do a further lowering
+    return xtensor_op(support_point(rv))
+
+
+class MeasurableXTensorFromTensor(MeasurableOp, XTensorFromTensor):
+    pass
+
+
+@node_rewriter([XTensorFromTensor])
+def find_measurable_xtensor_from_tensor(fgraph, node) -> list[XTensorVariable] | None:
+    if isinstance(node.op, MeasurableXTensorFromTensor):
+        return None
+
+    if not filter_measurable_variables(node.inputs):
+        return None
+
+    return [MeasurableXTensorFromTensor(dims=node.op.dims)(*node.inputs)]
+
+
+@_logprob.register(MeasurableXTensorFromTensor)
+def measurable_xtensor_from_tensor(op, values, rv, **kwargs):
+    rv_logp = _logprob(rv.owner.op, tuple(v.values for v in values), *rv.owner.inputs, **kwargs)
+    return xtensor_from_tensor(rv_logp, dims=op.dims)
+
+
+measurable_ir_rewrites_db.register(
+    "measurable_xtensor_from_tensor", find_measurable_xtensor_from_tensor, "basic", "xtensor"
+)
+
+
 class DimDistribution:
     """Base class for PyMC distribution that wrap pytensor.xtensor.random operations, and follow xarray-like semantics."""
 
@@ -117,10 +154,10 @@ def __new__(
         else:
            # Align observed dims with those of the RV
            # TODO: If this fails give a more informative error message
-            observed = observed.transpose(*rv_dims).values
+            observed = observed.transpose(*rv_dims)
 
        rv = model.register_rv(
-            rv.values,
+            rv,
            name=name,
            observed=observed,
            total_size=total_size,
@@ -177,10 +214,10 @@ def dist(self, *args, core_dims: str | Sequence[str] | None = None, **kwargs):
 class PositiveDimDistribution(DimDistribution):
     """Base class for positive continuous distributions."""
 
-    default_transform = transforms.log
+    default_transform = log_transform
 
 
 class UnitDimDistribution(DimDistribution):
     """Base class for unit-valued distributions."""
 
-    default_transform = transforms.logodds
+    default_transform = log_odds_transform
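
A small sketch of what the two new registrations buy (illustrative, not part of the diff): the support point of an xtensor-wrapped RV simply re-wraps the support point of the underlying tensor RV, and the measurable rewrite lets logprob derivation pass through xtensor_from_tensor in the same way.

import pymc as pm
from pymc.distributions.distribution import support_point
from pytensor.xtensor.basic import xtensor_from_tensor

raw = pm.Normal.dist(mu=2.0, shape=(3,))     # plain tensor RV
xrv = xtensor_from_tensor(raw, dims=("a",))  # XTensorFromTensor around it
# Dispatches to the registration above; an XTensorVariable expected to
# evaluate to [2., 2., 2.] since the Normal's support point is its mean.
sp = support_point(xrv)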

pymc/dims/transforms.py

Lines changed: 49 additions & 0 deletions
@@ -0,0 +1,49 @@
+# Copyright 2025 - present The PyMC Developers
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import pytensor.xtensor as ptx
+
+from pymc.logprob.transforms import Transform
+
+
+class LogTransform(Transform):
+    name = "log"
+
+    def forward(self, value, *inputs):
+        return ptx.math.log(value)
+
+    def backward(self, value, *inputs):
+        return ptx.math.exp(value)
+
+    def log_jac_det(self, value, *inputs):
+        return value
+
+
+log_transform = LogTransform()
+
+
+class LogOddsTransform(Transform):
+    name = "logodds"
+
+    def backward(self, value, *inputs):
+        return ptx.math.expit(value)
+
+    def forward(self, value, *inputs):
+        return ptx.math.log(value / (1 - value))
+
+    def log_jac_det(self, value, *inputs):
+        sigmoid_value = ptx.math.sigmoid(value)
+        return ptx.math.log(sigmoid_value) + ptx.math.log1p(-sigmoid_value)
+
+
+log_odds_transform = LogOddsTransform()
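
A quick numerical sanity check of the new transforms (a sketch, not part of the commit; it assumes .eval() lowers the xtensor graph with the default rewrites):

import numpy as np
import pytensor.tensor as pt
from pytensor.xtensor.basic import xtensor_from_tensor
from pymc.dims.transforms import log_odds_transform

p = xtensor_from_tensor(pt.as_tensor(np.array([0.2, 0.5, 0.9])), dims=("a",))
z = log_odds_transform.forward(p)        # logit(p), the unconstrained value
p_back = log_odds_transform.backward(z)  # expit(logit(p)) should recover p
np.testing.assert_allclose(p_back.eval(), [0.2, 0.5, 0.9])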

pymc/initial_point.py

Lines changed: 15 additions & 3 deletions
@@ -20,7 +20,8 @@
 import pytensor
 import pytensor.tensor as pt
 
-from pytensor.graph.basic import Constant, Variable
+from pytensor.compile.ops import TypeCastingOp
+from pytensor.graph.basic import Apply, Constant, Variable
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.rewriting.db import RewriteDatabaseQuery, SequenceDB
 from pytensor.tensor.variable import TensorVariable
@@ -195,6 +196,14 @@ def inner(seed, *args, **kwargs):
     return make_seeded_function(func)
 
 
+class InitialPoint(TypeCastingOp):
+    def make_node(self, var):
+        return Apply(self, [var], [var.type()])
+
+
+initial_point_op = InitialPoint()
+
+
 def make_initial_point_expression(
     *,
     free_rvs: Sequence[TensorVariable],
@@ -235,6 +244,9 @@ def make_initial_point_expression(
 
     # Clone free_rvs so we don't modify the original graph
     initial_point_fgraph = FunctionGraph(outputs=free_rvs, clone=True)
+    # Wrap each rv in an initial_point Operation to avoid losing dependency between the RVs
+    replacements = tuple((rv, initial_point_op(rv)) for rv in initial_point_fgraph.outputs)
+    toposort_replace(initial_point_fgraph, replacements, reverse=True)
 
     # Apply any rewrites necessary to compute the initial points.
     initial_point_rewriter = initial_point_rewrites_db.query(initial_point_basic_query)
@@ -254,10 +266,10 @@
         if isinstance(strategy, str):
             if strategy == "support_point":
                 try:
-                    value = support_point(variable)
+                    value = support_point(variable.owner.inputs[0])
                 except NotImplementedError:
                     warnings.warn(
-                        f"Moment not defined for variable {variable} of type "
+                        f"support_point not defined for variable {variable} of type "
                         f"{variable.owner.op.__class__.__name__}, defaulting to "
                         f"a draw from the prior. This can lead to difficulties "
                        f"during tuning. You can manually define an initval or "

pymc/logprob/basic.py

Lines changed: 6 additions & 6 deletions
@@ -197,7 +197,7 @@ def normal_logp(value, mu, sigma):
     [ir_valued_var] = fgraph.outputs
     [ir_rv, ir_value] = ir_valued_var.owner.inputs
     expr = _logprob_helper(ir_rv, ir_value, **kwargs)
-    cleanup_ir([expr])
+    [expr] = cleanup_ir([expr])
     if warn_rvs:
         _warn_rvs_in_inferred_graph(expr)
     return expr
@@ -297,7 +297,7 @@ def normal_logcdf(value, mu, sigma):
     [ir_valued_rv] = fgraph.outputs
     [ir_rv, ir_value] = ir_valued_rv.owner.inputs
     expr = _logcdf_helper(ir_rv, ir_value, **kwargs)
-    cleanup_ir([expr])
+    [expr] = cleanup_ir([expr])
     if warn_rvs:
         _warn_rvs_in_inferred_graph(expr)
     return expr
@@ -379,7 +379,7 @@ def icdf(rv: TensorVariable, value: TensorLike, warn_rvs=True, **kwargs) -> Tens
     [ir_valued_rv] = fgraph.outputs
     [ir_rv, ir_value] = ir_valued_rv.owner.inputs
     expr = _icdf_helper(ir_rv, ir_value, **kwargs)
-    cleanup_ir([expr])
+    [expr] = cleanup_ir([expr])
     if warn_rvs:
         _warn_rvs_in_inferred_graph(expr)
     return expr
@@ -540,15 +540,15 @@ def conditional_logp(
             f"The logprob terms of the following value variables could not be derived: {missing_value_terms}"
         )
 
-    logprobs = list(values_to_logprobs.values())
-    cleanup_ir(logprobs)
+    values, logprobs = zip(*values_to_logprobs.items())
+    logprobs = cleanup_ir(logprobs)
 
     if warn_rvs:
         rvs_in_logp_expressions = _find_unallowed_rvs_in_graph(logprobs)
         if rvs_in_logp_expressions:
             warnings.warn(RVS_IN_JOINT_LOGP_GRAPH_MSG % rvs_in_logp_expressions, UserWarning)
 
-    return values_to_logprobs
+    return dict(zip(values, logprobs))
 
 
 def transformed_conditional_logp(
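
The cleanup rewrites can replace the output variables of the IR graph (for example when lowering xtensor ops), so callers now re-read the rewritten expressions returned by cleanup_ir instead of assuming the originals were rewritten in place. A short usage sketch of the unchanged public behaviour (illustrative, not from the diff):

import pymc as pm
from pymc.logprob.basic import conditional_logp

x_rv = pm.Normal.dist(0, 1)
x_vv = x_rv.clone()                  # value variable standing in for x_rv
logp_terms = conditional_logp({x_rv: x_vv})
x_logp = logp_terms[x_vv]            # the cleaned-up (post-rewrite) logp expression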

pymc/logprob/rewriting.py

Lines changed: 5 additions & 2 deletions
@@ -133,6 +133,8 @@ def remove_DiracDelta(fgraph, node):
 
 
 logprob_rewrites_basic_query = RewriteDatabaseQuery(include=["basic"])
+logprob_rewrites_cleanup_query = RewriteDatabaseQuery(include=["cleanup"])
+
 logprob_rewrites_db = SequenceDB()
 logprob_rewrites_db.name = "logprob_rewrites_db"
 
@@ -276,10 +278,11 @@ def construct_ir_fgraph(
     return fgraph
 
 
-def cleanup_ir(vars: Sequence[Variable]) -> None:
+def cleanup_ir(vars: Sequence[Variable]) -> Sequence[Variable]:
     fgraph = FunctionGraph(outputs=vars, clone=False)
-    ir_rewriter = logprob_rewrites_db.query(RewriteDatabaseQuery(include=["cleanup"]))
+    ir_rewriter = logprob_rewrites_db.query(logprob_rewrites_cleanup_query)
     ir_rewriter.rewrite(fgraph)
+    return fgraph.outputs
 
 
 def assume_valued_outputs(outputs: Sequence[TensorVariable]) -> Sequence[TensorVariable]:

pymc/model/core.py

Lines changed: 9 additions & 3 deletions
@@ -35,6 +35,8 @@
 from pytensor.compile import DeepCopyOp, Function, ProfileStats, get_mode
 from pytensor.compile.sharedvalue import SharedVariable
 from pytensor.graph.basic import Constant, Variable, ancestors, graph_inputs
+from pytensor.tensor import as_tensor
+from pytensor.tensor.math import variadic_add
 from pytensor.tensor.random.op import RandomVariable
 from pytensor.tensor.random.type import RandomType
 from pytensor.tensor.variable import TensorConstant, TensorVariable
@@ -232,7 +234,9 @@ def __init__(
            grads = pytensor.grad(cost, grad_vars, disconnected_inputs="ignore")
            for grad_wrt, var in zip(grads, grad_vars):
                grad_wrt.name = f"{var.name}_grad"
-            grads = pt.join(0, *[pt.atleast_1d(grad.ravel()) for grad in grads])
+            grads = pt.join(
+                0, *[as_tensor(grad, allow_xtensor_conversion=True).ravel() for grad in grads]
+            )
            outputs = [cost, grads]
        else:
            outputs = [cost]
@@ -702,7 +706,9 @@ def logp(
        if not sum:
            return logp_factors
 
-        logp_scalar = pt.sum([pt.sum(factor) for factor in logp_factors])
+        logp_scalar = variadic_add(
+            *(as_tensor(factor, allow_xtensor_conversion=True).sum() for factor in logp_factors)
+        )
        logp_scalar_name = "__logp" if jacobian else "__logp_nojac"
        if self.name:
            logp_scalar_name = f"{logp_scalar_name}_{self.name}"
@@ -1322,7 +1328,7 @@ def make_obs_var(
        else:
            if sps.issparse(data):
                data = sparse.basic.as_sparse(data, name=name)
-            else:
+            elif not isinstance(data, Variable):
                data = pt.as_tensor_variable(data, name=name)
 
        if total_size:
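
Sketch of the summation pattern now used in Model.logp (assumption: the installed PyTensor provides as_tensor(..., allow_xtensor_conversion=True) and variadic_add, as imported above): each factor, tensor or xtensor, is converted to a plain tensor before the scalar reduction, so mixed factor types can be added.

import pytensor.tensor as pt
from pytensor.tensor import as_tensor
from pytensor.tensor.math import variadic_add
from pytensor.xtensor.basic import xtensor_from_tensor

tensor_factor = pt.vector("t")                                        # ordinary logp factor
xtensor_factor = xtensor_from_tensor(pt.tensor("x", shape=(2, 3)), dims=("a", "b"))
logp_scalar = variadic_add(
    *(as_tensor(f, allow_xtensor_conversion=True).sum() for f in (tensor_factor, xtensor_factor))
)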

pymc/pytensorf.py

Lines changed: 16 additions & 9 deletions
@@ -45,7 +45,7 @@
 from pytensor.tensor.random.var import RandomGeneratorSharedVariable
 from pytensor.tensor.rewriting.basic import topo_unconditional_constant_folding
 from pytensor.tensor.rewriting.shape import ShapeFeature
-from pytensor.tensor.sharedvar import SharedVariable, TensorSharedVariable
+from pytensor.tensor.sharedvar import SharedVariable
 from pytensor.tensor.subtensor import AdvancedIncSubtensor, AdvancedIncSubtensor1
 from pytensor.tensor.variable import TensorVariable
 
@@ -299,7 +299,9 @@ def smarttypeX(x):
 
 def gradient1(f, v):
     """Flat gradient of f wrt v."""
-    return pt.flatten(grad(f, v, disconnected_inputs="warn"))
+    return pt.as_tensor(
+        grad(f, v, disconnected_inputs="warn"), allow_xtensor_conversion=True
+    ).ravel()
 
 
 empty_gradient = pt.zeros(0, dtype="float32")
@@ -418,11 +420,11 @@ def make_shared_replacements(point, vars, model):
 
 def join_nonshared_inputs(
     point: dict[str, np.ndarray],
-    outputs: list[TensorVariable],
-    inputs: list[TensorVariable],
-    shared_inputs: dict[TensorVariable, TensorSharedVariable] | None = None,
+    outputs: Sequence[Variable],
+    inputs: Sequence[Variable],
+    shared_inputs: dict[Variable, Variable] | None = None,
     make_inputs_shared: bool = False,
-) -> tuple[list[TensorVariable], TensorVariable]:
+) -> tuple[Sequence[Variable], TensorVariable]:
     """
     Create new outputs and input TensorVariables where the non-shared inputs are joined in a single raveled vector input.
 
@@ -547,7 +549,9 @@ def join_nonshared_inputs(
     if not inputs:
         raise ValueError("Empty list of input variables.")
 
-    raveled_inputs = pt.concatenate([var.ravel() for var in inputs])
+    raveled_inputs = pt.concatenate(
+        [pt.as_tensor(var, allow_xtensor_conversion=True).ravel() for var in inputs]
+    )
 
     if not make_inputs_shared:
         tensor_type = raveled_inputs.type
@@ -559,12 +563,15 @@
     if pytensor.config.compute_test_value != "off":
         joined_inputs.tag.test_value = raveled_inputs.tag.test_value
 
-    replace: dict[TensorVariable, TensorVariable] = {}
+    replace: dict[Variable, Variable] = {}
     last_idx = 0
     for var in inputs:
         shape = point[var.name].shape
         arr_len = np.prod(shape, dtype=int)
-        replace[var] = joined_inputs[last_idx : last_idx + arr_len].reshape(shape).astype(var.dtype)
+        replacement_var = (
+            joined_inputs[last_idx : last_idx + arr_len].reshape(shape).astype(var.dtype)
+        )
+        replace[var] = var.type.filter_variable(replacement_var)
         last_idx += arr_len
 
     if shared_inputs is not None:

tests/dims/test_model.py

Lines changed: 1 addition & 1 deletion
@@ -73,7 +73,7 @@ def test_simple_model():
     np.testing.assert_allclose(draw, draw_same)
     assert not np.allclose(draw, draw_diff)
 
-    observed_values = DataArray(np.ones((3, 5)), dims=("a", "b")).transpose()
+    observed_values = DataArray(np.ones((3, 5)), dims=("a", "b"))
     with observe(xmodel, {"y": observed_values}):
         pm.sample_prior_predictive()
         idata = pm.sample(
