diff --git a/firedrake/formmanipulation.py b/firedrake/formmanipulation.py index 737b630519..3579d36807 100644 --- a/firedrake/formmanipulation.py +++ b/firedrake/formmanipulation.py @@ -3,7 +3,7 @@ import collections from ufl import as_tensor, as_vector, split -from ufl.classes import Zero, FixedIndex, ListTensor, ZeroBaseForm +from ufl.classes import Form, Zero, FixedIndex, ListTensor, ZeroBaseForm from ufl.algorithms.map_integrands import map_integrand_dags from ufl.algorithms import expand_derivatives from ufl.corealg.map_dag import MultiFunction, map_expr_dags @@ -173,27 +173,57 @@ def cofunction(self, o): return Cofunction(W, val=MixedDat(o.dat[i] for i in indices)) def matrix(self, o): + from firedrake.bcs import DirichletBC, EquationBCSplit ises = [] args = [] + argument_indices = [] for a in o.arguments(): V = a.function_space() iset = PETSc.IS() if a.number() in self.blocks: + fields = self.blocks[a.number()] asplit = self._subspace_argument(a) - for f in self.blocks[a.number()]: + for f in fields: fset = V.dof_dset.field_ises[f] iset = iset.expand(fset) else: + fields = tuple(range(len(V))) asplit = a for fset in V.dof_dset.field_ises: iset = iset.expand(fset) ises.append(iset) args.append(asplit) + argument_indices.append(fields) + + if isinstance(o.a, Form): + form = self.split(o.a, argument_indices=argument_indices) + if isinstance(form, ZeroBaseForm): + return form + else: + form = None submat = o.petscmat.createSubMatrix(*ises) - bcs = () - return AssembledMatrix(tuple(args), bcs, submat) + + bcs = [] + spaces = [a.function_space() for a in o.arguments()] + for bc in o.bcs: + W = bc.function_space() + while W.parent is not None: + W = W.parent + + number = spaces.index(W) + field = self.blocks[number] + V = args[number].function_space() + if isinstance(bc, DirichletBC): + bc_temp = bc.reconstruct(field=field, V=V, g=bc.function_arg, use_split=True) + elif isinstance(bc, EquationBCSplit): + row_field, col_field = argument_indices + bc_temp = bc.reconstruct(field=field, V=V, row_field=row_field, col_field=col_field, use_split=True) + if bc_temp is not None: + bcs.append(bc_temp) + + return AssembledMatrix(form or tuple(args), tuple(bcs), submat) def zero_base_form(self, o): return ZeroBaseForm(tuple(map(self, o.arguments()))) diff --git a/firedrake/solving_utils.py b/firedrake/solving_utils.py index db18865712..358eea12b9 100644 --- a/firedrake/solving_utils.py +++ b/firedrake/solving_utils.py @@ -432,8 +432,10 @@ def split(self, fields): Jp = replace(Jp, {problem.u_restrict: u}) else: Jp = None + # A preassembled Jacobian already encodes the boundary conditions + orig_bcs = [] if isinstance(J, MatrixBase) else problem.bcs bcs = [] - for bc in problem.bcs: + for bc in orig_bcs: if isinstance(bc, DirichletBC): bc_temp = bc.reconstruct(field=field, V=V, g=bc.function_arg, sub_domain=bc.sub_domain) elif isinstance(bc, EquationBC): diff --git a/tests/firedrake/adjoint/test_reduced_functional.py b/tests/firedrake/adjoint/test_reduced_functional.py index ef6f3d1fba..7a3caebf27 100644 --- a/tests/firedrake/adjoint/test_reduced_functional.py +++ b/tests/firedrake/adjoint/test_reduced_functional.py @@ -298,3 +298,61 @@ def test_ad_dot(riesz_representation): h.dat.data[:] = np.random.rand(V.dof_dset.size) dJdh = dJhat._ad_dot(h, options={'riesz_representation': riesz_representation}) assert taylor_test(Jhat, f, h, dJdm=dJdh) > 1.9 + + +@pytest.mark.skipcomplex +def test_fieldsplit(): + mesh = UnitSquareMesh(2, 2) + V = VectorFunctionSpace(mesh, "CG", 2) + Q = FunctionSpace(mesh, "CG", 1) + W = MixedFunctionSpace([V, Q]) + + bcs = [DirichletBC(W.sub(0), Constant((0, 0)), (1, 2, 3)), + DirichletBC(W.sub(0), Constant((1, 0)), 4)] + + sp = { + 'mat_type': 'nest', + 'snes_converged_reason': None, + 'ksp_converged_reason': None, + 'ksp_type': 'fgmres', + 'pc_type': 'fieldsplit', + 'pc_fieldsplit_type': 'schur', + 'pc_fieldsplit_schur_factorization_type': 'full', + 'fieldsplit_0': { + 'ksp_type': 'preonly', + 'pc_type': 'lu', + "pc_factor_mat_solver_type": 'mumps', + }, + 'fieldsplit_1': { + 'ksp_type': 'cg', + 'ksp_rtol': 1e-9, + 'ksp_atol': 1e-9, + 'pc_type': 'python', + 'pc_python_type': 'firedrake.MassInvPC', + 'Mp_pc_type': 'lu', + 'Mp_pc_factor_mat_solver_type': 'mumps', + }, + } + + constant_nsp = VectorSpaceBasis(constant=True, comm=Q.comm) + nsp = MixedVectorSpaceBasis(W, [W.sub(0), constant_nsp]) + + A = FunctionSpace(mesh, "CG", 1) + rho = Function(A).interpolate(Constant(1)) + + w = Function(W) + (u, p) = split(w) + z = TestFunction(W) + (v, q) = split(z) + F = (inner(sym(grad(u)) * rho, sym(grad(v))) * dx + - inner(p, div(v)) * dx + - inner(div(u), q) * dx + ) + solve(F == 0, w, bcs, solver_parameters=sp, nullspace=nsp) + + J = assemble(0.5*inner(sym(grad(u)) * rho, sym(grad(u))) * dx) + Jhat = ReducedFunctional(J, Control(rho)) + + rg = RandomGenerator(PCG64(seed=0)) + h = rg.uniform(A) + assert taylor_test(Jhat, rho, h) > 1.9