Skip to content

Commit 862222f

Browse files
committed
Run numba and jax tests in separate jobs
1 parent ec894be commit 862222f

File tree

1 file changed

+14
-4
lines changed

1 file changed

+14
-4
lines changed

.github/workflows/test.yml

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -73,9 +73,10 @@ jobs:
7373
python-version: ["3.8", "3.11"]
7474
fast-compile: [0,1]
7575
float32: [0,1]
76-
install-numba: [1]
76+
install-numba: [0]
77+
install-jax: [0]
7778
part:
78-
- "tests --ignore=tests/tensor --ignore=tests/scan --ignore=tests/sparse"
79+
- "tests --ignore=tests/tensor --ignore=tests/scan --ignore=tests/sparse --ignore=tests/link/numba --ignore=tests/link/jax"
7980
- "tests/scan"
8081
- "tests/sparse"
8182
- "tests/tensor --ignore=tests/tensor/conv --ignore=tests/tensor/rewriting --ignore=tests/tensor/test_math.py --ignore=tests/tensor/test_basic.py --ignore=tests/tensor/test_blas.py --ignore=tests/tensor/test_math_scipy.py --ignore=tests/tensor/test_inplace.py --ignore=tests/tensor/test_elemwise.py"
@@ -93,6 +94,15 @@ jobs:
9394
part: "tests/tensor/test_math.py"
9495
- fast-compile: 1
9596
float32: 1
97+
include:
98+
- install-numba: 1
99+
fast-compile: 0
100+
float32: 0
101+
part: "tests/link/numba"
102+
- install-jax: 1
103+
fast-compile: 0
104+
float32: 0
105+
part: "tests/link/jax"
96106
steps:
97107
- uses: actions/checkout@v3
98108
with:
@@ -118,8 +128,8 @@ jobs:
118128
shell: bash -l {0}
119129
run: |
120130
mamba install --yes -q "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service graphviz cython pytest coverage pytest-cov pytest-benchmark sympy
121-
if [[ $INSTALL_NUMBA == "1" ]]; then mamba install --yes -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" "numpy<1.23" "numba>=0.57" numba-scipy; fi
122-
mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" jax jaxlib numpyro
131+
if [[ $INSTALL_NUMBA == "1" ]]; then mamba install --yes -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" "numpy<1.23" "numba>=0.57" numba-scipy "numpy=1.23"; fi
132+
if [[ $INSTALL_JAX == "1" ]]; then mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" jax jaxlib numpyro; fi
123133
pip install -e ./
124134
mamba list && pip freeze
125135
python -c 'import pytensor; print(pytensor.config.__str__(print_doc=False))'

0 commit comments

Comments
 (0)