From 3acc9da9c3f9c787ac9139d0238ba61314d5c980 Mon Sep 17 00:00:00 2001 From: Vivek Anand Singh <17vivekanandsingh@gmail.com> Date: Wed, 15 Nov 2023 00:06:32 +0530 Subject: [PATCH] Updated MeasurableComparison variable to Tensor Variable --- pymc/logprob/binary.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pymc/logprob/binary.py b/pymc/logprob/binary.py index a344d80673..6003b269c1 100644 --- a/pymc/logprob/binary.py +++ b/pymc/logprob/binary.py @@ -20,6 +20,7 @@ from pytensor.graph.fg import FunctionGraph from pytensor.graph.rewriting.basic import node_rewriter from pytensor.scalar.basic import GE, GT, LE, LT, Invert +from pytensor.tensor import TensorVariable from pytensor.tensor.math import ge, gt, invert, le, lt from pymc.logprob.abstract import ( @@ -41,7 +42,7 @@ class MeasurableComparison(MeasurableElemwise): @node_rewriter(tracks=[gt, lt, ge, le]) def find_measurable_comparisons( fgraph: FunctionGraph, node: Node -) -> Optional[List[MeasurableComparison]]: +) -> Optional[List[TensorVariable]]: rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None) if rv_map_feature is None: return None # pragma: no cover