Skip to content

Commit 8a1644b

Browse files
committed
Do not generate C code for BatchedDot when BLAS flags are missing
1 parent 89fe939 commit 8a1644b

File tree

2 files changed

+30
-0
lines changed

2 files changed

+30
-0
lines changed

pytensor/tensor/blas.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1795,6 +1795,10 @@ def c_header_dirs(self, **kwargs):
17951795
return ldflags(libs=False, include_dir=True)
17961796

17971797
def c_code(self, node, name, inp, out, sub):
1798+
# Can only compile if linked to blas libraries
1799+
if len(self.c_libraries()) <= 0:
1800+
raise NotImplementedError()
1801+
17981802
_x, _y = inp
17991803
(_z,) = out
18001804
fail = sub["fail"]

tests/tensor/test_blas.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from pytensor.tensor import inplace
2424
from pytensor.tensor.basic import as_tensor_variable
2525
from pytensor.tensor.blas import (
26+
BatchedDot,
2627
Dot22,
2728
Dot22Scalar,
2829
Gemm,
@@ -2700,6 +2701,31 @@ def check_first_dim(inverted):
27002701
check_first_dim(inverted)
27012702

27022703

2704+
@config.change_flags(mode="FAST_RUN")
2705+
def test_batched_dot_blas_flags():
2706+
"""Test that BatchedDot works regardless of presence of Blas flags"""
2707+
mode = "FAST_RUN"
2708+
rng = np.random.default_rng(2708)
2709+
2710+
x = tensor("x", shape=(2, 5, 3))
2711+
y = tensor("y", shape=(2, 3, 1))
2712+
out = batched_dot(x, y)
2713+
assert isinstance(out.owner.op, BatchedDot)
2714+
x_test = rng.normal(size=x.type.shape).astype(x.type.dtype)
2715+
y_test = rng.normal(size=y.type.shape).astype(y.type.dtype)
2716+
2717+
fn = function([x, y], out, mode=mode)
2718+
[batched_dot_thunk] = fn.vm.thunks
2719+
assert hasattr(batched_dot_thunk, "cthunk")
2720+
np.testing.assert_allclose(fn(x_test, y_test), x_test @ y_test)
2721+
2722+
with config.change_flags(blas__ldflags=""):
2723+
fn = function([x, y], out, mode=mode)
2724+
[batched_dot_thunk] = fn.vm.thunks
2725+
assert not hasattr(batched_dot_thunk, "cthunk")
2726+
np.testing.assert_allclose(fn(x_test, y_test), x_test @ y_test)
2727+
2728+
27032729
def test_batched_tensordot():
27042730
rng = np.random.default_rng(unittest_tools.fetch_seed())
27052731
first = tensor4("first")

0 commit comments

Comments
 (0)