Skip to content

Commit

Permalink
update(plotting): check for user set axes limits
Browse files Browse the repository at this point in the history
* added method to check for user supplied axes limits and set with default extents if autoscaling is on
* update docstring in StructuredGrid.neighbors
* add tests for axes limit checks and scaling
  • Loading branch information
jlarsen-usgs committed Feb 16, 2024
1 parent 626563a commit 664ab37
Show file tree
Hide file tree
Showing 5 changed files with 142 additions and 21 deletions.
47 changes: 47 additions & 0 deletions autotest/test_plot_cross_section.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np
import pytest
import matplotlib.pyplot as plt
from matplotlib.collections import PatchCollection
from modflow_devtools.markers import requires_pkg

Expand Down Expand Up @@ -181,3 +182,49 @@ def test_cross_section_invalid_line_representations_fail(line):
grid = structured_square_grid(side=10)
with pytest.raises(ValueError):
flopy.plot.PlotCrossSection(modelgrid=grid, line={"line": line})


def test_plot_limits():
xymin, xymax = 0, 1000
cellsize = 50
nrow = (xymax - xymin) // cellsize
ncol = nrow
nlay = 1

delc = np.full((nrow,), cellsize)
delr = np.full((ncol,), cellsize)

top = np.full((nrow, ncol), 100)
botm = np.full((nlay, nrow, ncol), 0)
idomain = np.ones(botm.shape, dtype=int)

grid = flopy.discretization.StructuredGrid(
delc=delc, delr=delr, top=top, botm=botm, idomain=idomain
)

fig, ax = plt.subplots()
user_extent = 0, 500, 0, 25
ax.axis(user_extent)

pxc = flopy.plot.PlotCrossSection(
modelgrid=grid, ax=ax, line={"column": 4}
)
pxc.plot_grid()

lims = ax.axes.viewLim
if (lims.x0, lims.x1, lims.y0, lims.y1) != user_extent:
raise AssertionError("PlotMapView not checking for user scaling")

plt.close(fig)

fig, ax = plt.subplots(figsize=(8, 8))
pxc = flopy.plot.PlotCrossSection(
modelgrid=grid, ax=ax, line={"column": 4}
)
pxc.plot_grid()

lims = ax.axes.viewLim
if (lims.x0, lims.x1, lims.y0, lims.y1) != pxc.extent:
raise AssertionError("PlotMapView auto extent setting not working")

plt.close(fig)
42 changes: 42 additions & 0 deletions autotest/test_plot_map_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,3 +235,45 @@ def test_map_view_contour_array_structured(function_tmpdir, ndim, rng):
# for ix, lev in enumerate(contours.levels):
# if not np.allclose(lev, levels[ix]):
# raise AssertionError("TriContour NaN catch Failed")


def test_plot_limits():
xymin, xymax = 0, 1000
cellsize = 50
nrow = (xymax - xymin) // cellsize
ncol = nrow
nlay = 1

delc = np.full((nrow,), cellsize)
delr = np.full((ncol,), cellsize)

top = np.full((nrow, ncol), 100)
botm = np.full((nlay, nrow, ncol), 0)
idomain = np.ones(botm.shape, dtype=int)

grid = flopy.discretization.StructuredGrid(
delc=delc, delr=delr, top=top, botm=botm, idomain=idomain
)

fig, ax = plt.subplots()
user_extent = 0, 300, 0, 100
ax.axis(user_extent)

pmv = flopy.plot.PlotMapView(modelgrid=grid, ax=ax)
pmv.plot_grid()

lims = ax.axes.viewLim
if (lims.x0, lims.x1, lims.y0, lims.y1) != user_extent:
raise AssertionError("PlotMapView not checking for user scaling")

plt.close(fig)

fig, ax = plt.subplots(figsize=(8, 8))
pmv = flopy.plot.PlotMapView(modelgrid=grid, ax=ax)
pmv.plot_grid()

lims = ax.axes.viewLim
if (lims.x0, lims.x1, lims.y0, lims.y1) != pmv.extent:
raise AssertionError("PlotMapView auto extent setting not working")

plt.close(fig)
2 changes: 1 addition & 1 deletion flopy/discretization/structuredgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -836,7 +836,7 @@ def neighbors(self, *args, **kwargs):
row number
j : int
column number
as_node : bool
as_nodes : bool
flag to return neighbors as node numbers
method : str
"rook" for shared edge neighbors (default) "queen" for shared
Expand Down
34 changes: 24 additions & 10 deletions flopy/plot/crosssection.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,8 +265,7 @@ def __init__(
self._masked_values = [model.hnoflo, model.hdry]

# Set axis limits
self.ax.set_xlim(self.extent[0], self.extent[1])
self.ax.set_ylim(self.extent[2], self.extent[3])
self._set_axes_limits(self.ax)

@staticmethod
def _is_valid(line):
Expand Down Expand Up @@ -357,6 +356,25 @@ def get_extent(self):

return xmin, xmax, ymin, ymax

def _set_axes_limits(self, ax):
"""
Internal method to set axes limits
Parameters
----------
ax : matplotlib.pyplot axis
The plot axis
Returns
-------
ax : matplotlib.pyplot axis object
"""
if ax.get_autoscale_on():
ax.set_xlim(self.extent[0], self.extent[1])
ax.set_ylim(self.extent[2], self.extent[3])
return ax

def plot_array(self, a, masked_values=None, head=None, **kwargs):
"""
Plot a three-dimensional array as a patch collection.
Expand Down Expand Up @@ -402,8 +420,7 @@ def plot_array(self, a, masked_values=None, head=None, **kwargs):
pc = self.get_grid_patch_collection(a, projpts, **kwargs)
if pc is not None:
ax.add_collection(pc)
ax.set_xlim(self.extent[0], self.extent[1])
ax.set_ylim(self.extent[2], self.extent[3])
ax = self._set_axes_limits(ax)

return pc

Expand Down Expand Up @@ -464,8 +481,7 @@ def plot_surface(self, a, masked_values=None, **kwargs):
)
surface.append(line)

ax.set_xlim(self.extent[0], self.extent[1])
ax.set_ylim(self.extent[2], self.extent[3])
ax = self._set_axes_limits(ax)

return surface

Expand Down Expand Up @@ -523,8 +539,7 @@ def plot_fill_between(
)
if pc is not None:
ax.add_collection(pc)
ax.set_xlim(self.extent[0], self.extent[1])
ax.set_ylim(self.extent[2], self.extent[3])
ax = self._set_axes_limits(ax)

return pc

Expand Down Expand Up @@ -659,8 +674,7 @@ def contour_array(self, a, masked_values=None, head=None, **kwargs):
if plot_triplot:
ax.triplot(triang, color="black", marker="o", lw=0.75)

ax.set_xlim(self.extent[0], self.extent[1])
ax.set_ylim(self.extent[2], self.extent[3])
ax = self._set_axes_limits(ax)

return contour_set

Expand Down
38 changes: 28 additions & 10 deletions flopy/plot/map.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,25 @@ def extent(self):
self._extent = self.mg.extent
return self._extent

def _set_axes_limits(self, ax):
"""
Internal method to set axes limits
Parameters
----------
ax : matplotlib.pyplot axis
The plot axis
Returns
-------
ax : matplotlib.pyplot axis object
"""
if ax.get_autoscale_on():
ax.set_xlim(self.extent[0], self.extent[1])
ax.set_ylim(self.extent[2], self.extent[3])
return ax

def plot_array(self, a, masked_values=None, **kwargs):
"""
Plot an array. If the array is three-dimensional, then the method
Expand Down Expand Up @@ -152,8 +171,7 @@ def plot_array(self, a, masked_values=None, **kwargs):
ax.add_collection(collection)

# set limits
ax.set_xlim(self.extent[0], self.extent[1])
ax.set_ylim(self.extent[2], self.extent[3])
ax = self._set_axes_limits(ax)
return collection

def contour_array(self, a, masked_values=None, tri_mask=False, **kwargs):
Expand Down Expand Up @@ -284,7 +302,7 @@ def contour_array(self, a, masked_values=None, tri_mask=False, **kwargs):
for ix, nodes in enumerate(triangles):
neighbors = self.mg.neighbors(nodes[i], as_nodes=True)
isin = np.isin(nodes[i + 1 :], neighbors)
if not np.all(isin):
if not np.alltrue(isin):
mask[ix] = True

if ismasked is not None:
Expand All @@ -305,8 +323,7 @@ def contour_array(self, a, masked_values=None, tri_mask=False, **kwargs):
if plot_triplot:
ax.triplot(triang, color="black", marker="o", lw=0.75)

ax.set_xlim(self.extent[0], self.extent[1])
ax.set_ylim(self.extent[2], self.extent[3])
ax = self._set_axes_limits(ax)

return contour_set

Expand Down Expand Up @@ -423,8 +440,7 @@ def plot_grid(self, **kwargs):
collection = LineCollection(grid_lines, colors=colors, **kwargs)

ax.add_collection(collection)
ax.set_xlim(self.extent[0], self.extent[1])
ax.set_ylim(self.extent[2], self.extent[3])
ax = self._set_axes_limits(ax)
return collection

def plot_bc(
Expand Down Expand Up @@ -605,6 +621,7 @@ def plot_shapes(self, obj, **kwargs):
"""
ax = kwargs.pop("ax", self.ax)
patch_collection = plotutil.plot_shapefile(obj, ax, **kwargs)
ax = self._set_axes_limits(ax)
return patch_collection

def plot_vector(
Expand Down Expand Up @@ -701,6 +718,7 @@ def plot_vector(
# these are vectors not locations
urot, vrot = geometry.rotate(u, v, 0.0, 0.0, self.mg.angrot_radians)
quiver = ax.quiver(x, y, urot, vrot, pivot=pivot, **kwargs)
ax = self._set_axes_limits(ax)
return quiver

def plot_pathline(self, pl, travel_time=None, **kwargs):
Expand Down Expand Up @@ -845,9 +863,7 @@ def plot_pathline(self, pl, travel_time=None, **kwargs):
)

# set axis limits
ax.set_xlim(self.extent[0], self.extent[1])
ax.set_ylim(self.extent[2], self.extent[3])

ax = self._set_axes_limits(ax)
return lc

def plot_timeseries(self, ts, travel_time=None, **kwargs):
Expand Down Expand Up @@ -989,4 +1005,6 @@ def plot_endpoint(
if createcb:
cb = plt.colorbar(sp, ax=ax, shrink=shrink)
cb.set_label(colorbar_label)

ax = self._set_axes_limits(ax)
return sp

0 comments on commit 664ab37

Please sign in to comment.