16
16
17
17
from typing import Union
18
18
19
- from aesara .graph .basic import walk
19
+ from aesara .compile import SharedVariable
20
+ from aesara .graph .basic import Constant , walk
20
21
from aesara .tensor .basic import TensorVariable , Variable
21
22
from aesara .tensor .elemwise import DimShuffle
22
23
from aesara .tensor .random .basic import RandomVariable
23
24
from aesara .tensor .random .var import (
24
25
RandomGeneratorSharedVariable ,
25
26
RandomStateSharedVariable ,
26
27
)
27
- from aesara .tensor .var import TensorConstant
28
28
29
29
from pymc .model import Model
30
30
@@ -163,7 +163,7 @@ def _is_potential_or_determinstic(var: Variable) -> bool:
163
163
# in case other code overrides str_repr, fallback
164
164
return False
165
165
166
- if isinstance (var , TensorConstant ):
166
+ if isinstance (var , ( Constant , SharedVariable ) ):
167
167
return _str_for_constant (var , formatting )
168
168
elif isinstance (
169
169
var .owner .op , (RandomVariable , SymbolicRandomVariable )
@@ -189,15 +189,22 @@ def _str_for_input_rv(var: Variable, formatting: str) -> str:
189
189
return _str
190
190
191
191
192
- def _str_for_constant (var : TensorConstant , formatting : str ) -> str :
193
- if len (var .data .shape ) == 0 :
194
- return f"{ var .data :.3g} "
195
- elif len (var .data .shape ) == 1 and var .data .shape [0 ] == 1 :
196
- return f"{ var .data [0 ]:.3g} "
192
+ def _str_for_constant (var : Union [Constant , SharedVariable ], formatting : str ) -> str :
193
+ if isinstance (var , Constant ):
194
+ var_data = var .data
195
+ var_type = "constant"
196
+ else :
197
+ var_data = var .get_value ()
198
+ var_type = "shared"
199
+
200
+ if len (var_data .shape ) == 0 :
201
+ return f"{ var_data :.3g} "
202
+ elif len (var_data .shape ) == 1 and var_data .shape [0 ] == 1 :
203
+ return f"{ var_data [0 ]:.3g} "
197
204
elif "latex" in formatting :
198
- return r "\text{<constant> }"
205
+ return rf "\text{{< { var_type } >} }"
199
206
else :
200
- return r"<constant >"
207
+ return rf"< { var_type } >"
201
208
202
209
203
210
def _str_for_expression (var : Variable , formatting : str ) -> str :
0 commit comments