Skip to content

Commit

Permalink
address review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
stephenworsley committed Mar 30, 2022
1 parent 9b2986d commit c473e0d
Show file tree
Hide file tree
Showing 6 changed files with 14 additions and 40 deletions.
16 changes: 8 additions & 8 deletions esmf_regrid/_esmf_sdo.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def shape(self):
return self._shape

@property
def _extended_shape(self):
def _refined_shape(self):
"""Return shape passed to ESMF."""
return self._shape

Expand All @@ -52,9 +52,9 @@ def size(self):
return np.prod(self._shape)

@property
def _extended_size(self):
def _refined_size(self):
"""Return the number of cells passed to ESMF."""
return np.prod(self._extended_shape)
return np.prod(self._refined_shape)

@property
def index_offset(self):
Expand Down Expand Up @@ -198,7 +198,7 @@ def __init__(
)

def _as_esmf_info(self):
shape = np.array(self._extended_shape)
shape = np.array(self._refined_shape)

londims = len(self.lons.shape)

Expand Down Expand Up @@ -373,7 +373,7 @@ def __init__(
self.lat_expansion = 1

@property
def _extended_shape(self):
def _refined_shape(self):
"""Return shape passed to ESMF."""
return (
self.n_lats_orig * self.lat_expansion,
Expand All @@ -396,7 +396,7 @@ def _collapse_weights(self, is_tgt):
True if the target field is being represented, False otherwise.
"""
# The column indices represent each of the cells in the refined grid.
column_indices = np.arange(self._extended_size)
column_indices = np.arange(self._refined_size)

# The row indices represent the cells of the unrefined grid. These are broadcast
# so that each row index coincides with all column indices of the refined cells
Expand All @@ -418,10 +418,10 @@ def _collapse_weights(self, is_tgt):
[self.n_lons_orig, self.n_lats_orig]
)[:, np.newaxis, :]
row_indices = row_indices.flatten()
matrix_shape = (self.size, self._extended_size)
matrix_shape = (self.size, self._refined_size)
refinement_weights = scipy.sparse.csr_matrix(
(
np.ones(self._extended_size),
np.ones(self._refined_size),
(row_indices, column_indices),
),
shape=matrix_shape,
Expand Down
2 changes: 1 addition & 1 deletion esmf_regrid/esmf_regridder.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def __init__(self, src, tgt, method="conservative", precomputed_weights=None):
)
self.weight_matrix = _weights_dict_to_sparse_array(
weights_dict,
(self.tgt._extended_size, self.src._extended_size),
(self.tgt._refined_size, self.src._refined_size),
(self.tgt.index_offset, self.src.index_offset),
)
if isinstance(tgt, RefinedGridInfo):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -367,19 +367,6 @@ def test_resolution():
lat_bounds = (-90, 90)
grid = _grid_cube(n_lons, n_lats, lon_bounds, lat_bounds, circular=True)

h = 2
t = 3
height = DimCoord(np.arange(h), standard_name="height")
time = DimCoord(np.arange(t), standard_name="time")

src_data = np.empty([t, n_lats, n_lons, h])
src_data[:] = np.arange(t * h).reshape([t, h])[:, np.newaxis, np.newaxis, :]
cube = Cube(src_data)
cube.add_dim_coord(grid.coord("latitude"), 1)
cube.add_dim_coord(grid.coord("longitude"), 2)
cube.add_dim_coord(time, 0)
cube.add_dim_coord(height, 3)

resolution = 8

result = GridToMeshESMFRegridder(grid, tgt, resolution=resolution)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -330,22 +330,7 @@ def test_resolution():
Tests for the resolution keyword.
"""
mesh = _full_mesh()
mesh_length = mesh.connectivity(contains_face=True).shape[0]

h = 2
t = 3
height = DimCoord(np.arange(h), standard_name="height")
time = DimCoord(np.arange(t), standard_name="time")

src_data = np.empty([t, mesh_length, h])
src_data[:] = np.arange(t * h).reshape([t, h])[:, np.newaxis, :]
mesh_cube = Cube(src_data)
mesh_coord_x, mesh_coord_y = mesh.to_MeshCoords("face")
mesh_cube.add_aux_coord(mesh_coord_x, 1)
mesh_cube.add_aux_coord(mesh_coord_y, 1)
mesh_cube.add_dim_coord(time, 0)
mesh_cube.add_dim_coord(height, 2)
mesh_cube = _flat_mesh_cube()

lon_bounds = (-180, 180)
lat_bounds = (-90, 90)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ def test_resolution():
n_lats = 4
lon_bounds = (-180, 180)
lat_bounds = (-90, 90)
src = _grid_cube(n_lons, n_lats, lon_bounds, lat_bounds, circular=True)
src = _grid_cube(n_lons, n_lats, lon_bounds, lat_bounds)
# Ensure data in the target grid is different to the expected data.
# i.e. target grid data is all zero, expected data is all one
tgt.data[:] = 0
Expand All @@ -261,6 +261,7 @@ def test_resolution():
expected_cube = _add_metadata(tgt)

# Lenient check for data.
# Note that when resolution=None, this would be a fully masked array.
assert np.allclose(expected_data, result.data)

# Check metadata and scalar coords.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def test_resolution():
n_lats = 5
lon_bounds = (-180, 180)
lat_bounds = (-90, 90)
tgt = _grid_cube(n_lons, n_lats, lon_bounds, lat_bounds, circular=True)
tgt = _grid_cube(n_lons, n_lats, lon_bounds, lat_bounds)
# Ensure data in the target grid is different to the expected data.
# i.e. target grid data is all zero, expected data is all one
tgt.data[:] = 0
Expand All @@ -204,6 +204,7 @@ def test_resolution():
expected_cube = _add_metadata(tgt)

# Lenient check for data.
# Note that when resolution=None, this would be a fully masked array.
assert np.allclose(expected_data, result.data)

# Check metadata and scalar coords.
Expand Down

0 comments on commit c473e0d

Please sign in to comment.