Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 34 additions & 4 deletions firedrake/formmanipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import collections

from ufl import as_tensor, as_vector, split
from ufl.classes import Zero, FixedIndex, ListTensor, ZeroBaseForm
from ufl.classes import Form, Zero, FixedIndex, ListTensor, ZeroBaseForm
from ufl.algorithms.map_integrands import map_integrand_dags
from ufl.algorithms import expand_derivatives
from ufl.corealg.map_dag import MultiFunction, map_expr_dags
Expand Down Expand Up @@ -173,27 +173,57 @@ def cofunction(self, o):
return Cofunction(W, val=MixedDat(o.dat[i] for i in indices))

def matrix(self, o):
from firedrake.bcs import DirichletBC, EquationBCSplit
ises = []
args = []
argument_indices = []
for a in o.arguments():
V = a.function_space()
iset = PETSc.IS()
if a.number() in self.blocks:
fields = self.blocks[a.number()]
asplit = self._subspace_argument(a)
for f in self.blocks[a.number()]:
for f in fields:
fset = V.dof_dset.field_ises[f]
iset = iset.expand(fset)
else:
fields = tuple(range(len(V)))
asplit = a
for fset in V.dof_dset.field_ises:
iset = iset.expand(fset)

ises.append(iset)
args.append(asplit)
argument_indices.append(fields)

if isinstance(o.a, Form):
form = self.split(o.a, argument_indices=argument_indices)
if isinstance(form, ZeroBaseForm):
return form
else:
form = None

submat = o.petscmat.createSubMatrix(*ises)
bcs = ()
return AssembledMatrix(tuple(args), bcs, submat)

bcs = []
spaces = [a.function_space() for a in o.arguments()]
for bc in o.bcs:
W = bc.function_space()
while W.parent is not None:
W = W.parent

number = spaces.index(W)
field = self.blocks[number]
V = args[number].function_space()
if isinstance(bc, DirichletBC):
bc_temp = bc.reconstruct(field=field, V=V, g=bc.function_arg, use_split=True)
elif isinstance(bc, EquationBCSplit):
row_field, col_field = argument_indices
bc_temp = bc.reconstruct(field=field, V=V, row_field=row_field, col_field=col_field, use_split=True)
if bc_temp is not None:
bcs.append(bc_temp)

return AssembledMatrix(form or tuple(args), tuple(bcs), submat)

def zero_base_form(self, o):
return ZeroBaseForm(tuple(map(self, o.arguments())))
Expand Down
4 changes: 3 additions & 1 deletion firedrake/solving_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,8 +432,10 @@ def split(self, fields):
Jp = replace(Jp, {problem.u_restrict: u})
else:
Jp = None
# A preassembled Jacobian already encodes the boundary conditions
orig_bcs = [] if isinstance(J, MatrixBase) else problem.bcs
bcs = []
for bc in problem.bcs:
for bc in orig_bcs:
if isinstance(bc, DirichletBC):
bc_temp = bc.reconstruct(field=field, V=V, g=bc.function_arg, sub_domain=bc.sub_domain)
elif isinstance(bc, EquationBC):
Expand Down
58 changes: 58 additions & 0 deletions tests/firedrake/adjoint/test_reduced_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,3 +298,61 @@ def test_ad_dot(riesz_representation):
h.dat.data[:] = np.random.rand(V.dof_dset.size)
dJdh = dJhat._ad_dot(h, options={'riesz_representation': riesz_representation})
assert taylor_test(Jhat, f, h, dJdm=dJdh) > 1.9


@pytest.mark.skipcomplex
def test_fieldsplit():
mesh = UnitSquareMesh(2, 2)
V = VectorFunctionSpace(mesh, "CG", 2)
Q = FunctionSpace(mesh, "CG", 1)
W = MixedFunctionSpace([V, Q])

bcs = [DirichletBC(W.sub(0), Constant((0, 0)), (1, 2, 3)),
DirichletBC(W.sub(0), Constant((1, 0)), 4)]

sp = {
'mat_type': 'nest',
'snes_converged_reason': None,
'ksp_converged_reason': None,
'ksp_type': 'fgmres',
'pc_type': 'fieldsplit',
'pc_fieldsplit_type': 'schur',
'pc_fieldsplit_schur_factorization_type': 'full',
'fieldsplit_0': {
'ksp_type': 'preonly',
'pc_type': 'lu',
"pc_factor_mat_solver_type": 'mumps',
},
'fieldsplit_1': {
'ksp_type': 'cg',
'ksp_rtol': 1e-9,
'ksp_atol': 1e-9,
'pc_type': 'python',
'pc_python_type': 'firedrake.MassInvPC',
'Mp_pc_type': 'lu',
'Mp_pc_factor_mat_solver_type': 'mumps',
},
}

constant_nsp = VectorSpaceBasis(constant=True, comm=Q.comm)
nsp = MixedVectorSpaceBasis(W, [W.sub(0), constant_nsp])

A = FunctionSpace(mesh, "CG", 1)
rho = Function(A).interpolate(Constant(1))

w = Function(W)
(u, p) = split(w)
z = TestFunction(W)
(v, q) = split(z)
F = (inner(sym(grad(u)) * rho, sym(grad(v))) * dx
- inner(p, div(v)) * dx
- inner(div(u), q) * dx
)
solve(F == 0, w, bcs, solver_parameters=sp, nullspace=nsp)

J = assemble(0.5*inner(sym(grad(u)) * rho, sym(grad(u))) * dx)
Jhat = ReducedFunctional(J, Control(rho))

rg = RandomGenerator(PCG64(seed=0))
h = rg.uniform(A)
assert taylor_test(Jhat, rho, h) > 1.9
Loading