Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

start of plotting "help" module #258

Merged
merged 23 commits into from
Nov 27, 2023
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
207 changes: 207 additions & 0 deletions clouddrift/utility.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
"""
This module contains various utility functions.
philippemiron marked this conversation as resolved.
Show resolved Hide resolved
"""

from clouddrift.ragged import segment, rowsize_to_index
import numpy as np
import pandas as pd
from typing import Optional, Union
import xarray as xr
import pandas as pd
from typing import Optional, Tuple, Union
from clouddrift.ragged import segment, rowsize_to_index, subset

try:
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from matplotlib.collections import LineCollection
from matplotlib import cm
except ImportError:
raise ImportError("missing optional dependency 'matplotlib'")


def plot_ragged(
ax: plt.Axes,
longitude: Union[list, np.ndarray, pd.Series, xr.DataArray],
latitude: Union[list, np.ndarray, pd.Series, xr.DataArray],
rowsize: Union[list, np.ndarray, pd.Series, xr.DataArray],
colors: Optional[Union[list, np.ndarray, pd.Series, xr.DataArray]] = None,
*args,
tolerance: Optional[Union[float, int]] = 180,
**kwargs,
):
"""Function that wraps matplotlib plot function (plt.plot) and LineCollection
(matplotlib.collections) to efficiently plot trajectories from a ragged array dataset.

Parameters
----------
longitude : array-like
Longitude sequence. Unidimensional array input.
latitude : array-like
Latitude sequence. Unidimensional array input.
rowsize : list
List of integers specifying the number of data points in each row.
colors : array-like
Colors to use for plotting. If colors is the same shape as longitude and latitude,
the trajectories are splitted into segments and each segment is colored according
to the corresponding color value. If colors is the same shape as rowsize, the
trajectories are uniformly colored according to the corresponding color value.
*args : tuple
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't it *kwargs?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

**kwargs are below. LGTM.

Additional arguments to pass to `plt.plot`.
tolerance : float
Tolerance gap between data points (in degrees) for segmenting trajectories. For periodic
philippemiron marked this conversation as resolved.
Show resolved Hide resolved
domains, the tolerance parameter should be set to the maximum allowed gap
between data points. Defaults to 180.
**kwargs : dict
Additional keyword arguments to pass to `plt.plot`.

Returns
-------
list of matplotlib.lines.Line2D or matplotlib.collections.LineCollection
The plotted lines or line collection. Can be used to set a colorbar
after plotting or extract information from the lines.

Examples
--------

Plot the first 100 trajectories from the gdp1h dataset, assigning
a different color to each trajectory:

>>> from clouddrift import datasets
>>> ds = datasets.gdp1h()
>>> ds = subset(ds, {"ID": ds.ID[:100].values}).load()
>>> fig = plt.figure()
>>> ax = fig.add_subplot(1, 1, 1)

>>> time = [v.astype(np.int64) / 86400 / 1e9 for v in ds.time.values]
philippemiron marked this conversation as resolved.
Show resolved Hide resolved
>>> plot_ragged(
>>> ax,
>>> ds.lon,
>>> ds.lat,
>>> ds.rowsize,
>>> colors=np.arange(len(ds.rowsize))
>>> )

To plot the same trajectories, but assigning a different color to each
observation and specifying a colormap:

>>> time = [v.astype(np.int64) / 86400 / 1e9 for v in ds.time.values]
>>> fig = plt.figure()
>>> ax = fig.add_subplot(1, 1, 1)
>>> lc = plot_ragged(
>>> ax,
>>> ds.lon,
>>> ds.lat,
>>> ds.rowsize,
>>> colors=np.floor(time),
>>> cmap="inferno"
>>> )
>>> fig.colorbar(lc[0])
>>> ax.set_xlim([-180, 180])
>>> ax.set_ylim([-90, 90])

Finally, to plot the same trajectories, but using a cartopy
projection:

>>> fig = plt.figure()
>>> ax = fig.add_subplot(1, 1, 1, projection=ccrs.Mollweide())
>>> time = [v.astype(np.int64) / 86400 / 1e9 for v in ds.time.values]
philippemiron marked this conversation as resolved.
Show resolved Hide resolved
>>> lc = plot_ragged(
>>> ax,
>>> ds.lon,
>>> ds.lat,
>>> ds.rowsize,
>>> colors=np.arange(len(ds.rowsize)),
>>> transform=ccrs.PlateCarree(),
>>> cmap=cmocean.cm.ice,
>>> )

Raises
------
ValueError
If longitude and latitude arrays do not have the same shape.
If colors do not have the same shape as longitude and latitude arrays or rowsize.
If ax is not a matplotlib Axes or GeoAxes object.
If ax is a GeoAxes object and the transform keyword argument is not provided.

ImportError
If matplotlib is not installed.
If the axis is a GeoAxes object and cartopy is not installed.
"""

if hasattr(ax, "coastlines"): # check if GeoAxes without cartopy
try:
from cartopy.mpl.geoaxes import GeoAxes

if isinstance(ax, GeoAxes) and not kwargs.get("transform"):
raise ValueError(
"For GeoAxes, the transform keyword argument must be provided."
)
except ImportError:
raise ImportError("missing optional dependency 'cartopy'")
elif not isinstance(ax, plt.Axes):
raise ValueError("ax must be either: plt.Axes or GeoAxes.")

if np.sum(rowsize) != len(longitude):
raise ValueError("The sum of rowsize must equal the length of lon and lat.")

if len(longitude) != len(latitude):
raise ValueError("lon and lat must have the same length.")

if colors is None:
colors = np.arange(len(rowsize))
elif colors is not None and (len(colors) not in [len(longitude), len(rowsize)]):
raise ValueError("shape colors must match the shape of lon/lat or rowsize.")

# define a colormap
cmap = kwargs.pop("cmap", cm.viridis)

# define a normalization obtain uniform colors
# for the sequence of lines or LineCollection
norm = kwargs.pop(
"norm", mcolors.Normalize(vmin=np.nanmin(colors), vmax=np.nanmax(colors))
)

mpl_plot = True if colors is None or len(colors) == len(rowsize) else False
traj_idx = rowsize_to_index(rowsize)

lines = []
for i in range(len(rowsize)):
lon_i, lat_i = (
longitude[traj_idx[i] : traj_idx[i + 1]],
latitude[traj_idx[i] : traj_idx[i + 1]],
)

start = 0
for length in segment(lon_i, tolerance, rowsize=segment(lon_i, -tolerance)):
end = start + length

if mpl_plot:
line = ax.plot(
lon_i[start:end],
lat_i[start:end],
c=cmap(norm(colors[i])) if colors is not None else None,
*args,
**kwargs,
)
else:
colors_i = colors[traj_idx[i] : traj_idx[i + 1]]
segments = np.column_stack(
[
lon_i[start : end - 1],
lat_i[start : end - 1],
lon_i[start + 1 : end],
lat_i[start + 1 : end],
]
).reshape(-1, 2, 2)
line = LineCollection(segments, cmap=cmap, norm=norm, *args, **kwargs)
line.set_array(
# color of a segment is the average of its two data points
np.convolve(colors_i[start:end], [0.5, 0.5], mode="valid")
)
ax.add_collection(line)

start = end
lines.append(line)

return lines
149 changes: 149 additions & 0 deletions tests/utility_tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
import cartopy.crs as ccrs
from clouddrift.utility import plot_ragged
import matplotlib.pyplot as plt
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a mechanism to skip these tests if optional dependencies are not installed?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

True, I assumed that it was installed in the ci/cd, but I see that this wouldn't work if you don't have those packages locally.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you try again with the latest modification?

import numpy as np
import sys
import unittest
from unittest.mock import patch

if __name__ == "__main__":
unittest.main()


class utility_tests(unittest.TestCase):
@classmethod
def setUpClass(self):
"""
Create trajectories example
"""
self.lon = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
self.lat = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
self.rowsize = [3, 3, 4]

def test_lonlatrowsize(self):
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
lon_test = np.append(self.lon, 3)
rowsize_test = np.append(self.rowsize, 3)
with self.assertRaises(ValueError):
plot_ragged(ax, lon_test, self.lat, self.rowsize)
plot_ragged(ax, self.lon, self.lat, rowsize_test)

def test_axis(self):
ax = 1
with self.assertRaises(ValueError):
plot_ragged(ax, self.lon, self.lat, self.rowsize)

def test_plot(self):
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
l = plot_ragged(
ax, self.lon, self.lat, self.rowsize, colors=np.arange(len(self.rowsize))
)
self.assertIsInstance(l, list)

def test_plot_cartopy_transform(self):
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1, projection=ccrs.PlateCarree())
with self.assertRaises(ValueError):
l = plot_ragged(
ax,
self.lon,
self.lat,
self.rowsize,
colors=np.arange(len(self.rowsize)),
)

def test_plot_cartopy(self):
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1, projection=ccrs.Mollweide())
l = plot_ragged(
ax,
self.lon,
self.lat,
self.rowsize,
colors=np.arange(len(self.rowsize)),
transform=ccrs.PlateCarree(),
)
self.assertIsInstance(l, list)

def test_plot_segments(self):
self.lon = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
self.lat = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
self.rowsize = [3, 3, 4]

fig = plt.figure()
ax = fig.add_subplot(1, 1, 1, projection=ccrs.Mollweide())
l = plot_ragged(
ax,
self.lon,
self.lat,
self.rowsize,
colors=np.arange(len(self.rowsize)),
transform=ccrs.PlateCarree(),
)
self.assertIsInstance(l, list)
self.assertEqual(len(l), 3)

def test_plot_segments_split(self):
self.lon = [-170, -175, -180, 175, 170]
self.lat = [0, 1, 2, 3, 4]
self.rowsize = [5]

fig = plt.figure()
ax = fig.add_subplot(1, 1, 1, projection=ccrs.Mollweide())
l = plot_ragged(
ax,
self.lon,
self.lat,
self.rowsize,
colors=np.arange(len(self.rowsize)),
transform=ccrs.PlateCarree(),
)
self.assertIsInstance(l, list)
self.assertEqual(len(l), 2)

def test_plot_segments_split_domain(self):
self.lon = [-1, -2, -3, 3, 2, 1]
self.lat = [0, 1, 2, 3, 4, 5]
self.rowsize = [6]

fig = plt.figure()
ax = fig.add_subplot(1, 1, 1, projection=ccrs.Mollweide())
l = plot_ragged(
ax,
self.lon,
self.lat,
self.rowsize,
colors=np.arange(len(self.rowsize)),
transform=ccrs.PlateCarree(),
tolerance=5,
)
self.assertIsInstance(l, list)
self.assertEqual(len(l), 2)

def test_matplotlib_not_installed(self):
del sys.modules["clouddrift.utility"]
with patch.dict(sys.modules, {"matplotlib": None}):
with self.assertRaises(ImportError):
from clouddrift.utility import plot_ragged
# reload for other tests
from clouddrift.utility import plot_ragged

def test_cartopy_not_installed(self):
del sys.modules["clouddrift.utility"]
with patch.dict(sys.modules, {"cartopy": None}):
from clouddrift.utility import plot_ragged

fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
plot_ragged(
ax,
self.lon,
self.lat,
self.rowsize,
colors=np.arange(len(self.rowsize)),
)

# reload for other tests
from clouddrift.utility import plot_ragged