|
53 | 53 | from pytensor.scalar import (
|
54 | 54 | Abs,
|
55 | 55 | Add,
|
| 56 | + ArcCosh, |
| 57 | + ArcSinh, |
| 58 | + ArcTanh, |
56 | 59 | Cosh,
|
57 | 60 | Erf,
|
58 | 61 | Erfc,
|
|
71 | 74 | from pytensor.tensor.math import (
|
72 | 75 | abs,
|
73 | 76 | add,
|
| 77 | + arccosh, |
| 78 | + arcsinh, |
| 79 | + arctanh, |
74 | 80 | cosh,
|
75 | 81 | erf,
|
76 | 82 | erfc,
|
@@ -369,7 +375,23 @@ def apply(self, fgraph: FunctionGraph):
|
369 | 375 | class MeasurableTransform(MeasurableElemwise):
|
370 | 376 | """A placeholder used to specify a log-likelihood for a transformed measurable variable"""
|
371 | 377 |
|
372 |
| - valid_scalar_types = (Exp, Log, Add, Mul, Pow, Abs, Sinh, Cosh, Tanh, Erf, Erfc, Erfcx) |
| 378 | + valid_scalar_types = ( |
| 379 | + Exp, |
| 380 | + Log, |
| 381 | + Add, |
| 382 | + Mul, |
| 383 | + Pow, |
| 384 | + Abs, |
| 385 | + Sinh, |
| 386 | + Cosh, |
| 387 | + Tanh, |
| 388 | + ArcSinh, |
| 389 | + ArcCosh, |
| 390 | + ArcTanh, |
| 391 | + Erf, |
| 392 | + Erfc, |
| 393 | + Erfcx, |
| 394 | + ) |
373 | 395 |
|
374 | 396 | # Cannot use `transform` as name because it would clash with the property added by
|
375 | 397 | # the `TransformValuesRewrite`
|
@@ -501,7 +523,9 @@ def measurable_sub_to_neg(fgraph, node):
|
501 | 523 | return [pt.add(minuend, pt.neg(subtrahend))]
|
502 | 524 |
|
503 | 525 |
|
504 |
| -@node_rewriter([exp, log, add, mul, pow, abs, sinh, cosh, tanh, erf, erfc, erfcx]) |
| 526 | +@node_rewriter( |
| 527 | + [exp, log, add, mul, pow, abs, sinh, cosh, tanh, arcsinh, arccosh, arctanh, erf, erfc, erfcx] |
| 528 | +) |
505 | 529 | def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[List[Node]]:
|
506 | 530 | """Find measurable transformations from Elemwise operators."""
|
507 | 531 |
|
@@ -544,6 +568,9 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li
|
544 | 568 | Sinh: SinhTransform(),
|
545 | 569 | Cosh: CoshTransform(),
|
546 | 570 | Tanh: TanhTransform(),
|
| 571 | + ArcSinh: ArcsinhTransform(), |
| 572 | + ArcCosh: ArccoshTransform(), |
| 573 | + ArcTanh: ArctanhTransform(), |
547 | 574 | Erf: ErfTransform(),
|
548 | 575 | Erfc: ErfcTransform(),
|
549 | 576 | Erfcx: ErfcxTransform(),
|
@@ -660,6 +687,39 @@ def backward(self, value, *inputs):
|
660 | 687 | return pt.arctanh(value)
|
661 | 688 |
|
662 | 689 |
|
| 690 | +class ArcsinhTransform(RVTransform): |
| 691 | + name = "arcsinh" |
| 692 | + ndim_supp = 0 |
| 693 | + |
| 694 | + def forward(self, value, *inputs): |
| 695 | + return pt.arcsinh(value) |
| 696 | + |
| 697 | + def backward(self, value, *inputs): |
| 698 | + return pt.sinh(value) |
| 699 | + |
| 700 | + |
| 701 | +class ArccoshTransform(RVTransform): |
| 702 | + name = "arccosh" |
| 703 | + ndim_supp = 0 |
| 704 | + |
| 705 | + def forward(self, value, *inputs): |
| 706 | + return pt.arccosh(value) |
| 707 | + |
| 708 | + def backward(self, value, *inputs): |
| 709 | + return pt.cosh(value) |
| 710 | + |
| 711 | + |
| 712 | +class ArctanhTransform(RVTransform): |
| 713 | + name = "arctanh" |
| 714 | + ndim_supp = 0 |
| 715 | + |
| 716 | + def forward(self, value, *inputs): |
| 717 | + return pt.arctanh(value) |
| 718 | + |
| 719 | + def backward(self, value, *inputs): |
| 720 | + return pt.tanh(value) |
| 721 | + |
| 722 | + |
663 | 723 | class ErfTransform(RVTransform):
|
664 | 724 | name = "erf"
|
665 | 725 | ndim_supp = 0
|
|
0 commit comments