diff --git a/pymc_extras/prior.py b/pymc_extras/prior.py index 12eb36a0d..8d390aaea 100644 --- a/pymc_extras/prior.py +++ b/pymc_extras/prior.py @@ -278,7 +278,7 @@ def create_variable(self, name: str) -> pt.TensorVariable: def sample_prior( factory: VariableFactory, coords=None, - name: str = "var", + name: str = "variable", wrap: bool = False, **sample_prior_predictive_kwargs, ) -> xr.Dataset: @@ -292,7 +292,7 @@ def sample_prior( The coordinates for the variable, by default None. Only required if the dims are specified. name : str, optional - The name of the variable, by default "var". + The name of the variable, by default "variable". wrap : bool, optional Whether to wrap the variable in a `pm.Deterministic` node, by default False. sample_prior_predictive_kwargs : dict @@ -900,7 +900,7 @@ def __eq__(self, other) -> bool: def sample_prior( self, coords=None, - name: str = "var", + name: str = "variable", **sample_prior_predictive_kwargs, ) -> xr.Dataset: """Sample the prior distribution for the variable. @@ -911,7 +911,7 @@ def sample_prior( The coordinates for the variable, by default None. Only required if the dims are specified. name : str, optional - The name of the variable, by default "var". + The name of the variable, by default "variable". sample_prior_predictive_kwargs : dict Additional arguments to pass to `pm.sample_prior_predictive`. diff --git a/tests/test_prior.py b/tests/test_prior.py index 70729b9f9..a7b630f83 100644 --- a/tests/test_prior.py +++ b/tests/test_prior.py @@ -616,7 +616,10 @@ def test_custom_transform() -> None: prior = dist.sample_prior(draws=10) df_prior = prior.to_dataframe() - np.testing.assert_array_equal(df_prior["var"].to_numpy(), df_prior["var_raw"].to_numpy() ** 2) + np.testing.assert_array_equal( + df_prior.variable.to_numpy(), + df_prior.variable_raw.to_numpy() ** 2, + ) def test_custom_transform_comes_first() -> None: @@ -627,7 +630,10 @@ def test_custom_transform_comes_first() -> None: prior = dist.sample_prior(draws=10) df_prior = prior.to_dataframe() - np.testing.assert_array_equal(df_prior["var"].to_numpy(), 2 * df_prior["var_raw"].to_numpy()) + np.testing.assert_array_equal( + df_prior.variable.to_numpy(), + 2 * df_prior.variable_raw.to_numpy(), + ) clear_custom_transforms() @@ -686,7 +692,7 @@ def test_sample_prior_arbitrary_no_name() -> None: prior = sample_prior(var, coords={"channel": ["A", "B", "C"]}, draws=25) assert isinstance(prior, xr.Dataset) - assert "var" not in prior + assert "variable" not in prior prior_with = sample_prior( var, @@ -696,7 +702,7 @@ def test_sample_prior_arbitrary_no_name() -> None: ) assert isinstance(prior_with, xr.Dataset) - assert "var" in prior_with + assert "variable" in prior_with def test_create_prior_with_arbitrary() -> None: