diff --git a/demo_dolfinx.py b/demo_dolfinx.py new file mode 100644 index 00000000..88369a3b --- /dev/null +++ b/demo_dolfinx.py @@ -0,0 +1,58 @@ +from mpi4py import MPI +import dolfinx +import ufl +from irksome import GaussLegendre, Dt, MeshConstant +from irksome.stage_derivative import getForm +from irksome.tools import get_stage_space +from ufl import pi, atan, div, grad, inner, dx + +butcher_tableau = GaussLegendre(2) +N = 64 + +x0 = 0.0 +x1 = 10.0 +y0 = 0.0 +y1 = 10.0 + +msh = dolfinx.mesh.create_rectangle(MPI.COMM_WORLD, [[x0, y0], [x1, y1]], [N, N]) +V = dolfinx.fem.functionspace(msh, ("Lagrange", 1)) +x, y = ufl.SpatialCoordinate(msh) + +MC = MeshConstant(msh, backend="dolfinx") +dt = MC.Constant(10 / N) +t = MC.Constant(0.0) + +Constant = lambda val: dolfinx.fem.Constant(msh, val) +S = Constant(2.0) +C = Constant(1000.0) + + +B = ( + (x - Constant(x0)) + * (x - Constant(x1)) + * (y - Constant(y0)) + * (y - Constant(y1)) + / C +) +R = (x * x + y * y) ** 0.5 +uexact = B * atan(t) * (pi / 2.0 - atan(S * (R - t))) + +rhs = Dt(uexact) - div(grad(uexact)) + +u = dolfinx.fem.Function(V) +u.interpolate(dolfinx.fem.Expression(uexact, V.element.interpolation_points)) + +v = ufl.TestFunction(V) +F = inner(Dt(u), v) * dx + inner(grad(u), grad(v)) * dx - inner(rhs, v) * dx + +bc = [] +# bc = DirichletBC(V, 0, "on_boundary") + +# Get the function space for the stage-coupled problem and a function to hold the stages we're computing:: + +Vbig = get_stage_space(V, butcher_tableau.num_stages, backend="dolfinx") +k = dolfinx.fem.Function(Vbig) + +# Get the variational form and bcs for the stage-coupled variational problem:: + +Fnew, bcnew = getForm(F, butcher_tableau, t, dt, u, k, bcs=bc, backend="dolfinx") diff --git a/irksome/__init__.py b/irksome/__init__.py index af848619..28f5ba88 100644 --- a/irksome/__init__.py +++ b/irksome/__init__.py @@ -29,6 +29,7 @@ from .scheme import ContinuousPetrovGalerkinScheme, DiscontinuousGalerkinScheme from .scheme import GalerkinCollocationScheme + __all__ = [ "AdamsBashforth", "AdamsMoulton", @@ -65,7 +66,6 @@ from .bcs import BoundsConstrainedDirichletBC from .dirk_stepper import DIRKTimeStepper from .imex import RadauIIAIMEXMethod, DIRKIMEXMethod - from .stage_derivative import getForm from .nystrom_dirk_stepper import DIRKNystromTimeStepper, ExplicitNystromTimeStepper from .nystrom_stepper import ( StageDerivativeNystromTimeStepper, @@ -91,7 +91,6 @@ __all__ += [ "DIRKTimeStepper", "BoundsConstrainedDirichletBC", - "getForm", "RadauIIAIMEXMethod", "DIRKIMEXMethod", "DIRKNystromTimeStepper", diff --git a/irksome/backend.py b/irksome/backend.py index 29a81914..384dee12 100644 --- a/irksome/backend.py +++ b/irksome/backend.py @@ -1,4 +1,4 @@ -from typing import Protocol +from typing import Protocol, Any, Sequence import ufl from importlib import import_module @@ -7,6 +7,13 @@ class Backend(Protocol): def get_function_space(self, V: ufl.Coefficient) -> ufl.FunctionSpace: """Get a function space from the backend""" + def extract_bcs(bcs: Any) -> tuple[Any]: + """Extract boundary conditions""" + + class Function: ... + + class DirichletBC: ... + def get_stages(self, V: ufl.FunctionSpace, num_stages: int) -> ufl.Coefficient: """ Given a function space for a single time-step, get a duplicate of this space, @@ -38,6 +45,71 @@ def ConstantOrZero( def get_mesh_constant(MC: MeshConstant | None) -> ufl.core.expr.Expr: """Get a backend class to construct a mesh constant from""" + def TestFunction(space: ufl.FunctionSpace, part: int | None = None) -> ufl.Argument: + """Return a test-function that can be used by forms in the backend.""" + + def TrialFunction( + space: ufl.FunctionSpace, part: int | None = None + ) -> ufl.Argument: + """Return a trial-function that can be used by forms in the backend.""" + + def create_nonlinearvariational_problem( + F: ufl.Form, + u: ufl.Coefficient, + bcs: DirichletBC | Sequence | None = None, + **kwargs, + ) -> Any: + """Create a non-linear variational problem in the backend language.""" + + def create_nonlinearvariational_solver( + problem: Any, + solver_parameters: dict | None = None, + **kwargs, + ): + """Create a non-linear variational solver in the backend language.""" + + def create_linearvariational_problem( + a: ufl.Form, + L: ufl.Form, + u: ufl.Coefficient | Sequence[ufl.Coefficient], + bcs: DirichletBC | Sequence | None = None, + aP: ufl.Form | None = None, + **kwargs, + ) -> Any: + """Create a linear variational problem in the backend language.""" + + def create_linearvariational_solver( + problem: Any, + solver_parameters: dict | None = None, + **kwargs, + ): + """Create a linear variational solver in the backend language.""" + + def get_stage_spaces(V: ufl.FunctionSpace, num_stages: int) -> ufl.FunctionSpace: + """Create a stage space with M number of components.""" + + def norm( + v: ufl.core.expr.Expr, norm_type: str = "L2", mesh: ufl.Mesh | None = None + ) -> float: + """Compute the norm of a function in the backend language.""" + + def assemble(expr: ufl.core.expr.Expr) -> Any: + """Assemble a UFL expression in the backend language.""" + + def replace(expr: ufl.core.expr.Expr, mapping: dict) -> ufl.core.expr.Expr: + """Replace sub-expressions in a UFL expression with other expressions.""" + + def derivative( + form: ufl.Form, + u: ufl.Coefficient, + du: ufl.Argument | None = None, + coefficient_derivatives: dict | None = None, + ) -> ufl.Form: + """Compute the derivative of a form with respect to a coefficient in the backend language.""" + + def invalidate_jacobian(solver: Any): + """Invalidate the Jacobian matrix in the backend language.""" + def get_backend(backend: str) -> Backend: """Get backend class from backend name. diff --git a/irksome/backends/dolfinx.py b/irksome/backends/dolfinx.py index 8646ab52..343378ee 100644 --- a/irksome/backends/dolfinx.py +++ b/irksome/backends/dolfinx.py @@ -1,9 +1,76 @@ """DOLFINx backend for Irksome""" try: + from mpi4py import MPI import basix.ufl - import dolfinx + import dolfinx.fem.petsc import ufl + import typing + import numpy as np + + def get_stage_space(V: ufl.FunctionSpace, num_stages: int) -> ufl.FunctionSpace: + if num_stages == 1: + me = V.ufl_elemet() + else: + el = V.ufl_element() + if el.num_sub_elements > 0: + me = basix.ufl.mixed_element( + np.tile(el.sub_elements, num_stages).tolist() + ) + else: + me = basix.ufl.blocked_element(el, shape=(num_stages,)) + return dolfinx.fem.functionspace(V.mesh, me) + + def extract_bcs(bcs: typing.Any) -> tuple[typing.Any]: + """Extract boundary conditions""" + return bcs + + def create_linearvariational_problem( + a: ufl.Form, + L: ufl.Form, + u: ufl.Coefficient | typing.Sequence[ufl.Coefficient], + bcs: typing.Sequence | None = None, + aP: ufl.Form | None = None, + **kwargs, + ) -> dolfinx.fem.petsc.LinearProblem: + return dolfinx.fem.petsc.LinearProblem( + a, + L, + u, + bcs=bcs, + petsc_options_prefix="IrkSomeLinearSolver", + P=aP, + **kwargs, + ) + + def create_linearvariational_solver( + problem: dolfinx.fem.petsc.LinearProblem, + solver_parameters: dict | None = None, + **kwargs, + ): + """Create a linear variational solver that uses PETSc KSP.""" + return problem + + def create_nonlinearvariational_problem( + F: ufl.Form, + g: ufl.Coefficient, + bcs: typing.Sequence | None = None, + solver_parameters: dict | None = None, + ) -> dolfinx.fem.petsc.NonlinearProblem: + return dolfinx.fem.petsc.NonlinearProblem( + F, + g, + petsc_options_prefix="IrkSomeNonlinearSolver", + bcs=bcs, + petsc_options=solver_parameters, + ) + + def create_nonlinearvariational_solver( + problem: dolfinx.fem.petsc.NonlinearProblem, + solver_parameters: dict | None = None, + ): + """Create a non-linear variational solver that uses PETSc SNES.""" + return problem def get_function_space(u: ufl.Coefficient) -> ufl.FunctionSpace: return u.ufl_function_space() @@ -31,11 +98,20 @@ class MeshConstant(object): def __init__(self, msh): self.msh = msh try: - import scifem - except ModuleNotFoundError: - raise RuntimeError("Scifem is required to make mesh-constants") + import basix.ufl - self.V = scifem.create_real_functionspace(msh, ()) + r_el = basix.ufl.real_element( + msh.basix_cell(), value_shape=(), dtype=dolfinx.default_scalar_type + ) + self.V = dolfinx.fem.functionspace(msh, r_el) + except TypeError: + try: + import scifem + except ModuleNotFoundError: + raise RuntimeError( + "DOLFINx with real element support or Scifem is required to make mesh-constants" + ) + self.V = scifem.create_real_functionspace(msh, ()) def Constant(self, val=0.0) -> ufl.Coefficient: v = dolfinx.fem.Function(self.V) @@ -45,5 +121,52 @@ def Constant(self, val=0.0) -> ufl.Coefficient: def get_mesh_constant(MC: MeshConstant | None) -> ufl.core.expr.Expr: return MC.Constant if MC is not None else ufl.constantvalue.ComplexValue + class DirichletBC(dolfinx.fem.DirichletBC): + pass + + def norm( + v: ufl.core.Expr, norm_type: str = "L2", mesh: ufl.Mesh | None = None + ) -> float: + """Compute the norm of a function in the backend language.""" + if mesh is not None: + dx = ufl.Mesure("dx", domain=mesh) + else: + dx = ufl.dx + p = 2 + if norm_type.startswith("L"): + p = int(norm_type[1:]) + if p < 1: + raise ValueError(f"Invalid norm type {norm_type}") + expr = ufl.inner(v, v) ** (p / 2) + form = dolfinx.fem.form(expr * dx) + else: + raise NotImplementedError(f"Norm type {norm_type} not implemented") + norm_loc = dolfinx.fem.assemble_scalar(form) + return form.mesh.comm.Allreduce(MPI.IN_PLACE, norm_loc, op=MPI.SUM) ** (1 / p) + + def assemble(expr: ufl.core.Expr | float): + """Assemble a UFL expression in the backend language.""" + if isinstance(expr, float): + return float + else: + form = dolfinx.fem.form(expr) + if form.rank == 0: + return dolfinx.fem.assemble_scalar(form) + elif form.rank == 1: + return dolfinx.fem.assemble_vector(form) + elif form.rank == 2: + return dolfinx.fem.assemble_matrix(form) + else: + raise ValueError(f"Cannot assemble form of rank {form.rank}") + + replace = ufl.replace + derivative = ufl.derivative + TrialFunction = ufl.TrialFunction + Function = dolfinx.fem.Function + TestFunction = ufl.TestFunction + + def invalidate_jacobian(solver: dolfinx.fem.petsc.LinearProblem): + """Invalidate the Jacobian matrix in the backend language.""" + raise RuntimeError("DOLFINx does not support Jacobian invalidation") except ModuleNotFoundError: pass diff --git a/irksome/backends/firedrake.py b/irksome/backends/firedrake.py index 10d2b12d..581eaef7 100644 --- a/irksome/backends/firedrake.py +++ b/irksome/backends/firedrake.py @@ -1,8 +1,63 @@ """Firedrake backend for Irksome""" +from operator import mul +from functools import reduce + import firedrake import ufl from ..tools import get_stage_space +import typing + + +def get_stage_space(V: ufl.FunctionSpace, num_stages: int) -> ufl.FunctionSpace: + return reduce(mul, (V for _ in range(num_stages))) + + +def extract_bcs(bcs: typing.Any) -> tuple[typing.Any]: + """Return an iterable of boundary conditions on the residual form""" + return tuple(bc.extract_form("F") for bc in firedrake.solving._extract_bcs(bcs)) + + +def create_nonlinearvariational_problem( + F: ufl.Form, + u: ufl.Coefficient, + bcs: typing.Sequence | None = None, + **kwargs, +) -> firedrake.NonlinearVariationalProblem: + return firedrake.NonlinearVariationalProblem(F, u, bcs=bcs, **kwargs) + + +def create_nonlinearvariational_solver( + problem: firedrake.NonlinearVariationalProblem, + solver_parameters: dict | None = None, + **kwargs, +): + """Create a non-linear variational solver that uses PETSc SNES.""" + return firedrake.NonlinearVariationalSolver( + problem, solver_parameters=solver_parameters, **kwargs + ) + + +def create_linearvariational_problem( + a: ufl.Form, + L: ufl.Form, + u: ufl.Coefficient | typing.Sequence[ufl.Coefficient], + bcs: typing.Sequence | None = None, + aP: ufl.Form | None = None, + **kwargs, +) -> firedrake.LinearVariationalProblem: + return firedrake.LinearVariationalProblem(a, L, u, bcs=bcs, aP=aP, **kwargs) + + +def create_linearvariational_solver( + problem: firedrake.LinearVariationalProblem, + solver_parameters: dict | None = None, + **kwargs, +): + """Create a linear variational solver that uses PETSc KSP.""" + return firedrake.LinearVariationalSolver( + problem, solver_parameters=solver_parameters, **kwargs + ) def get_function_space(u: ufl.Coefficient) -> firedrake.FunctionSpace: @@ -36,3 +91,23 @@ def Constant(self, val=0.0) -> ufl.Coefficient: def get_mesh_constant(MC: MeshConstant | None): return MC.Constant if MC else firedrake.Constant + + +def invalidate_jacobian(solver: firedrake.LinearVariationalSolver): + """Invalidate the Jacobian matrix in the backend language.""" + firedrake.LinearVariationalSolver.invalidate_jacobian(solver) + + +Function = firedrake.Function + +DirichletBC = firedrake.DirichletBC + +norm = firedrake.norm + +assemble = firedrake.assemble + +replace = firedrake.replace + +derivative = firedrake.derivative +TestFunction = firedrake.TestFunction +TrialFunction = firedrake.TrialFunction diff --git a/irksome/base_time_stepper.py b/irksome/base_time_stepper.py index e926a93f..be2cbe8c 100644 --- a/irksome/base_time_stepper.py +++ b/irksome/base_time_stepper.py @@ -1,10 +1,6 @@ from abc import abstractmethod -from firedrake import ( - derivative, replace, lhs, rhs, Function, TrialFunction, - LinearVariationalProblem, LinearVariationalSolver, - NonlinearVariationalProblem, NonlinearVariationalSolver, -) -from firedrake.petsc import PETSc + +from petsc4py import PETSc from .tools import AI, getNullspace, flatten_dats, split_stages from .labeling import as_form, as_linear_form from .backend import get_backend @@ -16,9 +12,18 @@ class BaseTimeStepper: """Base class for various time steppers. This is mainly to give code reuse stashing objects that are common to all the time steppers. It's a developer-level class. """ - def __init__(self, F, t, dt, u0, - bcs=None, appctx=None, nullspace=None, backend: str = "firedrake"): - self._backend = get_backend(backend) + + def __init__( + self, + F, + t, + dt, + u0, + bcs=None, + appctx=None, + nullspace=None, + backend: str = "firedrake", + ): self.F = F self.t = t self.dt = dt @@ -27,6 +32,7 @@ def __init__(self, F, t, dt, u0, bcs = () self.orig_bcs = bcs self.nullspace = nullspace + self._backend = get_backend(backend) self.V = self._backend.get_function_space(u0) appctx_base = {"stepper": self} @@ -91,21 +97,38 @@ class StageCoupledTimeStepper(BaseTimeStepper): :kwarg sample_points: An optional kwarg used to evaluate collocation methods at additional points in time. """ - def __init__(self, F, t, dt, u0, num_stages, - bcs=None, Fp=None, solver_parameters=None, - appctx=None, nullspace=None, - transpose_nullspace=None, near_nullspace=None, - splitting=None, bc_type=None, - butcher_tableau=None, bounds=None, sample_points=None, - **kwargs): + + def __init__( + self, + F, + t, + dt, + u0, + num_stages, + bcs=None, + Fp=None, + solver_parameters=None, + appctx=None, + nullspace=None, + transpose_nullspace=None, + near_nullspace=None, + splitting=None, + bc_type=None, + butcher_tableau=None, + bounds=None, + sample_points=None, + backend="firedrake", + **kwargs, + ): is_linear = False if len(as_form(F).arguments()) == 2: F = as_linear_form(F, u0) is_linear = True - super().__init__(F, t, dt, u0, - bcs=bcs, appctx=appctx, nullspace=nullspace) + super().__init__( + F, t, dt, u0, bcs=bcs, appctx=appctx, nullspace=nullspace, backend=backend + ) self.num_stages = num_stages if butcher_tableau: @@ -128,7 +151,7 @@ def __init__(self, F, t, dt, u0, num_stages, if Fp is not None: Fp = as_linear_form(Fp, u0) Fpbig, _ = self.get_form_and_bcs(stages, F=Fp, bcs=()) - Jpbig = derivative(Fpbig, stages) + Jpbig = self._backend.derivative(Fpbig, stages) V = u0.function_space() Vbig = stages.function_space() @@ -139,35 +162,49 @@ def __init__(self, F, t, dt, u0, num_stages, self.bigBCs = bigBCs if is_linear: - Fbig = replace(Fbig, {stages: TrialFunction(stages.function_space())}) - abig = lhs(Fbig) - Lbig = rhs(Fbig) - problem = LinearVariationalProblem( - abig, Lbig, stages, bcs=bigBCs, aP=Jpbig, + Fbig = self._backend.replace( + Fbig, {stages: self._backend.TrialFunction(stages.function_space())} + ) + abig = ufl.lhs(Fbig) + Lbig = ufl.rhs(Fbig) + problem = self._backend.create_linearvariational_problem( + abig, + Lbig, + stages, + bcs=bigBCs, + aP=Jpbig, form_compiler_parameters=kwargs.pop("form_compiler_parameters", None), constant_jacobian=kwargs.pop("constant_jacobian", False), restrict=kwargs.pop("restrict", False), ) - solver_constructor = LinearVariationalSolver + self.solver = self._backend.create_linearvariational_solver( + problem, + appctx=self.appctx, + nullspace=nullspace, + transpose_nullspace=transpose_nullspace, + near_nullspace=near_nullspace, + solver_parameters=solver_parameters, + ) else: - problem = NonlinearVariationalProblem( - Fbig, stages, bcs=bigBCs, Jp=Jpbig, + problem = self._backend.create_nonlinearvariational_problem( + Fbig, + stages, + bcs=bigBCs, + Jp=Jpbig, form_compiler_parameters=kwargs.pop("form_compiler_parameters", None), is_linear=kwargs.pop("is_linear", False), restrict=kwargs.pop("restrict", False), ) problem._constant_jacobian = kwargs.pop("constant_jacobian", False) - solver_constructor = NonlinearVariationalSolver - - self.problem = problem - self.solver = solver_constructor( - self.problem, appctx=self.appctx, - nullspace=nullspace, - transpose_nullspace=transpose_nullspace, - near_nullspace=near_nullspace, - solver_parameters=solver_parameters, - **kwargs, - ) + self.solver = self._backend.create_nonlinearvariational_solver( + problem, + appctx=self.appctx, + nullspace=nullspace, + transpose_nullspace=transpose_nullspace, + near_nullspace=near_nullspace, + solver_parameters=solver_parameters, + **kwargs, + ) # stash these for later in case we do bounds constraints self.stage_bounds = self.get_stage_bounds(bounds) @@ -178,7 +215,6 @@ def __init__(self, F, t, dt, u0, num_stages, def advance(self): """Advances the system from time `t` to time `t + dt`. Note: overwrites the value `u0`.""" - self.solver.solve(bounds=self.stage_bounds) self.num_steps += 1 @@ -197,7 +233,11 @@ def get_form_and_bcs(self, stages, tableau=None, F=None): pass def solver_stats(self): - return (self.num_steps, self.num_nonlinear_iterations, self.num_linear_iterations) + return ( + self.num_steps, + self.num_nonlinear_iterations, + self.num_linear_iterations, + ) def get_stages(self) -> ufl.Coefficient: return self._backend.get_stages(self.V, self.num_stages) @@ -209,30 +249,30 @@ def get_stage_bounds(self, bounds=None): Vbig = self.stages.function_space() bounds_type, lower, upper = bounds if lower is None: - slb = Function(Vbig).assign(PETSc.NINFINITY) + slb = self._backend.Function(Vbig).assign(PETSc.NINFINITY) if upper is None: - sub = Function(Vbig).assign(PETSc.INFINITY) + sub = self._backend.Function(Vbig).assign(PETSc.INFINITY) if bounds_type == "stage": if lower is not None: dats = [lower.dat] * (self.num_stages) - slb = Function(Vbig, val=flatten_dats(dats)) + slb = self._backend.Function(Vbig, val=flatten_dats(dats)) if upper is not None: dats = [upper.dat] * (self.num_stages) - sub = Function(Vbig, val=flatten_dats(dats)) + sub = self._backend.Function(Vbig, val=flatten_dats(dats)) elif bounds_type == "last_stage": V = self.u0.function_space() if lower is not None: - ninfty = Function(V).assign(PETSc.NINFINITY) - dats = [ninfty.dat] * (self.num_stages-1) + ninfty = self._backend.Function(V).assign(PETSc.NINFINITY) + dats = [ninfty.dat] * (self.num_stages - 1) dats.append(lower.dat) - slb = Function(Vbig, val=flatten_dats(dats)) + slb = self._backend.Function(Vbig, val=flatten_dats(dats)) if upper is not None: - infty = Function(V).assign(PETSc.INFINITY) - dats = [infty.dat] * (self.num_stages-1) + infty = self._backend.Function(V).assign(PETSc.INFINITY) + dats = [infty.dat] * (self.num_stages - 1) dats.append(upper.dat) - sub = Function(Vbig, val=flatten_dats(dats)) + sub = self._backend.Function(Vbig, val=flatten_dats(dats)) else: raise ValueError("Unknown bounds type") @@ -244,25 +284,27 @@ def tabulate_poly(self, sample_points): pass def build_poly(self): - ''' + """ When provided with a list of `sample_points` (intended to be in the interval [0,1]), this builds a symbolic expression for the values of the RK collocation polynomial at the corresponding points on the interval [t_n, t_{n+1}]. These are stored in the list `self.sample_values` as functions in the same FunctionSpace as `self.u0`. The resulting expressions can then be assigned to a Function on that same FunctionSpace. - ''' + """ pts = numpy.reshape(self.sample_points, (-1, 1)) vander = self.tabulate_poly(pts) - self.u_old = Function(self.u0) + self.u_old = self._backend.Function(self.u0) ks = [self.u_old] ks.extend(split_stages(self.u0.function_space(), self.stages)) num_samples = vander.shape[1] - self.sample_values = [sum(ks[j] * vander[j, i] for j in range(len(ks))) - for i in range(num_samples)] + self.sample_values = [ + sum(ks[j] * vander[j, i] for j in range(len(ks))) + for i in range(num_samples) + ] def invalidate_jacobian(self): """ Forces the matrix to be reassembled next time it is required. """ - LinearVariationalSolver.invalidate_jacobian(self.solver) + self._backend.invalidate_jacobian(self.solver) diff --git a/irksome/bcs.py b/irksome/bcs.py index 97fb980b..68b92d5a 100644 --- a/irksome/bcs.py +++ b/irksome/bcs.py @@ -1,21 +1,8 @@ -from firedrake.solving import _extract_bcs -from firedrake import ( - DirichletBC, - Function, - TestFunction, - NonlinearVariationalProblem, - NonlinearVariationalSolver, - replace, - inner, - dx, -) +from functools import lru_cache -from ufl import as_ufl - -def extract_bcs(bcs): - """Return an iterable of boundary conditions on the residual form""" - return tuple(bc.extract_form("F") for bc in _extract_bcs(bcs)) +from .backend import get_backend +from ufl import as_ufl, inner, dx def get_sub(u, indices): @@ -54,19 +41,56 @@ def EmbeddedBCData(bc, butcher_tableau, t, dt, u0, stages): V = u0.function_space() field = 0 if len(V) == 1 else bc.function_space_index() comp = (bc.function_space().component,) - ws = stages.subfunctions[field::len(V)] + ws = stages.subfunctions[field :: len(V)] btilde = butcher_tableau.btilde num_stages = butcher_tableau.num_stages - - g = replace(as_ufl(gorig), {t: t + dt}) - gorig + g = get_backend(bc._backend).replace(as_ufl(gorig), {t: t + dt}) - gorig g -= sum(get_sub(ws[j], comp) * (btilde[j] * dt) for j in range(num_stages)) return bc.reconstruct(V=Vbc, g=g) -class BoundsConstrainedDirichletBC(DirichletBC): +class _BoundsConstrainedDirichletBCMeta(type): + def __call__(cls, *args, backend: str = "firedrake", **kwargs): + if getattr(cls, "_backend_impl", False): + return super().__call__(*args, backend=backend, **kwargs) + impl_cls = _get_bounds_constrained_dirichlet_bc_class(backend) + return impl_cls(*args, backend=backend, **kwargs) + + +@lru_cache(maxsize=None) +def _get_bounds_constrained_dirichlet_bc_class(backend: str): + backend_cls = get_backend(backend) + return type( + "BoundsConstrainedDirichletBC", + (BoundsConstrainedDirichletBC, backend_cls.DirichletBC), + {"_backend_impl": True, "__module__": __name__}, + ) + + +class BoundsConstrainedDirichletBC(metaclass=_BoundsConstrainedDirichletBCMeta): """A DirichletBC with bounds-constrained data.""" - def __init__(self, V, g, sub_domain, bounds, solver_parameters=None): + _backend_impl = False + + def __init__( + self, + V, + g, + sub_domain, + bounds, + solver_parameters=None, + backend: str = "firedrake", + ): + + self.g = g + self.solver_parameters = solver_parameters + self.bounds = bounds + self._backend_name = backend + backend_cls = get_backend(backend) + self.gnew = backend_cls.Function(V) + + F = inner(self.gnew - g, backend_cls.TestFunction(V)) * dx + if solver_parameters is None: solver_parameters = { "snes_type": "vinewtonrsls", @@ -75,15 +99,9 @@ def __init__(self, V, g, sub_domain, bounds, solver_parameters=None): "ksp_type": "preonly", "mat_type": "aij", } - self.g = g - self.solver_parameters = solver_parameters - self.bounds = bounds - - self.gnew = Function(V) - F = inner(self.gnew - g, TestFunction(V)) * dx - problem = NonlinearVariationalProblem(F, self.gnew) - self.solver = NonlinearVariationalSolver( - problem, solver_parameters=self.solver_parameters + problem = backend_cls.create_linearvariational_problem(F, self.gnew) + self.solver = backend_cls.create_nonlinearvariational_solver( + problem, solver_parameters=solver_parameters ) super().__init__(V, g, sub_domain) @@ -103,4 +121,11 @@ def reconstruct(self, V=None, g=None, sub_domain=None): V = V or self.function_space() g = g or self.g sub_domain = sub_domain or self.sub_domain - return type(self)(V, g, sub_domain, self.bounds, self.solver_parameters) + return type(self)( + V, + g, + sub_domain, + self.bounds, + self.solver_parameters, + backend=self._backend_name, + ) diff --git a/irksome/galerkin_stepper.py b/irksome/galerkin_stepper.py index d2b67f97..3cd55eff 100644 --- a/irksome/galerkin_stepper.py +++ b/irksome/galerkin_stepper.py @@ -5,7 +5,7 @@ from ufl import as_ufl, as_tensor from .base_time_stepper import StageCoupledTimeStepper -from .bcs import bc2space, extract_bcs, stage2spaces4bc +from .bcs import bc2space, stage2spaces4bc from .ufl.deriv import TimeDerivative, expand_time_derivatives from .ufl.estimate_degrees import TimeDegreeEstimator, get_degree_mapping from .labeling import split_quadrature, as_form @@ -14,7 +14,7 @@ from .constant import vecconst from .discontinuous_galerkin_stepper import getElement as getTestElement from .integrated_lagrange import IntegratedLagrange - +from .backends.firedrake import extract_bcs from .tableaux.ButcherTableaux import CollocationButcherTableau from .stage_derivative import getForm from .stage_value import getFormStage diff --git a/irksome/labeling.py b/irksome/labeling.py index 3b0e10e3..fee107f4 100644 --- a/irksome/labeling.py +++ b/irksome/labeling.py @@ -13,6 +13,7 @@ class TimeQuadratureLabel(Label): """If the constructor gets one argument, it's an integer for the order of the quadrature rule. If there are two arguments, assume they are the points and weights.""" + def __init__(self, *args, scheme="default"): if len(args) == 1: Q = create_time_quadrature(args[0], scheme=scheme) @@ -40,7 +41,9 @@ def has_quad_labels(term): return any(isinstance(label, TimeQuadratureRule) for label in term.labels) -def apply_time_quadrature_labels(form, degree_estimator, scheme=None, max_quadrature_degree=None): +def apply_time_quadrature_labels( + form, degree_estimator, scheme=None, max_quadrature_degree=None +): """ Estimates the polynomial degree in time for each integral in the given form and labels each term with a quadrature rule to be used for time integration. @@ -58,19 +61,25 @@ def apply_time_quadrature_labels(form, degree_estimator, scheme=None, max_quadra # Split labelled and unlabelled parts if isinstance(form, LabelledForm): F = form.label_map(has_quad_labels, map_if_true=keep, map_if_false=drop) - form = as_form(form.label_map(has_quad_labels, map_if_true=drop, map_if_false=keep)) + form = as_form( + form.label_map(has_quad_labels, map_if_true=drop, map_if_false=keep) + ) elif isinstance(form, BaseForm): F = LabelledForm() else: - raise ValueError(f"Expecting a BaseForm or a LabelledForm, not {type(form).__name__}") + raise ValueError( + f"Expecting a BaseForm or a LabelledForm, not {type(form).__name__}" + ) if not isinstance(form, Form) and isinstance(form, BaseForm): # Label the non-integral components if isinstance(form, FormSum): ws = form.weights() fs = form.components() - form = sum((w*f for f, w in zip(fs, ws) if isinstance(f, Form)), Form([])) - base_form = FormSum(*((f, w) for f, w in zip(fs, ws) if not isinstance(f, Form))) + form = sum((w * f for f, w in zip(fs, ws) if isinstance(f, Form)), Form([])) + base_form = FormSum( + *((f, w) for f, w in zip(fs, ws) if not isinstance(f, Form)) + ) else: base_form = form form = Form([]) @@ -99,7 +108,9 @@ def apply_time_quadrature_labels(form, degree_estimator, scheme=None, max_quadra return F -def split_quadrature(F, degree_estimator=None, Qdefault=None, max_quadrature_degree=None): +def split_quadrature( + F, degree_estimator=None, Qdefault=None, max_quadrature_degree=None +): """Splits a :class:`LabelledForm` into the terms to be integrated in time by the different :class:`TimeQuadratureRule` objects used as labels. @@ -117,34 +128,46 @@ def split_quadrature(F, degree_estimator=None, Qdefault=None, max_quadrature_deg do_estimate_degrees = Qdefault is None or isinstance(Qdefault, str) if do_estimate_degrees: scheme = Qdefault - F = apply_time_quadrature_labels(F, degree_estimator, scheme=scheme, - max_quadrature_degree=max_quadrature_degree) + F = apply_time_quadrature_labels( + F, + degree_estimator, + scheme=scheme, + max_quadrature_degree=max_quadrature_degree, + ) if not isinstance(F, LabelledForm): return {Qdefault: F} quad_labels = set() for term in F.terms: - cur_labels = [label for label in term.labels if isinstance(label, TimeQuadratureRule)] + cur_labels = [ + label for label in term.labels if isinstance(label, TimeQuadratureRule) + ] if len(cur_labels) == 1: quad_labels.update(cur_labels) elif len(cur_labels) > 1: raise ValueError("Multiple quadrature labels on one term.") splitting = {} - Fdefault = as_form(F.label_map(has_quad_labels, map_if_true=drop, map_if_false=keep)) + Fdefault = as_form( + F.label_map(has_quad_labels, map_if_true=drop, map_if_false=keep) + ) if do_estimate_degrees: # every term must have been labelled at this point assert Fdefault.empty() else: splitting[Qdefault] = Fdefault for Q in quad_labels: - splitting[Q] = F.label_map(lambda t: Q in t.labels, map_if_true=keep, map_if_false=drop) + splitting[Q] = F.label_map( + lambda t: Q in t.labels, map_if_true=keep, map_if_false=drop + ) # collapse TimeQuadratureRules based on numerical equality - rule_equals = lambda Q1, Q2: (type(Q1) == type(Q2) - and np.array_equal(Q1.get_points(), Q2.get_points()) - and np.array_equal(Q1.get_weights(), Q2.get_weights())) + rule_equals = lambda Q1, Q2: ( + type(Q1) == type(Q2) + and np.array_equal(Q1.get_points(), Q2.get_points()) + and np.array_equal(Q1.get_weights(), Q2.get_weights()) + ) forms = defaultdict(lambda: Form([])) for Q in sorted(splitting, key=lambda Q: tuple(Q.get_points()), reverse=True): @@ -158,12 +181,13 @@ def split_quadrature(F, degree_estimator=None, Qdefault=None, max_quadrature_deg def split_explicit(F): if not isinstance(F, LabelledForm): return (F, None) - exp_part = F.label_map(lambda t: t.has_label(explicit), - map_if_true=keep, - map_if_false=drop) + exp_part = F.label_map( + lambda t: t.has_label(explicit), map_if_true=keep, map_if_false=drop + ) - imp_part = F.label_map(lambda t: t.labels == {}, - map_if_true=keep, map_if_false=drop) + imp_part = F.label_map( + lambda t: t.labels == {}, map_if_true=keep, map_if_false=drop + ) return as_form(imp_part), as_form(exp_part) @@ -185,7 +209,9 @@ def as_linear_form(F, u0): nargs = len(form.arguments()) if nargs == 2: if u0 in form.coefficients(): - raise ValueError("The provided bilinear form must not depend on the solution") + raise ValueError( + "The provided bilinear form must not depend on the solution" + ) test, trial = form.arguments() F = replace(F, {trial: u0}) elif nargs != 1: diff --git a/irksome/multistep.py b/irksome/multistep.py index 9074cfda..8a244540 100644 --- a/irksome/multistep.py +++ b/irksome/multistep.py @@ -171,7 +171,6 @@ def get_form_and_bcs(self, F, t, dt, u0, a, b, bcs=None): return Fnew, bcsnew def advance(self): - self.solver.solve(bounds=self.bounds) # update previous steps diff --git a/irksome/stage_derivative.py b/irksome/stage_derivative.py index eb16ad1e..6fc5e38b 100644 --- a/irksome/stage_derivative.py +++ b/irksome/stage_derivative.py @@ -1,85 +1,94 @@ import numpy -from firedrake import Function, TestFunction -from firedrake import NonlinearVariationalProblem as NLVP -from firedrake import NonlinearVariationalSolver as NLVS -from firedrake import assemble, dx, inner, norm, as_tensor -from firedrake.bcs import EquationBC, EquationBCSplit -from FIAT import ufc_simplex -from FIAT.barycentric_interpolation import LagrangePolynomialSet - -from ufl.constantvalue import as_ufl -from .tableaux.ButcherTableaux import CollocationButcherTableau +from ufl import as_ufl, as_tensor, Form, Coefficient, dx, inner +from .tableaux import ButcherTableaux from .constant import vecconst from .tools import AI, dot, replace, reshape, fields_to_components from .ufl.deriv import Dt, TimeDerivative, expand_time_derivatives -from .bcs import EmbeddedBCData, BCStageData, extract_bcs, bc2space, stage2spaces4bc -from .ufl.manipulation import split_time_derivative_terms +from .backend import get_backend + +from .bcs import EmbeddedBCData from .base_time_stepper import StageCoupledTimeStepper +from .tableaux.ButcherTableaux import CollocationButcherTableau +from FIAT import ufc_simplex +from FIAT.barycentric_interpolation import LagrangePolynomialSet +from .ufl.manipulation import split_time_derivative_terms -def getForm(F, butch, t, dt, u0, stages, bcs=None, bc_type=None, splitting=AI, aux_indices=None): +def getForm( + F: Form, + butch: ButcherTableaux, + t: Coefficient, + dt: Coefficient, + u0: Coefficient, + stages, + bcs=None, + bc_type=None, + splitting=AI, + aux_indices=None, + backend: str = "firedrake", +): """Given a time-dependent variational form and a :class:`ButcherTableau`, produce UFL for the s-stage RK method. :arg F: a :class:`ufl.Form` instance describing the semi-discrete problem. :arg butch: the :class:`ButcherTableau` for the RK method being used to advance in time. - :arg t: a :class:`firedrake.Constant` or :class:`firedrake.Function` - on the Real space over the same mesh as ``u0``. This serves as - a variable referring to the current time. - :arg dt: a :class:`firedrake.Constant` or :class:`firedrake.Function` - on the Real space over the same mesh as ``u0``. This serves as - a variable referring to the current time step size. + :arg t: a :class:`Function` or :class:`Constant` on the Real space over the same mesh as + `u0`. This serves as a variable referring to the current time. + :arg dt: a :class:`Function` or :class:`Constant` on the Real space over the same mesh as + `u0`. This serves as a variable referring to the current time step. + The user may adjust this value between time steps. :arg u0: a :class:`Function` referring to the state of - the PDE system at time `t` + the PDE system at time `t` :arg stages: a :class:`Function` representing the stages to be solved for. :kwarg bcs: optionally, a :class:`DirichletBC` or :class:`EquationBC` - object (or iterable thereof) containing (possibly time-dependent) - boundary conditions imposed on the system. - :kwarg bc_type: How to manipulate the strongly-enforced boundary - conditions to derive the stage boundary conditions. Should - be a string, either "DAE", which implements BCs as - constraints in the style of a differential-algebraic - equation, or "ODE", which takes the time derivative of the - boundary data and evaluates this for the stage values. - Support for `firedrake.EquationBC` in `bcs` is limited - to DAE style BCs. - :kwarg splitting: a callable that maps the (floating point) Butcher matrix - to a pair of matrices `A1, A2` such that `butch.A = A1 A2`. This is used - to vary between the classical RK formulation and Butcher's reformulation - that leads to a denser mass matrix with block-diagonal stiffness. - Some choices of function will assume that `butch.A` is invertible. - :kwarg aux_indices: a list of field indices to be discretized as :class:`TimeDerivative`, - analogous to :class:`ContinouosPetrovGalerkinTimeStepper`. + object (or iterable thereof) containing (possibly time-dependent) + boundary conditions imposed on the system. + :kwarg bc_type: How to manipulate the strongly-enforced boundary + conditions to derive the stage boundary conditions. Should + be a string, either "DAE", which implements BCs as + constraints in the style of a differential-algebraic + equation, or "ODE", which takes the time derivative of the + boundary data and evaluates this for the stage values. + Support for `firedrake.EquationBC` in `bcs` is limited + to DAE style BCs. + :kwarg splitting: a callable that maps the (floating point) Butcher matrix + to a pair of matrices `A1, A2` such that `butch.A = A1 A2`. This is used + to vary between the classical RK formulation and Butcher's reformulation + that leads to a denser mass matrix with block-diagonal stiffness. + Some choices of function will assume that `butch.A` is invertible. + :kwarg aux_indices: a list of field indices to be discretized as :class:`TimeDerivative`, + analogous to :class:`ContinouosPetrovGalerkinTimeStepper`. :returns: a 2-tuple of - `Fnew`, the :class:`Form` - `bcnew`, a list of :class:`firedrake.DirichletBC` or :class:`EquationBC` objects to be posed on the stages """ + backend_cls = get_backend(backend) if bc_type is None: bc_type = "DAE" # preprocess time derivatives F = expand_time_derivatives(F, t=t, timedep_coeffs=(u0,)) - v, = F.arguments() - V = v.function_space() - assert V == u0.function_space() + (v,) = F.arguments() + V = backend_cls.get_function_space(v) + assert V == backend_cls.get_function_space(u0) - c = vecconst(butch.c) + c = vecconst(butch.c, backend=backend) bA1, bA2 = splitting(butch.A) try: bA2inv = numpy.linalg.inv(bA2) except numpy.linalg.LinAlgError: raise NotImplementedError("We require A = A1 A2 with A2 invertible") - A1 = vecconst(bA1) - A2inv = vecconst(bA2inv) + A1 = vecconst(bA1, backend=backend) + A2inv = vecconst(bA2inv, backend=backend) # s-way product space for the stage variables num_stages = butch.num_stages - Vbig = stages.function_space() - test = TestFunction(Vbig) + Vbig = backend_cls.get_function_space(stages) + test = backend_cls.TestFunction(Vbig) # set up the pieces we need to work with to do our substitutions v_np = reshape(test, (num_stages, *v.ufl_shape)) @@ -99,21 +108,25 @@ def getForm(F, butch, t, dt, u0, stages, bcs=None, bc_type=None, splitting=AI, a usub = reshape(usub, u0.ufl_shape) usub[aux_components] = dtusub[aux_components] * dt - repl[i] = {t: t + c[i] * dt, - v: v_np[i], - u0: usub, - dtu: dtusub} + repl[i] = {t: t + c[i] * dt, v: v_np[i], u0: usub, dtu: dtusub} Fnew = sum(replace(F, repl[i]) for i in range(num_stages)) if bcs is None: bcs = [] if bc_type == "ODE": - assert splitting == AI, "ODE-type BC aren't implemented for this splitting strategy" + assert splitting == AI, ( + "ODE-type BC aren't implemented for this splitting strategy" + ) def bc2stagebc(bc, i): + from irksome.bcs import BCStageData + from firedrake.bcs import EquationBCSplit + if isinstance(bc, EquationBCSplit): - raise NotImplementedError("EquationBC not implemented for ODE formulation") + raise NotImplementedError( + "EquationBC not implemented for ODE formulation" + ) gorig = as_ufl(bc._original_arg) gfoo = expand_time_derivatives(Dt(gorig), t=t, timedep_coeffs=(u0,)) gcur = replace(gfoo, {t: t + c[i] * dt}) @@ -122,31 +135,46 @@ def bc2stagebc(bc, i): elif bc_type == "DAE": try: bA1inv = numpy.linalg.inv(bA1) - A1inv = vecconst(bA1inv) + A1inv = vecconst(bA1inv, backend=backend) except numpy.linalg.LinAlgError: - raise NotImplementedError("Cannot have DAE BCs for this Butcher Tableau/splitting") + raise NotImplementedError( + "Cannot have DAE BCs for this Butcher Tableau/splitting" + ) def bc2stagebc(bc, i): + from irksome.bcs import BCStageData, stage2spaces4bc, bc2space + from firedrake.bcs import EquationBCSplit, EquationBC + if isinstance(bc, EquationBCSplit): F_bc_orig = expand_time_derivatives(bc.f, t=t, timedep_coeffs=(u0,)) F_bc_new = replace(F_bc_orig, repl[i]) Vbigi = stage2spaces4bc(bc, V, Vbig, i) - return EquationBC(F_bc_new == 0, stages, bc.sub_domain, V=Vbigi, - bcs=[bc2stagebc(innerbc, i) for innerbc in extract_bcs(bc.bcs)]) + return EquationBC( + F_bc_new == 0, + stages, + bc.sub_domain, + V=Vbigi, + bcs=[ + bc2stagebc(innerbc, i) + for innerbc in backend_cls.extract_bcs(bc.bcs) + ], + ) else: gcur = bc._original_arg if gcur != 0: gorig = as_ufl(gcur) ucur = bc2space(bc, u0) - gcur = (1/dt) * sum((replace(gorig, {t: t + c[j]*dt}) - ucur) * A1inv[i, j] - for j in range(num_stages)) + gcur = (1 / dt) * sum( + (replace(gorig, {t: t + c[j] * dt}) - ucur) * A1inv[i, j] + for j in range(num_stages) + ) return BCStageData(bc, gcur, u0, stages, i) else: raise ValueError(f"Unrecognised bc_type: {bc_type}") # This logic uses information set up in the previous section to # set up the new BCs for either method - bcs = extract_bcs(bcs) + bcs = backend_cls.extract_bcs(bcs) bcnew = [bc2stagebc(bc, i) for i in range(num_stages) for bc in bcs] return Fnew, bcnew @@ -191,9 +219,23 @@ class StageDerivativeTimeStepper(StageCoupledTimeStepper): :kwarg sample_points: An optional kwarg used to evaluate collocation methods at additional points in time. """ - def __init__(self, F, butcher_tableau, t, dt, u0, bcs=None, - solver_parameters=None, splitting=AI, - appctx=None, bc_type="DAE", aux_indices=None, sample_points=None, **kwargs): + + def __init__( + self, + F, + butcher_tableau, + t, + dt, + u0, + bcs=None, + solver_parameters=None, + splitting=AI, + appctx=None, + bc_type="DAE", + aux_indices=None, + sample_points=None, + **kwargs, + ): self.num_fields = len(u0.function_space()) self.butcher_tableau = butcher_tableau @@ -204,13 +246,21 @@ def __init__(self, F, butcher_tableau, t, dt, u0, bcs=None, raise NotImplementedError("A=A1 A2 splitting needs A2 invertible") self.aux_indices = aux_indices - super().__init__(F, t, dt, u0, - butcher_tableau.num_stages, bcs=bcs, - solver_parameters=solver_parameters, - appctx=appctx, - splitting=splitting, bc_type=bc_type, - butcher_tableau=butcher_tableau, - sample_points=sample_points, **kwargs) + super().__init__( + F, + t, + dt, + u0, + butcher_tableau.num_stages, + bcs=bcs, + solver_parameters=solver_parameters, + appctx=appctx, + splitting=splitting, + bc_type=bc_type, + butcher_tableau=butcher_tableau, + sample_points=sample_points, + **kwargs, + ) def _update(self): """Assuming the algebraic problem for the RK stages has been @@ -224,24 +274,36 @@ def _update(self): # Note: this now catches the optimized/stiffly accurate case as b[s] == Zero() will get dropped for i, u0bit in enumerate(self.u0.subfunctions): - u0bit += sum(self.stages.subfunctions[nf * s + i] * (b[s] * dt) for s in range(ns)) + u0bit += sum( + self.stages.subfunctions[nf * s + i] * (b[s] * dt) for s in range(ns) + ) def get_form_and_bcs(self, stages, F=None, bcs=None, tableau=None): if bcs is None: bcs = self.orig_bcs - return getForm(F or self.F, - tableau or self.butcher_tableau, - self.t, self.dt, - self.u0, stages, bcs, self.bc_type, - splitting=self.splitting, - aux_indices=self.aux_indices) + return getForm( + F or self.F, + tableau or self.butcher_tableau, + self.t, + self.dt, + self.u0, + stages, + bcs, + self.bc_type, + splitting=self.splitting, + aux_indices=self.aux_indices, + ) def tabulate_poly(self, sample_points): if not isinstance(self.butcher_tableau, CollocationButcherTableau): - raise ValueError("Need a collocation method to evaluate the collocation polynomial") + raise ValueError( + "Need a collocation method to evaluate the collocation polynomial" + ) nodes = numpy.insert(self.butcher_tableau.c, 0, 0.0) if len(set(nodes)) != len(nodes): - raise ValueError("Need non-confluent collocation method for polynomial evaluation") + raise ValueError( + "Need non-confluent collocation method for polynomial evaluation" + ) ref_el = ufc_simplex(1) lag_basis = LagrangePolynomialSet(ref_el, nodes) @@ -303,18 +365,51 @@ class AdaptiveTimeStepper(StageDerivativeTimeStepper): with each time step. :arg nullspace: An optional nullspace object. """ - def __init__(self, F, butcher_tableau, t, dt, u0, - bcs=None, appctx=None, solver_parameters=None, - bc_type="DAE", splitting=AI, nullspace=None, - tol=1.e-3, dtmin=1.e-15, dtmax=1.0, KI=1/15, KP=0.13, - max_reject=10, onscale_factor=1.2, safety_factor=0.9, - gamma0_params=None, **kwargs): - assert butcher_tableau.btilde is not None - super(AdaptiveTimeStepper, self).__init__(F, butcher_tableau, - t, dt, u0, bcs=bcs, appctx=appctx, solver_parameters=solver_parameters, - bc_type=bc_type, splitting=splitting, nullspace=nullspace, **kwargs) + def __init__( + self, + F, + butcher_tableau, + t, + dt, + u0, + bcs=None, + appctx=None, + solver_parameters=None, + bc_type="DAE", + splitting=AI, + nullspace=None, + tol=1.0e-3, + dtmin=1.0e-15, + dtmax=1.0, + KI=1 / 15, + KP=0.13, + max_reject=10, + onscale_factor=1.2, + safety_factor=0.9, + gamma0_params=None, + backend_cls: str = "firedrake", + **kwargs, + ): + assert butcher_tableau.btilde is not None + super(AdaptiveTimeStepper, self).__init__( + F, + butcher_tableau, + t, + dt, + u0, + bcs=bcs, + appctx=appctx, + solver_parameters=solver_parameters, + bc_type=bc_type, + splitting=splitting, + nullspace=nullspace, + **kwargs, + ) + + self._backend_cls = get_backend(backend_cls) from firedrake.petsc import PETSc + self.print = PETSc.Sys.Print self.dt_min = dtmin @@ -330,21 +425,27 @@ def __init__(self, F, butcher_tableau, t, dt, u0, self.onscale_factor = onscale_factor self.safety_factor = safety_factor - self.error_func = Function(u0.function_space()) + self.error_func = self._backend_cls.Function(u0.function_space()) self.tol = tol self.err_old = 0.0 self.contreject = 0 split_form = split_time_derivative_terms(F, t=t, timedep_coeffs=(u0,)) - F_remainder = expand_time_derivatives(split_form.remainder, t=t, timedep_coeffs=()) + F_remainder = expand_time_derivatives( + split_form.remainder, t=t, timedep_coeffs=() + ) self.dtless_form = -F_remainder # Set up and cache boundary conditions for error estimate embbc = [] if self.gamma0 != 0: # Grab spaces for BCs - embbc = [EmbeddedBCData(bc, butcher_tableau, self.t, self.dt, self.u0, self.stages) - for bc in bcs] + embbc = [ + EmbeddedBCData( + bc, butcher_tableau, self.t, self.dt, self.u0, self.stages + ) + for bc in bcs + ] self.embbc = embbc def _estimate_error(self): @@ -360,21 +461,28 @@ def _estimate_error(self): u0 = self.u0 # Initialize e to be gamma*h*f(old value of u) - error_func = Function(u0.function_space()) + error_func = self._backend_cls.Function(u0.function_space()) # Only do the hard stuff if gamma0 is not zero if self.gamma0 != 0.0: - error_test = TestFunction(u0.function_space()) - f_form = inner(error_func, error_test)*dx-self.gamma0*dtc*self.dtless_form - f_problem = NLVP(f_form, error_func, bcs=self.embbc) - f_solver = NLVS(f_problem, solver_parameters=self.gamma0_params) + error_test = self._backend_cls.TestFunction(u0.function_space()) + f_form = ( + inner(error_func, error_test) * dx + - self.gamma0 * dtc * self.dtless_form + ) + f_problem = self._backend_cls.create_nonlinearvariational_problem( + f_form, error_func, bcs=self.embbc + ) + f_solver = self._backend_cls.create_nonlinearvariational_solver( + f_problem, solver_parameters=self.gamma0_params + ) f_solver.solve() # Accumulate delta-b terms over stages error_func_bits = error_func.subfunctions for s in range(ns): for i, e in enumerate(error_func_bits): - e += dtc*float(delb[s])*ws[nf*s+i] - return norm(assemble(error_func)) + e += dtc * float(delb[s]) * ws[nf * s + i] + return self._backend_cls.norm(self._backend_cls.assemble(error_func)) def advance(self): """Attempts to advances the system from time `t` to time `t + @@ -394,23 +502,33 @@ def advance(self): dt_old = float(self.dt_old) dt_current = float(self.dt) tol = float(self.tol) - dt_pred = dt_current*((dt_current*tol)/err_current)**(1/self.butcher_tableau.embedded_order) + dt_pred = dt_current * ((dt_current * tol) / err_current) ** ( + 1 / self.butcher_tableau.embedded_order + ) self.print("\tTruncation error is %e" % (err_current)) # Rejected step shrinks the time-step - if err_current >= dt_current*tol: - dtnew = dt_current*(self.safety_factor*dt_current*tol/err_current)**(1./self.butcher_tableau.embedded_order) + if err_current >= dt_current * tol: + dtnew = dt_current * ( + self.safety_factor * dt_current * tol / err_current + ) ** (1.0 / self.butcher_tableau.embedded_order) self.print("\tShrinking time-step to %e" % (dtnew)) self.dt.assign(dtnew) self.contreject += 1 if dtnew <= self.dt_min or numpy.isfinite(dtnew) is False: raise RuntimeError("The time-step became an invalid number.") if self.contreject >= self.max_reject: - raise RuntimeError(f"The time-step was rejected {self.max_reject} times in a row. Please increase the tolerance or decrease the starting time-step.") + raise RuntimeError( + f"The time-step was rejected {self.max_reject} times in a row. Please increase the tolerance or decrease the starting time-step." + ) # Initial time-step selector - elif self.num_steps == 0 and dt_current < self.dt_max and dt_pred > self.onscale_factor*dt_current and self.contreject <= self.max_reject: - + elif ( + self.num_steps == 0 + and dt_current < self.dt_max + and dt_pred > self.onscale_factor * dt_current + and self.contreject <= self.max_reject + ): # Increase the initial time-step dtnew = min(dt_pred, self.dt_max) self.print("\tIncreasing time-step to %e" % (dtnew)) @@ -420,11 +538,22 @@ def advance(self): # Accepted step increases the time-step else: if dt_old != 0.0 and err_old != 0.0 and dt_current < self.dt_max: - dtnew = min(dt_current*((dt_current*tol)/err_current)**self.KI*(err_old/err_current)**self.KP*(dt_current/dt_old)**self.KP, self.dt_max) - self.print("\tThe step was accepted and the new time-step is %e" % (dtnew)) + dtnew = min( + dt_current + * ((dt_current * tol) / err_current) ** self.KI + * (err_old / err_current) ** self.KP + * (dt_current / dt_old) ** self.KP, + self.dt_max, + ) + self.print( + "\tThe step was accepted and the new time-step is %e" % (dtnew) + ) else: dtnew = min(dt_current, self.dt_max) - self.print("\tThe step was accepted and the time-step remains at %e " % (dtnew)) + self.print( + "\tThe step was accepted and the time-step remains at %e " + % (dtnew) + ) self._update() self.contreject = 0 self.num_steps += 1 diff --git a/irksome/tools.py b/irksome/tools.py index 73e6f1d0..b30c3018 100644 --- a/irksome/tools.py +++ b/irksome/tools.py @@ -1,13 +1,9 @@ -from operator import mul -from functools import reduce +from .backend import get_backend import numpy - -from firedrake.fml import LabelledForm, Term -from firedrake import VectorSpaceBasis, MixedVectorSpaceBasis from ufl.algorithms.analysis import extract_type from ufl import as_tensor from ufl import replace as ufl_replace -from pyop2.types import MixedDat + import FIAT from .ufl.deriv import TimeDerivative @@ -22,6 +18,7 @@ def reshape(expr, shape): def flatten_dats(dats): + from pyop2.types import MixedDat flat_dat = [] for dat in dats: if isinstance(dat, (tuple, list, MixedDat)): @@ -31,8 +28,9 @@ def flatten_dats(dats): return MixedDat(flat_dat) -def get_stage_space(V, num_stages): - return reduce(mul, (V for _ in range(num_stages))) +def get_stage_space(V, num_stages, backend:str="firedrake"): + backend_cls = get_backend(backend) + return backend_cls.get_stage_space(V, num_stages) def split_stages(V, stages): @@ -59,6 +57,8 @@ def fields_to_components(V, fields): """ cur = 0 components = [] + if len(fields) == 0: + return components for i, Vi in enumerate(V): if i in fields: components.extend(range(cur, cur+Vi.value_size)) @@ -78,6 +78,7 @@ def getNullspace(V, Vbig, num_stages, nullspace): On output, we produce a :class:`MixedVectorSpaceBasis` defining the nullspace for the multistage problem. """ + from firedrake import VectorSpaceBasis, MixedVectorSpaceBasis num_fields = len(V) if nullspace is None: @@ -110,11 +111,16 @@ def getNullspace(V, Vbig, num_stages, nullspace): def replace(e, mapping): """A wrapper for ufl.replace that allows numpy arrays.""" cmapping = {k: as_tensor(v) for k, v in mapping.items()} - if isinstance(e, LabelledForm): - enew = LabelledForm(*(Term(ufl_replace(term.form, cmapping), term.labels) - for term in e.terms)) - return enew - else: + try: + from firedrake.fml import LabelledForm, Term + if isinstance(e, LabelledForm): + enew = LabelledForm(*(Term(ufl_replace(term.form, cmapping), term.labels) + for term in e.terms)) + return enew + else: + return ufl_replace(e, cmapping) + + except ImportError: return ufl_replace(e, cmapping)