From c30d03226d73a24ce51aebfe5c93c9f4f21b3865 Mon Sep 17 00:00:00 2001 From: Sia Ghelichkhan Date: Wed, 12 Nov 2025 23:10:13 +1100 Subject: [PATCH 1/3] Adding stage callback function --- irksome/dirk_stepper.py | 14 ++++++++++++++ irksome/explicit_stepper.py | 4 ++-- irksome/stage_derivative.py | 3 ++- irksome/stepper.py | 12 +++++++----- 4 files changed, 25 insertions(+), 8 deletions(-) diff --git a/irksome/dirk_stepper.py b/irksome/dirk_stepper.py index e71f14f0..d1217165 100644 --- a/irksome/dirk_stepper.py +++ b/irksome/dirk_stepper.py @@ -69,6 +69,7 @@ def __init__(self, F, butcher_tableau, t, dt, u0, bcs=None, solver_parameters=None, appctx=None, nullspace=None, transpose_nullspace=None, near_nullspace=None, + stage_update_callback=None, **kwargs): assert butcher_tableau.is_diagonally_implicit @@ -81,6 +82,9 @@ def __init__(self, F, butcher_tableau, t, dt, u0, bcs=None, self.AAb = numpy.vstack((butcher_tableau.A, butcher_tableau.b)) self.CCone = numpy.append(butcher_tableau.c, 1.0) + # Store the stage update callback + self.stage_update_callback = stage_update_callback + # Need to be able to set BCs for either the DIRK or explicit cases. # For DIRK, we say that the stage i solution should match the @@ -168,6 +172,16 @@ def advance(self): # update BC constants for the variational problem self.update_bc_constants(i, c) + + # Call user-provided callback before solving this stage + # This allows updating time-dependent forcings or other coefficients + if self.stage_update_callback is not None: + # Use the Butcher tableau c value directly, not the mutable constant + # which has already been updated by update_bc_constants above + c_value = self.butcher_tableau.c[i] + stage_time = float(self.t) + c_value * float(dt) + self.stage_update_callback(i, stage_time) + a.assign(self.AA[i, i]) # solve new variational problem, stash the computed diff --git a/irksome/explicit_stepper.py b/irksome/explicit_stepper.py index 42827ca3..6ce6a500 100644 --- a/irksome/explicit_stepper.py +++ b/irksome/explicit_stepper.py @@ -7,7 +7,7 @@ class ExplicitTimeStepper(DIRKTimeStepper): def __init__(self, F, butcher_tableau, t, dt, u0, bcs=None, solver_parameters=None, - appctx=None): + appctx=None, stage_update_callback=None): assert butcher_tableau.is_explicit # we just have one mass matrix we're reusing for each time step and # each stage, so we can nudge this along @@ -19,4 +19,4 @@ def __init__(self, F, butcher_tableau, t, dt, u0, bcs=None, super(ExplicitTimeStepper, self).__init__( F, butcher_tableau, t, dt, u0, bcs=bcs, solver_parameters=solver_parameters, appctx=appctx, - nullspace=None) + nullspace=None, stage_update_callback=stage_update_callback) diff --git a/irksome/stage_derivative.py b/irksome/stage_derivative.py index 48fcde02..13bb4635 100644 --- a/irksome/stage_derivative.py +++ b/irksome/stage_derivative.py @@ -180,10 +180,11 @@ class StageDerivativeTimeStepper(StageCoupledTimeStepper): """ def __init__(self, F, butcher_tableau, t, dt, u0, bcs=None, solver_parameters=None, splitting=AI, - appctx=None, bc_type="DAE", **kwargs): + appctx=None, bc_type="DAE", stage_update_callback=None, **kwargs): self.num_fields = len(u0.function_space()) self.butcher_tableau = butcher_tableau + self.stage_update_callback = stage_update_callback A1, A2 = splitting(butcher_tableau.A) try: self.updateb = vecconst(numpy.linalg.solve(A2.T, butcher_tableau.b)) diff --git a/irksome/stepper.py b/irksome/stepper.py index a57002a4..947ba963 100644 --- a/irksome/stepper.py +++ b/irksome/stepper.py @@ -14,11 +14,11 @@ "appctx", "options_prefix", "pre_apply_bcs") valid_kwargs_per_stage_type = { - "deriv": ["stage_type", "bc_type", "splitting", "adaptive_parameters"], + "deriv": ["stage_type", "bc_type", "splitting", "adaptive_parameters", "stage_update_callback"], "value": ["stage_type", "basis_type", "update_solver_parameters", "splitting", "bounds", "use_collocation_update"], - "dirk": ["stage_type", "bcs", "nullspace", "solver_parameters", "appctx"], - "explicit": ["stage_type", "bcs", "solver_parameters", "appctx"], + "dirk": ["stage_type", "bcs", "nullspace", "solver_parameters", "appctx", "stage_update_callback"], + "explicit": ["stage_type", "bcs", "solver_parameters", "appctx", "stage_update_callback"], "imex": ["Fexp", "stage_type", "it_solver_parameters", "prop_solver_parameters", "splitting", "num_its_initial", "num_its_per_step"], "dirkimex": ["Fexp", "stage_type", "mass_parameters"], @@ -164,11 +164,13 @@ def TimeStepper(F, method, t, dt, u0, **kwargs): bounds=bounds, use_collocation_update=use_collocation_update, **base_kwargs) elif stage_type == "dirk": + stage_update_callback = kwargs.get("stage_update_callback") return DIRKTimeStepper( - F, method, t, dt, u0, bcs, **base_kwargs) + F, method, t, dt, u0, bcs, stage_update_callback=stage_update_callback, **base_kwargs) elif stage_type == "explicit": + stage_update_callback = kwargs.get("stage_update_callback") return ExplicitTimeStepper( - F, method, t, dt, u0, bcs, **base_kwargs) + F, method, t, dt, u0, bcs, stage_update_callback=stage_update_callback, **base_kwargs) elif stage_type == "imex": Fimp, Fexp = imex_separation(F, kwargs.get("Fexp"), stage_type) appctx = base_kwargs.get("appctx") From 0a5c0cb8019c97a85238946fd9d1d13904c673f3 Mon Sep 17 00:00:00 2001 From: Sia Ghelichkhan Date: Thu, 13 Nov 2025 13:52:28 +1100 Subject: [PATCH 2/3] a hack for now --- irksome/dirk_stepper.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/irksome/dirk_stepper.py b/irksome/dirk_stepper.py index d1217165..f5f02493 100644 --- a/irksome/dirk_stepper.py +++ b/irksome/dirk_stepper.py @@ -182,6 +182,15 @@ def advance(self): stage_time = float(self.t) + c_value * float(dt) self.stage_update_callback(i, stage_time) + # Invalidate form cache to ensure updated Functions are picked up + # The form references Functions by identity, so updates should be visible, + # but we need to ensure the form cache is invalidated + if hasattr(self.problem, '_form_cache'): + self.problem._form_cache.clear() + # Also invalidate the solver's form cache if it exists + if hasattr(self.solver, '_form_cache'): + self.solver._form_cache.clear() + a.assign(self.AA[i, i]) # solve new variational problem, stash the computed From a7051684de9d97006b0d3b97e21ed46680a1baa3 Mon Sep 17 00:00:00 2001 From: Sia Ghelichkhan Date: Thu, 13 Nov 2025 16:01:43 +1100 Subject: [PATCH 3/3] submitting change --- irksome/dirk_stepper.py | 22 +++++++--------------- 1 file changed, 7 insertions(+), 15 deletions(-) diff --git a/irksome/dirk_stepper.py b/irksome/dirk_stepper.py index f5f02493..2ec2709f 100644 --- a/irksome/dirk_stepper.py +++ b/irksome/dirk_stepper.py @@ -166,13 +166,6 @@ def advance(self): u0 = self.u0 dt = self.dt for i in range(self.num_stages): - # compute the already-known part of the state in the - # variational form - g.assign(sum((ks[j] * (self.AA[i, j] * dt) for j in range(i)), u0)) - - # update BC constants for the variational problem - self.update_bc_constants(i, c) - # Call user-provided callback before solving this stage # This allows updating time-dependent forcings or other coefficients if self.stage_update_callback is not None: @@ -182,14 +175,13 @@ def advance(self): stage_time = float(self.t) + c_value * float(dt) self.stage_update_callback(i, stage_time) - # Invalidate form cache to ensure updated Functions are picked up - # The form references Functions by identity, so updates should be visible, - # but we need to ensure the form cache is invalidated - if hasattr(self.problem, '_form_cache'): - self.problem._form_cache.clear() - # Also invalidate the solver's form cache if it exists - if hasattr(self.solver, '_form_cache'): - self.solver._form_cache.clear() + # compute the already-known part of the state in the + # variational form + g.assign(sum((ks[j] * (self.AA[i, j] * dt) for j in range(i)), u0)) + + # update BC constants for the variational problem + self.update_bc_constants(i, c) + a.assign(self.AA[i, i])