|
21 | 21 | import pytest
|
22 | 22 |
|
23 | 23 | from aesara.tensor import TensorVariable
|
24 |
| -from scipy import integrate |
25 | 24 |
|
26 | 25 | import pymc as pm
|
27 | 26 |
|
28 | 27 | from pymc.distributions import MvNormal, MvStudentT, joint_logp, logp
|
29 | 28 | from pymc.distributions.distribution import _moment, moment
|
30 | 29 | from pymc.distributions.shape_utils import to_tuple
|
31 | 30 | from pymc.tests.distributions.util import assert_moment_is_expected
|
32 |
| -from pymc.vartypes import continuous_types |
33 |
| - |
34 |
| - |
35 |
| -def integrate_nd(f, domain, shape, dtype): |
36 |
| - if shape == () or shape == (1,): |
37 |
| - if dtype in continuous_types: |
38 |
| - return integrate.quad(f, domain.lower, domain.upper, epsabs=1e-8)[0] |
39 |
| - else: |
40 |
| - return sum(f(j) for j in range(domain.lower, domain.upper + 1)) |
41 |
| - elif shape == (2,): |
42 |
| - |
43 |
| - def f2(a, b): |
44 |
| - return f([a, b]) |
45 |
| - |
46 |
| - return integrate.dblquad( |
47 |
| - f2, |
48 |
| - domain.lower[0], |
49 |
| - domain.upper[0], |
50 |
| - lambda _: domain.lower[1], |
51 |
| - lambda _: domain.upper[1], |
52 |
| - )[0] |
53 |
| - elif shape == (3,): |
54 |
| - |
55 |
| - def f3(a, b, c): |
56 |
| - return f([a, b, c]) |
57 |
| - |
58 |
| - return integrate.tplquad( |
59 |
| - f3, |
60 |
| - domain.lower[0], |
61 |
| - domain.upper[0], |
62 |
| - lambda _: domain.lower[1], |
63 |
| - lambda _: domain.upper[1], |
64 |
| - lambda _, __: domain.lower[2], |
65 |
| - lambda _, __: domain.upper[2], |
66 |
| - )[0] |
67 |
| - else: |
68 |
| - raise ValueError("Dont know how to integrate shape: " + str(shape)) |
69 | 31 |
|
70 | 32 |
|
71 | 33 | class TestBugfixes:
|
|
0 commit comments