From 53ec0dce3bc9ceeb2a520a0066e4b595df38f232 Mon Sep 17 00:00:00 2001 From: sfonxu Date: Wed, 23 Jul 2025 12:59:58 +0200 Subject: [PATCH 1/6] Add interpolate function and replace it in shallow-water examples --- PyMPDATA/impl/interpolate.py | 6 ++++++ .../Jarecka_et_al_2015/simulation.py | 11 ++--------- scenarios_mpi/shallow_water.py | 12 ++---------- 3 files changed, 10 insertions(+), 19 deletions(-) create mode 100644 PyMPDATA/impl/interpolate.py diff --git a/PyMPDATA/impl/interpolate.py b/PyMPDATA/impl/interpolate.py new file mode 100644 index 00000000..789d0a46 --- /dev/null +++ b/PyMPDATA/impl/interpolate.py @@ -0,0 +1,6 @@ +import numpy as np + + +def interpolate(psi, axis): + idx = ((slice(None, -1), slice(None, None)), (slice(None, None), slice(None, -1))) + return np.diff(psi, axis=axis) / 2 * psi[idx[axis]] diff --git a/examples/PyMPDATA_examples/Jarecka_et_al_2015/simulation.py b/examples/PyMPDATA_examples/Jarecka_et_al_2015/simulation.py index ff9b44bc..e254becc 100644 --- a/examples/PyMPDATA_examples/Jarecka_et_al_2015/simulation.py +++ b/examples/PyMPDATA_examples/Jarecka_et_al_2015/simulation.py @@ -3,6 +3,7 @@ from PyMPDATA import ScalarField, Solver, Stepper, VectorField from PyMPDATA.boundary_conditions import Constant +from PyMPDATA.impl.interpolate import interpolate class Simulation: @@ -36,14 +37,6 @@ def __init__(self, settings): k: Solver(stepper, v, self.advector) for k, v in advectees.items() } - @staticmethod - def interpolate(psi, axis): - idx = ( - (slice(None, -1), slice(None, None)), - (slice(None, None), slice(None, -1)), - ) - return np.diff(psi, axis=axis) / 2 + psi[idx[axis]] - def run(self): s = self.settings grid_step = (s.dx, s.dy) @@ -57,7 +50,7 @@ def run(self): vel = np.where(mask, np.nan, 0) np.divide(self.solvers[k].advectee.get(), h, where=mask, out=vel) self.advector.get_component(xy)[idx[xy]] = ( - self.interpolate(vel, axis=xy) * s.dt / grid_step[xy] + interpolate(vel, axis=xy) * s.dt / grid_step[xy] ) self.solvers["h"].advance(1) assert h.ctypes.data == self.solvers["h"].advectee.get().ctypes.data diff --git a/scenarios_mpi/shallow_water.py b/scenarios_mpi/shallow_water.py index 1ec9bec9..98772ac1 100644 --- a/scenarios_mpi/shallow_water.py +++ b/scenarios_mpi/shallow_water.py @@ -10,6 +10,7 @@ from PyMPDATA.boundary_conditions import Periodic from PyMPDATA.impl.domain_decomposition import make_subdomain from PyMPDATA.impl.enumerations import INNER, OUTER +from PyMPDATA.impl.interpolate import interpolate from scenarios_mpi._scenario import _Scenario subdomain = make_subdomain(jit_flags={}) @@ -122,15 +123,6 @@ def initial_condition(x, y, lx, ly): k: Solver(stepper, v, self.advector) for k, v in advectees.items() } - @staticmethod - def interpolate(psi, axis): - """Method that does simple interpolation of given field""" - idx = ( - (slice(None, -1), slice(None, None)), - (slice(None, None), slice(None, -1)), - ) - return np.diff(psi, axis=axis) / 2 + psi[idx[axis]] - def __getitem__(self, key): return self.solvers[key].advectee.get() @@ -152,7 +144,7 @@ def _solver_advance(self, n_steps): ].advectee._debug_fill_halos(self.traversals, range(self.n_threads)) np.divide(self.data(k), self.data("h"), where=mask, out=vel) self.advector.data[xy][:] = ( - self.interpolate(vel, xy) * self.dt / grid_step[xy] + interpolate(vel, xy) * self.dt / grid_step[xy] ) self.solvers["h"].advance(1) self.solvers[ # pylint: disable=protected-access From be4ec4b6b44cf00729e3beec4cc6ff35dd791f38 Mon Sep 17 00:00:00 2001 From: sfonxu Date: Wed, 23 Jul 2025 17:17:27 +0200 Subject: [PATCH 2/6] Make interpolate JIT-compilable --- PyMPDATA/impl/interpolate.py | 16 +++++++++++++--- .../Jarecka_et_al_2015/simulation.py | 6 ++++-- scenarios_mpi/shallow_water.py | 6 ++++-- 3 files changed, 21 insertions(+), 7 deletions(-) diff --git a/PyMPDATA/impl/interpolate.py b/PyMPDATA/impl/interpolate.py index 789d0a46..deef25f7 100644 --- a/PyMPDATA/impl/interpolate.py +++ b/PyMPDATA/impl/interpolate.py @@ -1,6 +1,16 @@ +import numba import numpy as np -def interpolate(psi, axis): - idx = ((slice(None, -1), slice(None, None)), (slice(None, None), slice(None, -1))) - return np.diff(psi, axis=axis) / 2 * psi[idx[axis]] +def make_interpolate(options): + """Function that returns JIT-compilable interpolate function""" + + @numba.njit(**options.jit_flags) + def interpolate(psi, axis): + idx = ( + (slice(None, -1), slice(None, None)), + (slice(None, None), slice(None, -1)), + ) + return np.diff(psi, axis=axis) / 2 * psi[idx[axis]] + + return interpolate diff --git a/examples/PyMPDATA_examples/Jarecka_et_al_2015/simulation.py b/examples/PyMPDATA_examples/Jarecka_et_al_2015/simulation.py index e254becc..5b72f15d 100644 --- a/examples/PyMPDATA_examples/Jarecka_et_al_2015/simulation.py +++ b/examples/PyMPDATA_examples/Jarecka_et_al_2015/simulation.py @@ -3,7 +3,7 @@ from PyMPDATA import ScalarField, Solver, Stepper, VectorField from PyMPDATA.boundary_conditions import Constant -from PyMPDATA.impl.interpolate import interpolate +from PyMPDATA.impl.interpolate import make_interpolate class Simulation: @@ -37,6 +37,8 @@ def __init__(self, settings): k: Solver(stepper, v, self.advector) for k, v in advectees.items() } + self.interpolate = make_interpolate(settings.options.jit_flags) + def run(self): s = self.settings grid_step = (s.dx, s.dy) @@ -50,7 +52,7 @@ def run(self): vel = np.where(mask, np.nan, 0) np.divide(self.solvers[k].advectee.get(), h, where=mask, out=vel) self.advector.get_component(xy)[idx[xy]] = ( - interpolate(vel, axis=xy) * s.dt / grid_step[xy] + self.interpolate(vel, axis=xy) * s.dt / grid_step[xy] ) self.solvers["h"].advance(1) assert h.ctypes.data == self.solvers["h"].advectee.get().ctypes.data diff --git a/scenarios_mpi/shallow_water.py b/scenarios_mpi/shallow_water.py index 98772ac1..6e34d71e 100644 --- a/scenarios_mpi/shallow_water.py +++ b/scenarios_mpi/shallow_water.py @@ -10,7 +10,7 @@ from PyMPDATA.boundary_conditions import Periodic from PyMPDATA.impl.domain_decomposition import make_subdomain from PyMPDATA.impl.enumerations import INNER, OUTER -from PyMPDATA.impl.interpolate import interpolate +from PyMPDATA.impl.interpolate import make_interpolate from scenarios_mpi._scenario import _Scenario subdomain = make_subdomain(jit_flags={}) @@ -123,6 +123,8 @@ def initial_condition(x, y, lx, ly): k: Solver(stepper, v, self.advector) for k, v in advectees.items() } + self.interpolate = make_interpolate(mpdata_options.jit_flags) + def __getitem__(self, key): return self.solvers[key].advectee.get() @@ -144,7 +146,7 @@ def _solver_advance(self, n_steps): ].advectee._debug_fill_halos(self.traversals, range(self.n_threads)) np.divide(self.data(k), self.data("h"), where=mask, out=vel) self.advector.data[xy][:] = ( - interpolate(vel, xy) * self.dt / grid_step[xy] + self.interpolate(vel, xy) * self.dt / grid_step[xy] ) self.solvers["h"].advance(1) self.solvers[ # pylint: disable=protected-access From d11158b2bba173a1fe81195c2cb22a7e5ac50194 Mon Sep 17 00:00:00 2001 From: sfonxu Date: Wed, 23 Jul 2025 17:27:52 +0200 Subject: [PATCH 3/6] Syntax fixes --- PyMPDATA/impl/interpolate.py | 2 ++ examples/PyMPDATA_examples/Jarecka_et_al_2015/simulation.py | 2 +- scenarios_mpi/shallow_water.py | 2 +- 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/PyMPDATA/impl/interpolate.py b/PyMPDATA/impl/interpolate.py index deef25f7..6868f97b 100644 --- a/PyMPDATA/impl/interpolate.py +++ b/PyMPDATA/impl/interpolate.py @@ -1,3 +1,5 @@ +"""Interpolation formulae sketch""" + import numba import numpy as np diff --git a/examples/PyMPDATA_examples/Jarecka_et_al_2015/simulation.py b/examples/PyMPDATA_examples/Jarecka_et_al_2015/simulation.py index 5b72f15d..ab9541d0 100644 --- a/examples/PyMPDATA_examples/Jarecka_et_al_2015/simulation.py +++ b/examples/PyMPDATA_examples/Jarecka_et_al_2015/simulation.py @@ -37,7 +37,7 @@ def __init__(self, settings): k: Solver(stepper, v, self.advector) for k, v in advectees.items() } - self.interpolate = make_interpolate(settings.options.jit_flags) + self.interpolate = make_interpolate(settings.options) def run(self): s = self.settings diff --git a/scenarios_mpi/shallow_water.py b/scenarios_mpi/shallow_water.py index 6e34d71e..06bc737f 100644 --- a/scenarios_mpi/shallow_water.py +++ b/scenarios_mpi/shallow_water.py @@ -123,7 +123,7 @@ def initial_condition(x, y, lx, ly): k: Solver(stepper, v, self.advector) for k, v in advectees.items() } - self.interpolate = make_interpolate(mpdata_options.jit_flags) + self.interpolate = make_interpolate(mpdata_options) def __getitem__(self, key): return self.solvers[key].advectee.get() From 9f7f0981556c6e9250bcfc36d49a751524a6916c Mon Sep 17 00:00:00 2001 From: sfonxu Date: Wed, 23 Jul 2025 17:35:14 +0200 Subject: [PATCH 4/6] Disable too-few-public-methods warning in SWE example --- examples/PyMPDATA_examples/Jarecka_et_al_2015/simulation.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/PyMPDATA_examples/Jarecka_et_al_2015/simulation.py b/examples/PyMPDATA_examples/Jarecka_et_al_2015/simulation.py index ab9541d0..d62cf72a 100644 --- a/examples/PyMPDATA_examples/Jarecka_et_al_2015/simulation.py +++ b/examples/PyMPDATA_examples/Jarecka_et_al_2015/simulation.py @@ -7,6 +7,7 @@ class Simulation: + # pylint: disable=too-few-public-methods def __init__(self, settings): self.settings = settings s = settings From 85dde036c1614933a3130a8ba8aed73ee8de01d3 Mon Sep 17 00:00:00 2001 From: sfonxu Date: Wed, 23 Jul 2025 18:47:14 +0200 Subject: [PATCH 5/6] Make a simple JIT-compatible interpolate by hand --- PyMPDATA/impl/interpolate.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/PyMPDATA/impl/interpolate.py b/PyMPDATA/impl/interpolate.py index 6868f97b..d39cdc7e 100644 --- a/PyMPDATA/impl/interpolate.py +++ b/PyMPDATA/impl/interpolate.py @@ -13,6 +13,13 @@ def interpolate(psi, axis): (slice(None, -1), slice(None, None)), (slice(None, None), slice(None, -1)), ) - return np.diff(psi, axis=axis) / 2 * psi[idx[axis]] + s1 = 2 * [slice(None)] + s2 = 2 * [slice(None)] + s1[axis] = slice(1, None) + s2[axis] = slice(None, -1) + s1 = tuple(s1) + s2 = tuple(s2) + out = psi[s1] - psi[s2] + return out / 2 * psi[idx[axis]] return interpolate From 5a0d9736625010c2c0898997631e8f28b3465f5d Mon Sep 17 00:00:00 2001 From: sfonxu Date: Fri, 25 Jul 2025 20:44:04 +0200 Subject: [PATCH 6/6] Add ati indexer, add support for multiple advectees in Solver, add unit tests for indexers --- PyMPDATA/impl/indexers.py | 43 +++++++++- PyMPDATA/solver.py | 29 ++++--- PyMPDATA/stepper.py | 134 ++++++++++++++++++-------------- tests/unit_tests/test_solver.py | 14 ++++ 4 files changed, 147 insertions(+), 73 deletions(-) diff --git a/PyMPDATA/impl/indexers.py b/PyMPDATA/impl/indexers.py index f40fc66d..f55f75eb 100644 --- a/PyMPDATA/impl/indexers.py +++ b/PyMPDATA/impl/indexers.py @@ -31,6 +31,14 @@ def ats_1d(focus, arr, k, _=INVALID_INDEX, __=INVALID_INDEX): def atv_1d(focus, arrs, k, _=INVALID_INDEX, __=INVALID_INDEX): return arrs[INNER][focus[INNER] + int(k - 0.5)] + @staticmethod + @numba.njit(**jit_flags) + def ati_1d(focus, arrs, k, _=INVALID_INDEX, __=INVALID_INDEX): + return ( + arrs[INNER][focus[INNER] + int(k - 0.5)] + + arrs[INNER][focus[INNER] + int(k + 0.5)] + ) / 2 + @staticmethod @numba.njit(**jit_flags) def set(arr, _, __, k, value): @@ -70,6 +78,30 @@ def atv_axis1(focus, arrs, k, i=0, _=INVALID_INDEX): dim, _ii, _kk = OUTER, int(i - 0.5), int(k) return arrs[dim][focus[OUTER] + _ii, focus[INNER] + _kk] + @staticmethod + @numba.njit(**jit_flags) + def ati_axis0(focus, arrs, i, k=0, _=INVALID_INDEX): + if _is_integral(i): + dim, _ii, _kk = INNER, int(i), int(k - 0.5) + else: + dim, _ii, _kk = OUTER, int(i - 0.5), int(k) + return ( + arrs[dim][focus[OUTER] + _ii, focus[INNER] + _kk] + + arrs[dim][focus[OUTER] + _ii + 1, focus[INNER] + _kk] + ) / 2 + + @staticmethod + @numba.njit(**jit_flags) + def ati_axis1(focus, arrs, k, i=0, _=INVALID_INDEX): + if _is_integral(i): + dim, _ii, _kk = INNER, int(i), int(k - 0.5) + else: + dim, _ii, _kk = OUTER, int(i - 0.5), int(k) + return ( + arrs[dim][focus[OUTER] + _ii, focus[INNER] + _kk] + + arrs[dim][focus[OUTER] + _ii, focus[INNER] + _kk + 1] + ) / 2 + @staticmethod @numba.njit(**jit_flags) def set(arr, i, _, k, value): @@ -140,17 +172,23 @@ def get(arr, i, j, k): return arr[i, j, k] Indexers = namedtuple( # pylint: disable=invalid-name - Path(__file__).stem + "_Indexers", ("ats", "atv", "set", "get", "n_dims") + Path(__file__).stem + "_Indexers", ("ats", "atv", "ati", "set", "get", "n_dims") ) indexers = ( None, Indexers( - (None, None, _1D.ats_1d), (None, None, _1D.atv_1d), _1D.set, _1D.get, 1 + (None, None, _1D.ats_1d), + (None, None, _1D.atv_1d), + (None, None, _1D.ati_1d), + _1D.set, + _1D.get, + 1, ), Indexers( (_2D.ats_axis0, None, _2D.ats_axis1), (_2D.atv_axis0, None, _2D.atv_axis1), + (_2D.ati_axis0, None, _2D.ati_axis1), _2D.set, _2D.get, 2, @@ -158,6 +196,7 @@ def get(arr, i, j, k): Indexers( (_3D.ats_axis0, _3D.ats_axis1, _3D.ats_axis2), (_3D.atv_axis0, _3D.atv_axis1, _3D.atv_axis2), + (None, None, None), _3D.set, _3D.get, 3, diff --git a/PyMPDATA/solver.py b/PyMPDATA/solver.py index 8e6fff61..7ae1dea2 100644 --- a/PyMPDATA/solver.py +++ b/PyMPDATA/solver.py @@ -3,7 +3,7 @@ class grouping user-supplied stepper, fields and post-step/post-iter hooks, as well as self-initialised temporary storage """ -from typing import Union +from typing import Iterable, Union import numba @@ -14,7 +14,7 @@ class grouping user-supplied stepper, fields and post-step/post-iter hooks, @numba.experimental.jitclass([]) -class PostStepNull: # pylint: disable=too-few-public-methods +class AntePostStepNull: # pylint: disable=too-few-public-methods """do-nothing version of the post-step hook""" def __init__(self): @@ -48,15 +48,18 @@ class Solver: def __init__( self, stepper: Stepper, - advectee: ScalarField, + advectee: [ScalarField, Iterable[ScalarField]], advector: VectorField, g_factor: [ScalarField, None] = None, ): + if isinstance(advectee, ScalarField): + advectee = (advectee,) + def scalar_field(dtype=None): - return ScalarField.clone(advectee, dtype=dtype) + return ScalarField.clone(advectee[0], dtype=dtype) def null_scalar_field(): - return ScalarField.make_null(advectee.n_dims, stepper.traversals) + return ScalarField.make_null(advectee[0].n_dims, stepper.traversals) def vector_field(): return VectorField.clone(advector) @@ -64,7 +67,7 @@ def vector_field(): def null_vector_field(): return VectorField.make_null(advector.n_dims, stepper.traversals) - for field in [advector, advectee] + ( + for field in [advector, *advectee] + ( [g_factor] if g_factor is not None else [] ): assert field.halo == stepper.options.n_halo @@ -93,16 +96,17 @@ def null_vector_field(): else null_scalar_field() ), } - for field in self.__fields.values(): - field.assemble(stepper.traversals) + for key, value in self.__fields.items(): + for field in (value,) if key != "advectee" else value: + field.assemble(stepper.traversals) self.__stepper = stepper @property - def advectee(self) -> ScalarField: + def advectee(self, index=0) -> ScalarField: """advectee scalar field (with halo), modified by advance(), may be modified from user code (e.g., source-term handling)""" - return self.__fields["advectee"] + return self.__fields["advectee"][index] @property def advector(self) -> VectorField: @@ -126,6 +130,7 @@ def advance( self, n_steps: int, mu_coeff: Union[tuple, None] = None, + ante_step=None, post_step=None, post_iter=None, ): @@ -144,12 +149,14 @@ def advance( ): raise NotImplementedError() - post_step = post_step or PostStepNull() + ante_step = ante_step or AntePostStepNull() + post_step = post_step or AntePostStepNull() post_iter = post_iter or PostIterNull() return self.__stepper( n_steps=n_steps, mu_coeff=mu_coeff, + ante_step=ante_step, post_step=post_step, post_iter=post_iter, fields=self.__fields, diff --git a/PyMPDATA/stepper.py b/PyMPDATA/stepper.py index 34c70f7a..d6ca6834 100644 --- a/PyMPDATA/stepper.py +++ b/PyMPDATA/stepper.py @@ -108,18 +108,28 @@ def n_dims(self) -> int: """dimensionality (1, 2 or 3)""" return self.__n_dims - def __call__(self, *, n_steps, mu_coeff, post_step, post_iter, fields): + def __call__(self, *, n_steps, mu_coeff, ante_step, post_step, post_iter, fields): assert self.n_threads == 1 or numba.get_num_threads() == self.n_threads with warnings.catch_warnings(): warnings.simplefilter("ignore", category=NumbaExperimentalFeatureWarning) wall_time_per_timestep = self.__call( n_steps, mu_coeff, + ante_step, post_step, post_iter, *( - _Impl(field=v.impl[IMPL_META_AND_DATA], bc=v.impl[IMPL_BC]) - for v in fields.values() + ( + _Impl(field=v.impl[IMPL_META_AND_DATA], bc=v.impl[IMPL_BC]) + if k != "advectee" + else tuple( + _Impl( + field=vv.impl[IMPL_META_AND_DATA], bc=vv.impl[IMPL_BC] + ) + for vv in v + ) + ) + for k, v in fields.items() ), self.traversals.data, ) @@ -163,9 +173,10 @@ def make_step_impl( def step( n_steps, mu_coeff, + ante_step, post_step, post_iter, - advectee, + advectees, advector, g_factor, vectmp_a, @@ -177,65 +188,68 @@ def step( ): time = clock() for step in range(n_steps): - if non_zero_mu_coeff: - advector_orig = advector - advector = vectmp_c - for iteration in range(n_iters): - if iteration == 0: - if nonoscillatory: - nonoscillatory_psi_extrema(null_impl, psi_extrema, advectee) - if non_zero_mu_coeff: - laplacian(null_impl, advector, advectee) - axpy( - *advector.field, - mu_coeff, - *advector.field, - *advector_orig.field, - ) - flux_first_pass(null_impl, vectmp_a, advector, advectee) - flux = vectmp_a - else: - if iteration == 1: - advector_oscil = advector - advector_nonos = vectmp_a - flux = vectmp_b - elif iteration % 2 == 0: - advector_oscil = vectmp_a - advector_nonos = vectmp_b + for advectee in advectees: + ante_step.call(advectee.field[ARG_DATA], step) + if non_zero_mu_coeff: + advector_orig = advector + advector = vectmp_c + for iteration in range(n_iters): + if iteration == 0: + if nonoscillatory: + nonoscillatory_psi_extrema(null_impl, psi_extrema, advectee) + if non_zero_mu_coeff: + laplacian(null_impl, advector, advectee) + axpy( + *advector.field, + mu_coeff, + *advector.field, + *advector_orig.field, + ) + flux_first_pass(null_impl, vectmp_a, advector, advectee) flux = vectmp_a else: - advector_oscil = vectmp_b - advector_nonos = vectmp_a - flux = vectmp_b - if iteration < n_iters - 1: - antidiff( - null_impl, - advector_nonos, - advectee, - advector_oscil, - g_factor, - ) - else: - antidiff_last_pass( - null_impl, - advector_nonos, - advectee, - advector_oscil, - g_factor, - ) - flux_subsequent(null_impl, flux, advectee, advector_nonos) - if nonoscillatory: - nonoscillatory_beta( - null_impl, beta, flux, advectee, psi_extrema, g_factor - ) - # note: in libmpdata++, the oscillatory advector from prev iter is used - nonoscillatory_correction(null_impl, advector_nonos, beta) + if iteration == 1: + advector_oscil = advector + advector_nonos = vectmp_a + flux = vectmp_b + elif iteration % 2 == 0: + advector_oscil = vectmp_a + advector_nonos = vectmp_b + flux = vectmp_a + else: + advector_oscil = vectmp_b + advector_nonos = vectmp_a + flux = vectmp_b + if iteration < n_iters - 1: + antidiff( + null_impl, + advector_nonos, + advectee, + advector_oscil, + g_factor, + ) + else: + antidiff_last_pass( + null_impl, + advector_nonos, + advectee, + advector_oscil, + g_factor, + ) flux_subsequent(null_impl, flux, advectee, advector_nonos) - upwind(null_impl, advectee, flux, g_factor) - post_iter.call(flux.field, g_factor.field, step, iteration) - if non_zero_mu_coeff: - advector = advector_orig - post_step.call(advectee.field[ARG_DATA], step) + if nonoscillatory: + nonoscillatory_beta( + null_impl, beta, flux, advectee, psi_extrema, g_factor + ) + # note: in libmpdata++, the oscillatory advector from prev iter is used + nonoscillatory_correction(null_impl, advector_nonos, beta) + flux_subsequent(null_impl, flux, advectee, advector_nonos) + upwind(null_impl, advectee, flux, g_factor) + post_iter.call(flux.field, g_factor.field, step, iteration) + if non_zero_mu_coeff: + advector = advector_orig + + post_step.call(advectee.field[ARG_DATA], step) return (clock() - time) / n_steps if n_steps > 0 else np.nan return step, traversals diff --git a/tests/unit_tests/test_solver.py b/tests/unit_tests/test_solver.py index 23baf7a0..7a3ae9ce 100644 --- a/tests/unit_tests/test_solver.py +++ b/tests/unit_tests/test_solver.py @@ -35,3 +35,17 @@ def test_mu_arg_handling(case): sut = Solver(stepper, advectee, advector, case["g_factor"]) sut.advance(1, mu_coeff=case["mu"]) + + +def test_multiple_scalar_fields(): + opt = Options() + data = np.asarray([4.0, 5]) + advector = VectorField((np.asarray([1.0, 2, 3]),), opt.n_halo, BCS) + advectees = [ScalarField(data, opt.n_halo, BCS)] * 5 + stepper = Stepper(options=opt, n_dims=1) + sut = Solver(stepper, advectees, advector) + + sut.advance(1) + + for advectee in advectees: + assert (advectee.get() != data).all()