Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions irksome/backends/firedrake.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,35 @@ def Constant(self, val=0.0) -> ufl.Coefficient:

def get_mesh_constant(MC: MeshConstant | None):
return MC.Constant if MC else firedrake.Constant


def create_variational_problem(F, u, bcs=None, J=None, Jp=None, **kwargs):
if len(F.arguments()) == 2:
a = ufl.lhs(F)
L = ufl.rhs(F)
kwargs.pop("is_linear", None)
problem = firedrake.LinearVariationalProblem(a, L, u, bcs=bcs, aP=Jp, **kwargs)
else:
constant_jacobian = kwargs.pop("constant_jacobian", False)
problem = firedrake.NonlinearVariationalProblem(F, u, bcs=bcs, J=J, Jp=Jp, **kwargs)
if constant_jacobian:
problem._constant_jacobian = constant_jacobian
return problem


def create_variational_solver(problem, **kwargs):
if isinstance(problem, firedrake.LinearVariationalProblem):
return firedrake.LinearVariationalSolver(problem, **kwargs)
else:
return firedrake.NonlinearVariationalSolver(problem, **kwargs)


def invalidate_jacobian(solver):
return firedrake.LinearVariationalSolver.invalidate_jacobian(solver)


derivative = firedrake.derivative
norm = firedrake.norm
Function = firedrake.Function
TestFunction = firedrake.TestFunction
TrialFunction = firedrake.TrialFunction
82 changes: 31 additions & 51 deletions irksome/base_time_stepper.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,7 @@
from abc import abstractmethod
from firedrake import (
derivative, replace, lhs, rhs, Function, TrialFunction,
LinearVariationalProblem, LinearVariationalSolver,
NonlinearVariationalProblem, NonlinearVariationalSolver,
)
from firedrake.petsc import PETSc
from .tools import AI, getNullspace, flatten_dats, split_stages
from .labeling import as_form, as_linear_form
from .labeling import as_form
from .backend import get_backend
import ufl
import numpy
Expand Down Expand Up @@ -99,11 +94,6 @@ def __init__(self, F, t, dt, u0, num_stages,
butcher_tableau=None, bounds=None, sample_points=None,
**kwargs):

is_linear = False
if len(as_form(F).arguments()) == 2:
F = as_linear_form(F, u0)
is_linear = True

super().__init__(F, t, dt, u0,
bcs=bcs, appctx=appctx, nullspace=nullspace)

Expand All @@ -123,44 +113,34 @@ def __init__(self, F, t, dt, u0, num_stages,
stages = self.get_stages()
self.stages = stages

Fbig, bigBCs = self.get_form_and_bcs(stages)
V = u0.function_space()
Vbig = stages.function_space()

F_linear = len(as_form(F).arguments()) == 2
stages_F = self._backend.TrialFunction(Vbig) if F_linear else stages
Fbig, bigBCs = self.get_form_and_bcs(stages_F)

Jpbig = None
if Fp is not None:
Fp = as_linear_form(Fp, u0)
Fpbig, _ = self.get_form_and_bcs(stages, F=Fp, bcs=())
Jpbig = derivative(Fpbig, stages)
Fp_linear = len(as_form(Fp).arguments()) == 2
stages_Fp = self._backend.TrialFunction(Vbig) if Fp_linear else stages
Fpbig, _ = self.get_form_and_bcs(stages_Fp, F=Fp, bcs=())
Jpbig = ufl.lhs(Fpbig) if Fp_linear else self._backend.derivative(Fpbig, stages_Fp)

V = u0.function_space()
Vbig = stages.function_space()
nullspace = getNullspace(V, Vbig, num_stages, nullspace)
transpose_nullspace = getNullspace(V, Vbig, num_stages, transpose_nullspace)
near_nullspace = getNullspace(V, Vbig, num_stages, near_nullspace)

self.bigBCs = bigBCs

if is_linear:
Fbig = replace(Fbig, {stages: TrialFunction(stages.function_space())})
abig = lhs(Fbig)
Lbig = rhs(Fbig)
problem = LinearVariationalProblem(
abig, Lbig, stages, bcs=bigBCs, aP=Jpbig,
form_compiler_parameters=kwargs.pop("form_compiler_parameters", None),
constant_jacobian=kwargs.pop("constant_jacobian", False),
restrict=kwargs.pop("restrict", False),
)
solver_constructor = LinearVariationalSolver
else:
problem = NonlinearVariationalProblem(
Fbig, stages, bcs=bigBCs, Jp=Jpbig,
form_compiler_parameters=kwargs.pop("form_compiler_parameters", None),
is_linear=kwargs.pop("is_linear", False),
restrict=kwargs.pop("restrict", False),
)
problem._constant_jacobian = kwargs.pop("constant_jacobian", False)
solver_constructor = NonlinearVariationalSolver

self.problem = problem
self.solver = solver_constructor(
self.problem = self._backend.create_variational_problem(
Fbig, stages, bcs=bigBCs, Jp=Jpbig,
form_compiler_parameters=kwargs.pop("form_compiler_parameters", None),
is_linear=kwargs.pop("is_linear", False),
restrict=kwargs.pop("restrict", False),
constant_jacobian=kwargs.pop("constant_jacobian", False),
)
self.solver = self._backend.create_variational_solver(
self.problem, appctx=self.appctx,
nullspace=nullspace,
transpose_nullspace=transpose_nullspace,
Expand Down Expand Up @@ -193,7 +173,7 @@ def advance(self):
# allow butcher tableau as input for preconditioners to create
# an alternate operator
@abstractmethod
def get_form_and_bcs(self, stages, tableau=None, F=None):
def get_form_and_bcs(self, stages, F=None, bcs=None, tableau=None):
pass

def solver_stats(self):
Expand All @@ -209,30 +189,30 @@ def get_stage_bounds(self, bounds=None):
Vbig = self.stages.function_space()
bounds_type, lower, upper = bounds
if lower is None:
slb = Function(Vbig).assign(PETSc.NINFINITY)
slb = self._backend.Function(Vbig).assign(PETSc.NINFINITY)
if upper is None:
sub = Function(Vbig).assign(PETSc.INFINITY)
sub = self._backend.Function(Vbig).assign(PETSc.INFINITY)

if bounds_type == "stage":
if lower is not None:
dats = [lower.dat] * (self.num_stages)
slb = Function(Vbig, val=flatten_dats(dats))
slb = self._backend.Function(Vbig, val=flatten_dats(dats))
if upper is not None:
dats = [upper.dat] * (self.num_stages)
sub = Function(Vbig, val=flatten_dats(dats))
sub = self._backend.Function(Vbig, val=flatten_dats(dats))

elif bounds_type == "last_stage":
V = self.u0.function_space()
if lower is not None:
ninfty = Function(V).assign(PETSc.NINFINITY)
ninfty = self._backend.Function(V).assign(PETSc.NINFINITY)
dats = [ninfty.dat] * (self.num_stages-1)
dats.append(lower.dat)
slb = Function(Vbig, val=flatten_dats(dats))
slb = self._backend.Function(Vbig, val=flatten_dats(dats))
if upper is not None:
infty = Function(V).assign(PETSc.INFINITY)
infty = self._backend.Function(V).assign(PETSc.INFINITY)
dats = [infty.dat] * (self.num_stages-1)
dats.append(upper.dat)
sub = Function(Vbig, val=flatten_dats(dats))
sub = self._backend.Function(Vbig, val=flatten_dats(dats))

else:
raise ValueError("Unknown bounds type")
Expand All @@ -254,7 +234,7 @@ def build_poly(self):
pts = numpy.reshape(self.sample_points, (-1, 1))
vander = self.tabulate_poly(pts)

self.u_old = Function(self.u0)
self.u_old = self._backend.Function(self.u0)
ks = [self.u_old]
ks.extend(split_stages(self.u0.function_space(), self.stages))
num_samples = vander.shape[1]
Expand All @@ -265,4 +245,4 @@ def invalidate_jacobian(self):
"""
Forces the matrix to be reassembled next time it is required.
"""
LinearVariationalSolver.invalidate_jacobian(self.solver)
self._backend.invalidate_jacobian(self.solver)
55 changes: 26 additions & 29 deletions irksome/dirk_stepper.py
Original file line number Diff line number Diff line change
@@ -1,47 +1,43 @@
import numpy
from firedrake import (derivative, Function,
LinearVariationalSolver,
NonlinearVariationalProblem,
NonlinearVariationalSolver)
from ufl.constantvalue import as_ufl
from ufl import as_ufl, lhs

from .ufl.deriv import TimeDerivative, expand_time_derivatives
from .constant import vecconst
from .tools import replace
from .constant import MeshConstant
from .backend import get_backend
from .bcs import bc2space
from .labeling import as_linear_form
from .ufl.deriv import TimeDerivative, expand_time_derivatives
from .tools import extract_timedep_arguments, replace


def getFormDIRK(F, ks, butch, t, dt, u0, bcs=None, kgac=None):
def getFormDIRK(F, ks, butch, t, dt, u0, bcs=None, kgac=None, backend="firedrake"):
backend_cls = get_backend(backend)
if bcs is None:
bcs = []

v, = F.arguments()
v, u = extract_timedep_arguments(F, u0)
V = v.function_space()
assert V == u0.function_space()

# preprocess time derivatives
F = expand_time_derivatives(F, t=t, timedep_coeffs=(u,))

num_stages = butch.num_stages

# Note: the Constant c is used for substitution in both the
# variational form and BC's, and we update it for each stage in
# the loop over stages in the advance method. The Constant a is
# used similarly in the variational form
MC = MeshConstant(V.mesh())
MC = backend_cls.MeshConstant(V.mesh())
if kgac is None:
k = Function(V)
g = Function(V)
k = backend_cls.Function(V)
g = backend_cls.Function(V)
a = MC.Constant(1.0)
c = MC.Constant(1.0)
else:
k, g, a, c = kgac

# preprocess time derivatives
F = expand_time_derivatives(F, t=t, timedep_coeffs=(u0,))

repl = {t: t + c * dt,
u0: g + k * (a * dt),
TimeDerivative(u0): k}
u: g + k * (a * dt),
TimeDerivative(u): k}
stage_F = replace(F, repl)

bcnew = []
Expand Down Expand Up @@ -76,7 +72,9 @@ def __init__(self, F, butcher_tableau, t, dt, u0, bcs=None, Fp=None,
solver_parameters=None,
appctx=None, nullspace=None,
transpose_nullspace=None, near_nullspace=None,
backend="firedrake",
**kwargs):
self._backend = backend_cls = get_backend(backend)
assert butcher_tableau.is_diagonally_implicit

self.num_steps = 0
Expand Down Expand Up @@ -114,15 +112,13 @@ def __init__(self, F, butcher_tableau, t, dt, u0, bcs=None, Fp=None,
self.dt = dt
self.orig_bcs = bcs
self.num_fields = len(u0.function_space())
self.ks = [Function(V) for _ in range(num_stages)]
self.ks = [backend_cls.Function(V) for _ in range(num_stages)]

# "k" is a generic function for which we will solve the
# NVLP for the next stage value
# "ks" is a list of functions for the stage values
# that we update as we go. We need to remember the
# stage values we've computed earlier in the time step...
F = as_linear_form(F, u0)

stage_F, kgac, bcnew, (a_vals, d_val) = getFormDIRK(
F, self.ks, butcher_tableau, t, dt, u0, bcs=bcs)
k, g, a, c = kgac
Expand All @@ -132,9 +128,10 @@ def __init__(self, F, butcher_tableau, t, dt, u0, bcs=None, Fp=None,

stage_Jp = None
if Fp is not None:
Fp = as_linear_form(Fp, u0)
stage_Fp, *_ = self.get_form_and_bcs(self.ks, F=Fp, bcs=())
stage_Jp = derivative(stage_Fp, k)
Fp_linear = len(Fp.arguments()) == 2
ks_Fp = Fp.arguments()[1] if Fp_linear else self.ks
stage_Fp, *_ = self.get_form_and_bcs(ks_Fp, F=Fp, bcs=())
stage_Jp = lhs(stage_Fp) if Fp_linear else backend_cls.derivative(stage_Fp, k)

appctx_irksome = {"stepper": self}
if appctx is None:
Expand All @@ -143,14 +140,14 @@ def __init__(self, F, butcher_tableau, t, dt, u0, bcs=None, Fp=None,
appctx = {**appctx, **appctx_irksome}
self.appctx = appctx

self.problem = NonlinearVariationalProblem(
self.problem = backend_cls.create_variational_problem(
stage_F, k, bcs=bcnew, Jp=stage_Jp,
form_compiler_parameters=kwargs.pop("form_compiler_parameters", None),
is_linear=kwargs.pop("is_linear", False),
restrict=kwargs.pop("restrict", False),
constant_jacobian=kwargs.pop("constant_jacobian", False),
)
self.problem._constant_jacobian = kwargs.pop("constant_jacobian", False)
self.solver = NonlinearVariationalSolver(
self.solver = backend_cls.create_variational_solver(
self.problem, appctx=appctx,
nullspace=nullspace,
transpose_nullspace=transpose_nullspace,
Expand Down Expand Up @@ -220,4 +217,4 @@ def invalidate_jacobian(self):
"""
Forces the matrix to be reassembled next time it is required.
"""
LinearVariationalSolver.invalidate_jacobian(self.solver)
self._backend.invalidate_jacobian(self.solver)
24 changes: 12 additions & 12 deletions irksome/discontinuous_galerkin_stepper.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from .ufl.deriv import TimeDerivative, expand_time_derivatives
from .ufl.manipulation import split_time_derivative_terms, remove_time_derivatives
from .scheme import create_time_quadrature, ufc_line
from .tools import IA, dot, reshape, replace
from .tools import IA, dot, extract_timedep_arguments, reshape, replace
from .constant import vecconst
from .tableaux.ButcherTableaux import CollocationButcherTableau
from .stage_value import getFormStage
Expand Down Expand Up @@ -41,7 +41,7 @@ def getElement(basis_type, order):


def getTermDiscGalerkin(F, L, Q, t, dt, u0, stages, test, deriv_type="strong"):
v, = F.arguments()
v, u = extract_timedep_arguments(F, u0)
V = v.function_space()
assert V == u0.function_space()

Expand All @@ -68,7 +68,7 @@ def getTermDiscGalerkin(F, L, Q, t, dt, u0, stages, test, deriv_type="strong"):
dtusub = dot(trial_dvals.T, u_np)

# preprocess time derivatives
split_form = split_time_derivative_terms(F, t=t, timedep_coeffs=(u0,))
split_form = split_time_derivative_terms(F, t=t, timedep_coeffs=(u,))
F_dtless = remove_time_derivatives(split_form.time)
if F_dtless.empty():
Fnew = F_dtless
Expand All @@ -81,8 +81,8 @@ def getTermDiscGalerkin(F, L, Q, t, dt, u0, stages, test, deriv_type="strong"):
u_at_0 = dot(L_at_0.T, u_np)
v_at_0 = dot(L_at_0.T, v_np)

repl_tminus = {v: v_at_0}
repl_tplus = {v: v_at_0, u0: u_at_0}
repl_tminus = {v: v_at_0, u: u0}
repl_tplus = {v: v_at_0, u: u_at_0}
Fnew = replace(F_dtless, repl_tplus) - replace(F_dtless, repl_tminus)
F_remainder = F
elif deriv_type == "weak":
Expand All @@ -93,8 +93,8 @@ def getTermDiscGalerkin(F, L, Q, t, dt, u0, stages, test, deriv_type="strong"):
u_at_01 = dot(L_at_01.T, u_np)
v_at_01 = dot(L_at_01.T, v_np)

repl_tminus = {v: v_at_01[0]}
repl_tnew = {v: v_at_01[1], u0: u_at_01[1], t: t + dt}
repl_tminus = {v: v_at_01[0], u: u0}
repl_tnew = {v: v_at_01[1], u: u_at_01[1], t: t + dt}
Fnew = replace(F_dtless, repl_tnew) - replace(F_dtless, repl_tminus)

# Terms with time derivatives: -(g(u), Dt(v))
Expand All @@ -103,20 +103,20 @@ def getTermDiscGalerkin(F, L, Q, t, dt, u0, stages, test, deriv_type="strong"):
for q in range(len(qpts)):
repl = {t: t + qpts[q] * dt,
v: dtvsub[q],
u0: usub[q]}
u: usub[q]}
Fnew -= replace(F_dtless, repl)
F_remainder = split_form.remainder
else:
raise ValueError(f"Unrecongnized deriv_type {deriv_type}")

# Handle the rest of the terms
F_remainder = expand_time_derivatives(F_remainder, t=t, timedep_coeffs=(u0,))
dtu0 = TimeDerivative(u0)
F_remainder = expand_time_derivatives(F_remainder, t=t, timedep_coeffs=(u,))
dtu = TimeDerivative(u)
for q in range(len(qpts)):
repl = {t: t + qpts[q] * dt,
v: vsub[q] * dt,
u0: usub[q],
dtu0: dtusub[q] / dt}
u: usub[q],
dtu: dtusub[q] / dt}
Fnew += replace(F_remainder, repl)
return Fnew

Expand Down
Loading
Loading