From ffe009728f6a07f3d2d771f2b895d8f820d1b88a Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Wed, 31 May 2023 09:37:05 +0200 Subject: [PATCH 1/4] Add failed assertion message in CI --- .github/workflows/test.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 88aac61148..b4b6c8d945 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -123,7 +123,7 @@ jobs: pip install -e ./ mamba list && pip freeze python -c 'import pytensor; print(pytensor.config.__str__(print_doc=False))' - python -c 'import pytensor; assert(pytensor.config.blas__ldflags != "")' + python -c 'import pytensor; assert pytensor.config.blas__ldflags != "", "Blas flags are empty"' env: PYTHON_VERSION: ${{ matrix.python-version }} INSTALL_NUMBA: ${{ matrix.install-numba }} @@ -175,7 +175,7 @@ jobs: pip install -e ./ mamba list && pip freeze python -c 'import pytensor; print(pytensor.config.__str__(print_doc=False))' - python -c 'import pytensor; assert(pytensor.config.blas__ldflags != "")' + python -c 'import pytensor; assert pytensor.config.blas__ldflags != "", "Blas flags are empty"' env: PYTHON_VERSION: 3.9 - name: Download previous benchmark data From 45c215391c8d18ac7a9b444148a3978e6f8f9d8f Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Wed, 31 May 2023 11:43:22 +0200 Subject: [PATCH 2/4] Pin numpy upper bound in numba install numba-scipy downgrades the installed scipy to 1.7.3 in Python 3.8, but not numpy, even though scipy 1.7 requires numpy<1.23. When installing PyTensor next, pip installs a lower version of numpy via the PyPI. --- .github/workflows/test.yml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index b4b6c8d945..9133a56e46 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -118,7 +118,9 @@ jobs: shell: bash -l {0} run: | mamba install --yes -q "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service graphviz cython pytest coverage pytest-cov pytest-benchmark sympy - if [[ $INSTALL_NUMBA == "1" ]]; then mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" "numba>=0.57" numba-scipy; fi +# numba-scipy downgrades the installed scipy to 1.7.3 in Python 3.8, but not numpy, even though scipy 1.7 requires numpy<1.23. When installing PyTensor next, pip installs a lower version of numpy via the PyPI. + if [[ $INSTALL_NUMBA == "1" ]] && [[ $PYTHON_VERSION == "3.8" ]]; then mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" "numpy<1.23" "numba>=0.57" numba-scipy; fi + if [[ $INSTALL_NUMBA == "1" ]] && [[ $PYTHON_VERSION != "3.8" ]]; then mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" "numba>=0.57" numba-scipy; fi mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" jax jaxlib numpyro pip install -e ./ mamba list && pip freeze From a5f2b663adb711f3628d76a97f178fa87b8baea5 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Wed, 31 May 2023 11:52:21 +0200 Subject: [PATCH 3/4] Run numba and jax tests in separate jobs --- .github/workflows/test.yml | 27 +++++++++++++++++++++++-- setup.cfg | 2 ++ tests/link/jax/test_tensor_basic.py | 5 ++++- tests/link/numba/test_basic.py | 4 +++- tests/link/numba/test_cython_support.py | 5 +++++ tests/link/numba/test_performance.py | 4 ++++ tests/link/numba/test_sparse.py | 5 ++++- 7 files changed, 47 insertions(+), 5 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 9133a56e46..5115fa67e5 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -73,7 +73,8 @@ jobs: python-version: ["3.8", "3.11"] fast-compile: [0,1] float32: [0,1] - install-numba: [1] + install-numba: [0] + install-jax: [0] part: - "tests --ignore=tests/tensor --ignore=tests/scan --ignore=tests/sparse" - "tests/scan" @@ -93,6 +94,27 @@ jobs: part: "tests/tensor/test_math.py" - fast-compile: 1 float32: 1 + include: + - install-numba: 1 + python-version: "3.8" + fast-compile: 0 + float32: 0 + part: "tests/link/numba" + - install-numba: 1 + python-version: "3.11" + fast-compile: 0 + float32: 0 + part: "tests/link/numba" + - install-jax: 1 + python-version: "3.8" + fast-compile: 0 + float32: 0 + part: "tests/link/jax" + - install-jax: 1 + python-version: "3.11" + fast-compile: 0 + float32: 0 + part: "tests/link/jax" steps: - uses: actions/checkout@v3 with: @@ -121,7 +143,7 @@ jobs: # numba-scipy downgrades the installed scipy to 1.7.3 in Python 3.8, but not numpy, even though scipy 1.7 requires numpy<1.23. When installing PyTensor next, pip installs a lower version of numpy via the PyPI. if [[ $INSTALL_NUMBA == "1" ]] && [[ $PYTHON_VERSION == "3.8" ]]; then mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" "numpy<1.23" "numba>=0.57" numba-scipy; fi if [[ $INSTALL_NUMBA == "1" ]] && [[ $PYTHON_VERSION != "3.8" ]]; then mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" "numba>=0.57" numba-scipy; fi - mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" jax jaxlib numpyro + if [[ $INSTALL_JAX == "1" ]]; then mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" jax jaxlib numpyro; fi pip install -e ./ mamba list && pip freeze python -c 'import pytensor; print(pytensor.config.__str__(print_doc=False))' @@ -129,6 +151,7 @@ jobs: env: PYTHON_VERSION: ${{ matrix.python-version }} INSTALL_NUMBA: ${{ matrix.install-numba }} + INSTALL_JAX: ${{ matrix.install-jax }} - name: Run tests shell: bash -l {0} diff --git a/setup.cfg b/setup.cfg index 554bfda88a..e28dc7e345 100644 --- a/setup.cfg +++ b/setup.cfg @@ -9,6 +9,8 @@ per-file-ignores = pytensor/link/jax/jax_dispatch.py:E402,F403,F401 pytensor/link/jax/jax_linker.py:E402,F403,F401 pytensor/sparse/sandbox/sp2.py:F401 + tests/link/jax/*.py:E402 + tests/link/numba/*.py:E402 tests/tensor/test_math_scipy.py:E402 tests/sparse/test_basic.py:E402 tests/sparse/test_opt.py:E402 diff --git a/tests/link/jax/test_tensor_basic.py b/tests/link/jax/test_tensor_basic.py index 8cbbc91e97..0bc456fe22 100644 --- a/tests/link/jax/test_tensor_basic.py +++ b/tests/link/jax/test_tensor_basic.py @@ -1,7 +1,10 @@ -import jax.errors import numpy as np import pytest + +jax = pytest.importorskip("jax") +import jax.errors + import pytensor import pytensor.tensor.basic as at from pytensor.configdefaults import config diff --git a/tests/link/numba/test_basic.py b/tests/link/numba/test_basic.py index ed1dd197de..318e882dd0 100644 --- a/tests/link/numba/test_basic.py +++ b/tests/link/numba/test_basic.py @@ -3,10 +3,12 @@ from typing import TYPE_CHECKING, Any, Callable, Optional, Sequence, Tuple, Union from unittest import mock -import numba import numpy as np import pytest + +numba = pytest.importorskip("numba") + import pytensor.scalar as aes import pytensor.scalar.math as aesm import pytensor.tensor as at diff --git a/tests/link/numba/test_cython_support.py b/tests/link/numba/test_cython_support.py index b96a22098f..65d1947c9d 100644 --- a/tests/link/numba/test_cython_support.py +++ b/tests/link/numba/test_cython_support.py @@ -1,6 +1,11 @@ import numpy as np import pytest import scipy.special.cython_special + + +numba = pytest.importorskip("numba") + + from numba.types import float32, float64, int32, int64 from pytensor.link.numba.dispatch.cython_support import Signature, wrap_cython_function diff --git a/tests/link/numba/test_performance.py b/tests/link/numba/test_performance.py index 4bddd70d3c..e5bf2a7f96 100644 --- a/tests/link/numba/test_performance.py +++ b/tests/link/numba/test_performance.py @@ -3,6 +3,9 @@ import numpy as np import pytest + +pytest.importorskip("numba") + import pytensor.tensor as aet from pytensor import config from pytensor.compile.function import function @@ -70,4 +73,5 @@ def test_careduce_performance(careduce_fn, numpy_fn, axis, inputs, input_vals): mean_numpy_time = np.mean(numpy_times) # mean_c_time = np.mean(c_times) + # FIXME: Why are we asserting >=? Numba could be doing worse than numpy! assert mean_numba_time / mean_numpy_time >= 0.75 diff --git a/tests/link/numba/test_sparse.py b/tests/link/numba/test_sparse.py index 482aec9558..6a01a5db76 100644 --- a/tests/link/numba/test_sparse.py +++ b/tests/link/numba/test_sparse.py @@ -1,8 +1,11 @@ -import numba import numpy as np import pytest import scipy as sp + +numba = pytest.importorskip("numba") + + # Make sure the Numba customizations are loaded import pytensor.link.numba.dispatch.sparse # noqa: F401 from pytensor import config From 7000e36dcbd76dfcbfe84f5bc8706bc7e5234c6a Mon Sep 17 00:00:00 2001 From: Ben Mares Date: Thu, 1 Jun 2023 22:07:59 +0200 Subject: [PATCH 4/4] Fix comment --- .github/workflows/test.yml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 5115fa67e5..8fc482fd95 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -140,7 +140,9 @@ jobs: shell: bash -l {0} run: | mamba install --yes -q "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service graphviz cython pytest coverage pytest-cov pytest-benchmark sympy -# numba-scipy downgrades the installed scipy to 1.7.3 in Python 3.8, but not numpy, even though scipy 1.7 requires numpy<1.23. When installing PyTensor next, pip installs a lower version of numpy via the PyPI. + # numba-scipy downgrades the installed scipy to 1.7.3 in Python 3.8, but + # not numpy, even though scipy 1.7 requires numpy<1.23. When installing + # PyTensor next, pip installs a lower version of numpy via the PyPI. if [[ $INSTALL_NUMBA == "1" ]] && [[ $PYTHON_VERSION == "3.8" ]]; then mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" "numpy<1.23" "numba>=0.57" numba-scipy; fi if [[ $INSTALL_NUMBA == "1" ]] && [[ $PYTHON_VERSION != "3.8" ]]; then mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" "numba>=0.57" numba-scipy; fi if [[ $INSTALL_JAX == "1" ]]; then mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" jax jaxlib numpyro; fi