@@ -1215,44 +1215,56 @@ def test_sample_deterministic():
1215
1215
1216
1216
1217
1217
class TestDraw (SeededTest ):
1218
- def test_draw_one_variable (self ):
1218
+ def test_univariate (self ):
1219
1219
with pm .Model ():
1220
1220
x = pm .Normal ("x" )
1221
1221
1222
1222
x_draws = pm .draw (x )
1223
- assert x_draws .shape == (1 , )
1223
+ assert x_draws .shape == ()
1224
1224
1225
- def test_draw_several_variables (self ):
1225
+ (x_draws ,) = pm .draw ([x ])
1226
+ assert x_draws .shape == ()
1227
+
1228
+ x_draws = pm .draw (x , draws = 10 )
1229
+ assert x_draws .shape == (10 ,)
1230
+
1231
+ (x_draws ,) = pm .draw ([x ], draws = 10 )
1232
+ assert x_draws .shape == (10 ,)
1233
+
1234
+ def test_multivariate (self ):
1235
+ with pm .Model ():
1236
+ mln = pm .Multinomial ("mln" , n = 5 , p = np .array ([0.25 , 0.25 , 0.25 , 0.25 ]))
1237
+
1238
+ mln_draws = pm .draw (mln , draws = 1 )
1239
+ assert mln_draws .shape == (4 ,)
1240
+
1241
+ (mln_draws ,) = pm .draw ([mln ], draws = 1 )
1242
+ assert mln_draws .shape == (4 ,)
1243
+
1244
+ mln_draws = pm .draw (mln , draws = 10 )
1245
+ assert mln_draws .shape == (10 , 4 )
1246
+
1247
+ (mln_draws ,) = pm .draw ([mln ], draws = 10 )
1248
+ assert mln_draws .shape == (10 , 4 )
1249
+
1250
+ def test_multiple_variables (self ):
1226
1251
with pm .Model ():
1227
1252
x = pm .Normal ("x" )
1228
1253
y = pm .Normal ("y" , shape = 10 )
1229
1254
z = pm .Uniform ("z" , shape = 5 )
1255
+ w = pm .Dirichlet ("w" , a = [1 , 1 , 1 ])
1230
1256
1231
- num_draws = 1000
1232
- # Draw samples of a list variables
1233
- draws = pm .draw ([x , y , z ], draws = num_draws )
1234
- assert draws [0 ].shape == (num_draws ,)
1235
- assert draws [1 ].shape == (num_draws , 10 )
1236
- assert draws [2 ].shape == (num_draws , 5 )
1237
-
1238
- # Draw samples of a tuple variables
1239
- draws = pm .draw ((x , y , z ), draws = num_draws )
1257
+ num_draws = 100
1258
+ draws = pm .draw ((x , y , z , w ), draws = num_draws )
1240
1259
assert draws [0 ].shape == (num_draws ,)
1241
1260
assert draws [1 ].shape == (num_draws , 10 )
1242
1261
assert draws [2 ].shape == (num_draws , 5 )
1243
-
1244
- def test_multivariate (self ):
1245
- with pm .Model ():
1246
- mln = pm .Multinomial ("mln" , n = 5 , p = np .array ([0.25 , 0.25 , 0.25 , 0.25 ]))
1247
-
1248
- mln_draws = pm .draw (mln , draws = 100 )
1249
- assert mln_draws .shape == (100 , 4 )
1262
+ assert draws [3 ].shape == (num_draws , 3 )
1250
1263
1251
1264
def test_draw_different_samples (self ):
1252
1265
with pm .Model ():
1253
1266
x = pm .Normal ("x" )
1254
1267
1255
1268
x_draws_1 = pm .draw (x , 100 )
1256
1269
x_draws_2 = pm .draw (x , 100 )
1257
- # Check if the draw function will draw different samples each time
1258
1270
assert not np .all (np .isclose (x_draws_1 , x_draws_2 ))
0 commit comments