From ac9ef0a6d54af752b7d07238099b073b580f0a7d Mon Sep 17 00:00:00 2001
From: Luciano Paz <luciano.paz.neuro@gmail.com>
Date: Tue, 14 Jan 2025 10:57:05 +0100
Subject: [PATCH] Fix conditional import of zarr

---
 pymc/backends/zarr.py | 37 +++++++++++++++++++++++--------------
 pymc/sampling/mcmc.py |  6 +++++-
 2 files changed, 28 insertions(+), 15 deletions(-)

diff --git a/pymc/backends/zarr.py b/pymc/backends/zarr.py
index e9aba5fe0d..b9c1e49ea3 100644
--- a/pymc/backends/zarr.py
+++ b/pymc/backends/zarr.py
@@ -15,14 +15,11 @@
 from typing import Any
 
 import arviz as az
-import numcodecs
 import numpy as np
 import xarray as xr
-import zarr
 
 from arviz.data.base import make_attrs
 from arviz.data.inference_data import WARMUP_TAG
-from numcodecs.abc import Codec
 from pytensor.tensor.variable import TensorVariable
 
 import pymc
@@ -44,11 +41,23 @@
 from pymc.util import UNSET, _UnsetType, get_default_varnames, is_transformed_name
 
 try:
+    import numcodecs
+    import zarr
+
+    from numcodecs.abc import Codec
+    from zarr import Group
     from zarr.storage import BaseStore, default_compressor
     from zarr.sync import Synchronizer
 
     _zarr_available = True
 except ImportError:
+    from typing import TYPE_CHECKING, TypeVar
+
+    if not TYPE_CHECKING:
+        Codec = TypeVar("Codec")
+        Group = TypeVar("Group")
+        BaseStore = TypeVar("BaseStore")
+        Synchronizer = TypeVar("Synchronizer")
     _zarr_available = False
 
 
@@ -243,7 +252,7 @@ def flush(self):
 
 def get_initial_fill_value_and_codec(
     dtype: Any,
-) -> tuple[FILL_VALUE_TYPE, np.typing.DTypeLike, numcodecs.abc.Codec | None]:
+) -> tuple[FILL_VALUE_TYPE, np.typing.DTypeLike, Codec | None]:
     _dtype = np.dtype(dtype)
     fill_value: FILL_VALUE_TYPE = None
     codec = None
@@ -366,27 +375,27 @@ def groups(self) -> list[str]:
         return [str(group_name) for group_name, _ in self.root.groups()]
 
     @property
-    def posterior(self) -> zarr.Group:
+    def posterior(self) -> Group:
         return self.root.posterior
 
     @property
-    def unconstrained_posterior(self) -> zarr.Group:
+    def unconstrained_posterior(self) -> Group:
         return self.root.unconstrained_posterior
 
     @property
-    def sample_stats(self) -> zarr.Group:
+    def sample_stats(self) -> Group:
         return self.root.sample_stats
 
     @property
-    def constant_data(self) -> zarr.Group:
+    def constant_data(self) -> Group:
         return self.root.constant_data
 
     @property
-    def observed_data(self) -> zarr.Group:
+    def observed_data(self) -> Group:
         return self.root.observed_data
 
     @property
-    def _sampling_state(self) -> zarr.Group:
+    def _sampling_state(self) -> Group:
         return self.root._sampling_state
 
     def init_trace(
@@ -646,12 +655,12 @@ def init_sampling_state_group(self, tune: int, chains: int):
 
     def init_group_with_empty(
         self,
-        group: zarr.Group,
+        group: Group,
         var_dtype_and_shape: dict[str, tuple[StatDtype, StatShape]],
         chains: int,
         draws: int,
         extra_var_attrs: dict | None = None,
-    ) -> zarr.Group:
+    ) -> Group:
         group_coords: dict[str, Any] = {"chain": range(chains), "draw": range(draws)}
         for name, (_dtype, shape) in var_dtype_and_shape.items():
             fill_value, dtype, object_codec = get_initial_fill_value_and_codec(_dtype)
@@ -689,8 +698,8 @@ def init_group_with_empty(
             array.attrs.update({"_ARRAY_DIMENSIONS": [dim]})
         return group
 
-    def create_group(self, name: str, data_dict: dict[str, np.ndarray]) -> zarr.Group | None:
-        group: zarr.Group | None = None
+    def create_group(self, name: str, data_dict: dict[str, np.ndarray]) -> Group | None:
+        group: Group | None = None
         if data_dict:
             group_coords = {}
             group = self.root.create_group(name=name, overwrite=True)
diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py
index ca91325ff1..7cbb6df26e 100644
--- a/pymc/sampling/mcmc.py
+++ b/pymc/sampling/mcmc.py
@@ -41,7 +41,6 @@
 from rich.theme import Theme
 from threadpoolctl import threadpool_limits
 from typing_extensions import Protocol
-from zarr.storage import MemoryStore
 
 import pymc as pm
 
@@ -80,6 +79,11 @@
 )
 from pymc.vartypes import discrete_types
 
+try:
+    from zarr.storage import MemoryStore
+except ImportError:
+    MemoryStore = type("MemoryStore", (), {})
+
 sys.setrecursionlimit(10000)
 
 __all__ = [