@@ -23,6 +23,7 @@
 from pytensor.tensor import inplace
 from pytensor.tensor.basic import as_tensor_variable
 from pytensor.tensor.blas import (
+    BatchedDot,
     Dot22,
     Dot22Scalar,
     Gemm,
@@ -2700,6 +2701,31 @@ def check_first_dim(inverted):
         check_first_dim(inverted)
 
 
+@config.change_flags(mode="FAST_RUN")
+def test_batched_dot_blas_flags():
+    """Test that BatchedDot works regardless of presence of Blas flags"""
+    mode = "FAST_RUN"
+    rng = np.random.default_rng(2708)
+
+    x = tensor("x", shape=(2, 5, 3))
+    y = tensor("y", shape=(2, 3, 1))
+    out = batched_dot(x, y)
+    assert isinstance(out.owner.op, BatchedDot)
+    x_test = rng.normal(size=x.type.shape).astype(x.type.dtype)
+    y_test = rng.normal(size=y.type.shape).astype(y.type.dtype)
+
+    fn = function([x, y], out, mode=mode)
+    [batched_dot_thunk] = fn.vm.thunks
+    assert hasattr(batched_dot_thunk, "cthunk")
+    np.testing.assert_allclose(fn(x_test, y_test), x_test @ y_test)
+
+    with config.change_flags(blas__ldflags=""):
+        fn = function([x, y], out, mode=mode)
+        [batched_dot_thunk] = fn.vm.thunks
+        assert not hasattr(batched_dot_thunk, "cthunk")
+        np.testing.assert_allclose(fn(x_test, y_test), x_test @ y_test)
+
+
 def test_batched_tensordot():
     rng = np.random.default_rng(unittest_tools.fetch_seed())
     first = tensor4("first")