Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion irksome/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
AdamsMoulton,
)
from .tableaux.pep_explicit_rk import PEPRK
from .ufl.deriv import Dt, expand_time_derivatives, check_irksome_import_order
from .ufl.deriv import Dt, expand_time_derivatives, lag, check_irksome_import_order

check_irksome_import_order()

Expand Down Expand Up @@ -44,6 +44,7 @@
"expand_time_derivatives",
"GalerkinCollocationScheme",
"GaussLegendre",
"lag",
"LobattoIIIA",
"LobattoIIIC",
"MeshConstant",
Expand Down
24 changes: 17 additions & 7 deletions irksome/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@
from ufl.algorithms.analysis import extract_type
from ufl import as_tensor
from ufl import replace as ufl_replace
from ufl.classes import Variable
from pyop2.types import MixedDat
import FIAT

from .ufl.deriv import TimeDerivative
from .ufl.deriv import TimeDerivative, lag_label


def dot(A, B):
Expand Down Expand Up @@ -108,14 +109,23 @@ def getNullspace(V, Vbig, num_stages, nullspace):


def replace(e, mapping):
"""A wrapper for ufl.replace that allows numpy arrays."""
"""A wrapper for ufl.replace that allows numpy arrays and skips
substitution into sub-expressions wrapped by :func:`lag`."""
Comment thread
pbrubeck marked this conversation as resolved.
Outdated
cmapping = {k: as_tensor(v) for k, v in mapping.items()}
Comment thread
pbrubeck marked this conversation as resolved.
Comment thread
pbrubeck marked this conversation as resolved.
if isinstance(e, LabelledForm):
enew = LabelledForm(*(Term(ufl_replace(term.form, cmapping), term.labels)
for term in e.terms))
return enew
else:
return ufl_replace(e, cmapping)
new_terms = []
Comment thread
pbrubeck marked this conversation as resolved.
Outdated
for term in e.terms:
tmap = dict(cmapping)
for var in extract_type(term.form, Variable):
if var.ufl_operands[1] is lag_label:
tmap.setdefault(var, var)
new_terms.append(Term(ufl_replace(term.form, tmap), term.labels))
return LabelledForm(*new_terms)
Comment thread
pbrubeck marked this conversation as resolved.
Outdated

for var in extract_type(e, Variable):
if var.ufl_operands[1] is lag_label:
cmapping.setdefault(var, var)
Comment thread
pbrubeck marked this conversation as resolved.
Comment thread
pbrubeck marked this conversation as resolved.
return ufl_replace(e, cmapping)
Comment thread
pbrubeck marked this conversation as resolved.
Outdated


# Utility functions that help us refactor
Expand Down
13 changes: 12 additions & 1 deletion irksome/ufl/deriv.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from ufl.algorithms.apply_algebra_lowering import apply_algebra_lowering
from ufl.form import BaseForm
from ufl.classes import (Coefficient, Conj, Curl, ConstantValue, Derivative,
Div, Expr, Grad, Indexed, ReferenceGrad,
Div, Expr, Grad, Indexed, Label, ReferenceGrad,
ReferenceValue, SpatialCoordinate, Variable)
from ufl.corealg.multifunction import MultiFunction

Expand Down Expand Up @@ -85,6 +85,17 @@ def Dt(f, order=1):
return f


# A :class:`ufl.Label` to mark nodes that are only evaluated at the start of
# the timestep.
lag_label = Label()


def lag(expr):
"""Mark a sub-expression to be evaluated only at the start of the
timestep during the implicit solve."""
return Variable(expr, lag_label)


Comment on lines +88 to +98
Copy link
Copy Markdown
Collaborator

@pbrubeck pbrubeck May 18, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This file does not seem the right place to define this. I think we should be creating a new file.

class TimeDerivativeRuleset(GenericDerivativeRuleset):
"""Apply AD rules to time derivative expressions."""
def __init__(self, t=None, timedep_coeffs=None):
Expand Down
38 changes: 38 additions & 0 deletions tests/test_lag.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import firedrake
from firedrake import Constant, inner, grad, dx, conditional
import irksome
from irksome import Dt, lag


def test_stefan_implicit():
"""Test lagging the conductivity on the Stefan problem"""
nx = 32
mesh = firedrake.UnitIntervalMesh(nx)
V = firedrake.FunctionSpace(mesh, "CG", 1)
x, = firedrake.SpatialCoordinate(mesh)

u = firedrake.Function(V)
u.interpolate(1 - 2 * x)

k_solid = Constant(2.0)
k_liquid = Constant(1.0)
k = lag(conditional(u < 0, k_solid, k_liquid))

v = firedrake.TestFunction(V)
F = (Dt(u) * v + k * inner(grad(u), grad(v))) * dx

T_1 = Constant(1.0)
T_2 = Constant(-1.0)
bcs = [firedrake.DirichletBC(V, T_1, 1), firedrake.DirichletBC(V, T_2, 2)]

t = Constant(0.0)
dt = Constant(0.1)
solver_params = {"snes_type": "newtonls", "snes_converged_reason": None}
params = {"bcs": bcs, "solver_parameters": solver_params}
method = irksome.BackwardEuler()
stepper = irksome.TimeStepper(F, method, t, dt, u, **params)

final_time = 10.0
num_steps = int(final_time / float(dt))
for step in range(num_steps):
stepper.advance()
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add an assertion here?

Loading