diff --git a/irksome/__init__.py b/irksome/__init__.py index 7736a7b2..f99c1a6c 100644 --- a/irksome/__init__.py +++ b/irksome/__init__.py @@ -33,4 +33,10 @@ from .scheme import ContinuousPetrovGalerkinScheme, DiscontinuousGalerkinScheme # noqa: F401 from .galerkin_stepper import ContinuousPetrovGalerkinTimeStepper # noqa: F401 from .discontinuous_galerkin_stepper import DiscontinuousGalerkinTimeStepper # noqa: F401 -from .labeling import TimeQuadratureLabel # noqa: F401 +from .labeling import ( + TimeQuadratureLabel, MeasureOverride, # noqa: F401 + dx_override, ds_override, dS_override, dr_override, dP_override, # noqa: F401 + dc_override, dC_override, dI_override, dO_override, # noqa: F401 + ds_b_override, ds_t_override, ds_v_override, # noqa: F401 + dS_h_override, dS_v_override, # noqa: F401 +) diff --git a/irksome/labeling.py b/irksome/labeling.py index 30a72e72..bf5c0cd7 100644 --- a/irksome/labeling.py +++ b/irksome/labeling.py @@ -1,5 +1,7 @@ from firedrake.fml import Label, keep, drop, LabelledForm from .scheme import create_time_quadrature +from ufl.form import Form +from ufl.measure import Measure import numpy as np explicit = Label("explicit") @@ -33,27 +35,81 @@ def get_weights(self): def split_quadrature(F, Qdefault=None): - if not isinstance(F, LabelledForm): + """Split a form into subforms grouped by time quadrature rule. + + Supports two mechanisms: + 1) firedrake.fml labels using TimeQuadratureLabel/TimeQuadratureRule + 2) UFL integral metadata containing Irksome keys + ("quadrature_degree_time" and optionally "quadrature_scheme_time"). + + If neither labelling nor metadata overrides are present, returns + a single entry mapping Qdefault -> F. + """ + # Case 1: LabelledForm path (existing behaviour) + if isinstance(F, LabelledForm): + quad_labels = set() + for term in F.terms: + cur_labels = [label for label in term.labels if isinstance(label, TimeQuadratureRule)] + if len(cur_labels) == 1: + quad_labels.update(cur_labels) + elif len(cur_labels) > 1: + raise ValueError("Multiple quadrature labels on one term.") + + splitting = {Q: F.label_map(lambda t: Q in t.labels, map_if_true=keep, map_if_false=drop) + for Q in quad_labels} + splitting[Qdefault] = F.label_map(lambda t: len(quad_labels.intersection(t.labels)) > 0, + map_if_true=drop, map_if_false=keep) + for Q in list(splitting): + try: + splitting[Q] = splitting[Q].form + except TypeError: + splitting.pop(Q) + return splitting + + # Case 2: Plain UFL form with per-integral metadata overrides + # See if I can recover integral; it not, return default + try: + integrals = F.integrals() + except Exception: + return {Qdefault: F} + + # Scan for Irksome metadata; if none present, return default + IRK_DEG = "quadrature_degree_override" + IRK_SCH = "quadrature_scheme_override" + has_override = any( + (IRK_DEG in (I.metadata() or {}) or IRK_SCH in (I.metadata() or {})) + for I in integrals + ) + if not has_override: return {Qdefault: F} - quad_labels = set() - for term in F.terms: - cur_labels = [label for label in term.labels if isinstance(label, TimeQuadratureRule)] - if len(cur_labels) == 1: - quad_labels.update(cur_labels) - elif len(cur_labels) > 1: - raise ValueError("Multiple quadrature labels on one term.") - - splitting = {Q: F.label_map(lambda t: Q in t.labels, map_if_true=keep, map_if_false=drop) - for Q in quad_labels} - splitting[Qdefault] = F.label_map(lambda t: len(quad_labels.intersection(t.labels)) > 0, - map_if_true=drop, map_if_false=keep) - for Q in list(splitting): - try: - splitting[Q] = splitting[Q].form - except TypeError: - splitting.pop(Q) - return splitting + # Since we got here, build groups keyed by (degree, scheme) tuples + groups = {} + default_ints = [] + # For each integral... + for I in integrals: + # ...get the metadata... + md = I.metadata() or {} + deg = md.get(IRK_DEG, None) + sch = md.get(IRK_SCH, None) + if deg is None: + # ...if no quadrature override is specified, add to default... + default_ints.append(I) + else: + # ...and otherwise, record in groups + sch = sch if sch is not None else "default" + key = (int(deg), str(sch)) + groups.setdefault(key, []).append(I) + + # Now, assemble into a dictionary as required using create_time_quadrature + result = {} + for (deg, sch), ints in groups.items(): + Q = create_time_quadrature(deg, scheme=sch) + result[Q] = Form(ints) + if default_ints: + result[Qdefault] = Form(default_ints) + + return result def split_explicit(F): @@ -67,3 +123,95 @@ def split_explicit(F): map_if_true=keep, map_if_false=drop) return imp_part.form, exp_part.form + + +class MeasureOverride(Measure): + """Thin wrappers around UFL Measures that allow users to tag + individual integrals with Irksome-specific overrides for + time quadrature used by Galerkin-in-time discretisations. + + Usage example: + F = inner(Dt(u), v) * dx_override(time_degree_override=5) + inner(u, v) * dx + + Here, only the first term will be integrated in time with a rule of + degree 5; the other terms will use the scheme defaults. + """ + def __call__( + self, + subdomain_id=None, + metadata=None, + domain=None, + subdomain_data=None, + degree=None, + scheme=None, + *, + time_degree_override=None, + time_scheme_override=None, + ): + """Reconfigure measure with (optional) time quadrature overrides. + + The optional keyword-only arguments time_degree_override and time_scheme_override + are stored in metadata keys understood by Irksome's Galerkin-in-time + machinery in split_quadrature(). + """ + # Inject time overrides into metadata + if time_degree_override is None and time_scheme_override is not None: + raise ValueError( + "Time quadrature override requires specification of time_degree_override." + ) + if time_degree_override is not None or time_scheme_override is not None: + metadata = {} if metadata is None else metadata.copy() + if time_degree_override is not None: + metadata["quadrature_degree_override"] = time_degree_override + if time_scheme_override is not None: + metadata["quadrature_scheme_override"] = time_scheme_override + + # Inject spatial (degree, scheme) into metadata if provided, mirroring + # ufl.measure.Measure.__call__ semantics. + if (degree, scheme) != (None, None): + metadata = {} if metadata is None else metadata.copy() + if degree is not None: + metadata["quadrature_degree"] = degree + if scheme is not None: + metadata["quadrature_rule"] = scheme + + # Support dx(domain) style: if first positional looks like a domain, treat accordingly + if subdomain_id is not None and hasattr(subdomain_id, "ufl_domain"): + if domain is not None: + raise ValueError( + "Ambiguous: setting domain both as keyword argument and first argument." + ) + subdomain_id, domain = "everywhere", subdomain_id + + # Without args, return everywhere + if all(x is None for x in (subdomain_id, metadata, domain, subdomain_data, degree, scheme)) and ( + time_degree_override is None and time_scheme_override is None + ): + return self.reconstruct(subdomain_id="everywhere") + + # Construct new Measure + return Measure( + self.integral_type(), + domain=domain or self.ufl_domain(), + subdomain_id=subdomain_id if subdomain_id is not None else self.subdomain_id(), + metadata=metadata if metadata is not None else self.metadata(), + subdomain_data=subdomain_data if subdomain_data is not None else self.subdomain_data(), + ) + + +# Convenience instances mirroring Firedrake/UFL defaults +dx_override = MeasureOverride("cell") +ds_override = MeasureOverride("exterior_facet") +dS_override = MeasureOverride("interior_facet") +dr_override = MeasureOverride("ridge") +dP_override = MeasureOverride("vertex") +dc_override = MeasureOverride("custom") +dC_override = MeasureOverride("cutcell") +dI_override = MeasureOverride("interface") +dO_override = MeasureOverride("overlap") +ds_b_override = MeasureOverride("exterior_facet_bottom") +ds_t_override = MeasureOverride("exterior_facet_top") +ds_v_override = MeasureOverride("exterior_facet_vert") +dS_h_override = MeasureOverride("interior_facet_horiz") +dS_v_override = MeasureOverride("interior_facet_vert") + diff --git a/tests/test_measureoverride.py b/tests/test_measureoverride.py new file mode 100644 index 00000000..a720ed67 --- /dev/null +++ b/tests/test_measureoverride.py @@ -0,0 +1,56 @@ +import pytest + +from firedrake import * +from irksome import Dt, TimeStepper, ContinuousPetrovGalerkinScheme, dx_override + +@pytest.mark.parametrize("order", [1, 2, 3]) +@pytest.mark.parametrize("scheme", ["gauss", "cpg"]) +def test_nls(order, scheme): + # Domain and space + mesh = PeriodicUnitIntervalMesh(10) + x, = SpatialCoordinate(mesh) + V = FunctionSpace(mesh, "CG", 1) + Z = V * V + + # State and test functions + psi = Function(Z) + a, b = split(psi) + c, d = TestFunctions(Z) + + # Initial condition: cosine + psi.project(as_vector([cos(x), 0])) + + # Time parameters + t = Constant(0.0) + dt = Constant(0.1) + + # Residual + dx_highorder = dx if scheme == "gauss" else dx_override(time_degree_override=4*order-1) + amp_sq = a**2 + b**2 + F = ( + inner(Dt(b), c) * dx + + 0.5 * inner(grad(a), grad(c)) * dx + - inner(amp_sq * a, c) * dx_highorder + - inner(Dt(a), d) * dx + + 0.5 * inner(grad(b), grad(d)) * dx + - inner(amp_sq * b, d) * dx_highorder + ) + + # Energy + E = 0.5 * (inner(grad(a), grad(a)) + inner(grad(b), grad(b)) - amp_sq**2) * dx + + # Time stepper with cPG(k); default time quadrature is 2k-1 + scheme_ = ContinuousPetrovGalerkinScheme(order=order, quadrature_degree=2*order-1) + stepper = TimeStepper(F, scheme_, t, dt, psi) + + # Record initial energy + E0 = float(assemble(E)) + + # Advance once + stepper.advance() + + # Final energy and drift + E1 = float(assemble(E)) + drift = abs(E1 - E0) + if scheme == "gauss": assert drift > 1e-10 + else: assert drift < 1e-10