Source code for cubedsphere.regrid

"""
Module for regridding cubedsphere datasets
Inspired by https://github.com/JiaweiZhuang/cubedsphere/blob/master/example_notebooks/C2L_regrid.ipynb
"""
import numpy as np
import time
import warnings
import xarray as xr
import xesmf as xe

import cubedsphere.const as c
from .grid import init_grid_CS
from .utils import _flatten_ds

# Import-time side effect: install process-wide "ignore" filters for
# FutureWarning and UserWarning. This affects every module in the process,
# not only this one.
warnings.simplefilter('ignore', FutureWarning)
warnings.simplefilter('ignore', UserWarning)


class Regridder:
    """
    Class that wraps the xESMF regridder for cs geometry.

    Only two methods are possible with the cs geometry: conservative when
    using concat_mode=False (requires lon_b to have different shape from lon
    and lat_b from lat) or nearest_s2d when using concat_mode=True.
    Conservative regridding should be used if possible!

    Attributes
    ----------
    regridder: list or xESMF regrid object
        (contains) the initialized xESMF regridder
    grid: dict
        output grid

    Examples
    --------
    >>> import cubedsphere as cs  # import cubedsphere
    >>> outdir = "../run"  # specify output directory
    >>> # open Dataset
    >>> ds_ascii, grid = cs.open_ascii_dataset(outdir, iters='all', prefix = ["T","U","V","W"])
    >>> # regrid dataset
    >>> regrid = cs.Regridder(ds_ascii, grid)
    >>> ds_reg = regrid()

    Notes
    -----
    You can find more examples in the examples directory
    """

    def __init__(self, ds, cs_grid, input_type="cs", d_lon=5, d_lat=4, concat_mode=False, filename="weights",
                 method='conservative', **kwargs):
        """
        Build the regridder. This step will create the output grid and weight
        files which will then be used to regrid the dataset.

        Parameters
        ----------
        ds: xarray DataSet
            Dataset to be regridded. Dataset must contain input grid information.
        cs_grid: xarray DataSet
            Dataset containing cubedsphere grid information.
            If input_type=="cs": Input grid. Required!
            If input_type=="ll": Output cs grid. Required!
        input_type: string
            Needs to be "cs" or "ll".
            If "cs": input=cs grid -> regrid to longitude latitude (ll)
            If "ll": input=ll grid -> regrid to cs grid
        d_lon: integer
            Longitude step size, i.e. grid resolution. Only used if input_type=="cs".
        d_lat: integer
            Latitude step size, i.e. grid resolution. Only used if input_type=="cs".
        concat_mode: boolean
            use one regridding instance instead of one regridder for each face.
            Only used if input_type=="cs".
        filename: string
            filename for weights (weights will be named filename(+ _tile{i}).nc)
        method: string
            Regridding method. See xe.Regridder for options. The builders may
            replace it with a fallback method; the method finally used is printed.
        kwargs:
            Optional parameters that are passed to xe.Regridder (see xe.Regridder for options).

        Raises
        ------
        NotImplementedError
            If input_type is neither "cs" nor "ll".
        """
        t = time.time()
        self._ds = ds
        self._input_type = input_type
        if self._input_type not in ["cs", "ll"]:
            raise NotImplementedError(
                f"wrong input_type={self._input_type}. "
                "You need to either use input_type='cs' or input_type='ll'.")

        if self._input_type == "cs":
            self._ds_grid_in = cs_grid
            self.grid = self._build_output_grid(d_lon, d_lat)
            self._concat_mode = concat_mode
        else:
            # ll -> cs: the dataset itself carries the input grid and
            # concat mode is never used.
            self.grid = cs_grid
            self._ds_grid_in = self._ds
            self._concat_mode = False

        self._method = method

        if self._concat_mode:
            self._build_regridder_concat(filename, **kwargs)
        else:
            self._build_regridder_faces(filename, **kwargs)

        print(f"time needed to build regridder: {time.time() - t}")
        # self._method may have been changed by the builders (fallbacks).
        print(f"Regridder will use {self._method} method")
        if self._method not in ["patch", "conservative"]:
            print("Caution: The regridding method that you chose might not conserve fluxes")
        if self._method not in ["conservative", "nearest_s2d"]:
            print(
                "Caution: The regridding method that you chose might return 0's on borders, "
                "double check by plotting the dataset")

    def _build_regridder_faces(self, filename, **kwargs):
        """
        Wrapper that builds one regridder for every cs-face.

        For input_type=="cs" without cell-corner information (lat and lat_b
        have the same shape) this falls back to concat mode with nearest_s2d.

        Parameters
        ----------
        filename: string
            filename for weights (weights will be named filename(+ _tile{i}).nc)
        kwargs:
            Optional parameters that are passed to xe.Regridder (see xe.Regridder for options).

        Raises
        ------
        KeyError
            For ll input, if one of the required horizontal coordinates is missing.
        NotImplementedError
            For ll input, if lat and lat_b have the same shape.
        """
        is_cs = self._input_type == "cs"

        ll_grid_in = None
        if not is_cs:
            # Look the coordinates up before the shape check below, so a
            # missing coordinate is reported with the descriptive message
            # rather than a bare KeyError.
            try:
                ll_grid_in = {'lat': self._ds_grid_in[c.lat], 'lon': self._ds_grid_in[c.lon],
                              'lat_b': self._ds_grid_in[c.lat_b], 'lon_b': self._ds_grid_in[c.lon_b]}
            except KeyError as err:
                raise KeyError(
                    f"You need the following horizontal dimensions with matching name in your input dataset: {c.lat}, {c.lon}, {c.lat_b}, {c.lon_b}") from err

        if np.all(self._ds_grid_in[c.lat].shape == self._ds_grid_in[c.lat_b].shape):
            if is_cs:
                self._method = "nearest_s2d"
                self._concat_mode = True
                self._build_regridder_concat(filename, **kwargs)
                print("falling back to concat mode. The ds you provide has no outer coordinates.")
                return
            raise NotImplementedError(
                "We need the outer grid corner information to proceed. Pleas provide correct lon_b and lat_b values!")

        if is_cs:
            self._grid_in = [
                {'lat': self._ds_grid_in[c.lat].isel(**{c.FACEDIM: i}),
                 'lon': self._ds_grid_in[c.lon].isel(**{c.FACEDIM: i}),
                 'lat_b': self._ds_grid_in[c.lat_b].isel(**{c.FACEDIM: i}),
                 'lon_b': self._ds_grid_in[c.lon_b].isel(**{c.FACEDIM: i})}
                for i in range(6)]

            if self._method in ["nearest_s2d", "nearest_d2s"]:
                self._method = "conservative"
                print("falling back to conservative. "
                      "Nearest neighbour methods aint working for `concat_mode=False`")

            self.regridder = [
                xe.Regridder(self._grid_in[i], self.grid, filename=f"{filename}_tile{i + 1}.nc",
                             method=self._method, **kwargs)
                for i in range(6)]
        else:
            # case of input_type=ll: one regridder per target face
            self._grid_in = ll_grid_in
            self.regridder = [
                xe.Regridder(self._grid_in, self.grid.isel(**{c.FACEDIM: i}),
                             filename=f"{filename}_tile{i + 1}.nc", method=self._method, **kwargs)
                for i in range(6)]

    def _build_regridder_concat(self, filename, **kwargs):
        """
        Wrapper that builds one regridder for the complete cs grid (ie by
        appending to one horizontal dimension). Only to be used for
        input_type=="cs".

        If a method other than nearest_s2d was requested, this either hands
        over to the per-face builder with the conservative method (when lon
        and lon_b differ in shape) or switches the method to nearest_s2d.

        Parameters
        ----------
        filename: string
            filename for weights (the weight file will be named filename.nc)
        kwargs:
            Optional parameters that are passed to xe.Regridder (see xe.Regridder for options).

        Raises
        ------
        NotImplementedError
            If input_type is not "cs".
        """
        if self._input_type != "cs":
            raise NotImplementedError("concat mode does only work with input_type='cs'")

        if self._method != "nearest_s2d":
            if np.all(self._ds_grid_in[c.lon].shape != self._ds_grid_in[c.lon_b].shape):
                self._method = "conservative"
                self._concat_mode = False
                print(
                    f"falling back to {self._method} and `concat_mode={self._concat_mode}: "
                    "The interpolation method you chose doesn't work with your grid geometry")
                self._build_regridder_faces(filename, **kwargs)
                return
            self._method = "nearest_s2d"
            print(
                f"falling back to {self._method}: "
                "The interpolation method you chose doesn't work with your grid geometry")

        self._grid_in = {'lat': _flatten_ds(self._ds_grid_in[c.lat]),
                         'lon': _flatten_ds(self._ds_grid_in[c.lon])}
        self.regridder = xe.Regridder(self._grid_in, self.grid, filename=f"{filename}.nc",
                                      method=self._method, periodic=False, **kwargs)

    def _build_output_grid(self, d_lon, d_lat):
        """
        Function that is used to build an output longitude-latitude (ll)
        grid for input_type=="cs".

        Parameters
        ----------
        d_lon: int
            longitude step size
        d_lat: int
            latitude step size

        Returns
        -------
        dict
            Keys 'lat', 'lon', 'lat_b', 'lon_b'; the values are taken from
            the first column (lat, lat_b) and the first row (lon, lon_b) of
            the grid returned by xe.util.grid_global.
        """
        grid = xe.util.grid_global(d_lon, d_lat)
        return {'lat': grid["lat"][:, 0].values,
                'lon': grid["lon"][0, :].values,
                'lat_b': grid["lat_b"][:, 0].values,
                'lon_b': grid["lon_b"][0, :].values}

    def __call__(self, **kwargs):
        """
        Wrapper that carries out the regridding of the dataset passed at
        construction time.

        For input_type=="cs", a helper variable "face_vis" (holding the face
        index) is added to the input dataset and regridded with the rest.
        Vector pairs that cannot be regridded because a required variable is
        missing are skipped with a printed warning. For input_type=="ll",
        vector quantities are not regridded.

        Parameters
        ----------
        kwargs:
            additional parameters that are passed to the regridding call

        Returns
        -------
        ds: xarray DataSet
            regridded Dataset
        """
        # initialize an empty dataset
        ds = xr.Dataset()

        # We do not want to regrid grid values
        grid_variables = [c.lon_b, c.lon, c.lat_b, c.lat, c.drF, c.drC, c.drS, c.dxG, c.dxC, c.drW, c.dyC, c.dyG,
                          c.dxF, c.dyU, c.dxV, c.dyF, c.HFacC, c.HFacW, c.HFacS, c.rAz, c.rA, c.rAw, c.rAs,
                          c.AngleSN, c.AngleCS, c.FACEDIM]

        # Variables carrying a "mate" attribute are treated as vector components.
        _all_vectors = [var for var in self._ds
                        if var not in grid_variables and "mate" in self._ds[var].attrs]

        # Pair every component living on the i_g dimension with its mate.
        u_vectors = []
        v_vectors = []
        for vector in _all_vectors:
            if c.i_g in self._ds[vector].dims:
                u_vectors.append(vector)
                v_vectors.append(self._ds[vector].attrs["mate"])

        # Vector components need to be rotated before regridding, so they
        # are kept out of the scalar pass.
        excluded = set(grid_variables) | set(_all_vectors)

        if self._input_type == "cs":
            # init grid to interp edge quantities to center
            grid = init_grid_CS(ds=self._ds)

            face_vis_data = np.zeros((len(self._ds[c.FACEDIM]), len(self._ds[c.i]), len(self._ds[c.j])))
            for face in self._ds[c.FACEDIM]:
                face_vis_data[face] = face
            self._ds["face_vis"] = xr.DataArray(face_vis_data,
                                                coords=[self._ds[c.FACEDIM], self._ds[c.i], self._ds[c.j]],
                                                dims=[c.FACEDIM, c.i, c.j])

            # We first need to interpolate quantities to the cell center (if necessary)
            reg_all = np.all(self._ds[c.i].shape != self._ds[c.i_g].shape)
            # Iterate in dataset order so the output variable order is deterministic.
            scalar_names = [name for name in self._ds.data_vars if name not in excluded]
            for data in scalar_names:
                dims = self._ds[data].dims
                on_i_edge = c.i_g in dims
                on_j_edge = c.j_g in dims
                if on_i_edge and not on_j_edge and reg_all:
                    interp = grid.interp(self._ds[data], to="center", axis=c.i)
                elif on_j_edge and not on_i_edge and reg_all:
                    interp = grid.interp(self._ds[data], to="center", axis=c.j)
                elif on_i_edge and on_j_edge and reg_all:
                    interp = grid.interp(self._ds[data], to="center", axis=[c.i, c.j])
                elif not on_i_edge and not on_j_edge:
                    interp = self._ds[data]
                else:
                    # edge quantity that can not be interpolated: skipped
                    interp = None

                if interp is not None:
                    ds[data] = self._regrid_wrapper(interp, **kwargs)

            # Regridding for vectors
            for u_name, v_name in zip(u_vectors, v_vectors):
                try:
                    # interpolate vectors to cell centers:
                    interp_UV = grid.interp_2d_vector(
                        vector={c.i: self._ds[u_name], c.j: self._ds[v_name]}, to="center")
                    # rotate vectors to geographic direction:
                    vector_E, vector_N = self._rotate_vector_to_EN(interp_UV[c.i], interp_UV[c.j],
                                                                   self._ds[c.AngleCS], self._ds[c.AngleSN])
                    # perform the regridding:
                    ds[u_name] = self._regrid_wrapper(vector_E, **kwargs)
                    ds[v_name] = self._regrid_wrapper(vector_N, **kwargs)
                except KeyError as err:
                    # Best effort: keep going, but tell the user which pair
                    # is missing from the output instead of dropping it silently.
                    print(f"WARNING: skipping vector pair ({u_name}, {v_name}), "
                          f"it could not be regridded: missing {err}")

            ds = ds.reset_coords(c.FACEDIM)
        else:
            # case of input_grid = "ll"
            scalar_names = [name for name in self._ds.data_vars if name not in excluded]
            for data in scalar_names:
                ds[data] = self._regrid_wrapper(self._ds[data], **kwargs)

            if set(_all_vectors).intersection(self._ds.data_vars):
                print("WARNING: You have some vector quantities in your input dataset. We can not regrid those yet.")

        # xESMF names longitude lon and latitude lat. We want to rename it to
        # whatever we set in const.py to be consistent
        ds = ds.rename({"lon": c.lon, "lat": c.lat})

        return ds

    def _regrid_wrapper(self, ds_in, **kwargs):
        """
        wrapper to regrid general scalar dataarray.
        Caution: Horizontal dimensions must be the last two dimensions!

        For input_type=="cs" only inputs with 3, 4 or 5 dimensions are
        regridded. Any other input is returned without regridding (its
        first entry along the leading axis if it contains the face
        dimension, after asserting that face 0 and face 1 are equal).

        Parameters
        ----------
        ds_in: xarray DataArray
            data to be regridded
        **kwargs
            additional parameters to be passed to regridding call

        Returns
        -------
        regridded data as returned by the xESMF regridder(s): the sum over
        the six per-face results (cs input without concat mode), the single
        regridder result (concat mode), or the six results concatenated
        along the face dimension (ll input)
        """
        if self._input_type != "cs":
            # ll -> cs: one regridder per target face, stacked along the face dimension
            per_face = [self.regridder[i](ds_in, **kwargs) for i in range(6)]
            return xr.concat(per_face, dim=c.FACEDIM)

        ndim = len(ds_in.shape)
        if ndim not in (3, 4, 5):
            if c.FACEDIM in ds_in.dims:
                assert np.all(ds_in.isel(**{c.FACEDIM: 0}) == ds_in.isel(
                    **{c.FACEDIM: 1})), "you have a really messed up input dataset!"
                return ds_in[0]
            return ds_in

        if self._concat_mode:
            return self.regridder(_flatten_ds(ds_in), **kwargs)

        # Accumulator shape: the axes between the leading axis and the two
        # trailing horizontal axes, followed by (lat, lon) of the output grid.
        leading_shape = list(ds_in.shape[1:ndim - 2])
        data_out = np.zeros(leading_shape + [self.grid['lat'].size, self.grid['lon'].size])
        for i in range(6):
            # add up the results for 6 tiles
            data_out += self.regridder[i](ds_in.isel(**{c.FACEDIM: i}), **kwargs)
        return data_out

    def _rotate_vector_to_EN(self, U, V, AngleCS, AngleSN):
        """
        rotate vector to east north direction.
        Assumes that AngleCS and AngleSN are already of same dimension as V
        and U (i.e. already interpolated to cell center)

        Parameters
        ----------
        U: xarray Dataarray
            zonal vector component
        V: xarray Dataarray
            meridional vector component
        AngleCS: xarray Dataarray
            Cosine of angle of the grid center relative to the geographic direction
        AngleSN: xarray Dataarray
            Sine of angle of the grid center relative to the geographic direction

        Returns
        -------
        uE: xarray Dataarray
            rotated zonal velocity
        vN: xarray Dataarray
            rotated meridional velocity
        """
        # rotate the vectors:
        uE = AngleCS * U - AngleSN * V
        vN = AngleSN * U + AngleCS * V

        # reorder coordinates so that the horizontal dimensions come last:
        uE = uE.transpose(..., c.j, c.i)
        vN = vN.transpose(..., c.j, c.i)
        return uE, vN