Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion irksome/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
AdamsMoulton,
)
from .tableaux.pep_explicit_rk import PEPRK
from .ufl.deriv import Dt, expand_time_derivatives, check_irksome_import_order
from .ufl.deriv import Dt, expand_time_derivatives, lag, check_irksome_import_order

check_irksome_import_order()

Expand Down Expand Up @@ -44,6 +44,7 @@
"expand_time_derivatives",
"GalerkinCollocationScheme",
"GaussLegendre",
"lag",
"LobattoIIIA",
"LobattoIIIC",
"MeshConstant",
Expand Down
24 changes: 17 additions & 7 deletions irksome/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@
from ufl.algorithms.analysis import extract_type
from ufl import as_tensor
from ufl import replace as ufl_replace
from ufl.classes import Variable
from pyop2.types import MixedDat
import FIAT

from .ufl.deriv import TimeDerivative
from .ufl.deriv import TimeDerivative, lag_label


def dot(A, B):
Expand Down Expand Up @@ -108,14 +109,23 @@ def getNullspace(V, Vbig, num_stages, nullspace):


def replace(e, mapping):
"""A wrapper for ufl.replace that allows numpy arrays."""
"""A wrapper for ufl.replace that allows numpy arrays and skips
substitution into sub-expressions wrapped by :func:`lag`."""
cmapping = {k: as_tensor(v) for k, v in mapping.items()}
if isinstance(e, LabelledForm):
enew = LabelledForm(*(Term(ufl_replace(term.form, cmapping), term.labels)
for term in e.terms))
return enew
else:
return ufl_replace(e, cmapping)
new_terms = []
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had a bit of help from Claude for this code path. The alternate path below (not a LabelledForm) was straightforward.

Copy link
Copy Markdown
Collaborator

@pbrubeck pbrubeck May 15, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This code path is going to disappear in #227

for term in e.terms:
tmap = dict(cmapping)
for var in extract_type(term.form, Variable):
if var.ufl_operands[1] is lag_label:
tmap.setdefault(var, var)
new_terms.append(Term(ufl_replace(term.form, tmap), term.labels))
return LabelledForm(*new_terms)

for var in extract_type(e, Variable):
if var.ufl_operands[1] is lag_label:
cmapping.setdefault(var, var)
Comment on lines +125 to +127
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should go at the top of the function, and the rest of the diff could go away

return ufl_replace(e, cmapping)


# Utility functions that help us refactor
Expand Down
13 changes: 12 additions & 1 deletion irksome/ufl/deriv.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from ufl.algorithms.apply_algebra_lowering import apply_algebra_lowering
from ufl.form import BaseForm
from ufl.classes import (Coefficient, Conj, Curl, ConstantValue, Derivative,
Div, Expr, Grad, Indexed, ReferenceGrad,
Div, Expr, Grad, Indexed, Label, ReferenceGrad,
ReferenceValue, SpatialCoordinate, Variable)
from ufl.corealg.multifunction import MultiFunction

Expand Down Expand Up @@ -85,6 +85,17 @@ def Dt(f, order=1):
return f


# A :class:`ufl.Label` to mark nodes that are only evaluated at the start of
# the timestep.
lag_label = Label()


def lag(expr):
"""Mark a sub-expression to be evaluated only at the start of the
timestep during the implicit solve."""
return Variable(expr, lag_label)


class TimeDerivativeRuleset(GenericDerivativeRuleset):
"""Apply AD rules to time derivative expressions."""
def __init__(self, t=None, timedep_coeffs=None):
Expand Down
38 changes: 38 additions & 0 deletions tests/test_lag.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import firedrake
from firedrake import Constant, inner, grad, dx, conditional
import irksome
from irksome import Dt, lag


def test_stefan_implicit():
"""Test lagging the conductivity on the Stefan problem"""
nx = 32
mesh = firedrake.UnitIntervalMesh(nx)
V = firedrake.FunctionSpace(mesh, "CG", 1)
x, = firedrake.SpatialCoordinate(mesh)

u = firedrake.Function(V)
u.interpolate(1 - 2 * x)

k_solid = Constant(2.0)
k_liquid = Constant(1.0)
k = lag(conditional(u < 0, k_solid, k_liquid))

v = firedrake.TestFunction(V)
F = (Dt(u) * v + k * inner(grad(u), grad(v))) * dx

T_1 = Constant(1.0)
T_2 = Constant(-1.0)
bcs = [firedrake.DirichletBC(V, T_1, 1), firedrake.DirichletBC(V, T_2, 2)]

t = Constant(0.0)
dt = Constant(0.1)
solver_params = {"snes_type": "newtonls", "snes_converged_reason": None}
params = {"bcs": bcs, "solver_parameters": solver_params}
method = irksome.BackwardEuler()
stepper = irksome.TimeStepper(F, method, t, dt, u, **params)

final_time = 10.0
num_steps = int(final_time / float(dt))
for step in range(num_steps):
stepper.advance()
Loading