Skip to content
Draft
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions irksome/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
from .scheme import ContinuousPetrovGalerkinScheme, DiscontinuousGalerkinScheme
from .scheme import GalerkinCollocationScheme

from .form_manipulation import getForm

__all__ = [
"Alexander",
"ARS_DIRK_IMEX",
Expand All @@ -31,6 +33,7 @@
"DiscontinuousGalerkinScheme",
"Dt",
"expand_time_derivatives",
"getForm",
"GalerkinCollocationScheme",
"GaussLegendre",
"LobattoIIIA",
Expand All @@ -53,7 +56,6 @@
from .bcs import BoundsConstrainedDirichletBC
from .dirk_stepper import DIRKTimeStepper
from .imex import RadauIIAIMEXMethod, DIRKIMEXMethod
from .stage_derivative import getForm
from .nystrom_dirk_stepper import DIRKNystromTimeStepper, ExplicitNystromTimeStepper
from .nystrom_stepper import (
StageDerivativeNystromTimeStepper,
Expand All @@ -78,7 +80,6 @@
__all__ += [
"DIRKTimeStepper",
"BoundsConstrainedDirichletBC",
"getForm",
"RadauIIAIMEXMethod",
"DIRKIMEXMethod",
"DIRKNystromTimeStepper",
Expand Down
14 changes: 13 additions & 1 deletion irksome/backend.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Protocol
from typing import Protocol, Any
import ufl
from importlib import import_module

Expand All @@ -7,6 +7,10 @@ class Backend(Protocol):
def get_function_space(self, V: ufl.Coefficient) -> ufl.FunctionSpace:
"""Get a function space from the backend"""

def extract_bcs(bcs: Any)->tuple[Any]:
"""Extract boundary conditions"""


Comment thread
pbrubeck marked this conversation as resolved.
def get_stages(self, V: ufl.FunctionSpace, num_stages: int) -> ufl.Coefficient:
"""
Given a function space for a single time-step, get a duplicate of this space,
Expand Down Expand Up @@ -38,6 +42,14 @@ def ConstantOrZero(
def get_mesh_constant(MC: MeshConstant | None) -> ufl.core.expr.Expr:
"""Get a backend class to construct a mesh constant from"""

def TestFunction(space: ufl.FunctionSpace, part: int|None=None)-> ufl.Argument:
"""Return a test-function that can be used by forms in the backend."""

def create_nonlinearvariational_problem(F: ufl.Form, g: ufl.Coefficient, solver_parameters: dict):
Comment thread
jorgensd marked this conversation as resolved.
Outdated
"""Create a non-linear variational solver that uses PETSc SNES."""

def get_stage_spaces(V: ufl.FunctionSpace, num_stages: int) -> ufl.FunctionSpace:
"""Create a stage space with M number of components."""

def get_backend(backend: str) -> Backend:
"""Get backend class from backend name.
Expand Down
26 changes: 25 additions & 1 deletion irksome/backends/dolfinx.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,32 @@

try:
import basix.ufl
import dolfinx
import dolfinx.fem.petsc
import ufl
import typing
import numpy as np

TestFunction = ufl.TestFunction

def get_stage_space(V: ufl.FunctionSpace, num_stages:int)->ufl.FunctionSpace:
if num_stages == 1:
me = V.ufl_elemet()
else:
el = V.ufl_element()
if el.num_sub_elements > 0:
me = basix.ufl.mixed_element(np.tile(el.sub_elements, num_stages).tolist())
else:
me = basix.ufl.blocked_element(el, shape=(num_stages, ))
return dolfinx.fem.functionspace(V.mesh, me)

def extract_bcs(bcs: typing.Any)->tuple[typing.Any]:
"""Extract boundary conditions"""
return bcs

def create_nonlinearvariational_problem(F: ufl.Form, g: ufl.Coefficient, solver_parameters: dict):
"""Create a non-linear variational solver that uses PETSc SNES."""
return dolfinx.fem.petsc.NonlinearProblem(F, g, petsc_options_prefix="IrkSomeSolver",
petsc_options=solver_parameters)

def get_function_space(u: ufl.Coefficient) -> ufl.FunctionSpace:
return u.ufl_function_space()
Expand Down
23 changes: 23 additions & 0 deletions irksome/backends/firedrake.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,32 @@
"""Firedrake backend for Irksome"""


from operator import mul
from functools import reduce

Comment on lines +3 to +5
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
from operator import mul
from functools import reduce

import firedrake
import ufl
from ..tools import get_stage_space
import typing

TestFunction = firedrake.TestFunction


def get_stage_space(V: ufl.FunctionSpace, num_stages:int)->ufl.FunctionSpace:
return reduce(mul, (V for _ in range(num_stages)))
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
return reduce(mul, (V for _ in range(num_stages)))
return firedrake.MixedFunctionSpace(tuple(V) * num_stages)



def extract_bcs(bcs: typing.Any)->tuple[typing.Any]:
"""Return an iterable of boundary conditions on the residual form"""
return tuple(bc.extract_form("F") for bc in firedrake.solving._extract_bcs(bcs))


def create_nonlinearvariational_problem(F: ufl.Form, u: ufl.Coefficient, solver_parameters: dict):
Comment thread
jorgensd marked this conversation as resolved.
Outdated
"""Create a non-linear variational solver that uses PETSc SNES."""
problem = firedrake.NonlinearVariationalProblem(F, u)
Comment thread
pbrubeck marked this conversation as resolved.
Outdated
return firedrake.NonlinearVariationalSolver(
problem, solver_parameters=solver_parameters
Comment thread
pbrubeck marked this conversation as resolved.
Outdated
)

def get_function_space(u: ufl.Coefficient) -> firedrake.FunctionSpace:
return u.function_space()
Expand Down
34 changes: 12 additions & 22 deletions irksome/bcs.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,14 @@
from firedrake.solving import _extract_bcs
from firedrake import (
DirichletBC,
Function,
TestFunction,
NonlinearVariationalProblem,
NonlinearVariationalSolver,
replace,
inner,
dx,
)

from .backend import get_backend
from ufl import as_ufl


def extract_bcs(bcs):
"""Return an iterable of boundary conditions on the residual form"""
return tuple(bc.extract_form("F") for bc in _extract_bcs(bcs))


def get_sub(u, indices):
for i in indices:
if i is not None:
Expand Down Expand Up @@ -66,25 +57,24 @@ def EmbeddedBCData(bc, butcher_tableau, t, dt, u0, stages):
class BoundsConstrainedDirichletBC(DirichletBC):
"""A DirichletBC with bounds-constrained data."""

def __init__(self, V, g, sub_domain, bounds, solver_parameters=None):
def __init__(self, V, g, sub_domain, bounds, solver_parameters=None, backend:str="firedrake"):

self.g = g
self.solver_parameters = solver_parameters
self.bounds = bounds
self.gnew = Function(V)
backend_cls = get_backend(backend)
F = inner(self.gnew - g, backend_cls.TestFunction(V)) * dx

if solver_parameters is None:
solver_parameters = {
"snes_type": "vinewtonrsls",
"snes_max_it": 300,
"snes_atol": 1.0e-8,
"ksp_type": "preonly",
"mat_type": "aij",
}
self.g = g
self.solver_parameters = solver_parameters
self.bounds = bounds

self.gnew = Function(V)
F = inner(self.gnew - g, TestFunction(V)) * dx
problem = NonlinearVariationalProblem(F, self.gnew)
self.solver = NonlinearVariationalSolver(
problem, solver_parameters=self.solver_parameters
)
}
self.solver = backend_cls.create_nonlinearvariational_problem(F, self.gnew, solver_parameters)
Comment thread
pbrubeck marked this conversation as resolved.
Outdated
super().__init__(V, g, sub_domain)

@property
Expand Down
148 changes: 148 additions & 0 deletions irksome/form_manipulation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@

Comment thread
jorgensd marked this conversation as resolved.
Outdated
import numpy
from ufl import as_ufl, as_tensor, Form, Coefficient
from .tableaux import ButcherTableaux
from .constant import vecconst
from .tools import AI, dot, replace, reshape, fields_to_components
from .ufl.deriv import Dt, TimeDerivative, expand_time_derivatives
from .backend import get_backend

__all__ = ["getForm"]

def getForm(F: Form, butch:ButcherTableaux, t: Coefficient, dt:Coefficient, u0:Coefficient, stages, bcs=None, bc_type=None, splitting=AI, aux_indices=None, backend:str="firedrake"):
"""Given a time-dependent variational form and a
:class:`ButcherTableau`, produce UFL for the s-stage RK method.

:arg F: UFL form for the semidiscrete ODE/DAE
:arg butch: the :class:`ButcherTableau` for the RK method being used to
advance in time.
:arg t: a :class:`Function` on the Real space over the same mesh as
`u0`. This serves as a variable referring to the current time.
:arg dt: a :class:`Function` on the Real space over the same mesh as
`u0`. This serves as a variable referring to the current time step.
The user may adjust this value between time steps.
:arg u0: a :class:`Function` referring to the state of
the PDE system at time `t`
:arg stages: a :class:`Function` representing the stages to be solved for.
:kwarg bcs: optionally, a :class:`DirichletBC` or :class:`EquationBC`
object (or iterable thereof) containing (possibly time-dependent)
boundary conditions imposed on the system.
:kwarg bc_type: How to manipulate the strongly-enforced boundary
conditions to derive the stage boundary conditions. Should
be a string, either "DAE", which implements BCs as
constraints in the style of a differential-algebraic
equation, or "ODE", which takes the time derivative of the
boundary data and evaluates this for the stage values.
Support for `firedrake.EquationBC` in `bcs` is limited
to DAE style BCs.
:kwarg splitting: a callable that maps the (floating point) Butcher matrix
a to a pair of matrices `A1, A2` such that `butch.A = A1 A2`. This is used
to vary between the classical RK formulation and Butcher's reformulation
that leads to a denser mass matrix with block-diagonal stiffness.
Some choices of function will assume that `butch.A` is invertible.
:kwarg aux_indices: a list of field indices to be discretized as :class:`TimeDerivative`,
analogouos to :class:`ContinouosPetrovGalerkinTimeStepper`.

:returns: a 2-tuple of
- `Fnew`, the :class:`Form`
- `bcnew`, a list of :class:`firedrake.DirichletBC` or :class:`EquationBC`
objects to be posed on the stages
"""
backend_cls = get_backend(backend)
if bc_type is None:
bc_type = "DAE"

# preprocess time derivatives
F = expand_time_derivatives(F, t=t, timedep_coeffs=(u0,))
v, = F.arguments()
V = backend_cls.get_function_space(v)
assert V == backend_cls.get_function_space(u0)

c = vecconst(butch.c, backend=backend)
bA1, bA2 = splitting(butch.A)
try:
bA2inv = numpy.linalg.inv(bA2)
except numpy.linalg.LinAlgError:
raise NotImplementedError("We require A = A1 A2 with A2 invertible")
A1 = vecconst(bA1, backend=backend)
A2inv = vecconst(bA2inv, backend=backend)

# s-way product space for the stage variables
num_stages = butch.num_stages
Vbig = backend_cls.get_function_space(stages)
test = backend_cls.TestFunction(Vbig)

# set up the pieces we need to work with to do our substitutions
v_np = reshape(test, (num_stages, *v.ufl_shape))
w_np = reshape(stages, (num_stages, *u0.ufl_shape))
A1w = dot(A1, w_np)
A2invw = dot(A2inv, w_np)
dtu = TimeDerivative(u0)

aux_components = fields_to_components(V, aux_indices or [])

repl = {}
for i in range(num_stages):
usub = u0 + as_tensor(A1w[i]) * dt
dtusub = A2invw[i]
if aux_components:
# Apply TimeDerivative substitution to auxiliary fields
usub = reshape(usub, u0.ufl_shape)
usub[aux_components] = dtusub[aux_components] * dt

repl[i] = {t: t + c[i] * dt,
v: v_np[i],
u0: usub,
dtu: dtusub}

Fnew = sum(replace(F, repl[i]) for i in range(num_stages))

if bcs is None:
bcs = []
if bc_type == "ODE":
assert splitting == AI, "ODE-type BC aren't implemented for this splitting strategy"

def bc2stagebc(bc, i):
from irksome.bcs import BCStageData
from firedrake.bcs import EquationBCSplit

if isinstance(bc, EquationBCSplit):
raise NotImplementedError("EquationBC not implemented for ODE formulation")
gorig = as_ufl(bc._original_arg)
gfoo = expand_time_derivatives(Dt(gorig), t=t, timedep_coeffs=(u0,))
gcur = replace(gfoo, {t: t + c[i] * dt})
return BCStageData(bc, gcur, u0, stages, i)

elif bc_type == "DAE":
try:
bA1inv = numpy.linalg.inv(bA1)
A1inv = vecconst(bA1inv, backend=backend)
except numpy.linalg.LinAlgError:
raise NotImplementedError("Cannot have DAE BCs for this Butcher Tableau/splitting")

def bc2stagebc(bc, i):
from irksome.bcs import BCStageData, stage2spaces4bc, bc2space
from firedrake.bcs import EquationBCSplit, EquationBC
if isinstance(bc, EquationBCSplit):
F_bc_orig = expand_time_derivatives(bc.f, t=t, timedep_coeffs=(u0,))
F_bc_new = replace(F_bc_orig, repl[i])
Vbigi = stage2spaces4bc(bc, V, Vbig, i)
return EquationBC(F_bc_new == 0, stages, bc.sub_domain, V=Vbigi,
bcs=[bc2stagebc(innerbc, i) for innerbc in backend_cls.extract_bcs(bc.bcs)])
else:
gcur = bc._original_arg
if gcur != 0:
gorig = as_ufl(gcur)
ucur = bc2space(bc, u0)
gcur = (1/dt) * sum((replace(gorig, {t: t + c[j]*dt}) - ucur) * A1inv[i, j]
for j in range(num_stages))
return BCStageData(bc, gcur, u0, stages, i)
else:
raise ValueError(f"Unrecognised bc_type: {bc_type}")

# This logic uses information set up in the previous section to
# set up the new BCs for either method
bcs = backend_cls.extract_bcs(bcs)
bcnew = [bc2stagebc(bc, i) for i in range(num_stages) for bc in bcs]

return Fnew, bcnew
6 changes: 3 additions & 3 deletions irksome/galerkin_stepper.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from ufl import as_ufl, as_tensor

from .base_time_stepper import StageCoupledTimeStepper
from .bcs import bc2space, extract_bcs, stage2spaces4bc
from .bcs import bc2space, stage2spaces4bc
from .ufl.deriv import TimeDerivative, expand_time_derivatives
from .ufl.estimate_degrees import TimeDegreeEstimator, get_degree_mapping
from .labeling import split_quadrature, as_form
Expand All @@ -14,9 +14,9 @@
from .constant import vecconst
from .discontinuous_galerkin_stepper import getElement as getTestElement
from .integrated_lagrange import IntegratedLagrange

from .backends.firedrake import extract_bcs
from .tableaux.ButcherTableaux import CollocationButcherTableau
from .stage_derivative import getForm
from .form_manipulation import getForm
from .stage_value import getFormStage

import numpy as np
Expand Down
Loading