From fd5f642a3b681594c8444bcdcea13dfe01cc5dfc Mon Sep 17 00:00:00 2001
From: Aryan <nandaaryan823@gmail.com>
Date: Sat, 3 Feb 2024 22:34:27 +0530
Subject: [PATCH 1/6] Remove block_diag from pymc.math in favor of alias to
 pytensor.tensor.slinalg.block_diag

---
 pymc/math.py | 63 ++++++++--------------------------------------------
 1 file changed, 9 insertions(+), 54 deletions(-)

diff --git a/pymc/math.py b/pymc/math.py
index 7c9ceaa9ec..1c8cb1f138 100644
--- a/pymc/math.py
+++ b/pymc/math.py
@@ -22,8 +22,6 @@
 import pytensor.sparse
 import pytensor.tensor as pt
 import pytensor.tensor.slinalg
-import scipy as sp
-import scipy.sparse
 
 from pytensor.graph.basic import Apply
 from pytensor.graph.op import Op
@@ -93,9 +91,8 @@
 from pytensor.tensor.linalg import solve_triangular
 from pytensor.tensor.nlinalg import matrix_inverse
 from pytensor.tensor.special import log_softmax, softmax
-from scipy.linalg import block_diag as scipy_block_diag
 
-from pymc.pytensorf import floatX, ix_, largest_common_dtype
+from pymc.pytensorf import floatX
 
 __all__ = [
     "abs",
@@ -513,55 +510,9 @@ def batched_diag(C):
         raise ValueError("Input should be 2 or 3 dimensional")
 
 
-class BlockDiagonalMatrix(Op):
-    __props__ = ("sparse", "format")
-
-    def __init__(self, sparse=False, format="csr"):
-        if format not in ("csr", "csc"):
-            raise ValueError(f"format must be one of: 'csr', 'csc', got {format}")
-        self.sparse = sparse
-        self.format = format
-
-    def make_node(self, *matrices):
-        if not matrices:
-            raise ValueError("no matrices to allocate")
-        matrices = list(map(pt.as_tensor, matrices))
-        if any(mat.type.ndim != 2 for mat in matrices):
-            raise TypeError("all data arguments must be matrices")
-        if self.sparse:
-            out_type = pytensor.sparse.matrix(self.format, dtype=largest_common_dtype(matrices))
-        else:
-            out_type = pytensor.tensor.matrix(dtype=largest_common_dtype(matrices))
-        return Apply(self, matrices, [out_type])
-
-    def perform(self, node, inputs, output_storage, params=None):
-        dtype = largest_common_dtype(inputs)
-        if self.sparse:
-            output_storage[0][0] = sp.sparse.block_diag(inputs, self.format, dtype)
-        else:
-            output_storage[0][0] = scipy_block_diag(*inputs).astype(dtype)
-
-    def grad(self, inputs, gout):
-        shapes = pt.stack([i.shape for i in inputs])
-        index_end = shapes.cumsum(0)
-        index_begin = index_end - shapes
-        slices = [
-            ix_(
-                pt.arange(index_begin[i, 0], index_end[i, 0]),
-                pt.arange(index_begin[i, 1], index_end[i, 1]),
-            )
-            for i in range(len(inputs))
-        ]
-        return [gout[0][slc] for slc in slices]
-
-    def infer_shape(self, fgraph, nodes, shapes):
-        first, second = zip(*shapes)
-        return [(pt.add(*first), pt.add(*second))]
-
-
-def block_diagonal(matrices, sparse=False, format="csr"):
-    r"""See scipy.sparse.block_diag or
-    scipy.linalg.block_diag for reference
+def block_diagonal(*matrices, sparse=False, format="csr"):
+    r"""See pt.slinalg.block_diag or
+    pytensor.sparse.basic.block_diag for reference
 
     Parameters
     ----------
@@ -577,4 +528,8 @@ def block_diagonal(matrices, sparse=False, format="csr"):
     """
     if len(matrices) == 1:  # graph optimization
         return matrices[0]
-    return BlockDiagonalMatrix(sparse=sparse, format=format)(*matrices)
+
+    if sparse:
+        return pytensor.sparse.basic.block_diag(*matrices, format=format)
+    else:
+        return pt.slinalg.block_diag(*matrices)

From a5913e5aabb7b87c3b17ab53dbb2ac21d1c670ad Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]"
 <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Sat, 3 Feb 2024 17:15:06 +0000
Subject: [PATCH 2/6] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 requirements-dev.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/requirements-dev.txt b/requirements-dev.txt
index b21437de01..8aff7d60c9 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -30,4 +30,4 @@ sphinx>=1.5
 sphinxext-rediraffe
 types-cachetools
 typing-extensions>=3.7.4
-watermark
\ No newline at end of file
+watermark

From b55c8773ef2aa10d95ee14fc439d831a6ca79050 Mon Sep 17 00:00:00 2001
From: Aryan <nandaaryan823@gmail.com>
Date: Tue, 6 Feb 2024 23:34:22 +0530
Subject: [PATCH 3/6] FutureWarning Added

---
 pymc/math.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/pymc/math.py b/pymc/math.py
index 1c8cb1f138..1924bb6043 100644
--- a/pymc/math.py
+++ b/pymc/math.py
@@ -527,8 +527,12 @@ def block_diagonal(*matrices, sparse=False, format="csr"):
     matrix
     """
     if len(matrices) == 1:  # graph optimization
+        warnings.warn(
+            "The behavior of block_diagonal when only one matrix is provided is deprecated. Use pytensor function instead.",
+            FutureWarning,
+            stacklevel=2,
+        )
         return matrices[0]
-
     if sparse:
         return pytensor.sparse.basic.block_diag(*matrices, format=format)
     else:

From 33201553badd699589f50bc55deab51ab4005b31 Mon Sep 17 00:00:00 2001
From: Aryan <nandaaryan823@gmail.com>
Date: Thu, 8 Feb 2024 17:16:20 +0530
Subject: [PATCH 4/6] Update

---
 pymc/math.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/pymc/math.py b/pymc/math.py
index 1924bb6043..90351fb028 100644
--- a/pymc/math.py
+++ b/pymc/math.py
@@ -526,12 +526,12 @@ def block_diagonal(*matrices, sparse=False, format="csr"):
     -------
     matrix
     """
+    warnings.warn(
+        "The behavior of block_diagonal when only one matrix is provided is deprecated. Use pytensor function instead.",
+        FutureWarning,
+        stacklevel=2,
+    )
     if len(matrices) == 1:  # graph optimization
-        warnings.warn(
-            "The behavior of block_diagonal when only one matrix is provided is deprecated. Use pytensor function instead.",
-            FutureWarning,
-            stacklevel=2,
-        )
         return matrices[0]
     if sparse:
         return pytensor.sparse.basic.block_diag(*matrices, format=format)

From a01db7f77456c52672398589e083ed4e67a60655 Mon Sep 17 00:00:00 2001
From: Aryan <nandaaryan823@gmail.com>
Date: Thu, 8 Feb 2024 19:21:54 +0530
Subject: [PATCH 5/6] Update

---
 pymc/math.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/pymc/math.py b/pymc/math.py
index 90351fb028..e9f41be018 100644
--- a/pymc/math.py
+++ b/pymc/math.py
@@ -527,9 +527,7 @@ def block_diagonal(*matrices, sparse=False, format="csr"):
     matrix
     """
     warnings.warn(
-        "The behavior of block_diagonal when only one matrix is provided is deprecated. Use pytensor function instead.",
-        FutureWarning,
-        stacklevel=2,
+        "pymc.math.block_diagonal is deprecated in favor of `pytensor.tensor.linalg.block_diag` and `pytensor.sparse.block_diag` functions. This function will be removed in a future release",
     )
     if len(matrices) == 1:  # graph optimization
         return matrices[0]

From fe8a06531581becfafcbe36f9037cc4207463611 Mon Sep 17 00:00:00 2001
From: Aryan <nandaaryan823@gmail.com>
Date: Thu, 8 Feb 2024 19:30:00 +0530
Subject: [PATCH 6/6] Update

---
 pymc/math.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pymc/math.py b/pymc/math.py
index e9f41be018..7fe8d1e5e5 100644
--- a/pymc/math.py
+++ b/pymc/math.py
@@ -510,7 +510,7 @@ def batched_diag(C):
         raise ValueError("Input should be 2 or 3 dimensional")
 
 
-def block_diagonal(*matrices, sparse=False, format="csr"):
+def block_diagonal(matrices, sparse=False, format="csr"):
     r"""See pt.slinalg.block_diag or
     pytensor.sparse.basic.block_diag for reference