diff --git a/docs/src/userguide/examples.rst b/docs/src/userguide/examples.rst index e66a5a14..a24edd3c 100644 --- a/docs/src/userguide/examples.rst +++ b/docs/src/userguide/examples.rst @@ -54,6 +54,97 @@ certain regridders. We can do this as follows:: # Use loaded regridder. result = loaded_regridder(source_mesh_cube) +Partitioning a Regridder +------------------------ + +If a regridder would be too large to handle in memory, it can be broken down +into smaller regridders which can collectively do the job of the larger regridder. +This is done using a `Partition` object. + +.. note:: Currently, it is only possible to partition regridding when the source is + a large grid and the target is small enough to fit in memory. + +A `Partition` is made by specifying a source, a target, a regridding scheme, a list of files, and a way +to divide the source grid into blocks whose regridders are saved to those files:: + + from iris.util import make_gridcube + + from esmf_regrid import ESMFAreaWeighted + from esmf_regrid.experimental.partition import Partition + + # Create a large source cube. + source_cube = make_gridcube(nx=800, ny=800) + + # Create a small target cube. + target_cube = make_gridcube(nx=100, ny=100) + + # Set the regridding scheme. + scheme = ESMFAreaWeighted() + + # List a collection of file names/paths to save partial regridders to. + files = ["file_1", "file_2", "file_3", "file_4"] + + # Set the size of each block of the partition. For the keyword `src_chunks` + # this follows the dask chunking API. + src_chunks = (400, 400) + + # Initialise the partition. + partition = Partition( + source_cube, + target_cube, + scheme, + files, + src_chunks=src_chunks + ) + +.. note:: There are several different ways of specifying the division of the + source into blocks: + see :class:`~esmf_regrid.experimental.partition.Partition`. + +Initialising the `Partition` will not generate the files automatically unless +the `auto_generate` keyword is set to `True`. 
In order for this `Partition` to +function, the regridder files must be generated by calling the `generate_files` +method:: + + # Generate partial regridders and save them to the list of files. + partition.generate_files() + + # Once the files have been generated, they can be used for regridding. + result = partition.apply_regridders(source_cube) + +.. note:: Not all files need to be generated at once. If you have a grid which + needs to be split into very many files, it is possible to generate only + a portion of those files within a given session by passing the number + of files to generate as an argument to `generate_files`. It is then possible + to split the file generation into batches across multiple Python sessions. + +Once the files for a regridder have been generated, they can be used to reconstruct +the partition object in a later session. This is done by passing in the list of +files which have already been generated:: + + # Use the same arguments which constructed the original partition. + source_cube = make_gridcube(nx=800, ny=800) + target_cube = make_gridcube(nx=100, ny=100) + scheme = ESMFAreaWeighted() + files = ["file_1", "file_2", "file_3", "file_4"] + src_chunks = (400, 400) + + # List the files which have already been generated. + saved_files = ["file_1", "file_2", "file_3", "file_4"] + + # Reconstruct Partition from pre-generated files. + partition = Partition( + source_cube, + target_cube, + scheme, + files, + src_chunks=src_chunks, + saved_files=saved_files # Pass in the list of saved files. + ) + + # The new Partition can now be used without the need for generating files. + result = partition.apply_regridders(source_cube) + .. todo: Add more examples. 
diff --git a/pyproject.toml b/pyproject.toml index e2cd86ef..10ada6d2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -298,5 +298,5 @@ convention = "numpy" [tool.ruff.lint.pylint] # TODO: refactor to reduce complexity, if possible max-args = 10 -max-branches = 21 +max-branches = 25 max-statements = 110 diff --git a/src/esmf_regrid/esmf_regridder.py b/src/esmf_regrid/esmf_regridder.py index 800fdbd3..c335320c 100644 --- a/src/esmf_regrid/esmf_regridder.py +++ b/src/esmf_regrid/esmf_regridder.py @@ -175,6 +175,46 @@ def _out_dtype(self, in_dtype): ).dtype return out_dtype + def _gen_weights_and_data(self, src_array): + extra_shape = src_array.shape[: -self.src.dims] + + if self.method == Constants.Method.NEAREST: + weight_matrix = self.weight_matrix.astype(src_array.dtype) + else: + weight_matrix = self.weight_matrix + + flat_src = self.src._array_to_matrix(ma.filled(src_array, 0.0)) + flat_tgt = weight_matrix @ flat_src + + src_inverted_mask = self.src._array_to_matrix(~ma.getmaskarray(src_array)) + weight_sums = weight_matrix @ src_inverted_mask + return weight_sums, flat_tgt, extra_shape + + def _regrid_from_weights_and_data( + self, + tgt_weights, + tgt_data, + extra, + norm_type=Constants.NormType.FRACAREA, + mdtol=1, + ): + # Set the minimum mdtol to be slightly higher than 0 to account for rounding + # errors. 
+ mdtol = max(mdtol, 1e-8) + tgt_mask = tgt_weights > 1 - mdtol + normalisations = np.ones_like(tgt_data) + if self.method != Constants.Method.NEAREST: + masked_weight_sums = tgt_weights * tgt_mask + if norm_type == Constants.NormType.FRACAREA: + normalisations[tgt_mask] /= masked_weight_sums[tgt_mask] + elif norm_type == Constants.NormType.DSTAREA: + pass + normalisations = ma.array(normalisations, mask=np.logical_not(tgt_mask)) + + tgt_array = tgt_data * normalisations + tgt_array = self.tgt._matrix_to_array(tgt_array, extra) + return tgt_array + def regrid(self, src_array, norm_type=Constants.NormType.FRACAREA, mdtol=1): """Perform regridding on an array of data. @@ -212,30 +252,8 @@ def regrid(self, src_array, norm_type=Constants.NormType.FRACAREA, mdtol=1): f"got an array with shape ending in {main_shape}." ) raise ValueError(e_msg) - extra_shape = array_shape[: -self.src.dims] - extra_size = max(1, np.prod(extra_shape)) - src_inverted_mask = self.src._array_to_matrix(~ma.getmaskarray(src_array)) - weight_matrix = self.weight_matrix - if self.method == Constants.Method.NEAREST: - # force out_dtype := in_dtype - weight_matrix = weight_matrix.astype(src_array.dtype) - weight_sums = weight_matrix @ src_inverted_mask - out_dtype = self._out_dtype(src_array.dtype) - # Set the minimum mdtol to be slightly higher than 0 to account for rounding - # errors. 
- mdtol = max(mdtol, 1e-8) - tgt_mask = weight_sums > 1 - mdtol - normalisations = np.ones([self.tgt.size, extra_size], dtype=out_dtype) - if self.method != Constants.Method.NEAREST: - masked_weight_sums = weight_sums * tgt_mask - if norm_type == Constants.NormType.FRACAREA: - normalisations[tgt_mask] /= masked_weight_sums[tgt_mask] - elif norm_type == Constants.NormType.DSTAREA: - pass - normalisations = ma.array(normalisations, mask=np.logical_not(tgt_mask)) - - flat_src = self.src._array_to_matrix(ma.filled(src_array, 0.0)) - flat_tgt = weight_matrix @ flat_src - flat_tgt = flat_tgt * normalisations - tgt_array = self.tgt._matrix_to_array(flat_tgt, extra_shape) + tgt_weights, tgt_data, extra = self._gen_weights_and_data(src_array) + tgt_array = self._regrid_from_weights_and_data( + tgt_weights, tgt_data, extra, norm_type=norm_type, mdtol=mdtol + ) return tgt_array diff --git a/src/esmf_regrid/experimental/_partial.py b/src/esmf_regrid/experimental/_partial.py new file mode 100644 index 00000000..ffaf8708 --- /dev/null +++ b/src/esmf_regrid/experimental/_partial.py @@ -0,0 +1,124 @@ +"""Provides a regridder class compatible with Partition.""" + +import numpy as np + +from esmf_regrid.schemes import ( + GridRecord, + MeshRecord, + _create_cube, + _ESMFRegridder, +) + + +class PartialRegridder(_ESMFRegridder): + """Regridder class designed for use in :class:`~esmf_regrid.experimental.Partition`.""" + + def __init__(self, src, tgt, src_slice, tgt_slice, weights, scheme, **kwargs): + """Create a regridder instance for a block of :class:`~esmf_regrid.experimental.Partition`. + + Parameters + ---------- + src : :class:`iris.cube.Cube` + The :class:`~iris.cube.Cube` providing the source. + tgt : :class:`iris.cube.Cube` or :class:`iris.mesh.MeshXY` + The :class:`~iris.cube.Cube` or :class:`~iris.mesh.MeshXY` providing the target. + src_slice : tuple + The upper and lower bounds of the block taken from the original source from which the + ``src`` was derived. 
In the form ((x_low, x_high), ...) where x_low and x_high are the + lower and upper bounds of the slice (in the x dimension) taken from the original source. + There are as many tuples of lower and upper bounds as there are horizontal dimensions in + the source cube (currently this is always 2 as Meshes are not yet supported for sources). + tgt_slice : tuple + The upper and lower bounds of the block taken from the original target from which the + ``tgt`` was derived. In the form ((x_low, x_high), ...) where x_low and x_high are the + lower and upper bounds of the slice (in the x dimension) taken from the original target. + There are as many tuples of lower and upper bounds as there are horizontal dimensions in + the target cube. + weights : :class:`scipy.sparse.spmatrix` + The weights to use for regridding. + scheme : :class:`~esmf_regrid.schemes.ESMFAreaWeighted` or :class:`~esmf_regrid.schemes.ESMFBilinear` + The scheme used to construct the regridder. + kwargs : dict + Additional keyword arguments to pass to the `scheme`'s regridder method. 
+ """ + self.src_slice = src_slice # this will be tuple-like + self.tgt_slice = tgt_slice + self.scheme = scheme + + self._regridder = scheme.regridder( + src, + tgt, + precomputed_weights=weights, + **kwargs, + ) + self.__dict__.update(self._regridder.__dict__) + + def __repr__(self): + """Return a representation of the class.""" + result = ( + f"PartialRegridder(" + f"src={self._src}, " + f"tgt_slice={self._tgt}, " + f"src_slice={self.src_slice}, " + f"tgt_slice={self.tgt_slice}, " + f"scheme={self.scheme})" + ) + return result + + def partial_regrid(self, src): + """Perform the first half of regridding, generating weights and data.""" + dims = self._get_cube_dims(src) + num_dims = len(dims) + standard_in_dims = [-1, -2][:num_dims] + data = np.moveaxis(src.data, dims, standard_in_dims) + result = self.regridder._gen_weights_and_data(data) + return result + + def finish_regridding(self, src_cube, weights, data, extra): + """Perform the second half of regridding, combining weights and data. + + This operation is used to process the combined results from all the partial + regridders in a Partition. + Since all the combined data is passed in, this operation can be done using + *any one* of the individual PartialRegridders. + However, the passed "src_cube" must be the "correct" slice of the source + data cube, corresponding to the 'tgt_slice' slice params it was created with. + It is also implicit that the 'extra' arg (additional dimensions) will be the + same for all partial results. + The `src_cube` provides coordinates for the non-horizontal dimensions of the + result cube, matching the dimensions of the `data` array. + For technical convenience, its *horizontal* coordinates need to match those + of the 'src' reference cube provided in regridder creation (`self._src`). + So, it must be the correct "corresponding slice" of the source cube. 
+ """ + src_dims = self._get_cube_dims(src_cube) + + result_data = self.regridder._regrid_from_weights_and_data(weights, data, extra) + + num_out_dims = self.regridder.tgt.dims + num_dims = len(src_dims) + standard_out_dims = [-1, -2][:num_out_dims] + if num_dims == 2 and num_out_dims == 1: + new_dims = [min(src_dims)] + elif num_dims == 1 and num_out_dims == 2: + # Note: this code is currently inaccessible since src_cube can't have a Mesh. + new_dims = [src_dims[0] + 1, src_dims[0]] + else: + new_dims = src_dims + + result_data = np.moveaxis(result_data, standard_out_dims, new_dims) + + if isinstance(self._tgt, GridRecord): + tgt_coords = self._tgt + out_dims = 2 + elif isinstance(self._tgt, MeshRecord): + tgt_coords = self._tgt.mesh.to_MeshCoords(self._tgt.location) + out_dims = 1 + else: + msg = "Unrecognised target information." + raise TypeError(msg) + + result_cube = _create_cube( + result_data, src_cube, src_dims, tgt_coords, out_dims + ) + return result_cube diff --git a/src/esmf_regrid/experimental/io.py b/src/esmf_regrid/experimental/io.py index e71d7365..62646e01 100644 --- a/src/esmf_regrid/experimental/io.py +++ b/src/esmf_regrid/experimental/io.py @@ -10,12 +10,15 @@ import esmf_regrid from esmf_regrid import Constants, _load_context, check_method, esmpy +from esmf_regrid.experimental._partial import PartialRegridder from esmf_regrid.experimental.unstructured_scheme import ( GridToMeshESMFRegridder, MeshToGridESMFRegridder, ) from esmf_regrid.schemes import ( + ESMFAreaWeighted, ESMFAreaWeightedRegridder, + ESMFBilinear, ESMFBilinearRegridder, ESMFNearestRegridder, GridRecord, @@ -28,6 +31,7 @@ ESMFNearestRegridder, GridToMeshESMFRegridder, MeshToGridESMFRegridder, + PartialRegridder, ] _REGRIDDER_NAME_MAP = {rg_class.__name__: rg_class for rg_class in SUPPORTED_REGRIDDERS} _SOURCE_NAME = "regridder_source_field" @@ -47,6 +51,8 @@ _SOURCE_RESOLUTION = "src_resolution" _TARGET_RESOLUTION = "tgt_resolution" _ESMF_ARGS = "esmf_args" +_SRC_SLICE_NAME = 
"src_slice" +_TGT_SLICE_NAME = "tgt_slice" _VALID_ESMF_KWARGS = [ "pole_method", "regrid_pole_npoints", @@ -69,6 +75,7 @@ "extrap_method": _EXTRAP_METHOD_DICT, "unmapped_action": _UNMAPPED_ACTION_DICT, } +_ESMF_BOOL_ARGS = ["ignore_degenerate", "large_file"] def _add_mask_to_cube(mask, cube, name): @@ -118,54 +125,41 @@ def _clean_var_names(cube): con.var_name = None -def save_regridder(rg, filename): - """Save a regridder scheme instance. +def _standard_grid_cube(grid, name): + if grid[0].ndim == 1: + shape = [coord.points.size for coord in grid] + else: + shape = grid[0].shape + data = np.zeros(shape) + cube = Cube(data, var_name=name, long_name=name) + if grid[0].ndim == 1: + cube.add_dim_coord(grid[0], 0) + cube.add_dim_coord(grid[1], 1) + else: + cube.add_aux_coord(grid[0], [0, 1]) + cube.add_aux_coord(grid[1], [0, 1]) + return cube - Saves any of the regridder classes, i.e. - :class:`~esmf_regrid.experimental.unstructured_scheme.GridToMeshESMFRegridder`, - :class:`~esmf_regrid.experimental.unstructured_scheme.MeshToGridESMFRegridder`, - :class:`~esmf_regrid.schemes.ESMFAreaWeightedRegridder`, - :class:`~esmf_regrid.schemes.ESMFBilinearRegridder` or - :class:`~esmf_regrid.schemes.ESMFNearestRegridder`. - . - Parameters - ---------- - rg : :class:`~esmf_regrid.schemes._ESMFRegridder` - The regridder instance to save. - filename : str - The file name to save to. 
- """ - regridder_type = rg.__class__.__name__ +def _standard_mesh_cube(mesh, location, name): + mesh_coords = mesh.to_MeshCoords(location) + data = np.zeros(mesh_coords[0].points.shape[0]) + cube = Cube(data, var_name=name, long_name=name) + for coord in mesh_coords: + cube.add_aux_coord(coord, 0) + return cube - def _standard_grid_cube(grid, name): - if grid[0].ndim == 1: - shape = [coord.points.size for coord in grid] - else: - shape = grid[0].shape - data = np.zeros(shape) - cube = Cube(data, var_name=name, long_name=name) - if grid[0].ndim == 1: - cube.add_dim_coord(grid[0], 0) - cube.add_dim_coord(grid[1], 1) - else: - cube.add_aux_coord(grid[0], [0, 1]) - cube.add_aux_coord(grid[1], [0, 1]) - return cube - - def _standard_mesh_cube(mesh, location, name): - mesh_coords = mesh.to_MeshCoords(location) - data = np.zeros(mesh_coords[0].points.shape[0]) - cube = Cube(data, var_name=name, long_name=name) - for coord in mesh_coords: - cube.add_aux_coord(coord, 0) - return cube +def _generate_src_tgt(regridder_type, rg, allow_partial): if regridder_type in [ "ESMFAreaWeightedRegridder", "ESMFBilinearRegridder", "ESMFNearestRegridder", + "PartialRegridder", ]: + if regridder_type == "PartialRegridder" and not allow_partial: + e_msg = "To save a PartialRegridder, `allow_partial=True` must be set." + raise ValueError(e_msg) src_grid = rg._src if isinstance(src_grid, GridRecord): src_cube = _standard_grid_cube( @@ -210,12 +204,36 @@ def _standard_mesh_cube(mesh, location, name): tgt_grid = (rg.grid_y, rg.grid_x) tgt_cube = _standard_grid_cube(tgt_grid, _TARGET_NAME) _add_mask_to_cube(rg.tgt_mask, tgt_cube, _TARGET_MASK_NAME) + else: - e_msg = ( - f"Expected a regridder of type `GridToMeshESMFRegridder` or " - f"`MeshToGridESMFRegridder`, got type {regridder_type}." - ) + e_msg = f"Unexpected regridder type {regridder_type}." 
raise TypeError(e_msg) + return src_cube, tgt_cube + + +def save_regridder(rg, filename, allow_partial=False): + """Save a regridder scheme instance. + + Saves any of the regridder classes, i.e. + :class:`~esmf_regrid.experimental.unstructured_scheme.GridToMeshESMFRegridder`, + :class:`~esmf_regrid.experimental.unstructured_scheme.MeshToGridESMFRegridder`, + :class:`~esmf_regrid.schemes.ESMFAreaWeightedRegridder`, + :class:`~esmf_regrid.schemes.ESMFBilinearRegridder` or + :class:`~esmf_regrid.schemes.ESMFNearestRegridder`. + . + + Parameters + ---------- + rg : :class:`~esmf_regrid.schemes._ESMFRegridder` + The regridder instance to save. + filename : str + The file name to save to. + allow_partial : bool, default=False + If True, allow the saving of :class:`~esmf_regrid.experimental._partial.PartialRegridder` instances. + """ + regridder_type = rg.__class__.__name__ + + src_cube, tgt_cube = _generate_src_tgt(regridder_type, rg, allow_partial) method = str(check_method(rg.method).name) @@ -223,7 +241,7 @@ def _standard_mesh_cube(mesh, location, name): resolution = rg.resolution src_resolution = None tgt_resolution = None - elif regridder_type == "ESMFAreaWeightedRegridder": + elif method == "CONSERVATIVE": resolution = None src_resolution = rg.src_resolution tgt_resolution = rg.tgt_resolution @@ -264,6 +282,22 @@ def _standard_mesh_cube(mesh, location, name): if tgt_resolution is not None: attributes[_TARGET_RESOLUTION] = tgt_resolution + extra_cubes = [] + if regridder_type == "PartialRegridder": + src_slice = rg.src_slice # this slice is described by a tuple + if src_slice is None: + src_slice = [] + src_slice_cube = Cube( + src_slice, long_name=_SRC_SLICE_NAME, var_name=_SRC_SLICE_NAME + ) + tgt_slice = rg.tgt_slice # this slice is described by a tuple + if tgt_slice is None: + tgt_slice = [] + tgt_slice_cube = Cube( + tgt_slice, long_name=_TGT_SLICE_NAME, var_name=_TGT_SLICE_NAME + ) + extra_cubes = [src_slice_cube, tgt_slice_cube] + weights_cube = 
Cube(weight_data, var_name=_WEIGHTS_NAME, long_name=_WEIGHTS_NAME) row_coord = AuxCoord( weight_rows, var_name=_WEIGHTS_ROW_NAME, long_name=_WEIGHTS_ROW_NAME @@ -298,7 +332,9 @@ def _standard_mesh_cube(mesh, location, name): # Save cubes while ensuring var_names do not conflict for the sake of consistency. with _managed_var_name(src_cube, tgt_cube): - cube_list = CubeList([src_cube, tgt_cube, weights_cube, weight_shape_cube]) + cube_list = CubeList( + [src_cube, tgt_cube, weights_cube, weight_shape_cube, *extra_cubes] + ) for cube in cube_list: cube.attributes = attributes @@ -306,7 +342,7 @@ def _standard_mesh_cube(mesh, location, name): iris.fileformats.netcdf.save(cube_list, filename) -def load_regridder(filename): +def load_regridder(filename, allow_partial=False): """Load a regridder scheme instance. Loads any of the regridder classes, i.e. @@ -320,6 +356,8 @@ def load_regridder(filename): ---------- filename : str The file name to load from. + allow_partial : bool, default=False + If True, allow the loading of :class:`~esmf_regrid.experimental._partial.PartialRegridder` instances. Returns ------- @@ -343,6 +381,12 @@ def load_regridder(filename): raise TypeError(e_msg) scheme = _REGRIDDER_NAME_MAP[regridder_type] + if regridder_type == "PartialRegridder" and not allow_partial: + e_msg = ( + "PartialRegridder cannot be loaded without setting `allow_partial=True`." + ) + raise ValueError(e_msg) + # Determine the regridding method, allowing for files created when # conservative regridding was the only method. 
method_string = weights_cube.attributes.get(_METHOD, "CONSERVATIVE") @@ -372,11 +416,11 @@ def load_regridder(filename): mdtol = weights_cube.attributes[_MDTOL] if src_cube.coords(_SOURCE_MASK_NAME): - use_src_mask = src_cube.coord(_SOURCE_MASK_NAME).points + use_src_mask = src_cube.coord(_SOURCE_MASK_NAME).points.astype(bool) else: use_src_mask = False if tgt_cube.coords(_TARGET_MASK_NAME): - use_tgt_mask = tgt_cube.coord(_TARGET_MASK_NAME).points + use_tgt_mask = tgt_cube.coord(_TARGET_MASK_NAME).points.astype(bool) else: use_tgt_mask = False @@ -389,6 +433,9 @@ def load_regridder(filename): for arg, arg_dict in _ESMF_ENUM_ARGS.items(): if arg in esmf_args: esmf_args[arg] = arg_dict[esmf_args[arg]] + for arg in _ESMF_BOOL_ARGS: + if arg in esmf_args: + esmf_args[arg] = bool(esmf_args[arg]) if scheme is GridToMeshESMFRegridder: resolution_keyword = _SOURCE_RESOLUTION @@ -396,26 +443,56 @@ def load_regridder(filename): elif scheme is MeshToGridESMFRegridder: resolution_keyword = _TARGET_RESOLUTION kwargs = {resolution_keyword: resolution, "method": method, "mdtol": mdtol} - elif scheme is ESMFAreaWeightedRegridder: + elif method is Constants.Method.CONSERVATIVE: kwargs = { _SOURCE_RESOLUTION: src_resolution, _TARGET_RESOLUTION: tgt_resolution, "mdtol": mdtol, } - elif scheme is ESMFBilinearRegridder: + elif method is Constants.Method.BILINEAR: kwargs = {"mdtol": mdtol} else: kwargs = {} - regridder = scheme( - src_cube, - tgt_cube, - precomputed_weights=weight_matrix, - use_src_mask=use_src_mask, - use_tgt_mask=use_tgt_mask, - esmf_args=esmf_args, - **kwargs, - ) + if scheme is PartialRegridder: + src_slice = cubes.extract_cube(_SRC_SLICE_NAME).data.tolist() + if src_slice == []: + src_slice = None + tgt_slice = cubes.extract_cube(_TGT_SLICE_NAME).data.tolist() + if tgt_slice == []: + tgt_slice = None + sub_scheme = { + Constants.Method.CONSERVATIVE: ESMFAreaWeighted, + Constants.Method.BILINEAR: ESMFBilinear, + }[method] + mdtol = kwargs.pop(_MDTOL, None) + 
sub_kwargs = {} + if mdtol is not None: + sub_kwargs[_MDTOL] = mdtol + regridder = scheme( + src_cube, + tgt_cube, + src_slice, + tgt_slice, + weight_matrix, + sub_scheme( + use_src_mask=use_src_mask, + use_tgt_mask=use_tgt_mask, + esmf_args=esmf_args, + **sub_kwargs, + ), + **kwargs, + ) + else: + regridder = scheme( + src_cube, + tgt_cube, + precomputed_weights=weight_matrix, + use_src_mask=use_src_mask, + use_tgt_mask=use_tgt_mask, + esmf_args=esmf_args, + **kwargs, + ) esmf_version = weights_cube.attributes[_VERSION_ESMF] regridder.regridder.esmf_version = esmf_version diff --git a/src/esmf_regrid/experimental/partition.py b/src/esmf_regrid/experimental/partition.py new file mode 100644 index 00000000..6513fa3a --- /dev/null +++ b/src/esmf_regrid/experimental/partition.py @@ -0,0 +1,289 @@ +"""Provides an interface for splitting up a large regridding task.""" + +import esmpy +import numpy as np + +from esmf_regrid.constants import Constants +from esmf_regrid.experimental._partial import PartialRegridder +from esmf_regrid.experimental.io import load_regridder, save_regridder +from esmf_regrid.schemes import _get_grid_dims + + +def _get_chunk(cube, sl): + if cube.mesh is None: + grid_dims = _get_grid_dims(cube) + else: + grid_dims = (cube.mesh_dim(),) + full_slice = [np.s_[:]] * len(cube.shape) + for s, d in zip(sl, grid_dims, strict=True): + full_slice[d] = np.s_[s[0] : s[1]] + return cube[*full_slice] + + +def _determine_blocks(shape, chunks, num_chunks, explicit_blocks): + which_inputs = ( + chunks is not None, + num_chunks is not None, + explicit_blocks is not None, + ) + if sum(which_inputs) == 0: + msg = "Partition blocks must must be specified by either chunks, num_chunks, or explicit_chunks." + raise ValueError(msg) + if sum(which_inputs) > 1: + msg = "Potentially conflicting partition block definitions." 
+ raise ValueError(msg) + if num_chunks is not None: + chunks = [s // n for s, n in zip(shape, num_chunks, strict=True)] + for chunk in chunks: + if chunk == 0: + msg = "`num_chunks` cannot divide a dimension into more blocks than the size of that dimension." + raise ValueError(msg) + if chunks is not None: + if all(isinstance(x, int) for x in chunks): + proper_chunks = [] + for s, c in zip(shape, chunks, strict=True): + proper_chunk = [c] * (s // c) + if s % c != 0: + proper_chunk += [s % c] + proper_chunks.append(proper_chunk) + chunks = proper_chunks + for s, chunk in zip(shape, chunks, strict=True): + if sum(chunk) != s: + msg = "Chunks must sum to the size of their respective dimension." + raise ValueError(msg) + bounds = [np.cumsum([0, *chunk]) for chunk in chunks] + if len(bounds) == 1: + msg = "Chunks must have exactly two dimensions." + raise ValueError(msg) + # TODO: This is currently blocked by the fact that slicing an Iris cube on its mesh dimension + # does not currently yield another cube with a mesh. When this is fixed, the following + # code can be uncommented and the noqa on the following line can be removed. + # explicit_blocks = [ + # [[int(lower), int(upper)]] + # for lower, upper in zip(bounds[0][:-1], bounds[0][1:], strict=True) + # ] + elif len(bounds) == 2: # noqa: RET506 + explicit_blocks = [ + [[int(ly), int(uy)], [int(lx), int(ux)]] + for ly, uy in zip(bounds[0][:-1], bounds[0][1:], strict=True) + for lx, ux in zip(bounds[1][:-1], bounds[1][1:], strict=True) + ] + else: + msg = "Chunks must not exceed two dimensions." + raise ValueError(msg) + if len(explicit_blocks[0]) != len(shape): + msg = "Dimensionality of blocks does not match the number of dimensions." 
+ raise ValueError(msg) + return explicit_blocks + + +class Partition: + """Class for breaking down regridding into manageable chunks.""" + + def __init__( + self, + src, + tgt, + scheme, + file_names, + use_dask_src_chunks=False, + src_chunks=None, + num_src_chunks=None, + explicit_src_blocks=None, + auto_generate=False, + saved_files=None, + ): + """Class for breaking down regridding into manageable chunks. + + Note + ---- + The source is partitioned into blocks using one of the four mutually exclusive arguments, + `use_dask_src_chunks`, `src_chunks`, `num_src_chunks`, or `explicit_src_blocks`. These + describe a partition into a number of blocks which must equal the number of `file_names`. + + Currently, it is only possible to divide the source grid into chunks. + Meshes are not yet supported as a source. + + Parameters + ---------- + src : cube + Source cube. + tgt : cube + Target cube. + scheme : regridding scheme + Regridding scheme to generate regridders, either ESMFAreaWeighted or ESMFBilinear. + file_names : iterable of str + A list of file names to save/load parts of the regridder to/from. + use_dask_src_chunks : bool, default=False + If true, partition using the same chunks from the source cube. + src_chunks : numpy array, tuple of int or tuple of tuple of int, default=None + Specify the size of blocks to use to divide up the cube. Dimensions are specified + in y,x axis order. If `src_chunks` is a tuple of int, each integer describes + the maximum size of a block in that dimension. If `src_chunks` is a tuple of tuples, + each sub-tuple describes the size of each successive block in that dimension. The sum + of these block sizes in each of the sub-tuples should add up to the total size of that + dimension or else an error is raised. + num_src_chunks : tuple of int + Specify the number of blocks to use to divide up the cube. Dimensions are specified + in y,x axis order. Each integer describes the number of blocks that dimension will + be divided into. 
+ explicit_src_blocks : arraylike NxMx2 + Explicitly specify the bounds of each block in the partition. Describes N blocks + along M dimensions with a pair of upper and lower bounds. The upper and lower bounds + describe a slice of an array, e.g. the bounds (3, 6) describe the indices 3, 4, 5 in + a particular dimension. + auto_generate : bool, default=False + When true, start generating files on initialisation. + saved_files : iterable of str + A list of paths to previously saved files. + """ + if scheme._method == Constants.Method.NEAREST: + msg = "The `Nearest` method is not implemented." + raise NotImplementedError(msg) + if scheme._method == Constants.Method.BILINEAR: + pole_method = scheme.esmf_args.get("pole_method") + if pole_method != esmpy.PoleMethod.NONE: + msg = ( + "Bilinear regridding must have a `pole_method` of `esmpy.PoleMethod.NONE` in " + "the `esmf_args` in order for Partition to work.`" + ) + raise ValueError(msg) + # TODO: Extract a slice of the cube. + self.src = src + if src.mesh is None: + grid_dims = _get_grid_dims(src) + else: + msg = "Partition does not yet support source meshes." + raise NotImplementedError(msg) + # TODO: This is currently blocked by the fact that slicing an Iris cube on its mesh dimension + # does not currently yield another cube with a mesh. When this is fixed, the following + # code can be uncommented. + # grid_dims = (src.mesh_dim(),) + shape = tuple(src.shape[i] for i in grid_dims) + self.tgt = tgt + self.scheme = scheme + # TODO: consider abstracting away the idea of files + self.file_names = file_names + if use_dask_src_chunks: + if src_chunks is not None: + msg = "`src_chunks` and `use_dask_src_chunks` cannot be used at the same time." + raise ValueError(msg) + if not src.has_lazy_data(): + msg = "If `use_dask_src_chunks=True`, the source cube must be lazy." 
+ raise TypeError(msg) + src_chunks = src.slices(grid_dims).next().lazy_data().chunks + self.src_blocks = _determine_blocks( + shape, src_chunks, num_src_chunks, explicit_src_blocks + ) + if len(self.src_blocks) != len(file_names): + msg = "Number of source blocks does not match number of file names." + raise ValueError(msg) + # This will be controllable in future + tgt_blocks = None + self.tgt_blocks = tgt_blocks + if tgt_blocks is not None: + msg = "Target chunking not yet implemented." + raise NotImplementedError(msg) + + # Note: this may need to become more sophisticated when both src and tgt are large + self.file_block_dict = dict(zip(self.file_names, self.src_blocks, strict=True)) + + if saved_files is None: + self.saved_files = [] + else: + self.saved_files = saved_files + if auto_generate: + self.generate_files(self.file_names) + + def __repr__(self): + """Return a representation of the class.""" + result = ( + f"Partition(" + f"src={self.src!r}, " + f"tgt={self.tgt!r}, " + f"scheme={self.scheme}, " + f"num file_names={len(self.file_names)}," + f"num saved_files={len(self.saved_files)})" + ) + return result + + @property + def unsaved_files(self): + """List of files not yet generated.""" + return [file for file in self.file_names if file not in self.saved_files] + + def generate_files(self, files_to_generate=None): + """Generate files with regridding information. + + Parameters + ---------- + files_to_generate : int, default=None + Specify the number of files to generate, default behaviour is to generate all files. + """ + if files_to_generate is None: + files = self.unsaved_files + else: + if not isinstance(files_to_generate, int): + msg = "`files_to_generate` must be an integer." 
+ raise ValueError(msg) + files = self.unsaved_files[:files_to_generate] + + for file in files: + src_block = self.file_block_dict[file] + src = _get_chunk(self.src, src_block) + tgt = self.tgt + regridder = self.scheme.regridder(src, tgt) + weights = regridder.regridder.weight_matrix + regridder = PartialRegridder( + src, tgt, src_block, None, weights, self.scheme + ) + save_regridder(regridder, file, allow_partial=True) + self.saved_files.append(file) + + def apply_regridders(self, cube, allow_incomplete=False): + """Apply the saved regridders to a cube. + + Parameters + ---------- + allow_incomplete : bool, default=False + If False, raise an error if not all files have been generated. If True, perform + regridding using the files which have been generated. + """ + # for each target chunk, iterate through each associated regridder + # for now, assume one target chunk + if len(self.saved_files) == 0: + msg = "No files have been generated." + raise OSError(msg) + if not allow_incomplete and len(self.unsaved_files) != 0: + msg = "Not all files have been generated." + raise OSError(msg) + current_result = None + current_weights = None + files = self.saved_files + + for file, chunk in zip(self.file_names, self.src_blocks, strict=True): + if file in files: + next_regridder = load_regridder(file, allow_partial=True) + cube_chunk = _get_chunk(cube, chunk) + next_weights, next_result, extra = next_regridder.partial_regrid( + cube_chunk + ) + if current_weights is None: + current_weights = next_weights + else: + current_weights += next_weights + if current_result is None: + current_result = next_result + else: + current_result += next_result + + # NOTE: the final "finish_regridding" operation could be performed using any one + # of the partial regridders,but the correct "corresponding" slice of the source + # must be passed. + # See :meth:`~esmf_regrid.experimental._partial.PartialRegridder.finish_regridding`. 
+ return next_regridder.finish_regridding( + cube_chunk, # matches *this* partial regridder + current_weights, + current_result, + extra, # should be *the same* for all the partial results + ) diff --git a/src/esmf_regrid/schemes.py b/src/esmf_regrid/schemes.py index adfe87b9..25f8c9e1 100644 --- a/src/esmf_regrid/schemes.py +++ b/src/esmf_regrid/schemes.py @@ -606,6 +606,18 @@ def _make_meshinfo(cube_or_mesh, method, mask, src_or_tgt, location=None): return _mesh_to_MeshInfo(mesh, location, mask=mask) +def _get_grid_dims(cube): + src_x = _get_coord(cube, "x") + src_y = _get_coord(cube, "y") + + if len(src_x.shape) == 1: + grid_x_dim = cube.coord_dims(src_x)[0] + grid_y_dim = cube.coord_dims(src_y)[0] + else: + grid_y_dim, grid_x_dim = cube.coord_dims(src_x) + return grid_y_dim, grid_x_dim + + def _regrid_rectilinear_to_rectilinear__prepare( src_grid_cube, tgt_grid_cube, @@ -625,14 +637,8 @@ def _regrid_rectilinear_to_rectilinear__prepare( """ tgt_x = _get_coord(tgt_grid_cube, "x") tgt_y = _get_coord(tgt_grid_cube, "y") - src_x = _get_coord(src_grid_cube, "x") - src_y = _get_coord(src_grid_cube, "y") - if len(src_x.shape) == 1: - grid_x_dim = src_grid_cube.coord_dims(src_x)[0] - grid_y_dim = src_grid_cube.coord_dims(src_y)[0] - else: - grid_y_dim, grid_x_dim = src_grid_cube.coord_dims(src_x) + grid_y_dim, grid_x_dim = _get_grid_dims(src_grid_cube) srcinfo = _make_gridinfo(src_grid_cube, method, src_resolution, src_mask) tgtinfo = _make_gridinfo(tgt_grid_cube, method, tgt_resolution, tgt_mask) @@ -805,8 +811,6 @@ def _regrid_rectilinear_to_unstructured__prepare( The 'regrid info' returned can be reused over many 2d slices. 
""" - grid_x = _get_coord(src_grid_cube, "x") - grid_y = _get_coord(src_grid_cube, "y") if isinstance(tgt_cube_or_mesh, MeshXY): mesh = tgt_cube_or_mesh location = tgt_location @@ -814,11 +818,7 @@ def _regrid_rectilinear_to_unstructured__prepare( mesh = tgt_cube_or_mesh.mesh location = tgt_cube_or_mesh.location - if grid_x.ndim == 1: - (grid_x_dim,) = src_grid_cube.coord_dims(grid_x) - (grid_y_dim,) = src_grid_cube.coord_dims(grid_y) - else: - grid_y_dim, grid_x_dim = src_grid_cube.coord_dims(grid_x) + grid_y_dim, grid_x_dim = _get_grid_dims(src_grid_cube) meshinfo = _make_meshinfo( tgt_cube_or_mesh, method, tgt_mask, "target", location=tgt_location @@ -1087,6 +1087,7 @@ def __init__( the regridder is saved . """ + self._method = Constants.Method.CONSERVATIVE if not (0 <= mdtol <= 1): msg = "Value for mdtol must be in range 0 - 1, got {}." raise ValueError(msg.format(mdtol)) @@ -1123,6 +1124,7 @@ def regridder( use_tgt_mask=None, tgt_location="face", esmf_args=None, + precomputed_weights=None, ): """Create regridder to perform regridding from ``src_grid`` to ``tgt_grid``. @@ -1191,6 +1193,7 @@ def regridder( use_tgt_mask=use_tgt_mask, tgt_location="face", esmf_args=esmf_args, + precomputed_weights=precomputed_weights, ) @@ -1240,6 +1243,7 @@ def __init__( the regridder is saved . """ + self._method = Constants.Method.BILINEAR if not (0 <= mdtol <= 1): msg = "Value for mdtol must be in range 0 - 1, got {}." raise ValueError(msg.format(mdtol)) @@ -1274,6 +1278,7 @@ def regridder( tgt_location=None, extrapolate_gaps=False, esmf_args=None, + precomputed_weights=None, ): """Create regridder to perform regridding from ``src_grid`` to ``tgt_grid``. @@ -1336,6 +1341,7 @@ def regridder( use_tgt_mask=use_tgt_mask, tgt_location=tgt_location, esmf_args=esmf_args, + precomputed_weights=precomputed_weights, ) @@ -1389,6 +1395,7 @@ def __init__( arguments are recorded as a property of this regridder and are stored when the regridder is saved . 
""" + self._method = Constants.Method.NEAREST self.use_src_mask = use_src_mask self.use_tgt_mask = use_tgt_mask self.tgt_location = tgt_location @@ -1415,6 +1422,7 @@ def regridder( use_tgt_mask=None, tgt_location=None, esmf_args=None, + precomputed_weights=None, ): """Create regridder to perform regridding from ``src_grid`` to ``tgt_grid``. @@ -1468,6 +1476,7 @@ def regridder( use_tgt_mask=use_tgt_mask, tgt_location=tgt_location, esmf_args=esmf_args, + precomputed_weights=precomputed_weights, ) @@ -1491,9 +1500,9 @@ def __init__( Parameters ---------- src : :class:`iris.cube.Cube` - The rectilinear :class:`~iris.cube.Cube` providing the source grid. + The :class:`~iris.cube.Cube` providing the source grid. tgt : :class:`iris.cube.Cube` or :class:`iris.mesh.MeshXY` - The rectilinear :class:`~iris.cube.Cube` providing the target grid. + The :class:`~iris.cube.Cube` providing the target grid. method : :class:`Constants.Method` The method to be used to calculate weights. mdtol : float, default=None @@ -1566,26 +1575,7 @@ def __init__( else: self._src = GridRecord(_get_coord(src, "x"), _get_coord(src, "y")) - def __call__(self, cube): - """Regrid this :class:`~iris.cube.Cube` onto the target grid of this regridder instance. - - The given :class:`~iris.cube.Cube` must be defined with the same grid as the source - :class:`~iris.cube.Cube` used to create this :class:`_ESMFRegridder` instance. - - Parameters - ---------- - cube : :class:`iris.cube.Cube` - A :class:`~iris.cube.Cube` instance to be regridded. - - Returns - ------- - :class:`iris.cube.Cube` - A :class:`~iris.cube.Cube` defined with the horizontal dimensions of the target - and the other dimensions from this :class:`~iris.cube.Cube`. The data values of - this :class:`~iris.cube.Cube` will be converted to values on the new grid using - regridding via :mod:`esmpy` generated weights. - - """ + def _get_cube_dims(self, cube): if cube.mesh is not None: # TODO: replace temporary hack when iris issues are sorted. 
# Ignore differences in var_name that might be caused by saving. @@ -1629,6 +1619,29 @@ def __call__(self, cube): else: # Due to structural reasons, the order here must be reversed. dims = cube.coord_dims(new_src_x)[::-1] + return dims + + def __call__(self, cube): + """Regrid this :class:`~iris.cube.Cube` onto the target grid of this regridder instance. + + The given :class:`~iris.cube.Cube` must be defined with the same grid as the source + :class:`~iris.cube.Cube` used to create this :class:`_ESMFRegridder` instance. + + Parameters + ---------- + cube : :class:`iris.cube.Cube` + A :class:`~iris.cube.Cube` instance to be regridded. + + Returns + ------- + :class:`iris.cube.Cube` + A :class:`~iris.cube.Cube` defined with the horizontal dimensions of the target + and the other dimensions from this :class:`~iris.cube.Cube`. The data values of + this :class:`~iris.cube.Cube` will be converted to values on the new grid using + regridding via :mod:`esmpy` generated weights. + + """ + dims = self._get_cube_dims(cube) regrid_info = RegridInfo( dims=dims, @@ -1673,11 +1686,11 @@ def __init__( Parameters ---------- src : :class:`iris.cube.Cube` - The rectilinear :class:`~iris.cube.Cube` providing the source. + The :class:`~iris.cube.Cube` providing the source. If this cube has a grid defined by latitude/longitude coordinates, those coordinates must have bounds. tgt : :class:`iris.cube.Cube` or :class:`iris.mesh.MeshXY` - The unstructured :class:`~iris.cube.Cube`or + The :class:`~iris.cube.Cube`or :class:`~iris.mesh.MeshXY` defining the target. If this cube has a grid defined by latitude/longitude coordinates, those coordinates must have bounds. @@ -1760,9 +1773,9 @@ def __init__( Parameters ---------- src : :class:`iris.cube.Cube` - The rectilinear :class:`~iris.cube.Cube` providing the source. + The :class:`~iris.cube.Cube` providing the source. 
tgt : :class:`iris.cube.Cube` or :class:`iris.mesh.MeshXY` - The unstructured :class:`~iris.cube.Cube`or + The :class:`~iris.cube.Cube`or :class:`~iris.mesh.MeshXY` defining the target. mdtol : float, default=0 Tolerance of missing data. The value returned in each element of @@ -1834,9 +1847,9 @@ def __init__( Parameters ---------- src : :class:`iris.cube.Cube` - The rectilinear :class:`~iris.cube.Cube` providing the source. + The :class:`~iris.cube.Cube` providing the source. tgt : :class:`iris.cube.Cube` or :class:`iris.mesh.MeshXY` - The unstructured :class:`~iris.cube.Cube`or + The :class:`~iris.cube.Cube`or :class:`~iris.mesh.MeshXY` defining the target. precomputed_weights : :class:`scipy.sparse.spmatrix`, optional If ``None``, :mod:`esmpy` will be used to diff --git a/src/esmf_regrid/tests/unit/experimental/partition/__init__.py b/src/esmf_regrid/tests/unit/experimental/partition/__init__.py new file mode 100644 index 00000000..656fc3a9 --- /dev/null +++ b/src/esmf_regrid/tests/unit/experimental/partition/__init__.py @@ -0,0 +1 @@ +"""Unit tests for :mod:`esmf_regrid.experimental.partition`.""" diff --git a/src/esmf_regrid/tests/unit/experimental/partition/test_PartialRegridder.py b/src/esmf_regrid/tests/unit/experimental/partition/test_PartialRegridder.py new file mode 100644 index 00000000..d418f66f --- /dev/null +++ b/src/esmf_regrid/tests/unit/experimental/partition/test_PartialRegridder.py @@ -0,0 +1,56 @@ +"""Unit tests for :mod:`esmf_regrid.experimental.partition`.""" + +import numpy as np + +from esmf_regrid import ESMFAreaWeighted +from esmf_regrid.experimental._partial import PartialRegridder +from esmf_regrid.experimental.io import load_regridder, save_regridder +from esmf_regrid.tests.unit.schemes.test__cube_to_GridInfo import ( + _grid_cube, +) + + +def test_PartialRegridder_repr(): + """Test repr of PartialRegridder instance.""" + src = _grid_cube(10, 15, (-180, 180), (-90, 90), circular=True) + tgt = _grid_cube(5, 10, (-180, 180), (-90, 90), 
circular=True) + src_slice = ((10, 20), (15, 30)) + tgt_slice = ((0, 5), (0, 10)) + weights = None + scheme = ESMFAreaWeighted(mdtol=0.5) + + pr = PartialRegridder(src, tgt, src_slice, tgt_slice, weights, scheme) + + expected_repr = ( + "PartialRegridder(src=GridRecord(" + "grid_x=, " + "grid_y=), " + "tgt_slice=GridRecord(grid_x=, " + "grid_y=), " + "src_slice=((10, 20), (15, 30)), tgt_slice=((0, 5), (0, 10)), scheme=ESMFAreaWeighted(mdtol=0.5, " + "use_src_mask=False, use_tgt_mask=False, esmf_args={}))" + ) + assert repr(pr) == expected_repr + + +def test_PartialRegridder_roundtrip(tmp_path): + """Test load/save for PartialRegridder instance.""" + src = _grid_cube(10, 15, (-180, 180), (-90, 90), circular=True) + mask = np.zeros_like(src.data) + mask[0, 0] = 1 + src.data = np.ma.array(src.data, mask=mask) + tgt = _grid_cube(5, 10, (-180, 180), (-90, 90), circular=True) + src_slice = [[10, 20], [15, 30]] + tgt_slice = [[0, 5], [0, 10]] + weights = None + scheme = ESMFAreaWeighted( + mdtol=0.5, use_src_mask=src.data.mask, esmf_args={"ignore_degenerate": True} + ) + + pr = PartialRegridder(src, tgt, src_slice, tgt_slice, weights, scheme) + file = tmp_path / "partial.nc" + + save_regridder(pr, file, allow_partial=True) + loaded_pr = load_regridder(file, allow_partial=True) + + assert repr(loaded_pr) == repr(pr) diff --git a/src/esmf_regrid/tests/unit/experimental/partition/test_Partition.py b/src/esmf_regrid/tests/unit/experimental/partition/test_Partition.py new file mode 100644 index 00000000..ed42231e --- /dev/null +++ b/src/esmf_regrid/tests/unit/experimental/partition/test_Partition.py @@ -0,0 +1,333 @@ +"""Unit tests for :mod:`esmf_regrid.experimental.partition`.""" + +import dask.array as da +import esmpy +import numpy as np +import pytest + +from esmf_regrid import ESMFAreaWeighted, ESMFBilinear, ESMFNearest +from esmf_regrid.experimental.partition import Partition +from esmf_regrid.tests.unit.schemes.test__cube_to_GridInfo import ( + _curvilinear_cube, + 
_grid_cube, +) +from esmf_regrid.tests.unit.schemes.test__mesh_to_MeshInfo import ( + _gridlike_mesh_cube, +) +from esmf_regrid.tests.unit.schemes.test_regrid_rectilinear_to_rectilinear import ( + _make_full_cubes, +) + + +def test_Partition(tmp_path): + """Test basic implementation of Partition class.""" + src = _grid_cube(150, 500, (-180, 180), (-90, 90), circular=True) + src.data = np.arange(150 * 500).reshape([500, 150]) + tgt = _grid_cube(16, 36, (-180, 180), (-90, 90), circular=True) + + files = [tmp_path / f"partial_{x}.nc" for x in range(5)] + scheme = ESMFAreaWeighted(mdtol=1) + + blocks = [ + [[0, 100], [0, 150]], + [[100, 200], [0, 150]], + [[200, 300], [0, 150]], + [[300, 400], [0, 150]], + [[400, 500], [0, 150]], + ] + partition = Partition(src, tgt, scheme, files, explicit_src_blocks=blocks) + + partition.generate_files() + + result = partition.apply_regridders(src) + expected = src.regrid(tgt, scheme) + assert np.allclose(result.data, expected.data) + assert result == expected + + +def test_Partition_block_api(tmp_path): + """Test API for controlling block shape for Partition class.""" + src = _grid_cube(150, 500, (-180, 180), (-90, 90), circular=True) + tgt = _grid_cube(16, 36, (-180, 180), (-90, 90), circular=True) + + files = [tmp_path / f"partial_{x}.nc" for x in range(5)] + scheme = ESMFAreaWeighted(mdtol=1) + + num_src_chunks = (5, 1) + partition = Partition(src, tgt, scheme, files, num_src_chunks=num_src_chunks) + + expected_chunks = [ + [[0, 100], [0, 150]], + [[100, 200], [0, 150]], + [[200, 300], [0, 150]], + [[300, 400], [0, 150]], + [[400, 500], [0, 150]], + ] + assert partition.src_blocks == expected_chunks + + src_chunks = (100, 150) + partition = Partition(src, tgt, scheme, files, src_chunks=src_chunks) + + expected_chunks = [ + [[0, 100], [0, 150]], + [[100, 200], [0, 150]], + [[200, 300], [0, 150]], + [[300, 400], [0, 150]], + [[400, 500], [0, 150]], + ] + assert partition.src_blocks == expected_chunks + + src_chunks = ((100, 100, 
100, 100, 100), (150,)) + partition = Partition(src, tgt, scheme, files, src_chunks=src_chunks) + + expected_chunks = [ + [[0, 100], [0, 150]], + [[100, 200], [0, 150]], + [[200, 300], [0, 150]], + [[300, 400], [0, 150]], + [[400, 500], [0, 150]], + ] + assert partition.src_blocks == expected_chunks + + src.data = da.from_array(src.data, chunks=src_chunks) + partition = Partition(src, tgt, scheme, files, use_dask_src_chunks=True) + + expected_chunks = [ + [[0, 100], [0, 150]], + [[100, 200], [0, 150]], + [[200, 300], [0, 150]], + [[300, 400], [0, 150]], + [[400, 500], [0, 150]], + ] + assert partition.src_blocks == expected_chunks + + +def test_Partition_mesh_src(tmp_path): + """Test Partition class when the source has a mesh.""" + src = _gridlike_mesh_cube(150, 500) + src.data = np.arange(150 * 500) + tgt = _grid_cube(16, 36, (-180, 180), (-90, 90), circular=True) + + files = [tmp_path / f"partial_{x}.nc" for x in range(5)] + scheme = ESMFAreaWeighted(mdtol=1) + + src_chunks = (15000,) + with pytest.raises(NotImplementedError): + _ = Partition(src, tgt, scheme, files, src_chunks=src_chunks) + + # TODO: when mesh partitioning becomes possible, uncomment. 
+ # expected_src_chunks = [[[0, 15000]], [[15000, 30000]], [[30000, 45000]], [[45000, 60000]], [[60000, 75000]]] + # assert partition.src_blocks == expected_src_chunks + # + # partition.generate_files() + # + # result = partition.apply_regridders(src) + # expected = src.regrid(tgt, scheme) + # assert np.allclose(result.data, expected.data) + # assert result == expected + + +def test_Partition_curv_src(tmp_path): + """Test Partition class when the source has a curvilinear grid.""" + src = _curvilinear_cube(150, 500, (-180, 180), (-90, 90)) + src.data = np.arange(150 * 500).reshape([500, 150]) + tgt = _grid_cube(16, 36, (-180, 180), (-90, 90), circular=True) + + files = [tmp_path / f"partial_{x}.nc" for x in range(5)] + scheme = ESMFAreaWeighted(mdtol=1) + + src_chunks = (100, 150) + partition = Partition(src, tgt, scheme, files, src_chunks=src_chunks) + + expected_src_chunks = [ + [[0, 100], [0, 150]], + [[100, 200], [0, 150]], + [[200, 300], [0, 150]], + [[300, 400], [0, 150]], + [[400, 500], [0, 150]], + ] + assert partition.src_blocks == expected_src_chunks + + partition.generate_files() + + result = partition.apply_regridders(src) + expected = src.regrid(tgt, scheme) + assert np.allclose(result.data, expected.data) + assert result == expected + + +def test_Partition_bilinear(tmp_path): + """Test Partition class for bilinear regridding.""" + src = _grid_cube(150, 500, (-180, 180), (-90, 90), circular=True) + src.data = np.arange(150 * 500).reshape([500, 150]) + tgt = _grid_cube(16, 36, (-180, 180), (-90, 90), circular=True) + + files = [tmp_path / f"partial_{x}.nc" for x in range(5)] + src_chunks = (100, 150) + + bad_scheme = ESMFBilinear() + with pytest.raises(ValueError): + _ = Partition(src, tgt, bad_scheme, files, src_chunks=src_chunks) + + # The pole_method must be NONE for bilinear regridding partitions to work. 
+ scheme = ESMFBilinear(esmf_args={"pole_method": esmpy.PoleMethod.NONE}) + + partition = Partition(src, tgt, scheme, files, src_chunks=src_chunks) + + partition.generate_files() + + result = partition.apply_regridders(src) + expected = src.regrid(tgt, scheme) + assert np.allclose(result.data, expected.data) + assert result == expected + + +def test_Partition_mesh_tgt(tmp_path): + """Test Partition class when the target has a mesh.""" + src = _grid_cube(150, 500, (-180, 180), (-90, 90), circular=True) + src.data = np.arange(150 * 500).reshape([500, 150]) + tgt = _gridlike_mesh_cube(16, 36) + + files = [tmp_path / f"partial_{x}.nc" for x in range(5)] + scheme = ESMFAreaWeighted(mdtol=1) + + src_chunks = (100, 150) + partition = Partition(src, tgt, scheme, files, src_chunks=src_chunks) + + partition.generate_files() + + result = partition.apply_regridders(src) + expected = src.regrid(tgt, scheme) + assert np.allclose(result.data, expected.data) + assert result == expected + + +def test_conflicting_chunks(tmp_path): + """Test error handling of Partition class.""" + src = _grid_cube(150, 500, (-180, 180), (-90, 90), circular=True) + tgt = _grid_cube(16, 36, (-180, 180), (-90, 90), circular=True) + + files = [tmp_path / f"partial_{x}.nc" for x in range(5)] + scheme = ESMFAreaWeighted(mdtol=1) + num_src_chunks = (5, 1) + src_chunks = (100, 150) + blocks = [ + [[0, 100], [0, 150]], + [[100, 200], [0, 150]], + [[200, 300], [0, 150]], + [[300, 400], [0, 150]], + [[400, 500], [0, 150]], + ] + + with pytest.raises(ValueError): + _ = Partition( + src, + tgt, + scheme, + files, + src_chunks=src_chunks, + num_src_chunks=num_src_chunks, + ) + with pytest.raises(ValueError): + _ = Partition( + src, tgt, scheme, files, src_chunks=src_chunks, explicit_src_blocks=blocks + ) + with pytest.raises(ValueError): + _ = Partition(src, tgt, scheme, files) + with pytest.raises(TypeError): + _ = Partition(src, tgt, scheme, files, use_dask_src_chunks=True) + with pytest.raises(ValueError): + _ 
= Partition(src, tgt, scheme, files[:-1], src_chunks=src_chunks) + + +def test_multidimensional_cube(tmp_path): + """Test Partition class when the source has a multidimensional cube.""" + src_cube, tgt_grid, expected_cube = _make_full_cubes() + files = [tmp_path / f"partial_{x}.nc" for x in range(4)] + scheme = ESMFAreaWeighted(mdtol=1) + chunks = (2, 3) + + partition = Partition(src_cube, tgt_grid, scheme, files, src_chunks=chunks) + + partition.generate_files() + + result = partition.apply_regridders(src_cube) + + # Lenient check for data. + assert np.allclose(result.data, expected_cube.data) + + # Check metadata and coords. + result.data = expected_cube.data + assert result == expected_cube + + +def test_save_incomplete(tmp_path): + """Test Partition class when a limited number of files are saved.""" + src = _grid_cube(150, 500, (-180, 180), (-90, 90), circular=True) + tgt = _grid_cube(16, 36, (-180, 180), (-90, 90), circular=True) + + files = [tmp_path / f"partial_{x}.nc" for x in range(5)] + src_chunks = (100, 150) + scheme = ESMFAreaWeighted(mdtol=1) + num_initial_chunks = 3 + expected_files = files[:num_initial_chunks] + + partition = Partition(src, tgt, scheme, files, src_chunks=src_chunks) + with pytest.raises(OSError): + _ = partition.apply_regridders(src, allow_incomplete=True) + + partition.generate_files(files_to_generate=num_initial_chunks) + assert partition.saved_files == expected_files + + expected_array_partial = np.ma.zeros([36, 16]) + expected_array_partial[22:] = np.ma.masked + + with pytest.raises(OSError): + _ = partition.apply_regridders(src) + partial_result = partition.apply_regridders(src, allow_incomplete=True) + assert np.ma.allclose(partial_result.data, expected_array_partial) + + loaded_partition = Partition( + src, tgt, scheme, files, src_chunks=src_chunks, saved_files=expected_files + ) + + with pytest.raises(OSError): + _ = loaded_partition.apply_regridders(src) + partial_result_2 = partition.apply_regridders(src, 
allow_incomplete=True) + assert np.ma.allclose(partial_result_2.data, expected_array_partial) + + loaded_partition.generate_files() + + result = loaded_partition.apply_regridders(src) + expected_array = np.ma.zeros([36, 16]) + assert np.ma.allclose(result.data, expected_array) + + +def test_nearest_invalid(tmp_path): + """Test Partition class when initialised with an invalid scheme.""" + src_cube, tgt_grid, _ = _make_full_cubes() + files = [tmp_path / f"partial_{x}.nc" for x in range(4)] + scheme = ESMFNearest() + chunks = (2, 3) + + with pytest.raises(NotImplementedError): + _ = Partition(src_cube, tgt_grid, scheme, files, src_chunks=chunks) + + +def test_Partition_repr(tmp_path): + """Test repr of Partition instance.""" + src_cube, tgt_grid, _ = _make_full_cubes() + files = [tmp_path / f"partial_{x}.nc" for x in range(4)] + scheme = ESMFAreaWeighted() + chunks = (2, 3) + + partition = Partition(src_cube, tgt_grid, scheme, files, src_chunks=chunks) + + expected_repr = ( + "Partition(src=, " + "tgt=, " + "scheme=ESMFAreaWeighted(mdtol=0, use_src_mask=False, use_tgt_mask=False, esmf_args={}), " + "num file_names=4,num saved_files=0)" + ) + assert repr(partition) == expected_repr diff --git a/src/esmf_regrid/tests/unit/schemes/test_regrid_rectilinear_to_rectilinear.py b/src/esmf_regrid/tests/unit/schemes/test_regrid_rectilinear_to_rectilinear.py index 548cc1f4..1233b696 100644 --- a/src/esmf_regrid/tests/unit/schemes/test_regrid_rectilinear_to_rectilinear.py +++ b/src/esmf_regrid/tests/unit/schemes/test_regrid_rectilinear_to_rectilinear.py @@ -1,5 +1,7 @@ """Unit tests for :func:`esmf_regrid.schemes.regrid_rectilinear_to_rectilinear`.""" +from functools import partial + import dask.array as da from iris.coord_systems import RotatedGeogCS from iris.coords import AuxCoord, DimCoord @@ -68,12 +70,19 @@ def test_rotated_regridding(): assert np.allclose(expected_data, full_mdtol_result.data) -def test_extra_dims(): - """Test for 
:func:`esmf_regrid.schemes.regrid_rectilinear_to_rectilinear`. +def _add_metadata(cube): + result = cube.copy() + result.units = "K" + result.attributes = {"a": 1} + result.standard_name = "air_temperature" + scalar_height = AuxCoord([5], units="m", standard_name="height") + scalar_time = DimCoord([10], units="s", standard_name="time") + result.add_aux_coord(scalar_height) + result.add_aux_coord(scalar_time) + return result - Tests the handling of extra dimensions and metadata. Ensures that proper - coordinates, attributes, names and units are copied over. - """ + +def _make_full_cubes(src_rectilinear=True, tgt_rectilinear=True): h = 2 t = 4 e = 6 @@ -86,13 +95,22 @@ def test_extra_dims(): lon_bounds = (-180, 180) lat_bounds = (-90, 90) - src_grid = _grid_cube( + if src_rectilinear: + src_func = partial(_grid_cube, circular=True) + else: + src_func = _curvilinear_cube + if tgt_rectilinear: + tgt_func = partial(_grid_cube, circular=True) + else: + tgt_func = _curvilinear_cube + src_grid = src_func( src_lons, src_lats, lon_bounds, lat_bounds, ) - tgt_grid = _grid_cube( + + tgt_grid = tgt_func( tgt_lons, tgt_lats, lon_bounds, @@ -110,47 +128,56 @@ def test_extra_dims(): ] src_cube = Cube(src_data) + if src_rectilinear: + src_cube.add_dim_coord(src_grid.coord("latitude"), 1) + src_cube.add_dim_coord(src_grid.coord("longitude"), 3) + else: + src_cube.add_aux_coord(src_grid.coord("latitude"), (1, 3)) + src_cube.add_aux_coord(src_grid.coord("longitude"), (1, 3)) src_cube.add_dim_coord(height, 0) - src_cube.add_dim_coord(src_grid.coord("latitude"), 1) src_cube.add_dim_coord(time, 2) - src_cube.add_dim_coord(src_grid.coord("longitude"), 3) src_cube.add_aux_coord(extra, 4) src_cube.add_aux_coord(spanning, [0, 2, 4]) - def _add_metadata(cube): - result = cube.copy() - result.units = "K" - result.attributes = {"a": 1} - result.standard_name = "air_temperature" - scalar_height = AuxCoord([5], units="m", standard_name="height") - scalar_time = DimCoord([10], units="s", 
standard_name="time") - result.add_aux_coord(scalar_height) - result.add_aux_coord(scalar_time) - return result - src_cube = _add_metadata(src_cube) - result = regrid_rectilinear_to_rectilinear(src_cube, tgt_grid) - expected_data = np.empty([h, tgt_lats, t, tgt_lons, e]) expected_data[:] = np.arange(t * h * e).reshape([h, t, e])[ :, np.newaxis, :, np.newaxis, : ] expected_cube = Cube(expected_data) + if tgt_rectilinear: + expected_cube.add_dim_coord(tgt_grid.coord("latitude"), 1) + expected_cube.add_dim_coord(tgt_grid.coord("longitude"), 3) + else: + expected_cube.add_aux_coord(tgt_grid.coord("latitude"), (1, 3)) + expected_cube.add_aux_coord(tgt_grid.coord("longitude"), (1, 3)) expected_cube.add_dim_coord(height, 0) - expected_cube.add_dim_coord(tgt_grid.coord("latitude"), 1) expected_cube.add_dim_coord(time, 2) - expected_cube.add_dim_coord(tgt_grid.coord("longitude"), 3) expected_cube.add_aux_coord(extra, 4) expected_cube.add_aux_coord(spanning, [0, 2, 4]) expected_cube = _add_metadata(expected_cube) + return src_cube, tgt_grid, expected_cube + + +def test_extra_dims(): + """Test for :func:`esmf_regrid.schemes.regrid_rectilinear_to_rectilinear`. + + Tests the handling of extra dimensions and metadata. Ensures that proper + coordinates, attributes, names and units are copied over. + """ + src_cube, tgt_grid, expected_cube = _make_full_cubes( + src_rectilinear=True, tgt_rectilinear=True + ) + + result = regrid_rectilinear_to_rectilinear(src_cube, tgt_grid) # Lenient check for data. - assert np.allclose(expected_data, result.data) + assert np.allclose(expected_cube.data, result.data) # Check metadata and coords. - result.data = expected_data + result.data = expected_cube.data assert expected_cube == result @@ -266,83 +293,17 @@ def test_extra_dims_curvilinear(): Tests the handling of extra dimensions and metadata. Ensures that proper coordinates, attributes, names and units are copied over. 
""" - h = 2 - t = 4 - e = 6 - src_lats = 3 - src_lons = 5 - - tgt_lats = 5 - tgt_lons = 3 - - lon_bounds = (-180, 180) - lat_bounds = (-90, 90) - - src_grid = _curvilinear_cube( - src_lons, - src_lats, - lon_bounds, - lat_bounds, - ) - tgt_grid = _curvilinear_cube( - tgt_lons, - tgt_lats, - lon_bounds, - lat_bounds, + src_cube, tgt_grid, expected_cube = _make_full_cubes( + src_rectilinear=False, tgt_rectilinear=False ) - height = DimCoord(np.arange(h), standard_name="height") - time = DimCoord(np.arange(t), standard_name="time") - extra = AuxCoord(np.arange(e), long_name="extra dim") - spanning = AuxCoord(np.ones([h, t, e]), long_name="spanning dim") - - src_data = np.empty([h, src_lats, t, src_lons, e]) - src_data[:] = np.arange(t * h * e).reshape([h, t, e])[ - :, np.newaxis, :, np.newaxis, : - ] - - src_cube = Cube(src_data) - src_cube.add_dim_coord(height, 0) - src_cube.add_aux_coord(src_grid.coord("latitude"), (1, 3)) - src_cube.add_dim_coord(time, 2) - src_cube.add_aux_coord(src_grid.coord("longitude"), (1, 3)) - src_cube.add_aux_coord(extra, 4) - src_cube.add_aux_coord(spanning, [0, 2, 4]) - - def _add_metadata(cube): - result = cube.copy() - result.units = "K" - result.attributes = {"a": 1} - result.standard_name = "air_temperature" - scalar_height = AuxCoord([5], units="m", standard_name="height") - scalar_time = DimCoord([10], units="s", standard_name="time") - result.add_aux_coord(scalar_height) - result.add_aux_coord(scalar_time) - return result - - src_cube = _add_metadata(src_cube) - result = regrid_rectilinear_to_rectilinear(src_cube, tgt_grid) - expected_data = np.empty([h, tgt_lats, t, tgt_lons, e]) - expected_data[:] = np.arange(t * h * e).reshape([h, t, e])[ - :, np.newaxis, :, np.newaxis, : - ] - - expected_cube = Cube(expected_data) - expected_cube.add_dim_coord(height, 0) - expected_cube.add_aux_coord(tgt_grid.coord("latitude"), (1, 3)) - expected_cube.add_dim_coord(time, 2) - expected_cube.add_aux_coord(tgt_grid.coord("longitude"), (1, 3)) - 
expected_cube.add_aux_coord(extra, 4) - expected_cube.add_aux_coord(spanning, [0, 2, 4]) - expected_cube = _add_metadata(expected_cube) - # Lenient check for data. - assert np.allclose(expected_data, result.data) + assert np.allclose(expected_cube.data, result.data) # Check metadata and coords. - result.data = expected_data + result.data = expected_cube.data assert expected_cube == result @@ -352,83 +313,17 @@ def test_extra_dims_curvilinear_to_rectilinear(): Tests the handling of extra dimensions and metadata. Ensures that proper coordinates, attributes, names and units are copied over. """ - h = 2 - t = 4 - e = 6 - src_lats = 3 - src_lons = 5 - - tgt_lats = 5 - tgt_lons = 3 - - lon_bounds = (-180, 180) - lat_bounds = (-90, 90) - - src_grid = _curvilinear_cube( - src_lons, - src_lats, - lon_bounds, - lat_bounds, - ) - tgt_grid = _grid_cube( - tgt_lons, - tgt_lats, - lon_bounds, - lat_bounds, + src_cube, tgt_grid, expected_cube = _make_full_cubes( + src_rectilinear=False, tgt_rectilinear=True ) - height = DimCoord(np.arange(h), standard_name="height") - time = DimCoord(np.arange(t), standard_name="time") - extra = AuxCoord(np.arange(e), long_name="extra dim") - spanning = AuxCoord(np.ones([h, t, e]), long_name="spanning dim") - - src_data = np.empty([h, src_lats, t, src_lons, e]) - src_data[:] = np.arange(t * h * e).reshape([h, t, e])[ - :, np.newaxis, :, np.newaxis, : - ] - - src_cube = Cube(src_data) - src_cube.add_dim_coord(height, 0) - src_cube.add_aux_coord(src_grid.coord("latitude"), (1, 3)) - src_cube.add_dim_coord(time, 2) - src_cube.add_aux_coord(src_grid.coord("longitude"), (1, 3)) - src_cube.add_aux_coord(extra, 4) - src_cube.add_aux_coord(spanning, [0, 2, 4]) - - def _add_metadata(cube): - result = cube.copy() - result.units = "K" - result.attributes = {"a": 1} - result.standard_name = "air_temperature" - scalar_height = AuxCoord([5], units="m", standard_name="height") - scalar_time = DimCoord([10], units="s", standard_name="time") - 
result.add_aux_coord(scalar_height) - result.add_aux_coord(scalar_time) - return result - - src_cube = _add_metadata(src_cube) - result = regrid_rectilinear_to_rectilinear(src_cube, tgt_grid) - expected_data = np.empty([h, tgt_lats, t, tgt_lons, e]) - expected_data[:] = np.arange(t * h * e).reshape([h, t, e])[ - :, np.newaxis, :, np.newaxis, : - ] - - expected_cube = Cube(expected_data) - expected_cube.add_dim_coord(height, 0) - expected_cube.add_dim_coord(tgt_grid.coord("latitude"), 1) - expected_cube.add_dim_coord(time, 2) - expected_cube.add_dim_coord(tgt_grid.coord("longitude"), 3) - expected_cube.add_aux_coord(extra, 4) - expected_cube.add_aux_coord(spanning, [0, 2, 4]) - expected_cube = _add_metadata(expected_cube) - # Lenient check for data. - assert np.allclose(expected_data, result.data) + assert np.allclose(expected_cube.data, result.data) # Check metadata and coords. - result.data = expected_data + result.data = expected_cube.data assert expected_cube == result @@ -438,81 +333,15 @@ def test_extra_dims_rectilinear_to_curvilinear(): Tests the handling of extra dimensions and metadata. Ensures that proper coordinates, attributes, names and units are copied over. 
""" - h = 2 - t = 4 - e = 6 - src_lats = 3 - src_lons = 5 - - tgt_lats = 5 - tgt_lons = 3 - - lon_bounds = (-180, 180) - lat_bounds = (-90, 90) - - src_grid = _grid_cube( - src_lons, - src_lats, - lon_bounds, - lat_bounds, + src_cube, tgt_grid, expected_cube = _make_full_cubes( + src_rectilinear=True, tgt_rectilinear=False ) - tgt_grid = _curvilinear_cube( - tgt_lons, - tgt_lats, - lon_bounds, - lat_bounds, - ) - - height = DimCoord(np.arange(h), standard_name="height") - time = DimCoord(np.arange(t), standard_name="time") - extra = AuxCoord(np.arange(e), long_name="extra dim") - spanning = AuxCoord(np.ones([h, t, e]), long_name="spanning dim") - - src_data = np.empty([h, src_lats, t, src_lons, e]) - src_data[:] = np.arange(t * h * e).reshape([h, t, e])[ - :, np.newaxis, :, np.newaxis, : - ] - - src_cube = Cube(src_data) - src_cube.add_dim_coord(height, 0) - src_cube.add_dim_coord(src_grid.coord("latitude"), 1) - src_cube.add_dim_coord(time, 2) - src_cube.add_dim_coord(src_grid.coord("longitude"), 3) - src_cube.add_aux_coord(extra, 4) - src_cube.add_aux_coord(spanning, [0, 2, 4]) - - def _add_metadata(cube): - result = cube.copy() - result.units = "K" - result.attributes = {"a": 1} - result.standard_name = "air_temperature" - scalar_height = AuxCoord([5], units="m", standard_name="height") - scalar_time = DimCoord([10], units="s", standard_name="time") - result.add_aux_coord(scalar_height) - result.add_aux_coord(scalar_time) - return result - - src_cube = _add_metadata(src_cube) result = regrid_rectilinear_to_rectilinear(src_cube, tgt_grid) - expected_data = np.empty([h, tgt_lats, t, tgt_lons, e]) - expected_data[:] = np.arange(t * h * e).reshape([h, t, e])[ - :, np.newaxis, :, np.newaxis, : - ] - - expected_cube = Cube(expected_data) - expected_cube.add_dim_coord(height, 0) - expected_cube.add_aux_coord(tgt_grid.coord("latitude"), (1, 3)) - expected_cube.add_dim_coord(time, 2) - expected_cube.add_aux_coord(tgt_grid.coord("longitude"), (1, 3)) - 
expected_cube.add_aux_coord(extra, 4) - expected_cube.add_aux_coord(spanning, [0, 2, 4]) - expected_cube = _add_metadata(expected_cube) - # Lenient check for data. - assert np.allclose(expected_data, result.data) + assert np.allclose(expected_cube.data, result.data) # Check metadata and coords. - result.data = expected_data + result.data = expected_cube.data assert expected_cube == result