Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 29 additions & 10 deletions numpyro/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@

from collections import OrderedDict
from collections.abc import Callable
from typing import Any, Protocol, runtime_checkable
from typing import Any, Optional, Protocol, Union, runtime_checkable

try:
from typing import ParamSpec, TypeAlias
except ImportError:
from typing_extensions import ParamSpec, TypeAlias

import numpy as np

import jax
from jax.typing import ArrayLike

Expand All @@ -21,6 +23,18 @@
TraceT: TypeAlias = OrderedDict[str, Message]


NonScalarArray = Union[np.ndarray, jax.Array]
"""An alias for array-like types excluding scalars."""


NumLike = Union[NonScalarArray, np.number, int, float, complex]
"""An alias for array-like types excluding `np.bool_` and `bool`."""


PyTree: TypeAlias = Any
"""A generic type for a pytree, i.e. a nested structure of lists, tuples, dicts, and arrays."""


@runtime_checkable
class ConstraintT(Protocol):
is_discrete: bool = ...
Expand Down Expand Up @@ -87,20 +101,25 @@ def is_discrete(self) -> bool: ...

@runtime_checkable
class TransformT(Protocol):
domain = ConstraintT
codomain = ConstraintT
_inv: "TransformT" = None
domain: ConstraintT = ...
codomain: ConstraintT = ...
_inv: Optional["TransformT"] = ...

def __call__(self, x: ArrayLike) -> ArrayLike: ...
def _inverse(self, y: ArrayLike) -> ArrayLike: ...
def __call__(self, x: NumLike) -> NumLike: ...
def _inverse(self, y: NumLike) -> NumLike: ...
def log_abs_det_jacobian(
self, x: ArrayLike, y: ArrayLike, intermediates=None
) -> ArrayLike: ...
def call_with_intermediates(self, x: ArrayLike) -> tuple[ArrayLike, None]: ...
self,
x: NumLike,
y: NumLike,
intermediates: Optional[PyTree] = None,
) -> NumLike: ...
def call_with_intermediates(
self, x: NumLike
) -> tuple[NumLike, Optional[PyTree]]: ...
def forward_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: ...
def inverse_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: ...

@property
def inv(self) -> "TransformT": ...
@property
def sign(self) -> ArrayLike: ...
def sign(self) -> NumLike: ...
22 changes: 11 additions & 11 deletions numpyro/distributions/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -801,18 +801,18 @@ def tree_flatten(self):
corr_cholesky: ConstraintT = _CorrCholesky()
corr_matrix: ConstraintT = _CorrMatrix()
dependent: ConstraintT = _Dependent()
greater_than: ConstraintT = _GreaterThan
greater_than_eq: ConstraintT = _GreaterThanEq
less_than: ConstraintT = _LessThan
less_than_eq: ConstraintT = _LessThanEq
independent: ConstraintT = _IndependentConstraint
integer_interval: ConstraintT = _IntegerInterval
integer_greater_than: ConstraintT = _IntegerGreaterThan
interval: ConstraintT = _Interval
greater_than = _GreaterThan
greater_than_eq = _GreaterThanEq
less_than = _LessThan
less_than_eq = _LessThanEq
independent = _IndependentConstraint
integer_interval = _IntegerInterval
integer_greater_than = _IntegerGreaterThan
interval = _Interval
l1_ball: ConstraintT = _L1Ball()
lower_cholesky: ConstraintT = _LowerCholesky()
scaled_unit_lower_cholesky: ConstraintT = _ScaledUnitLowerCholesky()
multinomial: ConstraintT = _Multinomial
multinomial = _Multinomial
nonnegative: ConstraintT = _Nonnegative()
nonnegative_integer: ConstraintT = _IntegerNonnegative()
ordered_vector: ConstraintT = _OrderedVector()
Expand All @@ -830,5 +830,5 @@ def tree_flatten(self):
softplus_positive: ConstraintT = _SoftplusPositive()
sphere: ConstraintT = _Sphere()
unit_interval: ConstraintT = _UnitInterval()
open_interval: ConstraintT = _OpenInterval
zero_sum: ConstraintT = _ZeroSum
open_interval = _OpenInterval
zero_sum = _ZeroSum
Loading