diff --git a/irksome/stage_value.py b/irksome/stage_value.py index 3739c3f9..e68969fb 100644 --- a/irksome/stage_value.py +++ b/irksome/stage_value.py @@ -3,15 +3,16 @@ from FIAT import Bernstein, ufc_simplex from FIAT.barycentric_interpolation import LagrangePolynomialSet from firedrake import (Function, NonlinearVariationalProblem, - NonlinearVariationalSolver, TestFunction, dx, - inner) + NonlinearVariationalSolver, TestFunction) from ufl import as_tensor, Form from ufl.constantvalue import as_ufl from .bcs import stage2spaces4bc from .tableaux.ButcherTableaux import CollocationButcherTableau from .ufl.deriv import expand_time_derivatives -from .ufl.manipulation import split_time_derivative_terms, remove_time_derivatives +from .ufl.manipulation import (has_nonlinear_time_derivative, + split_time_derivative_terms, + remove_time_derivatives) from .tools import AI, dot, reshape, replace from .constant import vecconst from .base_time_stepper import StageCoupledTimeStepper @@ -189,20 +190,30 @@ def __init__(self, F, butcher_tableau, t, dt, u0, bcs=None, if use_collocation_update: # Use the terminal value of the collocation polynomial to update the solution. Note: collocation update is only implemented for constant-in-time boundary conditions. # TODO: create an assertion to check for constant-in-time boundary conditions. - self.collocation_vander = self.tabulate_poly((1.0,)) self._update = self._update_collocation elif (not butcher_tableau.is_stiffly_accurate) and (vandermonde is None): - try: - A = butcher_tableau.A - b = butcher_tableau.b - self.bAinv = vecconst(numpy.linalg.solve(A.T, b)) - self.update_scale = 1-numpy.sum(self.bAinv) - self._update = self._update_Ainv - except numpy.linalg.LinAlgError: + # For nonlinear Dt(g(u)) we need the conservative variational + # update; the bAinv linear combination of stages is correct + # for g = identity but breaks for nonlinear g. For the + # linear case we keep the bAinv shortcut: it is conservative + # by construction (g = id makes both formulations agree) + # AND it handles DAEs correctly, which the conservative + # variational update does not. + if has_nonlinear_time_derivative(F, u0): self.unew, self.update_solver = self.get_update_solver(update_solver_parameters) self._update = self._update_general + else: + try: + A = butcher_tableau.A + b = butcher_tableau.b + self.bAinv = vecconst(numpy.linalg.solve(A.T, b)) + self.update_scale = 1-numpy.sum(self.bAinv) + self._update = self._update_Ainv + except numpy.linalg.LinAlgError: + self.unew, self.update_solver = self.get_update_solver(update_solver_parameters) + self._update = self._update_general else: self._update = self._update_stiff_acc @@ -220,20 +231,44 @@ def _update_stiff_acc(self): u0bit.assign(self.stages.subfunctions[self.num_fields*(self.num_stages-1)+i]) def get_update_solver(self, update_solver_parameters): - # only form update stuff if we need it - # which means neither stiffly accurate nor Vandermonde - v, = self.F.arguments() - unew = Function(self.u0.function_space()) - Fupdate = inner(unew - self.u0, v) * dx - - C = vecconst(self.butcher_tableau.c) - B = vecconst(self.butcher_tableau.b) + """Build a conservative variational update solve for u_new. + + For a mass term ``inner(Dt(g(u)), v) * dx`` the update head is + + inner(g(u_new) - g(u_0), v) * dx + + evaluated at the stage-solve test function ``v``. For + ``g = identity`` it reduces to ``inner(u_new - u_0, v) * dx``, + so the discrete update equation is unchanged in the linear + case. The remaining (non-time-derivative) part of the form is + contributed by the standard RK quadrature + ``sum_i b_i * F_remainder(stage_i)``. + + ``update_solver_parameters`` does not inherit from + ``solver_parameters``. The update solve is a different + problem from the stage solve -- it is posed on ``V`` rather + than ``V^s = V x ... x V``, and its Jacobian is a (nonlinear) + weighted mass matrix rather than the stage operator. Stage- + tuned options such as fieldsplit indices, ``snes_type='ksponly'``, + lagged Jacobians, or custom multigrid transfers generally do + not apply. If ``update_solver_parameters`` is None, Firedrake's + default solver parameters are used (typically a sparse direct + solve). Pass an explicit dict to override. + """ F = self.F t = self.t dt = self.dt u0 = self.u0 + unew = Function(u0.function_space()) + split_form = split_time_derivative_terms(F, t=t, timedep_coeffs=(u0,)) + F_dtless = remove_time_derivatives(split_form.time) F_remainder = expand_time_derivatives(split_form.remainder, t=t, timedep_coeffs=()) + + Fupdate = replace(F_dtless, {u0: unew}) - F_dtless + + C = vecconst(self.butcher_tableau.c) + B = vecconst(self.butcher_tableau.b) u_np = to_value(self.u0, self.stages, self.vandermonde) for i in range(self.num_stages): @@ -258,6 +293,8 @@ def get_update_solver(self, update_solver_parameters): return unew, update_solver def _update_general(self): + # Constant-in-time initial guess to prevent singular Jacobian + self.unew.assign(self.u0) self.update_solver.solve() self.u0.assign(self.unew) diff --git a/irksome/ufl/manipulation.py b/irksome/ufl/manipulation.py index c1f9772c..4122a818 100644 --- a/irksome/ufl/manipulation.py +++ b/irksome/ufl/manipulation.py @@ -11,6 +11,9 @@ from operator import or_ from typing import NamedTuple, Sequence, FrozenSet +from ufl import derivative +from ufl.algorithms import expand_derivatives +from ufl.algorithms.analysis import extract_coefficients, extract_type from ufl.corealg.traversal import traverse_unique_terminals from ufl.corealg.dag_traverser import DAGTraverser from ufl.classes import ( @@ -23,7 +26,46 @@ from .deriv import TimeDerivative -__all__ = ("SplitTimeForm", "check_integrals", "split_time_derivative_terms", "remove_time_derivatives") +__all__ = ("SplitTimeForm", "check_integrals", "split_time_derivative_terms", + "remove_time_derivatives", "has_nonlinear_time_derivative") + + +def has_nonlinear_time_derivative(F, u0): + """True iff ``F`` contains a TimeDerivative of an expression that is + nonlinear in u0 -- i.e. ``Dt(g(u0))`` for some nonlinear g. These + cases lose mass conservation when chain-ruled through the + stage-derivative form, and require the conservative two-evaluation + discretisation. + + For each ``Dt(f)`` in the form, the Gateaux derivative of ``f`` with + respect to ``u0`` is taken in a trial direction. If the derivative + still depends on ``u0``, ``f`` is nonlinear in u0. This delegates + the classification of linear operators (Grad, Div, Indexed, + restrictions, ListTensor, ComponentTensor, ...) to UFL's own + derivative machinery rather than maintaining a parallel exemption + list inside Irksome. + + .. warning:: + + The detection is syntactic: it checks whether ``u0`` appears + under ``Dt`` after differentiation. If a user creates an + intermediate :class:`~firedrake.Function` whose values were + interpolated from an expression in u0 and then writes + ``Dt(that_intermediate)``, the syntactic dependence on u0 is + lost and this function will declare the form safe. The + resulting discretisation is *not* mass-conservative. Always + wrap the symbolic expression directly in ``Dt`` (as + ``Dt(theta(u))``, not ``Dt(theta_function)``). + """ + for td in extract_type(F, TimeDerivative): + f, = td.ufl_operands + if u0 not in extract_coefficients(f): + # Dt(f(t,x)) -- no u0 dependence, chain-ruled analytically + continue + D = expand_derivatives(derivative(f, u0)) + if u0 in extract_coefficients(D): + return True + return False class SplitTimeForm(NamedTuple): diff --git a/tests/test_has_nonlinear_time_derivative.py b/tests/test_has_nonlinear_time_derivative.py new file mode 100644 index 00000000..b36e994d --- /dev/null +++ b/tests/test_has_nonlinear_time_derivative.py @@ -0,0 +1,75 @@ +"""Classification contract for ``has_nonlinear_time_derivative``. + +The helper is what routes stage_value between the existing +linear-combination update (``_update_Ainv``) and the new conservative +variational update. Misclassification has two failure modes: + +* False negative on a nonlinear ``Dt(g(u))`` would silently drop the form + back onto the linear-combination update and lose mass conservation for + non-stiffly-accurate tableaux. + +* False positive on a linear ``Dt(c*u)`` or ``Dt(u + f(t))`` would route + the form through the conservative variational head, which has no + algebraic block and breaks DAEs. + +Pin both directions of the contract here so a future change to the +walker cannot quietly regress either case. +""" +import pytest +from firedrake import ( + Constant, Function, FunctionSpace, TestFunction, + UnitIntervalMesh, dx, exp, inner, +) +from ufl import sin + +from irksome import Dt +from irksome.ufl.manipulation import has_nonlinear_time_derivative + + +def _theta(h, theta_r=Constant(0.15), theta_s=Constant(0.45), + alpha=Constant(0.328)): + """Exponential soil moisture, the canonical nonlinear g(h) we care about.""" + return theta_r + (theta_s - theta_r) * exp(alpha * h) + + +@pytest.fixture +def setup(): + mesh = UnitIntervalMesh(4) + V = FunctionSpace(mesh, "CG", 1) + return V, Function(V), TestFunction(V), Constant(0.0) + + +def test_linear_arithmetic_is_not_flagged(setup): + """Constant-coefficient scalings of u and additive time forcings must + not be flagged. These are linear in u; ``_update_Ainv`` is correct + for them and is also the path that handles DAE structure.""" + V, u, v, t = setup + forms = { + "Dt(2*u)": Dt(Constant(2.0) * u), + "Dt(u/2)": Dt(u / Constant(2.0)), + "Dt(u + sin(t))": Dt(u + sin(t)), + "Dt(2*u + 3)": Dt(Constant(2.0) * u + Constant(3.0)), + } + for name, expr in forms.items(): + F = inner(expr, v) * dx + assert not has_nonlinear_time_derivative(F, u), ( + f"{name} was incorrectly flagged as nonlinear -- the linear " + "stage_value path would be skipped, breaking DAEs." + ) + + +def test_nonlinear_is_flagged(setup): + """Genuinely nonlinear g(u) inside Dt must be flagged so that + stage_value routes through the conservative variational update.""" + V, u, v, t = setup + nonlinear = { + "Dt(u*u)": Dt(u * u), + "Dt(theta(u))": Dt(_theta(u)), + "Dt(1/u)": Dt(Constant(1.0) / u), + } + for name, expr in nonlinear.items(): + F = inner(expr, v) * dx + assert has_nonlinear_time_derivative(F, u), ( + f"{name} was missed -- this would silently produce a " + "non-conservative discretisation for non-SA stage_value." + ) diff --git a/tests/test_mass_conservation.py b/tests/test_mass_conservation.py index 68c32721..35ec1b98 100644 --- a/tests/test_mass_conservation.py +++ b/tests/test_mass_conservation.py @@ -11,7 +11,10 @@ Constant, Function, FunctionSpace, TestFunction, UnitSquareMesh, assemble, ds, dx, exp, grad, inner, ) -from irksome import BackwardEuler, DiscontinuousGalerkinScheme, Dt, RadauIIA, TimeStepper, BDF, AdamsMoulton, MultistepTableau +from irksome import ( + AdamsMoulton, BackwardEuler, BDF, DiscontinuousGalerkinScheme, Dt, + GaussLegendre, MultistepTableau, QinZhang, RadauIIA, TimeStepper, +) import numpy as np @@ -47,11 +50,15 @@ def run_richards(scheme, **kwargs): kwargs['startup_parameters'] = {'tableau': RadauIIA(2), 'stepper_kwargs': {'stage_type': 'value'}} + snes_params = {"snes_rtol": 1e-10, "snes_atol": 1e-14} + if kwargs.get("stage_type") == "value": + # The update solve does not inherit solver options from the stage + # solve; pass tight SNES tolerances explicitly so the conservation + # assertion below sits at machine precision rather than at the + # default SNES rtol. + kwargs["update_solver_parameters"] = snes_params stepper = TimeStepper(F, scheme, t, dt, h, - solver_parameters={ - "snes_rtol": 1e-10, - "snes_atol": 1e-14, - }, + solver_parameters=snes_params, **kwargs) if isinstance(scheme, MultistepTableau): @@ -80,8 +87,19 @@ def run_richards(scheme, **kwargs): return mean_error -@pytest.mark.parametrize("scheme", [BackwardEuler(), RadauIIA(2), BDF(1), AdamsMoulton(0), AdamsMoulton(1), AdamsMoulton(2)], - ids=["BackwardEuler", "RadauIIA2", "BDF1", "AM0", "AM1", "AM2"]) +# The three non-SA tableaux at the end (GaussLegendre(1) = ImplicitMidpoint, +# GaussLegendre(2), QinZhang) exercise the conservative variational update +# path introduced for non-stiffly-accurate stage_value. On master they fail +# this test with mass errors of order 1e-6 to 1e-7 because the linear- +# combination update destroys the conservation property the stage equations +# build for nonlinear theta(h). +@pytest.mark.parametrize("scheme", + [BackwardEuler(), RadauIIA(2), BDF(1), + AdamsMoulton(0), AdamsMoulton(1), AdamsMoulton(2), + GaussLegendre(1), GaussLegendre(2), QinZhang()], + ids=["BackwardEuler", "RadauIIA2", "BDF1", + "AM0", "AM1", "AM2", + "ImplicitMidpoint", "GaussLegendre2", "QinZhang"]) def test_mass_conservation_stage_value(scheme): """Test mass conservation with Dt(theta(h))""" err = run_richards(scheme, stage_type="value")