diff --git a/docs/source/_static/references.bib b/docs/source/_static/references.bib index 7f5e7dfc99..06c3202e01 100644 --- a/docs/source/_static/references.bib +++ b/docs/source/_static/references.bib @@ -6,6 +6,18 @@ @comment{ This bibtex file is for references in the documentation that aren't specifically to Firedrake. } +@article{Clement1975, + author = {Cl{\'e}ment, Ph}, + journal = {Revue fran{\c{c}}aise d'automatique, informatique, recherche op{\'e}rationnelle. Analyse num{\'e}rique}, + number = {R2}, + pages = {77--84}, + publisher = {EDP Sciences}, + title = {Approximation by finite element functions using local regularization}, + url = {https://doi.org/10.1051/m2an/197509R200771}, + volume = {9}, + year = {1975} +} + @article{Farrell2013, author = {Farrell, Patrick E and Ham, David A and Funke, Simon W and Rognes, Marie E}, doi = {10.1137/120873558}, diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index 05d3438318..9f9e3b8c55 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -22,7 +22,8 @@ import finat import firedrake -from firedrake import tsfc_interface, utils, functionspaceimpl +from firedrake import tsfc_interface, utils, functionspaceimpl, parloops +import firedrake.function as ffunc from firedrake.ufl_expr import Argument, action, adjoint as expr_adjoint from firedrake.mesh import MissingPointsBehaviour, VertexOnlyMeshMissingPointsError from firedrake.petsc import PETSc @@ -38,6 +39,7 @@ "DofNotDefinedError", "CrossMeshInterpolator", "SameMeshInterpolator", + "ClementInterpolator", ) @@ -1664,3 +1666,96 @@ def multHermitian(self, source_vec, target_vec): # matrix will then have rows of zeros for those points. target_vec.zeroEntries() self.reduce(source_vec, target_vec) + + +class ClementInterpolator(SameMeshInterpolator): + r""" + Compute the Clément interpolant of a :math:`\mathbb{P}0` source field, i.e., take + the volume average over neighbouring cells at each vertex. + + See :cite:`Clement1975` for details. + + For arguments, see :class:`.Interpolator`. + """ + + # NOTE: We need to overload the __new__ inherited from Interpolator because it will + # only ever return instances of SameMeshInterpolator or CrossMeshInterpolator, not + # ClementInterpolator. + def __new__(cls, *args, **kwargs): + return object.__new__(ClementInterpolator) + + @no_annotations + def __init__( + self, + expr, + V, + subset=None, + freeze_expr=False, + access=op2.WRITE, + bcs=None, + allow_missing_dofs=False, + ): + if subset: + raise NotImplementedError("subset not implemented") + if freeze_expr: + raise NotImplementedError("freeze_expr not implemented") + if access != op2.WRITE: + raise NotImplementedError("access other than op2.WRITE not implemented") + if bcs: + raise NotImplementedError("bcs not implemented") + target_mesh = as_domain(V) + source_mesh = extract_unique_domain(expr) or target_mesh + if target_mesh is not source_mesh: + raise ValueError("Clément interpolation requires the source and target meshes to coincide.") + element = V.ufl_element() + if element.family() != "Lagrange" or element.degree() != 1: + raise ValueError("Clément interpolation must target a P1 space.") + super().__init__(expr, V, subset=subset, freeze_expr=freeze_expr, access=access, + bcs=bcs, allow_missing_dofs=allow_missing_dofs) + + @PETSc.Log.EventDecorator() + def interpolate(self, function, output=None, adjoint=False): + """ + Compute the Clément interpolant applied to a source function. + + Parameters + ---------- + function: firedrake.function.Function or firedrake.cofunction.Cofunction + Function to be interpolated. + output: firedrake.function.Function or firedrake.cofunction.Cofunction + A function to contain the output. + adjoint: bool + Set to true to apply the adjoint of the interpolation + operator. + + Returns + ------- + firedrake.function.Function or firedrake.cofunction.Cofunction + The resulting interpolated function. + """ + if adjoint: + raise NotImplementedError("Adjoint of Clément interpolation not implemented.") + Vs = function.function_space() + element = Vs.ufl_element() + if not (element.family() == "Discontinuous Lagrange" and element.degree() == 0): + raise ValueError("Source function must live in P0 space.") + rank = len(Vs.value_shape) + if rank != len(self.V.value_shape): + raise ValueError(f"Rank-{rank} input inconsistent with target space.") + mesh = self.V.mesh() + if output is None: + output = ffunc.Function(self.V) + + # Take the weighted average of the source function over the neighbouring cells + domain = f"{{[i, j]: 0 <= i < out.dofs and 0 <= j < {Vs.block_size}}}" + instructions = "out[i, j] = out[i, j] + vol[0] * f[0, j]" + keys = {"f": (function, op2.READ), "vol": (mesh.cell_volume, op2.READ), "out": (output, op2.RW)} + parloops.par_loop((domain, instructions), ufl.dx(domain=mesh), keys) + + # Divide by the volume of the patch of neighbouring cells + domain = f"{{[j]: 0 <= j < {Vs.block_size}}}" + instructions = "out[0, j] = out[0, j] / patch[0]" + keys = {"patch": (mesh.patch_volume, op2.READ), "out": (output, op2.RW)} + parloops.par_loop((domain, instructions), parloops.direct, keys) + + return output diff --git a/firedrake/mesh.py b/firedrake/mesh.py index 0be58d6d95..819e7f0b9d 100644 --- a/firedrake/mesh.py +++ b/firedrake/mesh.py @@ -2415,6 +2415,37 @@ def clear_cell_sizes(self): except AttributeError: pass + @utils.cached_property + def cell_volume(self): + """ + A :class:`~.Function` in the :math:`P^0` space containing the local mesh volume. + + This is computed by interpolating the UFL :class:`~.CellVolume` for the current + mesh. + """ + from firedrake.function import Function + from firedrake.functionspace import FunctionSpace + DG0 = FunctionSpace(self, "Discontinuous Lagrange", 0) + volume = Function(DG0) + return volume.interpolate(ufl.CellVolume(self)) + + @utils.cached_property + def patch_volume(self): + """ + A :class:`~.Function` in the :math:`P^1` space containing the sum of the volumes + of cells neighbouring a vertex. + """ + from firedrake.function import Function + from firedrake.functionspace import FunctionSpace + from firedrake.parloops import par_loop, READ, RW + CG1 = FunctionSpace(self, "Lagrange", 1) + patch_vol = Function(CG1) + domain = "{[i]: 0 <= i < patch.dofs}" + instructions = "patch[i] = patch[i] + vol[0]" + keys = {"vol": (self.cell_volume, READ), "patch": (patch_vol, RW)} + par_loop((domain, instructions), ufl.dx(domain=self), keys) + return patch_vol + @property def tolerance(self): """The relative tolerance (i.e. as defined on the reference cell) for diff --git a/tests/firedrake/regression/test_interpolate.py b/tests/firedrake/regression/test_interpolate.py index e8860a32e7..6707e7c224 100644 --- a/tests/firedrake/regression/test_interpolate.py +++ b/tests/firedrake/regression/test_interpolate.py @@ -566,3 +566,44 @@ def test_interpolate_logical_not(): a = assemble(interpolate(conditional(Not(x < .2), 1, 0), V)) b = assemble(interpolate(conditional(x >= .2, 1, 0), V)) assert np.allclose(a.dat.data, b.dat.data) + + +@pytest.mark.parametrize("tdim,shape", [(1, tuple()), (2, tuple()), (3, tuple()), + (1, (1,)), (2, (2,)), (2, (3,)), (3, (3,)), + (1, (1, 2)), (2, (2, 3)), (3, (2, 3))], + ids=["1d-scalar", "2d-scalar", "3d-scalar", "1d-vector", + "2d-vector", "2d-3vector", "3d-vector", "1d-matrix", + "2d-matrix", "3d-matrix"]) +def test_clement_interpolator_simplex(tdim, shape): + mesh = { + 1: UnitIntervalMesh, + 2: UnitSquareMesh, + 3: UnitCubeMesh, + }[tdim](*(5 for _ in range(tdim))) + x = SpatialCoordinate(mesh) + if len(shape) == 0: + P0 = FunctionSpace(mesh, "DG", 0) + P1 = FunctionSpace(mesh, "CG", 1) + expr = sum(x) + elif len(shape) == 1: + dim = shape[0] + P0 = VectorFunctionSpace(mesh, "DG", 0, dim=dim) + P1 = VectorFunctionSpace(mesh, "CG", 1, dim=dim) + expr = as_vector(x if dim == tdim else [x[0] for _ in range(dim)]) + else: + P0 = TensorFunctionSpace(mesh, "DG", 0, shape=shape) + P1 = TensorFunctionSpace(mesh, "CG", 1, shape=shape) + rows = [Constant(tuple(range(i+1, i+1+tdim))) for i in range(P1.block_size)] + expr = as_tensor(np.reshape([dot(row, x) for row in rows], shape)) + + # Projecting into P0 space and then applying Clement interpolation should recover + # the original function + x_P0 = assemble(project(expr, P0)) + interpolator = ClementInterpolator(TestFunction(P0), P1) + x_P1 = interpolator.interpolate(x_P0) + x_P1_direct = Function(P1).interpolate(expr) + + # Account for the fact that the Clement interpolant breaks down at domain boundaries + DirichletBC(P1, x_P1_direct, "on_boundary").apply(x_P1) + + assert np.isclose(errornorm(x_P1_direct, x_P1), 0)