Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions irksome/dirk_stepper.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def __init__(self, F, butcher_tableau, t, dt, u0, bcs=None,
solver_parameters=None,
appctx=None, nullspace=None,
transpose_nullspace=None, near_nullspace=None,
stage_update_callback=None,
**kwargs):
assert butcher_tableau.is_diagonally_implicit

Expand All @@ -81,6 +82,9 @@ def __init__(self, F, butcher_tableau, t, dt, u0, bcs=None,
self.AAb = numpy.vstack((butcher_tableau.A, butcher_tableau.b))
self.CCone = numpy.append(butcher_tableau.c, 1.0)

# Store the stage update callback
self.stage_update_callback = stage_update_callback

# Need to be able to set BCs for either the DIRK or explicit cases.

# For DIRK, we say that the stage i solution should match the
Expand Down Expand Up @@ -162,12 +166,23 @@ def advance(self):
u0 = self.u0
dt = self.dt
for i in range(self.num_stages):
# Call user-provided callback before solving this stage
# This allows updating time-dependent forcings or other coefficients
if self.stage_update_callback is not None:
# Use the Butcher tableau c value directly, not the mutable constant
# which has already been updated by update_bc_constants above
c_value = self.butcher_tableau.c[i]
stage_time = float(self.t) + c_value * float(dt)
self.stage_update_callback(i, stage_time)

# compute the already-known part of the state in the
# variational form
g.assign(sum((ks[j] * (self.AA[i, j] * dt) for j in range(i)), u0))

# update BC constants for the variational problem
self.update_bc_constants(i, c)


a.assign(self.AA[i, i])

# solve new variational problem, stash the computed
Expand Down
4 changes: 2 additions & 2 deletions irksome/explicit_stepper.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
class ExplicitTimeStepper(DIRKTimeStepper):
def __init__(self, F, butcher_tableau, t, dt, u0, bcs=None,
solver_parameters=None,
appctx=None):
appctx=None, stage_update_callback=None):
assert butcher_tableau.is_explicit
# we just have one mass matrix we're reusing for each time step and
# each stage, so we can nudge this along
Expand All @@ -19,4 +19,4 @@ def __init__(self, F, butcher_tableau, t, dt, u0, bcs=None,
super(ExplicitTimeStepper, self).__init__(
F, butcher_tableau, t, dt, u0, bcs=bcs,
solver_parameters=solver_parameters, appctx=appctx,
nullspace=None)
nullspace=None, stage_update_callback=stage_update_callback)
3 changes: 2 additions & 1 deletion irksome/stage_derivative.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,10 +180,11 @@ class StageDerivativeTimeStepper(StageCoupledTimeStepper):
"""
def __init__(self, F, butcher_tableau, t, dt, u0, bcs=None,
solver_parameters=None, splitting=AI,
appctx=None, bc_type="DAE", **kwargs):
appctx=None, bc_type="DAE", stage_update_callback=None, **kwargs):

self.num_fields = len(u0.function_space())
self.butcher_tableau = butcher_tableau
self.stage_update_callback = stage_update_callback
A1, A2 = splitting(butcher_tableau.A)
try:
self.updateb = vecconst(numpy.linalg.solve(A2.T, butcher_tableau.b))
Expand Down
12 changes: 7 additions & 5 deletions irksome/stepper.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@
"appctx", "options_prefix", "pre_apply_bcs")

valid_kwargs_per_stage_type = {
"deriv": ["stage_type", "bc_type", "splitting", "adaptive_parameters"],
"deriv": ["stage_type", "bc_type", "splitting", "adaptive_parameters", "stage_update_callback"],
"value": ["stage_type", "basis_type",
"update_solver_parameters", "splitting", "bounds", "use_collocation_update"],
"dirk": ["stage_type", "bcs", "nullspace", "solver_parameters", "appctx"],
"explicit": ["stage_type", "bcs", "solver_parameters", "appctx"],
"dirk": ["stage_type", "bcs", "nullspace", "solver_parameters", "appctx", "stage_update_callback"],
"explicit": ["stage_type", "bcs", "solver_parameters", "appctx", "stage_update_callback"],
"imex": ["Fexp", "stage_type", "it_solver_parameters", "prop_solver_parameters",
"splitting", "num_its_initial", "num_its_per_step"],
"dirkimex": ["Fexp", "stage_type", "mass_parameters"],
Expand Down Expand Up @@ -164,11 +164,13 @@ def TimeStepper(F, method, t, dt, u0, **kwargs):
bounds=bounds, use_collocation_update=use_collocation_update,
**base_kwargs)
elif stage_type == "dirk":
stage_update_callback = kwargs.get("stage_update_callback")
return DIRKTimeStepper(
F, method, t, dt, u0, bcs, **base_kwargs)
F, method, t, dt, u0, bcs, stage_update_callback=stage_update_callback, **base_kwargs)
elif stage_type == "explicit":
stage_update_callback = kwargs.get("stage_update_callback")
return ExplicitTimeStepper(
F, method, t, dt, u0, bcs, **base_kwargs)
F, method, t, dt, u0, bcs, stage_update_callback=stage_update_callback, **base_kwargs)
elif stage_type == "imex":
Fimp, Fexp = imex_separation(F, kwargs.get("Fexp"), stage_type)
appctx = base_kwargs.get("appctx")
Expand Down