Add experimental dims module with objects that follow dim-based semantics (like xarray) #7820

Draft: wants to merge 11 commits into main

ricardoV94 (Member) commented Jun 17, 2025

This builds on top of PyTensor's xtensor module to introduce distributions and model objects that follow xarray-like semantics. Example model:

import numpy as np
import pymc as pm
import pymc.dims as pmd

# Very realistic looking data!
observed_response_np = np.ones((5, 20), dtype=int)
coords = {
    "participant": range(5),
    "trial": range(20),
    "item": range(3),
}
with pm.Model(coords=coords) as dmodel:
    observed_response = pmd.Data(
        "observed_response", observed_response_np, dims=("participant", "trial")
    )

    # Participant constant preferences for each item
    participant_preference = pmd.ZeroSumNormal(
        "participant_preference",
        core_dims="item",
        dims=("participant", "item"),
    )

    # Shared time effects across all participants
    time_effects = pmd.Normal("time_effects", dims=("item", "trial"))

    trial_preference = pmd.Deterministic(
        "trial_preference",
        participant_preference + time_effects,
        dims=(...,),  # No need to specify, PyMC knows them
    )

    response = pmd.Categorical(
        "response",
        p=pmd.math.softmax(trial_preference, dim="item"),
        core_dims="item",
        observed=observed_response,
        dims=(...,), # No need to specify, PyMC knows them
    )

Equivalently, with the traditional API:

with pm.Model(coords=coords) as model:
    observed_response = pm.Data(
        "observed_response", observed_response_np, dims=("participant", "trial")
    )

    # Participant constant preferences for each item
    participant_preference = pm.ZeroSumNormal(
        "participant_preference",
        n_zerosum_axes=1,
        dims=("participant", "item"),
    )

    # Shared time effects across all participants
    time_effects = pm.Normal("time_effects", dims=("trial", "item"))

    trial_preference = pm.Deterministic(
        "trial_preference",
        participant_preference[:, None, :] + time_effects[None, :, :],
        dims=("participant", "trial", "item"),
    )

    response = pm.Categorical(
        "response",
        p=pm.math.softmax(trial_preference, axis=-1),
        observed=observed_response,
        dims=("participant", "trial"),
    )
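
To make the difference between the two models concrete, here is a minimal NumPy-only sketch of what dim-based broadcasting saves you from writing by hand. It does not use `pymc.dims` or PyTensor's xtensor module; `align_by_dims` is a hypothetical helper written for this illustration, and the random arrays merely stand in for the model variables above.

```python
import numpy as np


def align_by_dims(values, dims, out_dims):
    """Hypothetical helper: reorder and expand `values` so its axes follow `out_dims`."""
    # Reorder the existing axes to match their order in out_dims
    present = [d for d in out_dims if d in dims]
    values = np.transpose(values, [dims.index(d) for d in present])
    # Insert a length-1 axis for every dim the array does not have
    index = tuple(slice(None) if d in dims else None for d in out_dims)
    return values[index]


# Stand-ins for the model variables (shapes follow the coords above)
participant_preference = np.random.default_rng(0).normal(size=(5, 3))  # (participant, item)
time_effects = np.random.default_rng(1).normal(size=(3, 20))  # (item, trial)

out_dims = ("participant", "trial", "item")

# Dim-based: alignment is driven by names, so the axis order of the inputs is irrelevant
by_name = align_by_dims(
    participant_preference, ("participant", "item"), out_dims
) + align_by_dims(time_effects, ("item", "trial"), out_dims)

# Axis-based: the caller must know each array's layout and insert axes manually
by_axis = participant_preference[:, None, :] + time_effects.T[None, :, :]

assert by_name.shape == (5, 20, 3)
assert np.allclose(by_name, by_axis)
```

Note that `time_effects` is declared with dims `("item", "trial")` in the dims model and `("trial", "item")` in the traditional one; with name-based alignment that ordering does not change the result.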

More details in the new core notebook

day_of_conception = datetime.date(2025, 6, 17)
day_of_last_bug = datetime.date(2025, 6, 17)
today = datetime.date.today()
days_with_bugs = (day_of_last_bug - day_of_conception).days

twiecki (Member) commented Jun 18, 2025

wtf 😆

ricardoV94 (Member, Author) replied

This has two purposes: distract reviewers so they don't focus on the critical changes, and prove that OSS libraries can't be fun.

@ricardoV94 ricardoV94 force-pushed the model_with_dims branch 9 times, most recently from 1cfde5b to f571e5d Compare June 21, 2025 15:54

twiecki (Member) commented Jun 21, 2025

Can this index using labels? x["a"]

ricardoV94 (Member, Author) commented Jun 21, 2025

> Can this index using labels? x["a"]

I don't know what x["a"] means :).

Is "a" a coordinate? x.loc["a"] would be the xarray syntax? You can't do that.

Like in xarray, you can do x.isel(dim=idxs) or x[{dim: idxs}].

You cannot do x.sel(dim=coords) or x.loc[coords].

The new PyTensor objects have dims but not coords. It's not trivial to encode coord-based operations in our backends.
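
One possible workaround, sketched here under assumptions rather than taken from the PR: since the model's `coords` live on the Python side, labels can be translated to integer positions before indexing. `labels_to_positions` is a hypothetical helper and the string labels are invented for illustration (the example model above uses `range` coords); a plain NumPy array stands in for a dims-aware variable.

```python
import numpy as np

# Hypothetical coords with string labels, for illustration only
coords = {"item": ["apple", "banana", "cherry"]}


def labels_to_positions(coords, dim, labels):
    """Hypothetical helper: map coordinate labels along `dim` to integer positions."""
    lookup = {label: i for i, label in enumerate(coords[dim])}
    return [lookup[label] for label in labels]


idxs = labels_to_positions(coords, "item", ["cherry", "apple"])

# With a dims-aware variable x, these positions could then be used as
# x.isel(item=idxs) or x[{"item": idxs}], per the comment above.
# Here a NumPy array with axes (participant, item) stands in for x.
values = np.arange(6).reshape(2, 3)
selected = np.take(values, idxs, axis=1)

assert idxs == [2, 0]
assert selected.tolist() == [[2, 0], [5, 3]]
```

The lookup happens once at model-building time, so the backend only ever sees positional indices.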

twiecki (Member) commented Jun 30, 2025

We should make this the 6.0 release.

ricardoV94 (Member, Author) commented

> We should make this the 6.0 release.

I agree, but I would perhaps wait until we have beta-tested it to the point where it no longer feels too experimental.

@ricardoV94 ricardoV94 changed the title Model with dims Add experimental dims module with objects that follow dim-based semantics (like xarray) Jun 30, 2025
@ricardoV94 ricardoV94 force-pushed the model_with_dims branch 2 times, most recently from 9129df4 to 833e18e Compare June 30, 2025 15:57