Skip to content

Add experimental dims module with objects that follow dim-based semantics (like xarray) #7820

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 13 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,12 @@ jobs:
tests/logprob/test_transforms.py
tests/logprob/test_utils.py

- |
tests/dims/distributions/test_core.py
tests/dims/distributions/test_scalar.py
tests/dims/distributions/test_vector.py
tests/dims/test_model.py

fail-fast: false
runs-on: ${{ matrix.os }}
env:
Expand Down
2 changes: 1 addition & 1 deletion conda-envs/environment-alternative-backends.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ dependencies:
- numpyro>=0.8.0
- pandas>=0.24.0
- pip
- pytensor>=2.31.2,<2.32
- pytensor>=2.31.5,<2.32
- python-graphviz
- networkx
- rich>=13.7.1
Expand Down
2 changes: 1 addition & 1 deletion conda-envs/environment-dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ dependencies:
- numpy>=1.25.0
- pandas>=0.24.0
- pip
- pytensor>=2.31.2,<2.32
- pytensor>=2.31.5,<2.32
- python-graphviz
- networkx
- scipy>=1.4.1
Expand Down
2 changes: 1 addition & 1 deletion conda-envs/environment-docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ dependencies:
- numpy>=1.25.0
- pandas>=0.24.0
- pip
- pytensor>=2.31.2,<2.32
- pytensor>=2.31.5,<2.32
- python-graphviz
- rich>=13.7.1
- scipy>=1.4.1
Expand Down
2 changes: 1 addition & 1 deletion conda-envs/environment-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ dependencies:
- pandas>=0.24.0
- pip
- polyagamma
- pytensor>=2.31.2,<2.32
- pytensor>=2.31.5,<2.32
- python-graphviz
- networkx
- rich>=13.7.1
Expand Down
2 changes: 1 addition & 1 deletion conda-envs/windows-environment-dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ dependencies:
- numpy>=1.25.0
- pandas>=0.24.0
- pip
- pytensor>=2.31.2,<2.32
- pytensor>=2.31.5,<2.32
- python-graphviz
- networkx
- rich>=13.7.1
Expand Down
2 changes: 1 addition & 1 deletion conda-envs/windows-environment-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ dependencies:
- pandas>=0.24.0
- pip
- polyagamma
- pytensor>=2.31.2,<2.32
- pytensor>=2.31.5,<2.32
- python-graphviz
- networkx
- rich>=13.7.1
Expand Down
2,418 changes: 2,418 additions & 0 deletions docs/source/learn/core_notebooks/dims_module.ipynb
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Inside the notebook you have to use myst syntax, so you'll have to replace the :reftype:`reftarget` with {reftype}`reftarget`

Large diffs are not rendered by default.

32 changes: 31 additions & 1 deletion pymc/backends/arviz.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,40 @@

_log = logging.getLogger(__name__)


# Controls how `dict_to_dataset_drop_incompatible_coords` treats coordinates whose
# length does not match the corresponding dimension length of a variable:
# - False: the offending coordinates are dropped and a UserWarning is emitted.
# - True: no check is done here and the coordinates are forwarded to
#   `dict_to_dataset` unchanged (presumably surfacing the mismatch as an error
#   downstream -- confirm against arviz/xarray behavior).
RAISE_ON_INCOMPATIBLE_COORD_LENGTHS = False


# random variable object ...
Var = Any


def dict_to_dataset_drop_incompatible_coords(vars_dict, *args, dims, coords, **kwargs):
    """Call ``dict_to_dataset`` after dropping coordinates that do not fit the data.

    For every variable, the length of each named dimension is compared with the
    length of the coordinate registered under the same name. When they differ
    and ``RAISE_ON_INCOMPATIBLE_COORD_LENGTHS`` is ``False``, a ``UserWarning``
    is emitted and that coordinate is left out of the ``coords`` forwarded to
    ``dict_to_dataset``. The caller's ``coords`` mapping is never mutated; a
    copy is made the first time a coordinate has to be removed.

    Parameters
    ----------
    vars_dict
        Mapping from variable name to an array-like exposing ``.shape``.
    *args
        Extra positional arguments forwarded to ``dict_to_dataset``.
    dims
        Mapping from variable name to its sequence of dimension names, or
        ``None``. Variables missing from the mapping are not checked.
    coords
        Mapping from dimension name to coordinate values (anything supporting
        ``len``), or ``None``.
    **kwargs
        Extra keyword arguments forwarded to ``dict_to_dataset``.

    Returns
    -------
    Whatever ``dict_to_dataset`` returns for the (possibly reduced) coordinates.
    """
    safe_coords = coords

    # With no coords or no dims there is nothing to compare; forward them untouched
    # instead of failing on `.items()` / `.get()` of `None`.
    if coords is not None and dims is not None and not RAISE_ON_INCOMPATIBLE_COORD_LENGTHS:
        coords_lengths = {name: len(values) for name, values in coords.items()}
        for var_name, var in vars_dict.items():
            var_dims = dims.get(var_name, ())
            # Iterate in reversed because of chain/draw batch dimensions
            for dim, dim_length in zip(reversed(var_dims), reversed(var.shape)):
                coord_length = coords_lengths.get(dim)
                if coord_length is None or coord_length == dim_length:
                    continue

                warnings.warn(
                    f"Incompatible coordinate length of {coord_length} for dimension '{dim}' of variable '{var_name}'.\n"
                    "The original coordinates for this dim will not be included in the returned dataset for any of the variables. "
                    "Instead they will default to `np.arange(var_length)` and the shorter variables will be right-padded with nan.\n"
                    "To make this warning into an error set `pymc.backends.arviz.RAISE_ON_INCOMPATIBLE_COORD_LENGTHS` to `True`",
                    UserWarning,
                )
                # Copy lazily so the caller's mapping is left intact.
                if safe_coords is coords:
                    safe_coords = coords.copy()
                safe_coords.pop(dim)
                # Forget the dim so it is neither warned about nor popped again.
                coords_lengths.pop(dim)

    # FIXME: Would be better to drop coordinates altogether, but arviz defaults to `np.arange(var_length)`
    return dict_to_dataset(vars_dict, *args, dims=dims, coords=safe_coords, **kwargs)


def find_observations(model: "Model") -> dict[str, Var]:
"""If there are observations available, return them as a dictionary."""
observations = {}
Expand Down Expand Up @@ -365,7 +395,7 @@ def priors_to_xarray(self):
priors_dict[group] = (
None
if var_names is None
else dict_to_dataset(
else dict_to_dataset_drop_incompatible_coords(
{k: np.expand_dims(self.prior[k], 0) for k in var_names},
library=pymc,
coords=self.coords,
Expand Down
24 changes: 15 additions & 9 deletions pymc/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,12 @@
# limitations under the License.

import io
import typing
import urllib.request

from collections.abc import Sequence
from copy import copy
from typing import cast
from typing import Union, cast

import numpy as np
import pandas as pd
Expand All @@ -32,12 +33,13 @@
from pytensor.tensor.random.basic import IntegersRV
from pytensor.tensor.variable import TensorConstant, TensorVariable

import pymc as pm

from pymc.logprob.utils import rvs_in_graph
from pymc.pytensorf import convert_data
from pymc.exceptions import ShapeError
from pymc.pytensorf import convert_data, rvs_in_graph
from pymc.vartypes import isgenerator

if typing.TYPE_CHECKING:
from pymc.model.core import Model

__all__ = [
"Data",
"Minibatch",
Expand Down Expand Up @@ -197,7 +199,7 @@ def determine_coords(

if isinstance(value, np.ndarray) and dims is not None:
if len(dims) != value.ndim:
raise pm.exceptions.ShapeError(
raise ShapeError(
"Invalid data shape. The rank of the dataset must match the length of `dims`.",
actual=value.shape,
expected=value.ndim,
Expand All @@ -222,6 +224,7 @@ def Data(
dims: Sequence[str] | None = None,
coords: dict[str, Sequence | np.ndarray] | None = None,
infer_dims_and_coords=False,
model: Union["Model", None] = None,
**kwargs,
) -> SharedVariable | TensorConstant:
"""Create a data container that registers a data variable with the model.
Expand Down Expand Up @@ -286,15 +289,18 @@ def Data(
... model.set_data("data", data_vals)
... idatas.append(pm.sample())
"""
from pymc.model.core import modelcontext

if coords is None:
coords = {}

if isinstance(value, list):
value = np.array(value)

# Add data container to the named variables of the model.
model = pm.Model.get_context(error_if_none=False)
if model is None:
try:
model = modelcontext(model)
except TypeError:
raise TypeError(
"No model on context stack, which is needed to instantiate a data container. "
"Add variable inside a 'with model:' block."
Expand All @@ -321,7 +327,7 @@ def Data(
if isinstance(dims, str):
dims = (dims,)
if not (dims is None or len(dims) == x.ndim):
raise pm.exceptions.ShapeError(
raise ShapeError(
"Length of `dims` must match the dimensions of the dataset.",
actual=len(dims),
expected=x.ndim,
Expand Down
73 changes: 73 additions & 0 deletions pymc/dims/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# Copyright 2025 - present The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


def __init__():
"""Make PyMC aware of the xtensor functionality.

This should be done eagerly once development matures.
"""
import datetime
import warnings

from pytensor.compile import optdb

from pymc.initial_point import initial_point_rewrites_db
from pymc.logprob.abstract import MeasurableOp
from pymc.logprob.rewriting import logprob_rewrites_db

# Filter PyTensor xtensor warning, we emit our own warning
with warnings.catch_warnings():
warnings.simplefilter("ignore", UserWarning)
import pytensor.xtensor

from pytensor.xtensor.vectorization import XRV

# Make PyMC aware of xtensor functionality
MeasurableOp.register(XRV)
logprob_rewrites_db.register(
"pre_lower_xtensor", optdb.query("+lower_xtensor"), "basic", position=0.1
)
logprob_rewrites_db.register(
"post_lower_xtensor", optdb.query("+lower_xtensor"), "cleanup", position=5.1
)
initial_point_rewrites_db.register(
"lower_xtensor", optdb.query("+lower_xtensor"), "basic", position=0.1
)

# TODO: Better model of probability of bugs
day_of_conception = datetime.date(2025, 6, 17)
day_of_last_bug = datetime.date(2025, 6, 30)
today = datetime.date.today()
days_with_bugs = (day_of_last_bug - day_of_conception).days
Copy link
Member

@twiecki twiecki Jun 18, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wtf 😆

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This has two purposes: distract reviewers so they don't focus on the critical changes, and prove that OSS libraries can't be fun.

days_without_bugs = (today - day_of_last_bug).days
p = 1 - (days_without_bugs / (days_without_bugs + days_with_bugs + 10))
if p > 0.05:
warnings.warn(
f"The `pymc.dims` module is experimental and may contain critical bugs (p={p:.3f}).\n"
"Please report any issues you encounter at https://github.com/pymc-devs/pymc/issues.\n"
"Disclaimer: This an experimental API and may change at any time.",
UserWarning,
stacklevel=2,
)


__init__()
del __init__

from pytensor.xtensor import as_xtensor, concat

from pymc.dims import math
from pymc.dims.distributions import *
from pymc.dims.model import Data, Deterministic, Potential
15 changes: 15 additions & 0 deletions pymc/dims/distributions/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Copyright 2025 - present The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pymc.dims.distributions.scalar import *
from pymc.dims.distributions.vector import *
Loading
Loading