Skip to content
Open
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 86 additions & 19 deletions irksome/stage_value.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,48 @@
from FIAT import Bernstein, ufc_simplex
from FIAT.barycentric_interpolation import LagrangePolynomialSet
from firedrake import (Function, NonlinearVariationalProblem,
NonlinearVariationalSolver, TestFunction, dx,
inner)
NonlinearVariationalSolver, TestFunction)
from ufl import as_tensor, Form
from ufl.constantvalue import as_ufl

from .bcs import stage2spaces4bc
from .tableaux.ButcherTableaux import CollocationButcherTableau
from .ufl.deriv import expand_time_derivatives
from .ufl.manipulation import split_time_derivative_terms, remove_time_derivatives
from .ufl.manipulation import (has_nonlinear_time_derivative,
split_time_derivative_terms,
remove_time_derivatives)
from .tools import AI, dot, reshape, replace
from .constant import vecconst
from .base_time_stepper import StageCoupledTimeStepper


# Default solver for the conservative update. The update problem is a
# nonlinear mass-matrix-like solve on V: the Jacobian is
# g'(u_new) * v * phi * dx (the spatial terms in the RK quadrature are
# evaluated at known stage values and are constant in u_new). That
# operator is SPD when g'(u) > 0 -- the well-posed case for Dt(g(u))
# with g monotone -- and its condition number is bounded independent
# of mesh size, so CG converges in O(1) iterations with a cheap
# preconditioner. bjacobi + icc respects the block structure of
# vector / mixed V and degenerates gracefully on scalar V.
#
# Assumptions: g is monotone in u (g'(u) > 0 a.e.). Points where g'
# vanishes trip a zero pivot regardless of preconditioner choice.
# Non-monotone g breaks SPD and needs an explicit override
# (e.g. GMRES + a more general preconditioner).
#
# Does not inherit from solver_parameters: the stage problem lives on
# V^s = V x ... x V and stage-tuned options -- fieldsplit indices,
# snes_type='ksponly', lagged Jacobians, custom MG transfers -- do not
# generally apply to the update solve on V.
_DEFAULT_UPDATE_SOLVER_PARAMETERS = {
'snes_type': 'newtonls',
'ksp_type': 'cg',
'pc_type': 'bjacobi',
'sub_pc_type': 'icc',
}


def to_value(u0, stages, vandermonde):
"""convert from Bernstein to Lagrange representation

Expand Down Expand Up @@ -189,20 +217,34 @@ def __init__(self, F, butcher_tableau, t, dt, u0, bcs=None,
if use_collocation_update:
# Use the terminal value of the collocation polynomial to update the solution. Note: collocation update is only implemented for constant-in-time boundary conditions.
# TODO: create an assertion to check for constant-in-time boundary conditions.

self.collocation_vander = self.tabulate_poly((1.0,))
self._update = self._update_collocation

elif (not butcher_tableau.is_stiffly_accurate) and (vandermonde is None):
try:
A = butcher_tableau.A
b = butcher_tableau.b
self.bAinv = vecconst(numpy.linalg.solve(A.T, b))
self.update_scale = 1-numpy.sum(self.bAinv)
self._update = self._update_Ainv
except numpy.linalg.LinAlgError:
# For nonlinear Dt(g(u)) we need the conservative variational
# update; the bAinv linear combination of stages is correct
# for g = identity but breaks for nonlinear g. For the
# linear case we keep the bAinv shortcut: it is conservative
# by construction (g = id makes both formulations agree)
# AND it handles DAEs correctly, which the conservative
# variational update does not.
if has_nonlinear_time_derivative(F, u0):
if update_solver_parameters is None:
update_solver_parameters = _DEFAULT_UPDATE_SOLVER_PARAMETERS
self.unew, self.update_solver = self.get_update_solver(update_solver_parameters)
self._update = self._update_general
else:
try:
A = butcher_tableau.A
b = butcher_tableau.b
self.bAinv = vecconst(numpy.linalg.solve(A.T, b))
self.update_scale = 1-numpy.sum(self.bAinv)
self._update = self._update_Ainv
except numpy.linalg.LinAlgError:
if update_solver_parameters is None:
update_solver_parameters = _DEFAULT_UPDATE_SOLVER_PARAMETERS
self.unew, self.update_solver = self.get_update_solver(update_solver_parameters)
self._update = self._update_general
else:
self._update = self._update_stiff_acc

Expand All @@ -220,20 +262,43 @@ def _update_stiff_acc(self):
u0bit.assign(self.stages.subfunctions[self.num_fields*(self.num_stages-1)+i])

def get_update_solver(self, update_solver_parameters):
# only form update stuff if we need it
# which means neither stiffly accurate nor Vandermonde
v, = self.F.arguments()
unew = Function(self.u0.function_space())
Fupdate = inner(unew - self.u0, v) * dx

C = vecconst(self.butcher_tableau.c)
B = vecconst(self.butcher_tableau.b)
"""Build a conservative variational update solve for u_new.

For a mass term ``inner(Dt(g(u)), v) * dx`` the update head is

inner(g(u_new) - g(u_0), v) * dx

evaluated at the stage-solve test function ``v``. For
``g = identity`` it reduces to ``inner(u_new - u_0, v) * dx``,
so the discrete update equation is unchanged in the linear
case. The remaining (non-time-derivative) part of the form is
contributed by the standard RK quadrature
``sum_i b_i * F_remainder(stage_i)``.

``update_solver_parameters`` does not inherit from
``solver_parameters``. The update solve is a different
problem from the stage solve -- it is posed on ``V`` rather
than ``V^s = V x ... x V``, and its Jacobian is a (nonlinear)
weighted mass matrix rather than the stage operator. Stage-
tuned options such as fieldsplit indices, ``snes_type='ksponly'``,
lagged Jacobians, or custom multigrid transfers generally do
not apply. If ``update_solver_parameters`` is left as None the
default ``_DEFAULT_UPDATE_SOLVER_PARAMETERS`` is used.
"""
F = self.F
t = self.t
dt = self.dt
u0 = self.u0
unew = Function(u0.function_space())

split_form = split_time_derivative_terms(F, t=t, timedep_coeffs=(u0,))
F_dtless = remove_time_derivatives(split_form.time)
F_remainder = expand_time_derivatives(split_form.remainder, t=t, timedep_coeffs=())

Fupdate = replace(F_dtless, {u0: unew}) - F_dtless
Comment thread
pbrubeck marked this conversation as resolved.

C = vecconst(self.butcher_tableau.c)
B = vecconst(self.butcher_tableau.b)
u_np = to_value(self.u0, self.stages, self.vandermonde)

for i in range(self.num_stages):
Expand All @@ -258,6 +323,8 @@ def get_update_solver(self, update_solver_parameters):
return unew, update_solver

def _update_general(self):
# Constant-in-time initial guess to prevent singular Jacobian
self.unew.assign(self.u0)
self.update_solver.solve()
self.u0.assign(self.unew)

Expand Down
49 changes: 46 additions & 3 deletions irksome/ufl/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,62 @@
from operator import or_
from typing import NamedTuple, Sequence, FrozenSet

from ufl import TrialFunction, derivative
Comment thread
sghelichkhani marked this conversation as resolved.
Outdated
from ufl.algorithms import expand_derivatives
from ufl.algorithms.analysis import extract_coefficients, extract_type
from ufl.corealg.traversal import traverse_unique_terminals
from ufl.corealg.dag_traverser import DAGTraverser
from ufl.classes import (
BaseForm, CellAvg, Coefficient, ComponentTensor,
Conj, Cross, Derivative, Division, Dot, Expr, FacetAvg,
Form, FormSum, Indexed, IndexSum, Inner, Integral,
Conj, Cross, Derivative, Div, Division, Dot, Expr, FacetAvg,
Form, FormSum, Grad, Indexed, IndexSum, Inner, Integral,
ListTensor, MultiIndex, NegativeRestricted, Outer, PositiveRestricted,
Product, Sum, Variable,
)

from .deriv import TimeDerivative

__all__ = ("SplitTimeForm", "check_integrals", "split_time_derivative_terms", "remove_time_derivatives")
__all__ = ("SplitTimeForm", "check_integrals", "split_time_derivative_terms",
"remove_time_derivatives", "has_nonlinear_time_derivative")


def has_nonlinear_time_derivative(F, u0):
"""True iff ``F`` contains a TimeDerivative of an expression that is
nonlinear in u0 -- i.e. ``Dt(g(u0))`` for some nonlinear g. These
cases lose mass conservation when chain-ruled through the
stage-derivative form, and require the conservative two-evaluation
discretisation.

For each ``Dt(f)`` in the form, the Gateaux derivative of ``f`` with
respect to ``u0`` is taken in a trial direction. If the derivative
still depends on ``u0``, ``f`` is nonlinear in u0. This delegates
the classification of linear operators (Grad, Div, Indexed,
restrictions, ListTensor, ComponentTensor, ...) to UFL's own
derivative machinery rather than maintaining a parallel exemption
list inside Irksome.

.. warning::

The detection is syntactic: it checks whether ``u0`` appears
under ``Dt`` after differentiation. If a user creates an
intermediate :class:`~firedrake.Function` whose values were
interpolated from an expression in u0 and then writes
``Dt(that_intermediate)``, the syntactic dependence on u0 is
lost and this function will declare the form safe. The
resulting discretisation is *not* mass-conservative. Always
wrap the symbolic expression directly in ``Dt`` (as
``Dt(theta(u))``, not ``Dt(theta_function)``).
"""
Trial = TrialFunction(u0.function_space())
Comment thread
sghelichkhani marked this conversation as resolved.
Outdated
for td in extract_type(F, TimeDerivative):
f, = td.ufl_operands
if u0 not in extract_coefficients(f):
# Dt(f(t,x)) -- no u0 dependence, chain-ruled analytically
continue
D = expand_derivatives(derivative(f, u0, Trial))
Comment thread
sghelichkhani marked this conversation as resolved.
Outdated
if u0 in extract_coefficients(D):
return True
return False


class SplitTimeForm(NamedTuple):
Expand Down
75 changes: 75 additions & 0 deletions tests/test_has_nonlinear_time_derivative.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
"""Classification contract for ``has_nonlinear_time_derivative``.

The helper is what routes stage_value between the existing
linear-combination update (``_update_Ainv``) and the new conservative
variational update. Misclassification has two failure modes:

* False negative on a nonlinear ``Dt(g(u))`` would silently drop the form
back onto the linear-combination update and lose mass conservation for
non-stiffly-accurate tableaux.

* False positive on a linear ``Dt(c*u)`` or ``Dt(u + f(t))`` would route
the form through the conservative variational head, which has no
algebraic block and breaks DAEs.

Pin both directions of the contract here so a future change to the
walker cannot quietly regress either case.
"""
import pytest
from firedrake import (
Constant, Function, FunctionSpace, TestFunction,
UnitIntervalMesh, dx, exp, inner,
)
from ufl import sin

from irksome import Dt
from irksome.ufl.manipulation import has_nonlinear_time_derivative


def _theta(h, theta_r=Constant(0.15), theta_s=Constant(0.45),
alpha=Constant(0.328)):
"""Exponential soil moisture, the canonical nonlinear g(h) we care about."""
return theta_r + (theta_s - theta_r) * exp(alpha * h)


@pytest.fixture
def setup():
mesh = UnitIntervalMesh(4)
V = FunctionSpace(mesh, "CG", 1)
return V, Function(V), TestFunction(V), Constant(0.0)


def test_linear_arithmetic_is_not_flagged(setup):
"""Constant-coefficient scalings of u and additive time forcings must
not be flagged. These are linear in u; ``_update_Ainv`` is correct
for them and is also the path that handles DAE structure."""
V, u, v, t = setup
forms = {
"Dt(2*u)": Dt(Constant(2.0) * u),
"Dt(u/2)": Dt(u / Constant(2.0)),
"Dt(u + sin(t))": Dt(u + sin(t)),
"Dt(2*u + 3)": Dt(Constant(2.0) * u + Constant(3.0)),
}
for name, expr in forms.items():
F = inner(expr, v) * dx
assert not has_nonlinear_time_derivative(F, u), (
f"{name} was incorrectly flagged as nonlinear -- the linear "
"stage_value path would be skipped, breaking DAEs."
)


def test_nonlinear_is_flagged(setup):
"""Genuinely nonlinear g(u) inside Dt must be flagged so that
stage_value routes through the conservative variational update."""
V, u, v, t = setup
nonlinear = {
"Dt(u*u)": Dt(u * u),
"Dt(theta(u))": Dt(_theta(u)),
"Dt(1/u)": Dt(Constant(1.0) / u),
}
for name, expr in nonlinear.items():
F = inner(expr, v) * dx
assert has_nonlinear_time_derivative(F, u), (
f"{name} was missed -- this would silently produce a "
"non-conservative discretisation for non-SA stage_value."
)
20 changes: 17 additions & 3 deletions tests/test_mass_conservation.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@
Constant, Function, FunctionSpace, TestFunction,
UnitSquareMesh, assemble, ds, dx, exp, grad, inner,
)
from irksome import BackwardEuler, DiscontinuousGalerkinScheme, Dt, RadauIIA, TimeStepper, BDF, AdamsMoulton, MultistepTableau
from irksome import (
AdamsMoulton, BackwardEuler, BDF, DiscontinuousGalerkinScheme, Dt,
GaussLegendre, MultistepTableau, QinZhang, RadauIIA, TimeStepper,
)
import numpy as np


Expand Down Expand Up @@ -80,8 +83,19 @@ def run_richards(scheme, **kwargs):
return mean_error


@pytest.mark.parametrize("scheme", [BackwardEuler(), RadauIIA(2), BDF(1), AdamsMoulton(0), AdamsMoulton(1), AdamsMoulton(2)],
ids=["BackwardEuler", "RadauIIA2", "BDF1", "AM0", "AM1", "AM2"])
# The three non-SA tableaux at the end (GaussLegendre(1) = ImplicitMidpoint,
# GaussLegendre(2), QinZhang) exercise the conservative variational update
# path introduced for non-stiffly-accurate stage_value. On master they fail
# this test with mass errors of order 1e-6 to 1e-7 because the linear-
# combination update destroys the conservation property the stage equations
# build for nonlinear theta(h).
@pytest.mark.parametrize("scheme",
[BackwardEuler(), RadauIIA(2), BDF(1),
AdamsMoulton(0), AdamsMoulton(1), AdamsMoulton(2),
GaussLegendre(1), GaussLegendre(2), QinZhang()],
ids=["BackwardEuler", "RadauIIA2", "BDF1",
"AM0", "AM1", "AM2",
"ImplicitMidpoint", "GaussLegendre2", "QinZhang"])
def test_mass_conservation_stage_value(scheme):
"""Test mass conservation with Dt(theta(h))"""
err = run_richards(scheme, stage_type="value")
Expand Down
Loading