diff --git a/pytensor/tensor/random/op.py b/pytensor/tensor/random/op.py
index 76d59a87a5..6891823576 100644
--- a/pytensor/tensor/random/op.py
+++ b/pytensor/tensor/random/op.py
@@ -392,6 +392,13 @@ def make_node(self, rng, size, *dist_params):
         out_type = TensorType(dtype=self.dtype, shape=static_shape)
         outputs = (rng.type(), out_type())
 
+        if self.dtype == "floatX":
+            # Commit to a specific float type if the Op is still using "floatX"
+            dtype = config.floatX
+            props = self._props_dict()
+            props["dtype"] = dtype
+            self = type(self)(**props)
+
         return Apply(self, inputs, outputs)
 
     def batch_ndim(self, node: Apply) -> int:
diff --git a/pytensor/xtensor/__init__.py b/pytensor/xtensor/__init__.py
index 29a9f5f996..7f1b9ecddb 100644
--- a/pytensor/xtensor/__init__.py
+++ b/pytensor/xtensor/__init__.py
@@ -1,7 +1,7 @@
 import warnings
 
 import pytensor.xtensor.rewriting
-from pytensor.xtensor import linalg
+from pytensor.xtensor import linalg, random
 from pytensor.xtensor.math import dot
 from pytensor.xtensor.shape import concat
 from pytensor.xtensor.type import (
diff --git a/pytensor/xtensor/math.py b/pytensor/xtensor/math.py
index fbbd75ef68..ad6f22bf51 100644
--- a/pytensor/xtensor/math.py
+++ b/pytensor/xtensor/math.py
@@ -101,7 +101,7 @@ def _as_xelemwise(core_op: ScalarOp) -> XElemwise:
 maximum = _as_xelemwise(ps.scalar_maximum)
 minimum = _as_xelemwise(ps.scalar_minimum)
 second = _as_xelemwise(ps.second)
-sigmoid = _as_xelemwise(ps.sigmoid)
+sigmoid = expit = _as_xelemwise(ps.sigmoid)
 sign = _as_xelemwise(ps.sign)
 sin = _as_xelemwise(ps.sin)
 sinh = _as_xelemwise(ps.sinh)
diff --git a/pytensor/xtensor/random.py b/pytensor/xtensor/random.py
index 8ee1be072b..8f24ae24e1 100644
--- a/pytensor/xtensor/random.py
+++ b/pytensor/xtensor/random.py
@@ -5,8 +5,8 @@
 import pytensor.tensor.random.basic as ptr
 from pytensor.graph.basic import Variable
 from pytensor.tensor.random.op import RandomVariable
-from pytensor.xtensor import as_xtensor
 from pytensor.xtensor.math import sqrt
+from pytensor.xtensor.type import as_xtensor
 from pytensor.xtensor.vectorization import XRV
 
 
@@ -14,6 +14,7 @@ def _as_xrv(
     core_op: RandomVariable,
     core_inps_dims_map: Sequence[Sequence[int]] | None = None,
     core_out_dims_map: Sequence[int] | None = None,
+    name: str | None = None,
 ):
     """Helper function to define an XRV constructor.
 
@@ -41,7 +42,14 @@ def _as_xrv(
         core_out_dims_map = tuple(range(core_op.ndim_supp))
 
     core_dims_needed = max(
-        (*(len(i) for i in core_inps_dims_map), len(core_out_dims_map)), default=0
+        max(
+            (
+                max((entry + 1 for entry in dims_map), default=0)
+                for dims_map in core_inps_dims_map
+            ),
+            default=0,
+        ),
+        max((entry + 1 for entry in core_out_dims_map), default=0),
     )
 
     @wraps(core_op)
@@ -76,7 +84,10 @@ def xrv_constructor(
             extra_dims = {}
 
         return XRV(
-            core_op, core_dims=full_core_dims, extra_dims=tuple(extra_dims.keys())
+            core_op,
+            core_dims=full_core_dims,
+            extra_dims=tuple(extra_dims.keys()),
+            name=name,
         )(rng, *extra_dims.values(), *params)
 
     return xrv_constructor
diff --git a/pytensor/xtensor/rewriting/utils.py b/pytensor/xtensor/rewriting/utils.py
index bf4ef5f802..f21747c2e6 100644
--- a/pytensor/xtensor/rewriting/utils.py
+++ b/pytensor/xtensor/rewriting/utils.py
@@ -1,6 +1,7 @@
 from pytensor.compile import optdb
-from pytensor.graph.rewriting.basic import NodeRewriter
+from pytensor.graph.rewriting.basic import NodeRewriter, in2out
 from pytensor.graph.rewriting.db import EquilibriumDB, RewriteDatabase
+from pytensor.tensor.rewriting.ofg import inline_ofg_expansion
 
 
 lower_xtensor_db = EquilibriumDB(ignore_newtrees=False)
@@ -14,6 +15,15 @@
     position=0.1,
 )
 
+# Register OFG inline again after lowering xtensor
+optdb.register(
+    "inline_ofg_expansion_xtensor",
+    in2out(inline_ofg_expansion),
+    "fast_run",
+    "fast_compile",
+    position=0.11,
+)
+
 
 def register_lower_xtensor(
     node_rewriter: RewriteDatabase | NodeRewriter | str, *tags: str, **kwargs
diff --git a/pytensor/xtensor/rewriting/vectorization.py b/pytensor/xtensor/rewriting/vectorization.py
index cc3834cc48..bed7da564b 100644
--- a/pytensor/xtensor/rewriting/vectorization.py
+++ b/pytensor/xtensor/rewriting/vectorization.py
@@ -116,7 +116,7 @@ def lower_rv(fgraph, node):
     size = [*extra_dim_lengths, *param_batch_shape]
 
     # RVs are their own core Op
-    new_next_rng, tensor_out = core_op(*tensor_params, rng=rng, size=size).owner.outputs
+    new_next_rng, tensor_out = core_op.make_node(rng, size, *tensor_params).outputs
 
     # Convert output Tensors to XTensors
     new_out = xtensor_from_tensor(tensor_out, dims=old_out.type.dims)
diff --git a/pytensor/xtensor/type.py b/pytensor/xtensor/type.py
index 94b0eeedfe..c5f345e45a 100644
--- a/pytensor/xtensor/type.py
+++ b/pytensor/xtensor/type.py
@@ -71,6 +71,8 @@ def __init__(
         self.name = name
         self.numpy_dtype = np.dtype(self.dtype)
         self.filter_checks_isfinite = False
+        # broadcastable is here just for code that would work fine with XTensorType but checks for it
+        self.broadcastable = (False,) * self.ndim
 
     def clone(
         self,
@@ -93,6 +95,10 @@ def filter(self, value, strict=False, allow_downcast=None):
             self, value, strict=strict, allow_downcast=allow_downcast
         )
 
+    @staticmethod
+    def may_share_memory(a, b):
+        return TensorType.may_share_memory(a, b)
+
     def filter_variable(self, other, allow_convert=True):
         if not isinstance(other, Variable):
             # The value is not a Variable: we cast it into
@@ -160,7 +166,7 @@ def convert_variable(self, var):
         return None
 
     def __repr__(self):
-        return f"XTensorType({self.dtype}, {self.dims}, {self.shape})"
+        return f"XTensorType({self.dtype}, shape={self.shape}, dims={self.dims})"
 
     def __hash__(self):
         return hash((type(self), self.dtype, self.shape, self.dims))
diff --git a/pytensor/xtensor/vectorization.py b/pytensor/xtensor/vectorization.py
index 7d99b9c63c..8243e78170 100644
--- a/pytensor/xtensor/vectorization.py
+++ b/pytensor/xtensor/vectorization.py
@@ -142,8 +142,12 @@ def __init__(
         core_op,
         core_dims: tuple[tuple[tuple[str, ...], ...], tuple[str, ...]],
         extra_dims: tuple[str, ...],
+        name: str | None = None,
     ):
         super().__init__()
+        if name is None:
+            name = getattr(core_op, "name", None)
+        self.name = name
         self.core_op = core_op
         inps_core_dims, out_core_dims = core_dims
         for operand_dims in (*inps_core_dims, out_core_dims):
@@ -154,6 +158,15 @@ def __init__(
             raise ValueError("size_dims must be unique")
         self.extra_dims = tuple(extra_dims)
 
+    def __str__(self):
+        if self.name is not None:
+            name = self.name
+            attrs = f"(core_dims={self.core_dims}, extra_dims={self.extra_dims})"
+        else:
+            name = self.__class__.__name__
+            attrs = f"(core_op={self.core_op}, core_dims={self.core_dims}, extra_dims={self.extra_dims})"
+        return f"{name}({attrs})"
+
     def update(self, node):
         # RNG input and update are the first input and output respectively
         return {node.inputs[0]: node.outputs[0]}
diff --git a/tests/sparse/test_basic.py b/tests/sparse/test_basic.py
index 4075ed3ed6..7da993b3dc 100644
--- a/tests/sparse/test_basic.py
+++ b/tests/sparse/test_basic.py
@@ -1159,6 +1159,10 @@ def test_csm_grad(self):
             structured=True,
         )
 
+    @pytest.mark.skipif(
+        version.parse(sp.__version__) >= version.parse("1.16.0"),
+        reason="Scipy 1.16 introduced some changes that make this test fail",
+    )
     def test_csm_sparser(self):
         # Test support for gradients sparser than the input.
 
@@ -1191,6 +1195,10 @@ def test_csm_sparser(self):
 
         assert len(spmat.data) == len(res)
 
+    @pytest.mark.skipif(
+        version.parse(sp.__version__) >= version.parse("1.16.0"),
+        reason="Scipy 1.16 introduced some changes that make this test fail",
+    )
     def test_csm_unsorted(self):
         # Test support for gradients of unsorted inputs.
 
diff --git a/tests/xtensor/test_random.py b/tests/xtensor/test_random.py
index de248c3cb7..cf822a03de 100644
--- a/tests/xtensor/test_random.py
+++ b/tests/xtensor/test_random.py
@@ -7,7 +7,7 @@
 
 import pytensor.tensor.random as ptr
 import pytensor.xtensor.random as pxr
-from pytensor import function, shared
+from pytensor import config, function, shared
 from pytensor.graph import rewrite_graph
 from pytensor.graph.basic import equal_computations
 from pytensor.tensor import broadcast_arrays, tensor
@@ -112,6 +112,19 @@ def _supp_shape_from_params(self, dist_params, param_shapes=None):
 )
 
 
+def test_dtype():
+    x = normal(0, 1)
+    assert x.type.dtype == config.floatX
+
+    with config.change_flags(floatX="float64"):
+        x = normal(0, 1)
+        assert x.type.dtype == "float64"
+
+    with config.change_flags(floatX="float32"):
+        x = normal(0, 1)
+        assert x.type.dtype == "float32"
+
+
 def test_normal():
     rng = random_generator_type("rng")
     c_size = tensor("c_size", shape=(), dtype=int)
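Illustrative note on the `RandomVariable.make_node` change above (not part of the patch): when an Op is declared with `dtype="floatX"`, the node it builds now stores a copy of the Op committed to the configured float type, so later passes that re-call `make_node` (such as the XRV lowering in `pytensor/xtensor/rewriting/vectorization.py`) no longer see the symbolic "floatX". A minimal sketch, assuming `pytensor.tensor.random.normal` is one of the RVs declared with `dtype="floatX"`:

    import pytensor.tensor.random as ptr
    from pytensor import config

    x = ptr.normal(0, 1)
    # The op attached to the node carries a concrete dtype, not "floatX" ...
    assert x.owner.op.dtype == config.floatX
    # ... and the output type matches it.
    assert x.type.dtype == config.floatX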
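The one-line change in `pytensor/xtensor/math.py` binds a second public name to the same `XElemwise`, matching the `scipy.special.expit` spelling. A minimal check (illustrative only):

    from pytensor.xtensor import math as pxm

    # `expit` is an alias of `sigmoid`, not a separate Op.
    assert pxm.expit is pxm.sigmoid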
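Worked example (not part of the patch) for the `core_dims_needed` hunk in `pytensor/xtensor/random.py`: the new expression counts core dims by the largest index referenced in the dims maps, whereas the old one only looked at the length of each map, which underestimates the count when a map skips indices. Both expressions below are copied from the hunk; the example maps are hypothetical.

    # Hypothetical dims maps: the only input refers to core dim index 1,
    # and the output refers to no core dims at all.
    core_inps_dims_map = [(1,)]
    core_out_dims_map = ()

    # New computation: largest referenced index + 1.
    new = max(
        max(
            (
                max((entry + 1 for entry in dims_map), default=0)
                for dims_map in core_inps_dims_map
            ),
            default=0,
        ),
        max((entry + 1 for entry in core_out_dims_map), default=0),
    )

    # Old computation: based only on map lengths.
    old = max(
        (*(len(i) for i in core_inps_dims_map), len(core_out_dims_map)), default=0
    )

    assert new == 2  # core dims 0 and 1 must both exist for index 1 to be valid
    assert old == 1  # length-based counting misses that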