Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 41 additions & 2 deletions PyMPDATA/impl/indexers.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,14 @@ def ats_1d(focus, arr, k, _=INVALID_INDEX, __=INVALID_INDEX):
def atv_1d(focus, arrs, k, _=INVALID_INDEX, __=INVALID_INDEX):
return arrs[INNER][focus[INNER] + int(k - 0.5)]

@staticmethod
@numba.njit(**jit_flags)
def ati_1d(focus, arrs, k, _=INVALID_INDEX, __=INVALID_INDEX):
return (
arrs[INNER][focus[INNER] + int(k - 0.5)]
+ arrs[INNER][focus[INNER] + int(k + 0.5)]
) / 2

@staticmethod
@numba.njit(**jit_flags)
def set(arr, _, __, k, value):
Expand Down Expand Up @@ -70,6 +78,30 @@ def atv_axis1(focus, arrs, k, i=0, _=INVALID_INDEX):
dim, _ii, _kk = OUTER, int(i - 0.5), int(k)
return arrs[dim][focus[OUTER] + _ii, focus[INNER] + _kk]

@staticmethod
@numba.njit(**jit_flags)
def ati_axis0(focus, arrs, i, k=0, _=INVALID_INDEX):
if _is_integral(i):
dim, _ii, _kk = INNER, int(i), int(k - 0.5)
else:
dim, _ii, _kk = OUTER, int(i - 0.5), int(k)
return (
arrs[dim][focus[OUTER] + _ii, focus[INNER] + _kk]
+ arrs[dim][focus[OUTER] + _ii + 1, focus[INNER] + _kk]
) / 2

@staticmethod
@numba.njit(**jit_flags)
def ati_axis1(focus, arrs, k, i=0, _=INVALID_INDEX):
if _is_integral(i):
dim, _ii, _kk = INNER, int(i), int(k - 0.5)
else:
dim, _ii, _kk = OUTER, int(i - 0.5), int(k)
return (
arrs[dim][focus[OUTER] + _ii, focus[INNER] + _kk]
+ arrs[dim][focus[OUTER] + _ii, focus[INNER] + _kk + 1]
) / 2

@staticmethod
@numba.njit(**jit_flags)
def set(arr, i, _, k, value):
Expand Down Expand Up @@ -140,24 +172,31 @@ def get(arr, i, j, k):
return arr[i, j, k]

Indexers = namedtuple( # pylint: disable=invalid-name
Path(__file__).stem + "_Indexers", ("ats", "atv", "set", "get", "n_dims")
Path(__file__).stem + "_Indexers", ("ats", "atv", "ati", "set", "get", "n_dims")
)

indexers = (
None,
Indexers(
(None, None, _1D.ats_1d), (None, None, _1D.atv_1d), _1D.set, _1D.get, 1
(None, None, _1D.ats_1d),
(None, None, _1D.atv_1d),
(None, None, _1D.ati_1d),
_1D.set,
_1D.get,
1,
),
Indexers(
(_2D.ats_axis0, None, _2D.ats_axis1),
(_2D.atv_axis0, None, _2D.atv_axis1),
(_2D.ati_axis0, None, _2D.ati_axis1),
_2D.set,
_2D.get,
2,
),
Indexers(
(_3D.ats_axis0, _3D.ats_axis1, _3D.ats_axis2),
(_3D.atv_axis0, _3D.atv_axis1, _3D.atv_axis2),
(None, None, None),
_3D.set,
_3D.get,
3,
Expand Down
25 changes: 25 additions & 0 deletions PyMPDATA/impl/interpolate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
"""Interpolation formulae sketch"""

import numba
import numpy as np


def make_interpolate(options):
"""Function that returns JIT-compilable interpolate function"""

@numba.njit(**options.jit_flags)
def interpolate(psi, axis):
idx = (
(slice(None, -1), slice(None, None)),
(slice(None, None), slice(None, -1)),
)
s1 = 2 * [slice(None)]
s2 = 2 * [slice(None)]
s1[axis] = slice(1, None)
s2[axis] = slice(None, -1)
s1 = tuple(s1)
s2 = tuple(s2)
out = psi[s1] - psi[s2]
return out / 2 * psi[idx[axis]]

return interpolate
29 changes: 18 additions & 11 deletions PyMPDATA/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ class grouping user-supplied stepper, fields and post-step/post-iter hooks,
as well as self-initialised temporary storage
"""

from typing import Union
from typing import Iterable, Union

import numba

Expand All @@ -14,7 +14,7 @@ class grouping user-supplied stepper, fields and post-step/post-iter hooks,


@numba.experimental.jitclass([])
class PostStepNull: # pylint: disable=too-few-public-methods
class AntePostStepNull: # pylint: disable=too-few-public-methods
"""do-nothing version of the post-step hook"""

def __init__(self):
Expand Down Expand Up @@ -48,23 +48,26 @@ class Solver:
def __init__(
self,
stepper: Stepper,
advectee: ScalarField,
advectee: [ScalarField, Iterable[ScalarField]],
advector: VectorField,
g_factor: [ScalarField, None] = None,
):
if isinstance(advectee, ScalarField):
advectee = (advectee,)

def scalar_field(dtype=None):
return ScalarField.clone(advectee, dtype=dtype)
return ScalarField.clone(advectee[0], dtype=dtype)

def null_scalar_field():
return ScalarField.make_null(advectee.n_dims, stepper.traversals)
return ScalarField.make_null(advectee[0].n_dims, stepper.traversals)

def vector_field():
return VectorField.clone(advector)

def null_vector_field():
return VectorField.make_null(advector.n_dims, stepper.traversals)

for field in [advector, advectee] + (
for field in [advector, *advectee] + (
[g_factor] if g_factor is not None else []
):
assert field.halo == stepper.options.n_halo
Expand Down Expand Up @@ -93,16 +96,17 @@ def null_vector_field():
else null_scalar_field()
),
}
for field in self.__fields.values():
field.assemble(stepper.traversals)
for key, value in self.__fields.items():
for field in (value,) if key != "advectee" else value:
field.assemble(stepper.traversals)

self.__stepper = stepper

@property
def advectee(self) -> ScalarField:
def advectee(self, index=0) -> ScalarField:
"""advectee scalar field (with halo), modified by advance(),
may be modified from user code (e.g., source-term handling)"""
return self.__fields["advectee"]
return self.__fields["advectee"][index]

@property
def advector(self) -> VectorField:
Expand All @@ -126,6 +130,7 @@ def advance(
self,
n_steps: int,
mu_coeff: Union[tuple, None] = None,
ante_step=None,
post_step=None,
post_iter=None,
):
Expand All @@ -144,12 +149,14 @@ def advance(
):
raise NotImplementedError()

post_step = post_step or PostStepNull()
ante_step = ante_step or AntePostStepNull()
post_step = post_step or AntePostStepNull()
post_iter = post_iter or PostIterNull()

return self.__stepper(
n_steps=n_steps,
mu_coeff=mu_coeff,
ante_step=ante_step,
post_step=post_step,
post_iter=post_iter,
fields=self.__fields,
Expand Down
134 changes: 74 additions & 60 deletions PyMPDATA/stepper.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,18 +108,28 @@ def n_dims(self) -> int:
"""dimensionality (1, 2 or 3)"""
return self.__n_dims

def __call__(self, *, n_steps, mu_coeff, post_step, post_iter, fields):
def __call__(self, *, n_steps, mu_coeff, ante_step, post_step, post_iter, fields):
assert self.n_threads == 1 or numba.get_num_threads() == self.n_threads
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=NumbaExperimentalFeatureWarning)
wall_time_per_timestep = self.__call(
n_steps,
mu_coeff,
ante_step,
post_step,
post_iter,
*(
_Impl(field=v.impl[IMPL_META_AND_DATA], bc=v.impl[IMPL_BC])
for v in fields.values()
(
_Impl(field=v.impl[IMPL_META_AND_DATA], bc=v.impl[IMPL_BC])
if k != "advectee"
else tuple(
_Impl(
field=vv.impl[IMPL_META_AND_DATA], bc=vv.impl[IMPL_BC]
)
for vv in v
)
)
for k, v in fields.items()
),
self.traversals.data,
)
Expand Down Expand Up @@ -163,9 +173,10 @@ def make_step_impl(
def step(
n_steps,
mu_coeff,
ante_step,
post_step,
post_iter,
advectee,
advectees,
advector,
g_factor,
vectmp_a,
Expand All @@ -177,65 +188,68 @@ def step(
):
time = clock()
for step in range(n_steps):
if non_zero_mu_coeff:
advector_orig = advector
advector = vectmp_c
for iteration in range(n_iters):
if iteration == 0:
if nonoscillatory:
nonoscillatory_psi_extrema(null_impl, psi_extrema, advectee)
if non_zero_mu_coeff:
laplacian(null_impl, advector, advectee)
axpy(
*advector.field,
mu_coeff,
*advector.field,
*advector_orig.field,
)
flux_first_pass(null_impl, vectmp_a, advector, advectee)
flux = vectmp_a
else:
if iteration == 1:
advector_oscil = advector
advector_nonos = vectmp_a
flux = vectmp_b
elif iteration % 2 == 0:
advector_oscil = vectmp_a
advector_nonos = vectmp_b
for advectee in advectees:
ante_step.call(advectee.field[ARG_DATA], step)
if non_zero_mu_coeff:
advector_orig = advector
advector = vectmp_c
for iteration in range(n_iters):
if iteration == 0:
if nonoscillatory:
nonoscillatory_psi_extrema(null_impl, psi_extrema, advectee)
if non_zero_mu_coeff:
laplacian(null_impl, advector, advectee)
axpy(
*advector.field,
mu_coeff,
*advector.field,
*advector_orig.field,
)
flux_first_pass(null_impl, vectmp_a, advector, advectee)
flux = vectmp_a
else:
advector_oscil = vectmp_b
advector_nonos = vectmp_a
flux = vectmp_b
if iteration < n_iters - 1:
antidiff(
null_impl,
advector_nonos,
advectee,
advector_oscil,
g_factor,
)
else:
antidiff_last_pass(
null_impl,
advector_nonos,
advectee,
advector_oscil,
g_factor,
)
flux_subsequent(null_impl, flux, advectee, advector_nonos)
if nonoscillatory:
nonoscillatory_beta(
null_impl, beta, flux, advectee, psi_extrema, g_factor
)
# note: in libmpdata++, the oscillatory advector from prev iter is used
nonoscillatory_correction(null_impl, advector_nonos, beta)
if iteration == 1:
advector_oscil = advector
advector_nonos = vectmp_a
flux = vectmp_b
elif iteration % 2 == 0:
advector_oscil = vectmp_a
advector_nonos = vectmp_b
flux = vectmp_a
else:
advector_oscil = vectmp_b
advector_nonos = vectmp_a
flux = vectmp_b
if iteration < n_iters - 1:
antidiff(
null_impl,
advector_nonos,
advectee,
advector_oscil,
g_factor,
)
else:
antidiff_last_pass(
null_impl,
advector_nonos,
advectee,
advector_oscil,
g_factor,
)
flux_subsequent(null_impl, flux, advectee, advector_nonos)
upwind(null_impl, advectee, flux, g_factor)
post_iter.call(flux.field, g_factor.field, step, iteration)
if non_zero_mu_coeff:
advector = advector_orig
post_step.call(advectee.field[ARG_DATA], step)
if nonoscillatory:
nonoscillatory_beta(
null_impl, beta, flux, advectee, psi_extrema, g_factor
)
# note: in libmpdata++, the oscillatory advector from prev iter is used
nonoscillatory_correction(null_impl, advector_nonos, beta)
flux_subsequent(null_impl, flux, advectee, advector_nonos)
upwind(null_impl, advectee, flux, g_factor)
post_iter.call(flux.field, g_factor.field, step, iteration)
if non_zero_mu_coeff:
advector = advector_orig

post_step.call(advectee.field[ARG_DATA], step)
return (clock() - time) / n_steps if n_steps > 0 else np.nan

return step, traversals
10 changes: 3 additions & 7 deletions examples/PyMPDATA_examples/Jarecka_et_al_2015/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@

from PyMPDATA import ScalarField, Solver, Stepper, VectorField
from PyMPDATA.boundary_conditions import Constant
from PyMPDATA.impl.interpolate import make_interpolate


class Simulation:
# pylint: disable=too-few-public-methods
def __init__(self, settings):
self.settings = settings
s = settings
Expand Down Expand Up @@ -36,13 +38,7 @@ def __init__(self, settings):
k: Solver(stepper, v, self.advector) for k, v in advectees.items()
}

@staticmethod
def interpolate(psi, axis):
idx = (
(slice(None, -1), slice(None, None)),
(slice(None, None), slice(None, -1)),
)
return np.diff(psi, axis=axis) / 2 + psi[idx[axis]]
self.interpolate = make_interpolate(settings.options)

def run(self):
s = self.settings
Expand Down
Loading
Loading