diff --git a/irksome/dirk_stepper.py b/irksome/dirk_stepper.py index e71f14f0..2ec2709f 100644 --- a/irksome/dirk_stepper.py +++ b/irksome/dirk_stepper.py @@ -69,6 +69,7 @@ def __init__(self, F, butcher_tableau, t, dt, u0, bcs=None, solver_parameters=None, appctx=None, nullspace=None, transpose_nullspace=None, near_nullspace=None, + stage_update_callback=None, **kwargs): assert butcher_tableau.is_diagonally_implicit @@ -81,6 +82,9 @@ def __init__(self, F, butcher_tableau, t, dt, u0, bcs=None, self.AAb = numpy.vstack((butcher_tableau.A, butcher_tableau.b)) self.CCone = numpy.append(butcher_tableau.c, 1.0) + # Store the stage update callback + self.stage_update_callback = stage_update_callback + # Need to be able to set BCs for either the DIRK or explicit cases. # For DIRK, we say that the stage i solution should match the @@ -162,12 +166,23 @@ def advance(self): u0 = self.u0 dt = self.dt for i in range(self.num_stages): + # Call user-provided callback before solving this stage + # This allows updating time-dependent forcings or other coefficients + if self.stage_update_callback is not None: + # Use the Butcher tableau c value directly, not the mutable constant + # which has already been updated by update_bc_constants above + c_value = self.butcher_tableau.c[i] + stage_time = float(self.t) + c_value * float(dt) + self.stage_update_callback(i, stage_time) + # compute the already-known part of the state in the # variational form g.assign(sum((ks[j] * (self.AA[i, j] * dt) for j in range(i)), u0)) # update BC constants for the variational problem self.update_bc_constants(i, c) + + a.assign(self.AA[i, i]) # solve new variational problem, stash the computed diff --git a/irksome/explicit_stepper.py b/irksome/explicit_stepper.py index 42827ca3..6ce6a500 100644 --- a/irksome/explicit_stepper.py +++ b/irksome/explicit_stepper.py @@ -7,7 +7,7 @@ class ExplicitTimeStepper(DIRKTimeStepper): def __init__(self, F, butcher_tableau, t, dt, u0, bcs=None, solver_parameters=None, - appctx=None): + appctx=None, stage_update_callback=None): assert butcher_tableau.is_explicit # we just have one mass matrix we're reusing for each time step and # each stage, so we can nudge this along @@ -19,4 +19,4 @@ def __init__(self, F, butcher_tableau, t, dt, u0, bcs=None, super(ExplicitTimeStepper, self).__init__( F, butcher_tableau, t, dt, u0, bcs=bcs, solver_parameters=solver_parameters, appctx=appctx, - nullspace=None) + nullspace=None, stage_update_callback=stage_update_callback) diff --git a/irksome/stage_derivative.py b/irksome/stage_derivative.py index 48fcde02..13bb4635 100644 --- a/irksome/stage_derivative.py +++ b/irksome/stage_derivative.py @@ -180,10 +180,11 @@ class StageDerivativeTimeStepper(StageCoupledTimeStepper): """ def __init__(self, F, butcher_tableau, t, dt, u0, bcs=None, solver_parameters=None, splitting=AI, - appctx=None, bc_type="DAE", **kwargs): + appctx=None, bc_type="DAE", stage_update_callback=None, **kwargs): self.num_fields = len(u0.function_space()) self.butcher_tableau = butcher_tableau + self.stage_update_callback = stage_update_callback A1, A2 = splitting(butcher_tableau.A) try: self.updateb = vecconst(numpy.linalg.solve(A2.T, butcher_tableau.b)) diff --git a/irksome/stepper.py b/irksome/stepper.py index a57002a4..947ba963 100644 --- a/irksome/stepper.py +++ b/irksome/stepper.py @@ -14,11 +14,11 @@ "appctx", "options_prefix", "pre_apply_bcs") valid_kwargs_per_stage_type = { - "deriv": ["stage_type", "bc_type", "splitting", "adaptive_parameters"], + "deriv": ["stage_type", "bc_type", "splitting", "adaptive_parameters", "stage_update_callback"], "value": ["stage_type", "basis_type", "update_solver_parameters", "splitting", "bounds", "use_collocation_update"], - "dirk": ["stage_type", "bcs", "nullspace", "solver_parameters", "appctx"], - "explicit": ["stage_type", "bcs", "solver_parameters", "appctx"], + "dirk": ["stage_type", "bcs", "nullspace", "solver_parameters", "appctx", "stage_update_callback"], + "explicit": ["stage_type", "bcs", "solver_parameters", "appctx", "stage_update_callback"], "imex": ["Fexp", "stage_type", "it_solver_parameters", "prop_solver_parameters", "splitting", "num_its_initial", "num_its_per_step"], "dirkimex": ["Fexp", "stage_type", "mass_parameters"], @@ -164,11 +164,13 @@ def TimeStepper(F, method, t, dt, u0, **kwargs): bounds=bounds, use_collocation_update=use_collocation_update, **base_kwargs) elif stage_type == "dirk": + stage_update_callback = kwargs.get("stage_update_callback") return DIRKTimeStepper( - F, method, t, dt, u0, bcs, **base_kwargs) + F, method, t, dt, u0, bcs, stage_update_callback=stage_update_callback, **base_kwargs) elif stage_type == "explicit": + stage_update_callback = kwargs.get("stage_update_callback") return ExplicitTimeStepper( - F, method, t, dt, u0, bcs, **base_kwargs) + F, method, t, dt, u0, bcs, stage_update_callback=stage_update_callback, **base_kwargs) elif stage_type == "imex": Fimp, Fexp = imex_separation(F, kwargs.get("Fexp"), stage_type) appctx = base_kwargs.get("appctx")