Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/api_reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ Distributions
:toctree: generated/

Chi
Maxwell
DiscreteMarkovChain
GeneralizedPoisson
GenExtreme
Expand Down
3 changes: 2 additions & 1 deletion pymc_experimental/distributions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
Experimental probability distributions for stochastic nodes in PyMC.
"""

from pymc_experimental.distributions.continuous import Chi, GenExtreme
from pymc_experimental.distributions.continuous import Chi, GenExtreme, Maxwell
from pymc_experimental.distributions.discrete import GeneralizedPoisson, Skellam
from pymc_experimental.distributions.histogram_utils import histogram_approximation
from pymc_experimental.distributions.multivariate import R2D2M2CP
Expand All @@ -30,4 +30,5 @@
"R2D2M2CP",
"histogram_approximation",
"Chi",
"Maxwell",
]
71 changes: 71 additions & 0 deletions pymc_experimental/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from pymc.distributions.dist_math import check_parameters
from pymc.distributions.distribution import Continuous
from pymc.distributions.shape_utils import rv_size_is_none
from pymc.logprob.utils import CheckParameterValue
from pymc.pytensorf import floatX
from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.variable import TensorVariable
Expand Down Expand Up @@ -280,3 +281,73 @@ def __new__(cls, name, nu, **kwargs):
@classmethod
def dist(cls, nu, **kwargs):
return CustomDist.dist(nu, dist=cls.chi_dist, class_name="Chi", **kwargs)


class Maxwell:
R"""
The Maxwell-Boltzmann distribution

The pdf of this distribution is

.. math::

f(x \mid a) = {\displaystyle {\sqrt {\frac {2}{\pi }}}\,{\frac {x^{2}}{a^{3}}}\,\exp \left({\frac {-x^{2}}{2a^{2}}}\right)}

Read more about it on `Wikipedia <https://en.wikipedia.org/wiki/Maxwell%E2%80%93Boltzmann_distribution>`_

.. plot::
:context: close-figs

import matplotlib.pyplot as plt
import numpy as np
import scipy.stats as st
import arviz as az
plt.style.use('arviz-darkgrid')
x = np.linspace(0, 20, 200)
for a in [1, 2, 5]:
pdf = st.maxwell.pdf(x, scale=a)
plt.plot(x, pdf, label=r'$a$ = {}'.format(a))
plt.xlabel('x', fontsize=12)
plt.ylabel('f(x)', fontsize=12)
plt.legend(loc=1)
plt.show()

======== =========================================================================
Support :math:`x \in (0, \infty)`
Mean :math:`2a \sqrt{\frac{2}{\pi}}`
Variance :math:`\frac{a^2(3 \pi - 8)}{\pi}`
======== =========================================================================

Parameters
----------
a : tensor_like of float
Scale parameter (a > 0).

"""

@staticmethod
def maxwell_dist(a: TensorVariable, size: TensorVariable) -> TensorVariable:
if rv_size_is_none(size):
size = a.shape

a = CheckParameterValue("a > 0")(a, pt.all(pt.gt(a, 0)))

return Chi.dist(nu=3, size=size) * a

def __new__(cls, name, a, **kwargs):
return CustomDist(
name,
a,
dist=cls.maxwell_dist,
class_name="Maxwell",
**kwargs,
)

@classmethod
def dist(cls, a, **kwargs):
return CustomDist.dist(
a,
dist=cls.maxwell_dist,
class_name="Maxwell",
**kwargs,
)
25 changes: 24 additions & 1 deletion pymc_experimental/tests/distributions/test_continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
)

# the distributions to be tested
from pymc_experimental.distributions import Chi, GenExtreme
from pymc_experimental.distributions import Chi, GenExtreme, Maxwell


class TestGenExtremeClass:
Expand Down Expand Up @@ -159,3 +159,26 @@ def test_logcdf(self):
{"nu": Rplus},
lambda value, nu: sp.chi.logcdf(value, df=nu),
)


class TestMaxwell:
"""
Wrapper class so that tests of experimental additions can be dropped into
PyMC directly on adoption.
"""

def test_logp(self):
check_logp(
Maxwell,
Rplus,
{"a": Rplus},
lambda value, a: sp.maxwell.logpdf(value, scale=a),
)

def test_logcdf(self):
check_logcdf(
Maxwell,
Rplus,
{"a": Rplus},
lambda value, a: sp.maxwell.logcdf(value, scale=a),
)