diff --git a/irksome/backends/firedrake.py b/irksome/backends/firedrake.py index 10d2b12d..9a97f114 100644 --- a/irksome/backends/firedrake.py +++ b/irksome/backends/firedrake.py @@ -36,3 +36,35 @@ def Constant(self, val=0.0) -> ufl.Coefficient: def get_mesh_constant(MC: MeshConstant | None): return MC.Constant if MC else firedrake.Constant + + +def create_variational_problem(F, u, bcs=None, J=None, Jp=None, **kwargs): + if len(F.arguments()) == 2: + a = ufl.lhs(F) + L = ufl.rhs(F) + kwargs.pop("is_linear", None) + problem = firedrake.LinearVariationalProblem(a, L, u, bcs=bcs, aP=Jp, **kwargs) + else: + constant_jacobian = kwargs.pop("constant_jacobian", False) + problem = firedrake.NonlinearVariationalProblem(F, u, bcs=bcs, J=J, Jp=Jp, **kwargs) + if constant_jacobian: + problem._constant_jacobian = constant_jacobian + return problem + + +def create_variational_solver(problem, **kwargs): + if isinstance(problem, firedrake.LinearVariationalProblem): + return firedrake.LinearVariationalSolver(problem, **kwargs) + else: + return firedrake.NonlinearVariationalSolver(problem, **kwargs) + + +def invalidate_jacobian(solver): + return firedrake.LinearVariationalSolver.invalidate_jacobian(solver) + + +derivative = firedrake.derivative +norm = firedrake.norm +Function = firedrake.Function +TestFunction = firedrake.TestFunction +TrialFunction = firedrake.TrialFunction diff --git a/irksome/base_time_stepper.py b/irksome/base_time_stepper.py index e926a93f..981c5297 100644 --- a/irksome/base_time_stepper.py +++ b/irksome/base_time_stepper.py @@ -1,12 +1,7 @@ from abc import abstractmethod -from firedrake import ( - derivative, replace, lhs, rhs, Function, TrialFunction, - LinearVariationalProblem, LinearVariationalSolver, - NonlinearVariationalProblem, NonlinearVariationalSolver, -) from firedrake.petsc import PETSc from .tools import AI, getNullspace, flatten_dats, split_stages -from .labeling import as_form, as_linear_form +from .labeling import as_form from .backend import get_backend import ufl import numpy @@ -99,11 +94,6 @@ def __init__(self, F, t, dt, u0, num_stages, butcher_tableau=None, bounds=None, sample_points=None, **kwargs): - is_linear = False - if len(as_form(F).arguments()) == 2: - F = as_linear_form(F, u0) - is_linear = True - super().__init__(F, t, dt, u0, bcs=bcs, appctx=appctx, nullspace=nullspace) @@ -123,44 +113,34 @@ def __init__(self, F, t, dt, u0, num_stages, stages = self.get_stages() self.stages = stages - Fbig, bigBCs = self.get_form_and_bcs(stages) + V = u0.function_space() + Vbig = stages.function_space() + + F_linear = len(as_form(F).arguments()) == 2 + stages_F = self._backend.TrialFunction(Vbig) if F_linear else stages + Fbig, bigBCs = self.get_form_and_bcs(stages_F) + Jpbig = None if Fp is not None: - Fp = as_linear_form(Fp, u0) - Fpbig, _ = self.get_form_and_bcs(stages, F=Fp, bcs=()) - Jpbig = derivative(Fpbig, stages) + Fp_linear = len(as_form(Fp).arguments()) == 2 + stages_Fp = self._backend.TrialFunction(Vbig) if Fp_linear else stages + Fpbig, _ = self.get_form_and_bcs(stages_Fp, F=Fp, bcs=()) + Jpbig = ufl.lhs(Fpbig) if Fp_linear else self._backend.derivative(Fpbig, stages_Fp) - V = u0.function_space() - Vbig = stages.function_space() nullspace = getNullspace(V, Vbig, num_stages, nullspace) transpose_nullspace = getNullspace(V, Vbig, num_stages, transpose_nullspace) near_nullspace = getNullspace(V, Vbig, num_stages, near_nullspace) self.bigBCs = bigBCs - if is_linear: - Fbig = replace(Fbig, {stages: TrialFunction(stages.function_space())}) - abig = lhs(Fbig) - Lbig = rhs(Fbig) - problem = LinearVariationalProblem( - abig, Lbig, stages, bcs=bigBCs, aP=Jpbig, - form_compiler_parameters=kwargs.pop("form_compiler_parameters", None), - constant_jacobian=kwargs.pop("constant_jacobian", False), - restrict=kwargs.pop("restrict", False), - ) - solver_constructor = LinearVariationalSolver - else: - problem = NonlinearVariationalProblem( - Fbig, stages, bcs=bigBCs, Jp=Jpbig, - form_compiler_parameters=kwargs.pop("form_compiler_parameters", None), - is_linear=kwargs.pop("is_linear", False), - restrict=kwargs.pop("restrict", False), - ) - problem._constant_jacobian = kwargs.pop("constant_jacobian", False) - solver_constructor = NonlinearVariationalSolver - - self.problem = problem - self.solver = solver_constructor( + self.problem = self._backend.create_variational_problem( + Fbig, stages, bcs=bigBCs, Jp=Jpbig, + form_compiler_parameters=kwargs.pop("form_compiler_parameters", None), + is_linear=kwargs.pop("is_linear", False), + restrict=kwargs.pop("restrict", False), + constant_jacobian=kwargs.pop("constant_jacobian", False), + ) + self.solver = self._backend.create_variational_solver( self.problem, appctx=self.appctx, nullspace=nullspace, transpose_nullspace=transpose_nullspace, @@ -193,7 +173,7 @@ def advance(self): # allow butcher tableau as input for preconditioners to create # an alternate operator @abstractmethod - def get_form_and_bcs(self, stages, tableau=None, F=None): + def get_form_and_bcs(self, stages, F=None, bcs=None, tableau=None): pass def solver_stats(self): @@ -209,30 +189,30 @@ def get_stage_bounds(self, bounds=None): Vbig = self.stages.function_space() bounds_type, lower, upper = bounds if lower is None: - slb = Function(Vbig).assign(PETSc.NINFINITY) + slb = self._backend.Function(Vbig).assign(PETSc.NINFINITY) if upper is None: - sub = Function(Vbig).assign(PETSc.INFINITY) + sub = self._backend.Function(Vbig).assign(PETSc.INFINITY) if bounds_type == "stage": if lower is not None: dats = [lower.dat] * (self.num_stages) - slb = Function(Vbig, val=flatten_dats(dats)) + slb = self._backend.Function(Vbig, val=flatten_dats(dats)) if upper is not None: dats = [upper.dat] * (self.num_stages) - sub = Function(Vbig, val=flatten_dats(dats)) + sub = self._backend.Function(Vbig, val=flatten_dats(dats)) elif bounds_type == "last_stage": V = self.u0.function_space() if lower is not None: - ninfty = Function(V).assign(PETSc.NINFINITY) + ninfty = self._backend.Function(V).assign(PETSc.NINFINITY) dats = [ninfty.dat] * (self.num_stages-1) dats.append(lower.dat) - slb = Function(Vbig, val=flatten_dats(dats)) + slb = self._backend.Function(Vbig, val=flatten_dats(dats)) if upper is not None: - infty = Function(V).assign(PETSc.INFINITY) + infty = self._backend.Function(V).assign(PETSc.INFINITY) dats = [infty.dat] * (self.num_stages-1) dats.append(upper.dat) - sub = Function(Vbig, val=flatten_dats(dats)) + sub = self._backend.Function(Vbig, val=flatten_dats(dats)) else: raise ValueError("Unknown bounds type") @@ -254,7 +234,7 @@ def build_poly(self): pts = numpy.reshape(self.sample_points, (-1, 1)) vander = self.tabulate_poly(pts) - self.u_old = Function(self.u0) + self.u_old = self._backend.Function(self.u0) ks = [self.u_old] ks.extend(split_stages(self.u0.function_space(), self.stages)) num_samples = vander.shape[1] @@ -265,4 +245,4 @@ def invalidate_jacobian(self): """ Forces the matrix to be reassembled next time it is required. """ - LinearVariationalSolver.invalidate_jacobian(self.solver) + self._backend.invalidate_jacobian(self.solver) diff --git a/irksome/dirk_stepper.py b/irksome/dirk_stepper.py index 13e01fe0..690fe6cb 100644 --- a/irksome/dirk_stepper.py +++ b/irksome/dirk_stepper.py @@ -1,47 +1,43 @@ import numpy -from firedrake import (derivative, Function, - LinearVariationalSolver, - NonlinearVariationalProblem, - NonlinearVariationalSolver) -from ufl.constantvalue import as_ufl +from ufl import as_ufl, lhs -from .ufl.deriv import TimeDerivative, expand_time_derivatives from .constant import vecconst -from .tools import replace -from .constant import MeshConstant +from .backend import get_backend from .bcs import bc2space -from .labeling import as_linear_form +from .ufl.deriv import TimeDerivative, expand_time_derivatives +from .tools import extract_timedep_arguments, replace -def getFormDIRK(F, ks, butch, t, dt, u0, bcs=None, kgac=None): +def getFormDIRK(F, ks, butch, t, dt, u0, bcs=None, kgac=None, backend="firedrake"): + backend_cls = get_backend(backend) if bcs is None: bcs = [] - v, = F.arguments() + v, u = extract_timedep_arguments(F, u0) V = v.function_space() assert V == u0.function_space() + # preprocess time derivatives + F = expand_time_derivatives(F, t=t, timedep_coeffs=(u,)) + num_stages = butch.num_stages # Note: the Constant c is used for substitution in both the # variational form and BC's, and we update it for each stage in # the loop over stages in the advance method. The Constant a is # used similarly in the variational form - MC = MeshConstant(V.mesh()) + MC = backend_cls.MeshConstant(V.mesh()) if kgac is None: - k = Function(V) - g = Function(V) + k = backend_cls.Function(V) + g = backend_cls.Function(V) a = MC.Constant(1.0) c = MC.Constant(1.0) else: k, g, a, c = kgac - # preprocess time derivatives - F = expand_time_derivatives(F, t=t, timedep_coeffs=(u0,)) - repl = {t: t + c * dt, - u0: g + k * (a * dt), - TimeDerivative(u0): k} + u: g + k * (a * dt), + TimeDerivative(u): k} stage_F = replace(F, repl) bcnew = [] @@ -76,7 +72,9 @@ def __init__(self, F, butcher_tableau, t, dt, u0, bcs=None, Fp=None, solver_parameters=None, appctx=None, nullspace=None, transpose_nullspace=None, near_nullspace=None, + backend="firedrake", **kwargs): + self._backend = backend_cls = get_backend(backend) assert butcher_tableau.is_diagonally_implicit self.num_steps = 0 @@ -114,15 +112,13 @@ def __init__(self, F, butcher_tableau, t, dt, u0, bcs=None, Fp=None, self.dt = dt self.orig_bcs = bcs self.num_fields = len(u0.function_space()) - self.ks = [Function(V) for _ in range(num_stages)] + self.ks = [backend_cls.Function(V) for _ in range(num_stages)] # "k" is a generic function for which we will solve the # NVLP for the next stage value # "ks" is a list of functions for the stage values # that we update as we go. We need to remember the # stage values we've computed earlier in the time step... - F = as_linear_form(F, u0) - stage_F, kgac, bcnew, (a_vals, d_val) = getFormDIRK( F, self.ks, butcher_tableau, t, dt, u0, bcs=bcs) k, g, a, c = kgac @@ -132,9 +128,10 @@ def __init__(self, F, butcher_tableau, t, dt, u0, bcs=None, Fp=None, stage_Jp = None if Fp is not None: - Fp = as_linear_form(Fp, u0) - stage_Fp, *_ = self.get_form_and_bcs(self.ks, F=Fp, bcs=()) - stage_Jp = derivative(stage_Fp, k) + Fp_linear = len(Fp.arguments()) == 2 + ks_Fp = Fp.arguments()[1] if Fp_linear else self.ks + stage_Fp, *_ = self.get_form_and_bcs(ks_Fp, F=Fp, bcs=()) + stage_Jp = lhs(stage_Fp) if Fp_linear else backend_cls.derivative(stage_Fp, k) appctx_irksome = {"stepper": self} if appctx is None: @@ -143,14 +140,14 @@ def __init__(self, F, butcher_tableau, t, dt, u0, bcs=None, Fp=None, appctx = {**appctx, **appctx_irksome} self.appctx = appctx - self.problem = NonlinearVariationalProblem( + self.problem = backend_cls.create_variational_problem( stage_F, k, bcs=bcnew, Jp=stage_Jp, form_compiler_parameters=kwargs.pop("form_compiler_parameters", None), is_linear=kwargs.pop("is_linear", False), restrict=kwargs.pop("restrict", False), + constant_jacobian=kwargs.pop("constant_jacobian", False), ) - self.problem._constant_jacobian = kwargs.pop("constant_jacobian", False) - self.solver = NonlinearVariationalSolver( + self.solver = backend_cls.create_variational_solver( self.problem, appctx=appctx, nullspace=nullspace, transpose_nullspace=transpose_nullspace, @@ -220,4 +217,4 @@ def invalidate_jacobian(self): """ Forces the matrix to be reassembled next time it is required. """ - LinearVariationalSolver.invalidate_jacobian(self.solver) + self._backend.invalidate_jacobian(self.solver) diff --git a/irksome/discontinuous_galerkin_stepper.py b/irksome/discontinuous_galerkin_stepper.py index 20d5588e..ee3e8d77 100644 --- a/irksome/discontinuous_galerkin_stepper.py +++ b/irksome/discontinuous_galerkin_stepper.py @@ -9,7 +9,7 @@ from .ufl.deriv import TimeDerivative, expand_time_derivatives from .ufl.manipulation import split_time_derivative_terms, remove_time_derivatives from .scheme import create_time_quadrature, ufc_line -from .tools import IA, dot, reshape, replace +from .tools import IA, dot, extract_timedep_arguments, reshape, replace from .constant import vecconst from .tableaux.ButcherTableaux import CollocationButcherTableau from .stage_value import getFormStage @@ -41,7 +41,7 @@ def getElement(basis_type, order): def getTermDiscGalerkin(F, L, Q, t, dt, u0, stages, test, deriv_type="strong"): - v, = F.arguments() + v, u = extract_timedep_arguments(F, u0) V = v.function_space() assert V == u0.function_space() @@ -68,7 +68,7 @@ def getTermDiscGalerkin(F, L, Q, t, dt, u0, stages, test, deriv_type="strong"): dtusub = dot(trial_dvals.T, u_np) # preprocess time derivatives - split_form = split_time_derivative_terms(F, t=t, timedep_coeffs=(u0,)) + split_form = split_time_derivative_terms(F, t=t, timedep_coeffs=(u,)) F_dtless = remove_time_derivatives(split_form.time) if F_dtless.empty(): Fnew = F_dtless @@ -81,8 +81,8 @@ def getTermDiscGalerkin(F, L, Q, t, dt, u0, stages, test, deriv_type="strong"): u_at_0 = dot(L_at_0.T, u_np) v_at_0 = dot(L_at_0.T, v_np) - repl_tminus = {v: v_at_0} - repl_tplus = {v: v_at_0, u0: u_at_0} + repl_tminus = {v: v_at_0, u: u0} + repl_tplus = {v: v_at_0, u: u_at_0} Fnew = replace(F_dtless, repl_tplus) - replace(F_dtless, repl_tminus) F_remainder = F elif deriv_type == "weak": @@ -93,8 +93,8 @@ def getTermDiscGalerkin(F, L, Q, t, dt, u0, stages, test, deriv_type="strong"): u_at_01 = dot(L_at_01.T, u_np) v_at_01 = dot(L_at_01.T, v_np) - repl_tminus = {v: v_at_01[0]} - repl_tnew = {v: v_at_01[1], u0: u_at_01[1], t: t + dt} + repl_tminus = {v: v_at_01[0], u: u0} + repl_tnew = {v: v_at_01[1], u: u_at_01[1], t: t + dt} Fnew = replace(F_dtless, repl_tnew) - replace(F_dtless, repl_tminus) # Terms with time derivatives: -(g(u), Dt(v)) @@ -103,20 +103,20 @@ def getTermDiscGalerkin(F, L, Q, t, dt, u0, stages, test, deriv_type="strong"): for q in range(len(qpts)): repl = {t: t + qpts[q] * dt, v: dtvsub[q], - u0: usub[q]} + u: usub[q]} Fnew -= replace(F_dtless, repl) F_remainder = split_form.remainder else: raise ValueError(f"Unrecongnized deriv_type {deriv_type}") # Handle the rest of the terms - F_remainder = expand_time_derivatives(F_remainder, t=t, timedep_coeffs=(u0,)) - dtu0 = TimeDerivative(u0) + F_remainder = expand_time_derivatives(F_remainder, t=t, timedep_coeffs=(u,)) + dtu = TimeDerivative(u) for q in range(len(qpts)): repl = {t: t + qpts[q] * dt, v: vsub[q] * dt, - u0: usub[q], - dtu0: dtusub[q] / dt} + u: usub[q], + dtu: dtusub[q] / dt} Fnew += replace(F_remainder, repl) return Fnew diff --git a/irksome/galerkin_stepper.py b/irksome/galerkin_stepper.py index 62fe38f7..99bb230f 100644 --- a/irksome/galerkin_stepper.py +++ b/irksome/galerkin_stepper.py @@ -10,7 +10,7 @@ from .ufl.estimate_degrees import TimeDegreeEstimator, get_degree_mapping from .labeling import split_quadrature, as_form from .scheme import create_time_quadrature, ufc_line -from .tools import AI, IA, dot, fields_to_components, reshape, replace +from .tools import AI, IA, dot, extract_timedep_arguments, fields_to_components, reshape, replace from .constant import vecconst from .discontinuous_galerkin_stepper import getElement as getTestElement from .integrated_lagrange import IntegratedLagrange @@ -66,12 +66,12 @@ def getElements(basis_type, order): def getTermGalerkin(F, L_trial, L_test, Q, t, dt, u0, stages, test, aux_indices): - # preprocess time derivatives - F = expand_time_derivatives(F, t=t, timedep_coeffs=(u0,)) - v, = F.arguments() + v, u = extract_timedep_arguments(F, u0) V = v.function_space() assert V == u0.function_space() - i0, = L_trial.entity_dofs()[0][0] + + # preprocess time derivatives + F = expand_time_derivatives(F, t=t, timedep_coeffs=(u,)) qpts = Q.get_points() qwts = Q.get_weights() @@ -89,14 +89,15 @@ def getTermGalerkin(F, L_trial, L_test, Q, t, dt, u0, stages, test, aux_indices) qpts = vecconst(np.reshape(qpts, (-1,))) # set up the pieces we need to work with to do our substitutions + i0, = L_trial.entity_dofs()[0][0] v_np = reshape(test, (-1, *v.ufl_shape)) - w_np = reshape(stages, (-1, *u0.ufl_shape)) + w_np = reshape(stages, (-1, *u.ufl_shape)) u_np = np.insert(w_np, i0, reshape(u0, (1, *u0.ufl_shape)), axis=0) vsub = dot(test_vals_w.T, v_np) usub = dot(trial_vals.T, u_np) - dtu0sub = dot(trial_dvals.T, u_np) - dtu0 = TimeDerivative(u0) + dtusub = dot(trial_dvals.T, u_np) + dtu = TimeDerivative(u) # discretize the auxiliary fields in the DG test space if aux_indices is not None: @@ -108,8 +109,8 @@ def getTermGalerkin(F, L_trial, L_test, Q, t, dt, u0, stages, test, aux_indices) for q in range(len(qpts)): repl[q] = {t: t + qpts[q] * dt, v: vsub[q] * dt, - u0: usub[q], - dtu0: dtu0sub[q] / dt} + u: usub[q], + dtu: dtusub[q] / dt} Fnew = sum(replace(F, repl[q]) for q in repl) return Fnew diff --git a/irksome/imex.py b/irksome/imex.py index 25513b52..25dcc363 100644 --- a/irksome/imex.py +++ b/irksome/imex.py @@ -1,18 +1,16 @@ import FIAT import numpy as np -from firedrake import (Function, LinearVariationalSolver, - NonlinearVariationalProblem, - NonlinearVariationalSolver, TestFunction, - as_ufl, dx, inner) -from ufl import zero +from firedrake import Function, TestFunction +from ufl import Form, as_ufl, dx, inner +from .backend import get_backend +from .bcs import bc2space +from .constant import MeshConstant, vecconst +from .stage_value import getFormStage +from .tools import (AI, IA, extract_timedep_arguments, reshape, replace, + getNullspace, get_stage_space) from .tableaux.ButcherTableaux import RadauIIA from .ufl.deriv import TimeDerivative, expand_time_derivatives -from .stage_value import getFormStage -from .tools import AI, IA, reshape, replace, getNullspace, get_stage_space -from .bcs import bc2space -from .constant import MeshConstant, ConstantOrZero -from .labeling import as_linear_form def riia_explicit_coeffs(k): @@ -41,7 +39,9 @@ def getFormExplicit(Fexp, butch, u0, UU, t, dt, splitting=None): """Processes the explicitly split-off part for a RadauIIA-IMEX method. Returns the forms for both the iterator and propagator, which really just differ by which constants are in them.""" - v = Fexp.arguments()[0] + v, u = extract_timedep_arguments(Fexp, u0) + V = v.function_space() + assert V == u0.function_space() Vbig = UU.function_space() VV = TestFunction(Vbig) @@ -49,8 +49,6 @@ def getFormExplicit(Fexp, butch, u0, UU, t, dt, splitting=None): Aexp = riia_explicit_coeffs(num_stages) - vecconst = np.vectorize(ConstantOrZero) - Aprop = vecconst(Aexp) Ait = vecconst(butch.A) C = vecconst(butch.c) @@ -58,22 +56,22 @@ def getFormExplicit(Fexp, butch, u0, UU, t, dt, splitting=None): v_np = reshape(VV, (num_stages, *u0.ufl_shape)) u_np = reshape(UU, (num_stages, *u0.ufl_shape)) - Fit = zero() - Fprop = zero() + Fit = Form([]) + Fprop = Form([]) # preprocess time derivatives - Fexp = expand_time_derivatives(Fexp, t=t, timedep_coeffs=(u0,)) + Fexp = expand_time_derivatives(Fexp, t=t, timedep_coeffs=(u,)) if splitting == AI: for i in range(num_stages): # replace test function - repl = {v: v_np[i]} + repl = {v: v_np[i], u: u0} Ftmp = replace(Fexp, repl) # replace the solution with stage values for j in range(num_stages): repl = {t: t + C[j] * dt, - u0: u_np[j]} + u: u_np[j]} # and sum the contribution replF = replace(Ftmp, repl) @@ -83,7 +81,7 @@ def getFormExplicit(Fexp, butch, u0, UU, t, dt, splitting=None): # diagonal contribution to iterator for i in range(num_stages): repl = {t: t+C[i]*dt, - u0: u_np[i], + u: u_np[i], v: v_np[i]} Fit += dt * replace(Fexp, repl) @@ -93,13 +91,13 @@ def getFormExplicit(Fexp, butch, u0, UU, t, dt, splitting=None): for i in range(num_stages): # replace test function - repl = {v: v_np[i]} + repl = {v: v_np[i], u: u0} Ftmp = replace(Fexp, repl) # replace the solution with stage values for j in range(num_stages): repl = {t: t + C[j] * dt, - u0: u_np[j]} + u: u_np[j]} # and sum the contribution Fprop += AinvAexp[i, j] * dt * replace(Ftmp, repl) @@ -168,8 +166,10 @@ def __init__(self, F, Fexp, butcher_tableau, nullspace=None, num_its_initial=0, num_its_per_step=0, + backend="firedrake", **kwargs): assert isinstance(butcher_tableau, RadauIIA) + self._backend = backend_cls = get_backend(backend) self.u0 = u0 self.t = t @@ -193,10 +193,8 @@ def __init__(self, F, Fexp, butcher_tableau, # the update information on the floor. V = u0.function_space() Vbig = get_stage_space(V, self.num_stages) - UU = Function(Vbig) + UU = backend_cls.Function(Vbig) - F = as_linear_form(F, u0) - Fexp = as_linear_form(Fexp, u0) restrict = kwargs.pop("restrict", False) is_linear = kwargs.pop("is_linear", False) constant_jacobian = kwargs.pop("constant_jacobian", False) @@ -217,15 +215,16 @@ def __init__(self, F, Fexp, butcher_tableau, Fit, Fprop = getFormExplicit( Fexp, butcher_tableau, u0, UU_old, t, dt, splitting) - self.itprob = NonlinearVariationalProblem( + self.itprob = backend_cls.create_variational_problem( Fbig + Fit, UU, bcs=bigBCs, - is_linear=is_linear, restrict=restrict) - self.propprob = NonlinearVariationalProblem( + is_linear=is_linear, restrict=restrict, + constant_jacobian=constant_jacobian, + ) + self.propprob = backend_cls.create_variational_problem( Fbig + Fprop, UU, bcs=bigBCs, - is_linear=is_linear, restrict=restrict) - - self.itprob._constant_jacobian = constant_jacobian - self.propprob._constant_jacobian = constant_jacobian + is_linear=is_linear, restrict=restrict, + constant_jacobian=constant_jacobian, + ) self.F = F self.orig_bcs = bcs @@ -238,11 +237,11 @@ def __init__(self, F, Fexp, butcher_tableau, else: appctx = {**appctx, **appctx_irksome} - self.it_solver = NonlinearVariationalSolver( + self.it_solver = backend_cls.create_variational_solver( self.itprob, appctx=appctx, solver_parameters=it_solver_parameters, nullspace=nsp, **kwargs) - self.prop_solver = NonlinearVariationalSolver( + self.prop_solver = backend_cls.create_variational_solver( self.propprob, appctx=appctx, solver_parameters=prop_solver_parameters, nullspace=nsp, **kwargs) @@ -308,29 +307,34 @@ def invalidate_jacobian(self): """ Forces the matrix to be reassembled next time it is required. """ - LinearVariationalSolver.invalidate_jacobian(self.prop_solver) - LinearVariationalSolver.invalidate_jacobian(self.it_solver) + self._backend.invalidate_jacobian(self.prop_solver) + self._backend.invalidate_jacobian(self.it_solver) def getFormsDIRKIMEX(F, Fexp, ks, khats, butch, t, dt, u0, bcs=None): if bcs is None: bcs = [] - - # preprocess time derivatives - F = expand_time_derivatives(F, t=t, timedep_coeffs=(u0,)) - Fexp = expand_time_derivatives(Fexp, t=t, timedep_coeffs=(u0,)) - - v = F.arguments()[0] + v, u = extract_timedep_arguments(F, u0) V = v.function_space() assert V == u0.function_space() + # preprocess time derivatives + F = expand_time_derivatives(F, t=t, timedep_coeffs=(u,)) + Fexp = expand_time_derivatives(Fexp, t=t, timedep_coeffs=(u,)) + num_stages = butch.num_stages - k = Function(V) + k0 = Function(V) g = Function(V) - khat = Function(V) + khat0 = Function(V) ghat = Function(V) vhat = TestFunction(V) + if u == u0: + k = k0 + khat = khat0 + else: + k = u + khat = u # Note: the Constant c is used for substitution in both the # implicit variational form and BC's, and we update it for each stage in @@ -342,19 +346,15 @@ def getFormsDIRKIMEX(F, Fexp, ks, khats, butch, t, dt, u0, bcs=None): chat = MC.Constant(1.0) a = MC.Constant(1.0) - # preprocess time derivatives - F = expand_time_derivatives(F, t=t, timedep_coeffs=(u0,)) - Fexp = expand_time_derivatives(Fexp, t=t, timedep_coeffs=(u0,)) - # Implicit replacement, solve at time t + c * dt, for k repl = {t: t + c * dt, - u0: g + dt * a * k, - TimeDerivative(u0): k} + u: g + dt * a * k, + TimeDerivative(u): k} stage_F = replace(F, repl) # Explicit replacement, solve at time t + chat * dt, for khat replhat = {t: t + chat * dt, - u0: ghat} + u: ghat} Fhat = inner(khat, vhat)*dx + replace(Fexp, replhat) @@ -386,7 +386,7 @@ def getFormsDIRKIMEX(F, Fexp, ks, khats, butch, t, dt, u0, bcs=None): gdat /= dt*d_val bcnew.append(bc.reconstruct(g=gdat)) - return stage_F, (k, g, a, c), bcnew, Fhat, (khat, ghat, chat), (a_vals, ahat_vals, d_val) + return stage_F, (k0, g, a, c), bcnew, Fhat, (khat0, ghat, chat), (a_vals, ahat_vals, d_val) class DIRKIMEXMethod: @@ -399,8 +399,10 @@ class DIRKIMEXMethod: """ def __init__(self, F, F_explicit, butcher_tableau, t, dt, u0, bcs=None, - solver_parameters=None, mass_parameters=None, appctx=None, nullspace=None, **kwargs): + solver_parameters=None, mass_parameters=None, appctx=None, nullspace=None, + backend="firedrake", **kwargs): assert butcher_tableau.is_dirk_imex + self._backend = backend_cls = get_backend(backend) self.num_steps = 0 self.num_nonlinear_iterations = 0 @@ -419,7 +421,6 @@ def __init__(self, F, F_explicit, butcher_tableau, t, dt, u0, bcs=None, self.ks = [Function(V) for _ in range(self.num_stages)] self.k_hat_s = [Function(V) for _ in range(self.num_stages)] - F = as_linear_form(F, u0) restrict = kwargs.pop("restrict", False) is_linear = kwargs.pop("is_linear", False) constant_jacobian = kwargs.pop("constant_jacobian", False) @@ -444,17 +445,25 @@ def __init__(self, F, F_explicit, butcher_tableau, t, dt, u0, bcs=None, else: appctx = {**appctx, **appctx_irksome} - self.problem = NonlinearVariationalProblem(stage_F, k, bcnew, - is_linear=is_linear, restrict=restrict) - self.problem._constant_jacobian = constant_jacobian - self.solver = NonlinearVariationalSolver(self.problem, appctx=appctx, - solver_parameters=solver_parameters, - nullspace=nullspace, **kwargs) - - self.mass_problem = NonlinearVariationalProblem(Fhat, khat, is_linear=True) - self.problem._constant_jacobian = True - self.mass_solver = NonlinearVariationalSolver(self.mass_problem, - solver_parameters=mass_parameters) + self.problem = backend_cls.create_variational_problem( + stage_F, k, bcnew, + is_linear=is_linear, + restrict=restrict, + constant_jacobian=constant_jacobian, + ) + self.solver = backend_cls.create_variational_solver( + self.problem, appctx=appctx, + solver_parameters=solver_parameters, + nullspace=nullspace, **kwargs, + ) + self.mass_problem = backend_cls.create_variational_problem( + Fhat, khat, is_linear=True, + constant_jacobian=True, + ) + self.mass_solver = backend_cls.create_variational_solver( + self.mass_problem, + solver_parameters=mass_parameters, + ) self.kgac = k, g, a, c self.kgchat = khat, ghat, chat @@ -634,4 +643,4 @@ def invalidate_jacobian(self): """ Forces the matrix to be reassembled next time it is required. """ - LinearVariationalSolver.invalidate_jacobian(self.solver) + self._backend.invalidate_jacobian(self.solver) diff --git a/irksome/labeling.py b/irksome/labeling.py index 3b0e10e3..12730d0f 100644 --- a/irksome/labeling.py +++ b/irksome/labeling.py @@ -3,7 +3,6 @@ from firedrake.fml import Label, keep, drop, LabelledForm from collections import defaultdict from .scheme import create_time_quadrature -from .tools import replace import numpy as np explicit = Label("explicit") @@ -173,21 +172,3 @@ def as_form(form): if isinstance(form, LabelledForm): form = Form([]) if len(form) == 0 else form.form return form - - -def as_linear_form(F, u0): - """ - If `F` is a bilinear :class:`Form` compute a linear - :class:`Form` by replacing the trial function with `u0`, - otherwise return `F`. - """ - form = as_form(F) - nargs = len(form.arguments()) - if nargs == 2: - if u0 in form.coefficients(): - raise ValueError("The provided bilinear form must not depend on the solution") - test, trial = form.arguments() - F = replace(F, {trial: u0}) - elif nargs != 1: - raise ValueError("Expecting a Form with 1 or 2 arguments.") - return F diff --git a/irksome/nystrom_dirk_stepper.py b/irksome/nystrom_dirk_stepper.py index 370b4867..dc67fb2a 100644 --- a/irksome/nystrom_dirk_stepper.py +++ b/irksome/nystrom_dirk_stepper.py @@ -5,7 +5,7 @@ from ufl.constantvalue import as_ufl from .ufl.deriv import Dt, expand_time_derivatives -from .tools import replace +from .tools import extract_timedep_arguments, replace from .bcs import bc2space from .constant import MeshConstant, vecconst from .nystrom_stepper import butcher_to_nystrom, NystromTableau @@ -17,11 +17,13 @@ def getFormDIRKNystrom(F, ks, tableau, t, dt, u0, ut0, bcs=None, bc_type=None): if bc_type is None: bc_type = "DAE" - v, = F.arguments() + v, u = extract_timedep_arguments(F, u0) V = v.function_space() - msh = V.mesh() assert V == u0.function_space() + # preprocess time derivatives + F = expand_time_derivatives(F, t=t, timedep_coeffs=(u,)) + num_stages = tableau.num_stages k = Function(V) g1 = Function(V) @@ -31,14 +33,12 @@ def getFormDIRKNystrom(F, ks, tableau, t, dt, u0, ut0, bcs=None, bc_type=None): # variational form and BC's, and we update it for each stage in # the loop over stages in the advance method. The Constants a # and abar are used similarly in the variational form + msh = V.mesh() MC = MeshConstant(msh) c = MC.Constant(1.0) a = MC.Constant(1.0) abar = MC.Constant(1.0) - # preprocess time derivatives - F = expand_time_derivatives(F, t=t, timedep_coeffs=(u0,)) - repl = {t: t + c * dt, u0: g1 + k * (abar * dt**2), Dt(u0): g2 + k * (a * dt), diff --git a/irksome/nystrom_stepper.py b/irksome/nystrom_stepper.py index b17ae8a4..b2a8085c 100644 --- a/irksome/nystrom_stepper.py +++ b/irksome/nystrom_stepper.py @@ -1,11 +1,11 @@ from .base_time_stepper import StageCoupledTimeStepper from .bcs import BCStageData, bc2space from .ufl.deriv import Dt, TimeDerivative, expand_time_derivatives -from .tools import dot, reshape, replace +from .tools import dot, extract_timedep_arguments, reshape, replace from .constant import vecconst from firedrake import TestFunction, as_ufl import numpy -from ufl import zero +from ufl import Form class NystromTableau: @@ -76,12 +76,13 @@ def getFormNystrom(F, tableau, t, dt, u0, ut0, stages, if bc_type is None: bc_type = "DAE" - # preprocess time derivatives - F = expand_time_derivatives(F, t=t, timedep_coeffs=(u0,)) - v, = F.arguments() + v, u = extract_timedep_arguments(F, u0) V = v.function_space() assert V == u0.function_space() + # preprocess time derivatives + F = expand_time_derivatives(F, t=t, timedep_coeffs=(u,)) + A = vecconst(tableau.A) Abar = vecconst(tableau.Abar) c = vecconst(tableau.c) @@ -96,15 +97,15 @@ def getFormNystrom(F, tableau, t, dt, u0, ut0, stages, Ak = dot(A, k_np) Abark = dot(Abar, k_np) - dtu = TimeDerivative(u0) + dtu = TimeDerivative(u) dt2u = TimeDerivative(dtu) - Fnew = zero() + Fnew = Form([]) for i in range(num_stages): repl = {t: t + c[i] * dt, v: v_np[i], - u0: u0 + ut0 * (c[i] * dt) + Abark[i] * dt**2, + u: u0 + ut0 * (c[i] * dt) + Abark[i] * dt**2, dtu: ut0 + Ak[i] * dt, dt2u: k_np[i]} Fnew += replace(F, repl) diff --git a/irksome/stage_derivative.py b/irksome/stage_derivative.py index eb16ad1e..c3963406 100644 --- a/irksome/stage_derivative.py +++ b/irksome/stage_derivative.py @@ -1,17 +1,15 @@ import numpy -from firedrake import Function, TestFunction -from firedrake import NonlinearVariationalProblem as NLVP -from firedrake import NonlinearVariationalSolver as NLVS -from firedrake import assemble, dx, inner, norm, as_tensor +from firedrake import TestFunction from firedrake.bcs import EquationBC, EquationBCSplit from FIAT import ufc_simplex from FIAT.barycentric_interpolation import LagrangePolynomialSet -from ufl.constantvalue import as_ufl +from ufl import as_tensor, as_ufl, dx, inner + from .tableaux.ButcherTableaux import CollocationButcherTableau from .constant import vecconst -from .tools import AI, dot, replace, reshape, fields_to_components +from .tools import AI, dot, extract_timedep_arguments, fields_to_components, replace, reshape from .ufl.deriv import Dt, TimeDerivative, expand_time_derivatives from .bcs import EmbeddedBCData, BCStageData, extract_bcs, bc2space, stage2spaces4bc from .ufl.manipulation import split_time_derivative_terms @@ -60,12 +58,11 @@ def getForm(F, butch, t, dt, u0, stages, bcs=None, bc_type=None, splitting=AI, a """ if bc_type is None: bc_type = "DAE" - - # preprocess time derivatives - F = expand_time_derivatives(F, t=t, timedep_coeffs=(u0,)) - v, = F.arguments() + v, u = extract_timedep_arguments(F, u0) V = v.function_space() assert V == u0.function_space() + # preprocess time derivatives + F = expand_time_derivatives(F, t=t, timedep_coeffs=(u,)) c = vecconst(butch.c) bA1, bA2 = splitting(butch.A) @@ -83,10 +80,10 @@ def getForm(F, butch, t, dt, u0, stages, bcs=None, bc_type=None, splitting=AI, a # set up the pieces we need to work with to do our substitutions v_np = reshape(test, (num_stages, *v.ufl_shape)) - w_np = reshape(stages, (num_stages, *u0.ufl_shape)) + w_np = reshape(stages, (num_stages, *u.ufl_shape)) A1w = dot(A1, w_np) A2invw = dot(A2inv, w_np) - dtu = TimeDerivative(u0) + dtu = TimeDerivative(u) aux_components = fields_to_components(V, aux_indices or []) @@ -101,7 +98,7 @@ def getForm(F, butch, t, dt, u0, stages, bcs=None, bc_type=None, splitting=AI, a repl[i] = {t: t + c[i] * dt, v: v_np[i], - u0: usub, + u: usub, dtu: dtusub} Fnew = sum(replace(F, repl[i]) for i in range(num_stages)) @@ -115,7 +112,7 @@ def bc2stagebc(bc, i): if isinstance(bc, EquationBCSplit): raise NotImplementedError("EquationBC not implemented for ODE formulation") gorig = as_ufl(bc._original_arg) - gfoo = expand_time_derivatives(Dt(gorig), t=t, timedep_coeffs=(u0,)) + gfoo = expand_time_derivatives(Dt(gorig), t=t, timedep_coeffs=(u,)) gcur = replace(gfoo, {t: t + c[i] * dt}) return BCStageData(bc, gcur, u0, stages, i) @@ -128,7 +125,7 @@ def bc2stagebc(bc, i): def bc2stagebc(bc, i): if isinstance(bc, EquationBCSplit): - F_bc_orig = expand_time_derivatives(bc.f, t=t, timedep_coeffs=(u0,)) + F_bc_orig = expand_time_derivatives(bc.f, t=t, timedep_coeffs=(u,)) F_bc_new = replace(F_bc_orig, repl[i]) Vbigi = stage2spaces4bc(bc, V, Vbig, i) return EquationBC(F_bc_new == 0, stages, bc.sub_domain, V=Vbigi, @@ -231,8 +228,8 @@ def get_form_and_bcs(self, stages, F=None, bcs=None, tableau=None): bcs = self.orig_bcs return getForm(F or self.F, tableau or self.butcher_tableau, - self.t, self.dt, - self.u0, stages, bcs, self.bc_type, + self.t, self.dt, self.u0, + stages, bcs, self.bc_type, splitting=self.splitting, aux_indices=self.aux_indices) @@ -330,7 +327,7 @@ def __init__(self, F, butcher_tableau, t, dt, u0, self.onscale_factor = onscale_factor self.safety_factor = safety_factor - self.error_func = Function(u0.function_space()) + self.error_func = self._backend.Function(u0.function_space()) self.tol = tol self.err_old = 0.0 self.contreject = 0 @@ -352,6 +349,7 @@ def _estimate_error(self): the temporal truncation error by taking the norm of the difference between the new solutions computed by the two methods. Typically will not be called by the end user.""" + backend_cls = self._backend dtc = float(self.dt) delb = self.delb ws = self.stages.subfunctions @@ -360,13 +358,13 @@ def _estimate_error(self): u0 = self.u0 # Initialize e to be gamma*h*f(old value of u) - error_func = Function(u0.function_space()) + error_func = backend_cls.Function(u0.function_space()) # Only do the hard stuff if gamma0 is not zero if self.gamma0 != 0.0: - error_test = TestFunction(u0.function_space()) + error_test = backend_cls.TestFunction(u0.function_space()) f_form = inner(error_func, error_test)*dx-self.gamma0*dtc*self.dtless_form - f_problem = NLVP(f_form, error_func, bcs=self.embbc) - f_solver = NLVS(f_problem, solver_parameters=self.gamma0_params) + f_problem = backend_cls.create_variational_problem(f_form, error_func, bcs=self.embbc) + f_solver = backend_cls.create_variational_solver(f_problem, solver_parameters=self.gamma0_params) f_solver.solve() # Accumulate delta-b terms over stages @@ -374,7 +372,7 @@ def _estimate_error(self): for s in range(ns): for i, e in enumerate(error_func_bits): e += dtc*float(delb[s])*ws[nf*s+i] - return norm(assemble(error_func)) + return backend_cls.norm(error_func) def advance(self): """Attempts to advances the system from time `t` to time `t + diff --git a/irksome/stage_value.py b/irksome/stage_value.py index 3739c3f9..8e77ca90 100644 --- a/irksome/stage_value.py +++ b/irksome/stage_value.py @@ -1,18 +1,17 @@ # formulate RK methods to solve for stage values rather than the stage derivatives. import numpy +from firedrake import TestFunction + from FIAT import Bernstein, ufc_simplex from FIAT.barycentric_interpolation import LagrangePolynomialSet -from firedrake import (Function, NonlinearVariationalProblem, - NonlinearVariationalSolver, TestFunction, dx, - inner) -from ufl import as_tensor, Form -from ufl.constantvalue import as_ufl + +from ufl import Form, as_tensor, as_ufl, dx, inner from .bcs import stage2spaces4bc from .tableaux.ButcherTableaux import CollocationButcherTableau from .ufl.deriv import expand_time_derivatives from .ufl.manipulation import split_time_derivative_terms, remove_time_derivatives -from .tools import AI, dot, reshape, replace +from .tools import AI, extract_timedep_arguments, dot, reshape, replace from .constant import vecconst from .base_time_stepper import StageCoupledTimeStepper @@ -77,7 +76,7 @@ def getFormStage(F, butch, t, dt, u0, stages, bcs=None, splitting=AI, vandermond - `bcnew`, a list of :class:`firedrake.DirichletBC` objects to be posed on the stages """ - v, = F.arguments() + v, u = extract_timedep_arguments(F, u0) V = v.function_space() assert V == u0.function_space() @@ -105,7 +104,7 @@ def getFormStage(F, butch, t, dt, u0, stages, bcs=None, splitting=AI, vandermond # assuming we have something of the form inner(Dt(g(u0)), v)*dx # For each stage i, this gets replaced with # inner((g(stages[i]) - g(u0))/dt, v)*dx - split_form = split_time_derivative_terms(F, t=t, timedep_coeffs=(u0,)) + split_form = split_time_derivative_terms(F, t=t, timedep_coeffs=(u,)) F_dtless = remove_time_derivatives(split_form.time) F_remainder = expand_time_derivatives(split_form.remainder, t=t, timedep_coeffs=()) @@ -117,10 +116,10 @@ def getFormStage(F, butch, t, dt, u0, stages, bcs=None, splitting=AI, vandermond for i in range(num_stages): repl_new = {t: t + c[i] * dt, v: A2invTv[i], - u0: w_np[i]} + u: w_np[i]} # Evaluate g at the old solution u0 (not substituted) and # old time t (not substituted). - repl_old = {v: A2invTv[i]} + repl_old = {v: A2invTv[i], u: u0} Fnew += replace(F_dtless, repl_new) - replace(F_dtless, repl_old) # Handle the rest of the terms @@ -128,7 +127,7 @@ def getFormStage(F, butch, t, dt, u0, stages, bcs=None, splitting=AI, vandermond # replace the solution with stage values repl = {t: t + c[i] * dt, v: A1Tv[i] * dt, - u0: w_np[i]} + u: w_np[i]} Fnew += replace(F_remainder, repl) if bcs is None: @@ -187,7 +186,8 @@ def __init__(self, F, butcher_tableau, t, dt, u0, bcs=None, self.set_initial_guess() if use_collocation_update: - # Use the terminal value of the collocation polynomial to update the solution. Note: collocation update is only implemented for constant-in-time boundary conditions. + # Use the terminal value of the collocation polynomial to update the solution. + # Note: collocation update is only implemented for constant-in-time boundary conditions. # TODO: create an assertion to check for constant-in-time boundary conditions. self.collocation_vander = self.tabulate_poly((1.0,)) @@ -222,23 +222,24 @@ def _update_stiff_acc(self): def get_update_solver(self, update_solver_parameters): # only form update stuff if we need it # which means neither stiffly accurate nor Vandermonde - v, = self.F.arguments() - unew = Function(self.u0.function_space()) - Fupdate = inner(unew - self.u0, v) * dx - + backend_cls = self._backend C = vecconst(self.butcher_tableau.c) B = vecconst(self.butcher_tableau.b) F = self.F t = self.t dt = self.dt u0 = self.u0 - split_form = split_time_derivative_terms(F, t=t, timedep_coeffs=(u0,)) + v, u = extract_timedep_arguments(F, u0) + unew = backend_cls.Function(u.function_space()) + Fupdate = inner(unew - self.u0, v) * dx + + split_form = split_time_derivative_terms(F, t=t, timedep_coeffs=(u,)) F_remainder = expand_time_derivatives(split_form.remainder, t=t, timedep_coeffs=()) u_np = to_value(self.u0, self.stages, self.vandermonde) for i in range(self.num_stages): repl = {t: t + C[i] * dt, - u0: u_np[i]} + u: u_np[i]} Fupdate += dt * B[i] * replace(F_remainder, repl) # And the BC's for the update -- just the original BC at t+dt @@ -248,12 +249,8 @@ def get_update_solver(self, update_solver_parameters): gcur = replace(bcarg, {t: t + dt}) update_bcs.append(bc.reconstruct(g=gcur)) - update_problem = NonlinearVariationalProblem( - Fupdate, unew, update_bcs) - - update_solver = NonlinearVariationalSolver( - update_problem, - solver_parameters=update_solver_parameters) + update_problem = backend_cls.create_variational_probelm(Fupdate, unew, update_bcs) + update_solver = backend_cls.create_variational_solver(update_problem, solver_parameters=update_solver_parameters) return unew, update_solver diff --git a/irksome/tools.py b/irksome/tools.py index ea9f493d..c7443cbe 100644 --- a/irksome/tools.py +++ b/irksome/tools.py @@ -2,7 +2,6 @@ from functools import reduce import numpy -from firedrake.fml import LabelledForm, Term from firedrake import VectorSpaceBasis, MixedVectorSpaceBasis from ufl.algorithms.analysis import extract_type from ufl import as_tensor @@ -46,6 +45,18 @@ def split_stages(V, stages): return ks +def extract_timedep_arguments(F, u0): + """Return both arguments if ``F`` is a bilinear form, otherwise + return the unique argument and ``u0``. + """ + try: + v, u = F.arguments() + except ValueError: + v, = F.arguments() + u = u0 + return v, u + + def fields_to_components(V, fields): """ Returns the scalar component indices corresponding to the possibly @@ -110,12 +121,7 @@ def getNullspace(V, Vbig, num_stages, nullspace): def replace(e, mapping): """A wrapper for ufl.replace that allows numpy arrays.""" cmapping = {k: as_tensor(v) for k, v in mapping.items()} - if isinstance(e, LabelledForm): - enew = LabelledForm(*(Term(ufl_replace(term.form, cmapping), term.labels) - for term in e.terms)) - return enew - else: - return ufl_replace(e, cmapping) + return ufl_replace(e, cmapping) # Utility functions that help us refactor diff --git a/irksome/ufl/deriv.py b/irksome/ufl/deriv.py index 8eb0a492..2de30b0c 100644 --- a/irksome/ufl/deriv.py +++ b/irksome/ufl/deriv.py @@ -9,7 +9,7 @@ from ufl.algorithms.apply_derivatives import GenericDerivativeRuleset from ufl.algorithms.apply_algebra_lowering import apply_algebra_lowering from ufl.form import BaseForm -from ufl.classes import (Coefficient, Conj, Curl, ConstantValue, Derivative, +from ufl.classes import (Argument, Coefficient, Conj, Curl, ConstantValue, Derivative, Div, Expr, Grad, Indexed, ReferenceGrad, ReferenceValue, SpatialCoordinate, Variable) from ufl.corealg.multifunction import MultiFunction @@ -106,6 +106,7 @@ def constant(self, o): else: return self.independent_terminal(o) + @process.register(Argument) @process.register(Coefficient) @process.register(SpatialCoordinate) def terminal(self, o):