Skip to content

Commit c792e88

Browse files
LukeLBLuke LB
and
Luke LB
authored
Probabilty inference for arc transformations
Co-authored-by: Luke LB <[email protected]>
1 parent 154f5b0 commit c792e88

File tree

2 files changed

+71
-2
lines changed

2 files changed

+71
-2
lines changed

pymc/logprob/transforms.py

Lines changed: 62 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,9 @@
5353
from pytensor.scalar import (
5454
Abs,
5555
Add,
56+
ArcCosh,
57+
ArcSinh,
58+
ArcTanh,
5659
Cosh,
5760
Erf,
5861
Erfc,
@@ -71,6 +74,9 @@
7174
from pytensor.tensor.math import (
7275
abs,
7376
add,
77+
arccosh,
78+
arcsinh,
79+
arctanh,
7480
cosh,
7581
erf,
7682
erfc,
@@ -369,7 +375,23 @@ def apply(self, fgraph: FunctionGraph):
369375
class MeasurableTransform(MeasurableElemwise):
370376
"""A placeholder used to specify a log-likelihood for a transformed measurable variable"""
371377

372-
valid_scalar_types = (Exp, Log, Add, Mul, Pow, Abs, Sinh, Cosh, Tanh, Erf, Erfc, Erfcx)
378+
valid_scalar_types = (
379+
Exp,
380+
Log,
381+
Add,
382+
Mul,
383+
Pow,
384+
Abs,
385+
Sinh,
386+
Cosh,
387+
Tanh,
388+
ArcSinh,
389+
ArcCosh,
390+
ArcTanh,
391+
Erf,
392+
Erfc,
393+
Erfcx,
394+
)
373395

374396
# Cannot use `transform` as name because it would clash with the property added by
375397
# the `TransformValuesRewrite`
@@ -501,7 +523,9 @@ def measurable_sub_to_neg(fgraph, node):
501523
return [pt.add(minuend, pt.neg(subtrahend))]
502524

503525

504-
@node_rewriter([exp, log, add, mul, pow, abs, sinh, cosh, tanh, erf, erfc, erfcx])
526+
@node_rewriter(
527+
[exp, log, add, mul, pow, abs, sinh, cosh, tanh, arcsinh, arccosh, arctanh, erf, erfc, erfcx]
528+
)
505529
def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[List[Node]]:
506530
"""Find measurable transformations from Elemwise operators."""
507531

@@ -544,6 +568,9 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li
544568
Sinh: SinhTransform(),
545569
Cosh: CoshTransform(),
546570
Tanh: TanhTransform(),
571+
ArcSinh: ArcsinhTransform(),
572+
ArcCosh: ArccoshTransform(),
573+
ArcTanh: ArctanhTransform(),
547574
Erf: ErfTransform(),
548575
Erfc: ErfcTransform(),
549576
Erfcx: ErfcxTransform(),
@@ -660,6 +687,39 @@ def backward(self, value, *inputs):
660687
return pt.arctanh(value)
661688

662689

690+
class ArcsinhTransform(RVTransform):
691+
name = "arcsinh"
692+
ndim_supp = 0
693+
694+
def forward(self, value, *inputs):
695+
return pt.arcsinh(value)
696+
697+
def backward(self, value, *inputs):
698+
return pt.sinh(value)
699+
700+
701+
class ArccoshTransform(RVTransform):
702+
name = "arccosh"
703+
ndim_supp = 0
704+
705+
def forward(self, value, *inputs):
706+
return pt.arccosh(value)
707+
708+
def backward(self, value, *inputs):
709+
return pt.cosh(value)
710+
711+
712+
class ArctanhTransform(RVTransform):
713+
name = "arctanh"
714+
ndim_supp = 0
715+
716+
def forward(self, value, *inputs):
717+
return pt.arctanh(value)
718+
719+
def backward(self, value, *inputs):
720+
return pt.tanh(value)
721+
722+
663723
class ErfTransform(RVTransform):
664724
name = "erf"
665725
ndim_supp = 0

tests/logprob/test_transforms.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,9 @@
5151
from pymc.logprob.abstract import MeasurableVariable, _logprob
5252
from pymc.logprob.basic import conditional_logp, logp
5353
from pymc.logprob.transforms import (
54+
ArccoshTransform,
55+
ArcsinhTransform,
56+
ArctanhTransform,
5457
ChainedTransform,
5558
CoshTransform,
5659
ErfcTransform,
@@ -1028,6 +1031,9 @@ def test_multivariate_transform(shift, scale):
10281031
(pt.sinh, SinhTransform()),
10291032
(pt.cosh, CoshTransform()),
10301033
(pt.tanh, TanhTransform()),
1034+
(pt.arcsinh, ArcsinhTransform()),
1035+
(pt.arccosh, ArccoshTransform()),
1036+
(pt.arctanh, ArctanhTransform()),
10311037
],
10321038
)
10331039
def test_erf_logp(pt_transform, transform):
@@ -1060,6 +1066,9 @@ def test_erf_logp(pt_transform, transform):
10601066
SinhTransform(),
10611067
CoshTransform(),
10621068
TanhTransform(),
1069+
ArcsinhTransform(),
1070+
ArccoshTransform(),
1071+
ArctanhTransform(),
10631072
],
10641073
)
10651074
def test_check_jac_det(transform):

0 commit comments

Comments
 (0)