@@ -232,6 +232,36 @@ def model_observation_dtype_casting():
232
232
return model , compute_graph , plates
233
233
234
234
235
+ def model_non_random_variable_rvs ():
236
+ """Test that node types are not inferred based on the variable Op type, but
237
+ model properties
238
+
239
+ See https://github.com/pymc-devs/pymc/issues/5766
240
+ """
241
+ with pm .Model () as model :
242
+ mu = pm .Normal (name = "mu" , mu = 0.0 , sigma = 5.0 )
243
+
244
+ y_raw = pm .Normal .dist (mu )
245
+ y = pm .math .clip (y_raw , - 3 , 3 )
246
+ model .register_rv (y , name = "y" )
247
+
248
+ z_raw = pm .Normal .dist (y , shape = (5 ,))
249
+ z = pm .math .clip (z_raw , - 1 , 1 )
250
+ model .register_rv (z , name = "z" , data = [0 ] * 5 )
251
+
252
+ compute_graph = {
253
+ "mu" : set (),
254
+ "y" : {"mu" },
255
+ "z" : {"y" },
256
+ }
257
+ plates = {
258
+ "" : {"mu" , "y" },
259
+ "5" : {"z" },
260
+ }
261
+
262
+ return model , compute_graph , plates
263
+
264
+
235
265
class BaseModelGraphTest (SeededTest ):
236
266
model_func = None
237
267
@@ -360,3 +390,7 @@ def test_subgraph(self, var_names, vars_to_plot, compute_graph):
360
390
mg = ModelGraph (model_with_different_descendants ())
361
391
assert set (mg .vars_to_plot (var_names = var_names )) == set (vars_to_plot )
362
392
assert mg .make_compute_graph (var_names = var_names ) == compute_graph
393
+
394
+
395
+ class TestModelNonRandomVariableRVs (BaseModelGraphTest ):
396
+ model_func = model_non_random_variable_rvs
0 commit comments