Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] - Plot updates & extensions #319

Merged
merged 23 commits into from
Feb 20, 2024
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
97fd872
add multi_time_series plot
TomDonoghue Aug 29, 2023
b6b7b6e
add & use helper function to checking plot time definition
TomDonoghue Aug 29, 2023
00d119a
add ticks & labels as catchable axis style args
TomDonoghue Aug 29, 2023
974213c
add special case to set minorticks on/off in axis_styler
TomDonoghue Aug 29, 2023
e0677b2
add combined ts & psd plot
TomDonoghue Aug 31, 2023
edf6da2
fix import
TomDonoghue Aug 31, 2023
63cbf33
udpate description of style_plot
TomDonoghue Sep 5, 2023
45f6f10
add collection styler
TomDonoghue Sep 5, 2023
19effe1
allow for turning off style funcs with None
TomDonoghue Sep 5, 2023
7fbdce1
fix collection styler
TomDonoghue Sep 5, 2023
b96650e
fix example
TomDonoghue Sep 5, 2023
235cd7b
fix doctest for checking style args
TomDonoghue Sep 6, 2023
942db13
udpate style management of tight_layout
TomDonoghue Feb 14, 2024
7f24d08
fix plot style doctest
TomDonoghue Feb 14, 2024
0099a6e
round doctests for checking
TomDonoghue Feb 14, 2024
cf503c9
add new plots to init
TomDonoghue Feb 15, 2024
a31759c
add new plts to API list
TomDonoghue Feb 15, 2024
52e9ec1
fix extra **'s in docs
TomDonoghue Feb 15, 2024
c497d1c
drop combined from init due to circular import
TomDonoghue Feb 15, 2024
6ce3df2
drop times from combined plot (& compute internally)
TomDonoghue Feb 15, 2024
6ac7c0f
fix up start_val and ts_range
TomDonoghue Feb 15, 2024
b551ef1
add to init & avoid circular imports
TomDonoghue Feb 20, 2024
0189891
tweak times related inputs to plot_combined
TomDonoghue Feb 20, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions neurodsp/filt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def compute_transition_band(f_db, db, low=-20, high=-3):
>>> from neurodsp.filt.fir import design_fir_filter
>>> filter_coefs = design_fir_filter(fs=500, pass_type='bandpass', f_range=(1, 25))
>>> f_db, db = compute_frequency_response(filter_coefs, 1, fs=500)
>>> compute_transition_band(f_db, db, low=-20, high=-3)
>>> round(compute_transition_band(f_db, db, low=-20, high=-3), 1)
0.5

Compute the transition band of an IIR filter, using the computed frequency response:
Expand All @@ -171,7 +171,7 @@ def compute_transition_band(f_db, db, low=-20, high=-3):
>>> sos = design_iir_filter(fs=500, pass_type='bandstop',
... f_range=(10, 20), butterworth_order=7)
>>> f_db, db = compute_frequency_response(sos, None, fs=500)
>>> compute_transition_band(f_db, db, low=-20, high=-3)
>>> round(compute_transition_band(f_db, db, low=-20, high=-3), 1)
2.0
"""

Expand Down
63 changes: 63 additions & 0 deletions neurodsp/plts/combined.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
"""Plotting functions for plots with combined panels."""

import matplotlib.pyplot as plt

from neurodsp.spectral import compute_spectrum
from neurodsp.spectral.utils import trim_spectrum
from neurodsp.plts.spectral import plot_power_spectra
from neurodsp.plts.time_series import plot_time_series
from neurodsp.plts.utils import savefig

###################################################################################################
###################################################################################################

@savefig
def plot_timeseries_and_spectrum(times, sig, fs, f_range=None, spectrum_kwargs=None,
TomDonoghue marked this conversation as resolved.
Show resolved Hide resolved
ts_kwargs=None, psd_kwargs=None, **plt_kwargs):
"""Plot a timeseries together with it's associated power spectrum.

Parameters
----------
times : 1d array, or None
Time definition(s) for the time series to be plotted.
If None, time series will be plotted in terms of samples instead of time.
sigs : 1d array
Time series to plot.
fs : float
Sampling rate, in Hz.
f_range : list of [float, float], optional
The frequency range to restrict the power spectrum to.
**spectrum_kwargs : dict, optional
Keyword arguments for computing the power spectrum.
See `compute_spectrum` for details.
**ts_kwargs : dict, optional
Keyword arguments for customizing the time series plot.
**psd_kwargs : dict, optional
TomDonoghue marked this conversation as resolved.
Show resolved Hide resolved
Keyword arguments for customizing the power spectrum plot.
**plt_kwargs
Keyword arguments for customizing the plots.
These arguments are passed to both plot axes.
"""

# Allow for defining color as 'color' (since one line per plot), rather than 'colors'
if 'color' in plt_kwargs:
plt_kwargs['colors'] = plt_kwargs.pop('color')

# Default to drawing both plots in same color (otherwise ts is black, psd is blue)
if 'colors' not in plt_kwargs:
psd_kwargs = {} if psd_kwargs is None else psd_kwargs
if 'colors' not in psd_kwargs:
psd_kwargs['colors'] = 'black'

fig = plt.figure(figsize=plt_kwargs.pop('figsize', None))
ax1 = fig.add_axes([0.0, 0.6, 1.3, 0.5])
ax2 = fig.add_axes([1.5, 0.6, 0.6, 0.5])

plot_time_series(times, sig, ax=ax1, **plt_kwargs,
**ts_kwargs if ts_kwargs else {})

freqs, psd = compute_spectrum(sig, fs, **spectrum_kwargs if spectrum_kwargs else {})
if f_range:
freqs, psd = trim_spectrum(freqs, psd, f_range)
plot_power_spectra(freqs, psd, ax=ax2, **plt_kwargs,
**psd_kwargs if psd_kwargs else {})
18 changes: 15 additions & 3 deletions neurodsp/plts/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,27 @@

## Define collections of style arguments
# Plot style arguments are those that can be defined on an axis object
AXIS_STYLE_ARGS = ['title', 'xlabel', 'ylabel', 'xlim', 'ylim']
AXIS_STYLE_ARGS = ['title', 'xlabel', 'ylabel', 'xlim', 'ylim',
'xticks', 'yticks', 'xticklabels', 'yticklabels',
'minorticks']

# Line style arguments are those that can be defined on a line object
LINE_STYLE_ARGS = ['alpha', 'lw', 'linewidth', 'ls', 'linestyle',
'marker', 'ms', 'markersize']

# Collection style arguments are those that can be defined on a collections object
COLLECTION_STYLE_ARGS = ['alpha', 'edgecolor']

# Custom style arguments are those that are custom-handled by the plot style function
CUSTOM_STYLE_ARGS = ['title_fontsize', 'label_size', 'tick_labelsize',
'legend_size', 'legend_loc']
'legend_size', 'legend_loc', 'tight_layout']

# Define list of available style functions - these can also be replaced by arguments
STYLERS = ['axis_styler', 'line_styler', 'custom_styler']
STYLE_ARGS = AXIS_STYLE_ARGS + LINE_STYLE_ARGS + CUSTOM_STYLE_ARGS + STYLERS

# Collect the full set of possible style related input keyword arguments
STYLE_ARGS = \
AXIS_STYLE_ARGS + LINE_STYLE_ARGS + COLLECTION_STYLE_ARGS + CUSTOM_STYLE_ARGS + STYLERS

## Define default values for aesthetics
# These are all custom style arguments
Expand Down
83 changes: 66 additions & 17 deletions neurodsp/plts/style.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
"""Functions and utilities to apply aesthetic styling to plots."""

import warnings
from itertools import cycle
from functools import wraps

import matplotlib.pyplot as plt

from neurodsp.plts.settings import AXIS_STYLE_ARGS, LINE_STYLE_ARGS, CUSTOM_STYLE_ARGS, STYLE_ARGS
from neurodsp.plts.settings import (LABEL_SIZE, LEGEND_SIZE, LEGEND_LOC,
TICK_LABELSIZE, TITLE_FONTSIZE)
from neurodsp.plts.settings import (AXIS_STYLE_ARGS, LINE_STYLE_ARGS, COLLECTION_STYLE_ARGS,
CUSTOM_STYLE_ARGS, STYLE_ARGS, TICK_LABELSIZE, TITLE_FONTSIZE,
LABEL_SIZE, LEGEND_SIZE, LEGEND_LOC)

###################################################################################################
###################################################################################################
Expand All @@ -16,9 +17,10 @@ def check_style_options():
"""Check the list of valid style arguments that can be passed into plot functions."""

print('Valid style arguments:')
for label, options in zip(['Axis', 'Line', 'Custom'],
[AXIS_STYLE_ARGS, LINE_STYLE_ARGS, CUSTOM_STYLE_ARGS]):
print(' ', label, '\t', ', '.join(options))
for label, options in zip(['Axis', 'Line', 'Collection', 'Custom'],
[AXIS_STYLE_ARGS, LINE_STYLE_ARGS,
COLLECTION_STYLE_ARGS, CUSTOM_STYLE_ARGS]):
print(' {:10s} {}'.format(label, ', '.join(options)))


def apply_axis_style(ax, style_args=AXIS_STYLE_ARGS, **kwargs):
Expand All @@ -34,9 +36,15 @@ def apply_axis_style(ax, style_args=AXIS_STYLE_ARGS, **kwargs):
Keyword arguments that define plot style to apply.
"""

axis_kwargs = {key : val for key, val in kwargs.items() if key in style_args}

# Special case: catch and apply minorticks being set to True or False
mtick_dict = {True : 'minorticks_on', False : 'minorticks_off'}
if 'minorticks' in axis_kwargs:
getattr(ax, mtick_dict[axis_kwargs.pop('minorticks')])()

# Apply any provided axis style arguments
plot_kwargs = {key : val for key, val in kwargs.items() if key in style_args}
ax.set(**plot_kwargs)
ax.set(**axis_kwargs)


def apply_line_style(ax, style_args=LINE_STYLE_ARGS, **kwargs):
Expand Down Expand Up @@ -69,6 +77,27 @@ def apply_line_style(ax, style_args=LINE_STYLE_ARGS, **kwargs):
line.set(**{style : next(values)})


def apply_collection_style(ax, style_args=COLLECTION_STYLE_ARGS, **kwargs):
"""Apply collection plot style.

Parameters
----------
ax : matplotlib.Axes
Figure axes to apply style to.
style_args : list of str
A list of arguments to be sub-selected from `kwargs` and applied as collection styling.
**kwargs
Keyword arguments that define collection style to apply.
"""

# Get the collection related styling arguments from the keyword arguments
collection_kwargs = {key : val for key, val in kwargs.items() if key in style_args}

# Apply any provided collection style arguments
for collection in ax.collections:
collection.set(**collection_kwargs)


def apply_custom_style(ax, **kwargs):
"""Apply custom plot style.

Expand Down Expand Up @@ -98,18 +127,22 @@ def apply_custom_style(ax, **kwargs):
ax.legend(prop={'size': kwargs.pop('legend_size', LEGEND_SIZE)},
loc=kwargs.pop('legend_loc', LEGEND_LOC))

plt.tight_layout()
if kwargs.pop('tight_layout', True):
with warnings.catch_warnings():
warnings.simplefilter("ignore")
plt.tight_layout()


def plot_style(ax, axis_styler=apply_axis_style, line_styler=apply_line_style,
custom_styler=apply_custom_style, **kwargs):
collection_styler=apply_collection_style, custom_styler=apply_custom_style,
**kwargs):
"""Apply plot style to a figure axis.

Parameters
----------
ax : matplotlib.Axes
Figure axes to apply style to.
axis_styler, line_styler, custom_styler : callable, optional
axis_styler, line_styler, collection_styler, custom_styler : callable, optional
Functions to apply style to aspects of the plot.
**kwargs
Keyword arguments that define style to apply.
Expand All @@ -120,9 +153,10 @@ def plot_style(ax, axis_styler=apply_axis_style, line_styler=apply_line_style,
Each of these sub-functions can be replaced by passing in a replacement callable.
"""

axis_styler(ax, **kwargs)
line_styler(ax, **kwargs)
custom_styler(ax, **kwargs)
axis_styler(ax, **kwargs) if axis_styler is not None else None
line_styler(ax, **kwargs) if line_styler is not None else None
collection_styler(ax, **kwargs) if collection_styler is not None else None
custom_styler(ax, **kwargs) if custom_styler is not None else None


def style_plot(func, *args, **kwargs):
Expand All @@ -147,9 +181,24 @@ def style_plot(func, *args, **kwargs):
By default, this function applies styling with the `plot_style` function. Custom
functions for applying style can be passed in using `plot_style` as a keyword argument.

The `plot_style` function calls sub-functions for applying style different plot elements,
and these sub-functions can be overridden by passing in alternatives for `axis_styler`,
`line_styler`, and `custom_styler`.
The `plot_style` function calls sub-functions for applying different plot elements, including:

- `axis_styler`: apply style options to an axis
- `line_styler`: applies style options to lines objects in a plot
- `collection_styler`: applies style options to collections objects in a plot
- `custom_style`: applies custom style options

Each of these sub-functions can be overridden by passing in alternatives.

To see the full set of style arguments that are supported, run the following code:

>>> from neurodsp.plts.style import check_style_options
>>> check_style_options()
Valid style arguments:
Axis title, xlabel, ylabel, xlim, ylim, xticks, yticks, xticklabels, yticklabels, minorticks
Line alpha, lw, linewidth, ls, linestyle, marker, ms, markersize
Collection alpha, edgecolor
Custom title_fontsize, label_size, tick_labelsize, legend_size, legend_loc, tight_layout
"""

@wraps(func)
Expand Down
61 changes: 53 additions & 8 deletions neurodsp/plts/time_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def plot_time_series(times, sigs, labels=None, colors=None, ax=None, **kwargs):
labels : list of str, optional
Labels for each time series.
colors : str or list of str
Colors to use to plot lines.
Color(s) to use to plot lines.
ax : matplotlib.Axes, optional
Figure axes upon which to plot.
**kwargs
Expand All @@ -50,13 +50,7 @@ def plot_time_series(times, sigs, labels=None, colors=None, ax=None, **kwargs):
ax = check_ax(ax, kwargs.pop('figsize', (15, 3)))

sigs = [sigs] if (isinstance(sigs, np.ndarray) and sigs.ndim == 1) else sigs

xlabel = 'Time (s)'
if times is None:
times = create_samples(len(sigs[0]))
xlabel = 'Samples'

times = repeat(times) if (isinstance(times, np.ndarray) and times.ndim == 1) else times
times, xlabel = _check_times(times, sigs)

if labels is not None:
labels = [labels] if not isinstance(labels, list) else labels
Expand Down Expand Up @@ -158,3 +152,54 @@ def plot_bursts(times, sig, bursting, ax=None, **kwargs):

bursts = np.ma.array(sig, mask=np.invert(bursting))
plot_time_series(times, [sig, bursts], ax=ax, **kwargs)


@savefig
@style_plot
def plot_multi_time_series(times, sigs, colors=None, ax=None, **plt_kwargs):
TomDonoghue marked this conversation as resolved.
Show resolved Hide resolved
"""Plot multiple time series, with a vertical offset.

Parameters
----------
times : 1d or 2d array, or list of 1d array, or None
Time definition(s) for the time series to be plotted.
If None, time series will be plotted in terms of samples instead of time.
sigs : 2d array or list of 1d array
Time series to plot, each list or row of the array representing a different channel.
colors : str or list of str
Color(s) to use to plot lines.
ax : matplotlib.Axes, optional
Figure axes upon which to plot.
**kwargs
Keyword arguments for customizing the plot.
"""

colors = 'black' if not colors else colors
colors = repeat(colors) if isinstance(colors, str) else iter(colors)

ax = check_ax(ax, figsize=plt_kwargs.pop('figsize', (15, 5)))

sigs = [sigs] if (isinstance(sigs, np.ndarray) and sigs.ndim == 1) else sigs
times, xlabel = _check_times(times, sigs)

step = 0.8 * np.ptp(sigs[0])

for ind, (time, sig) in enumerate(zip(times, sigs)):
ax.plot(time, sig+step*ind, color=next(colors), **plt_kwargs)

ax.set(yticks=[])
ax.set_xlabel(xlabel)
ax.set_ylabel('Channels')


def _check_times(times, sigs):
"""Helper function to check a times definition passed into a time series plot function."""

xlabel = 'Time (s)'
if times is None:
times = create_samples(len(sigs[0]))
xlabel = 'Samples'

times = repeat(times) if (isinstance(times, np.ndarray) and times.ndim == 1) else times

return times, xlabel
29 changes: 29 additions & 0 deletions neurodsp/tests/plts/test_combined.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
"""Tests for neurodsp.plts.combined."""

from neurodsp.utils import create_times

from neurodsp.tests.settings import TEST_PLOTS_PATH, N_SECONDS, FS
from neurodsp.tests.tutils import plot_test

from neurodsp.plts.combined import *

###################################################################################################
###################################################################################################

@plot_test
def test_plot_timeseries_and_spectrum(tsig_comb):

times = create_times(N_SECONDS, FS)

plot_timeseries_and_spectrum(times, tsig_comb, FS,
save_fig=True, file_path=TEST_PLOTS_PATH,
file_name='test_plot_combined_ts_psd.png')

# Test customizations
plot_timeseries_and_spectrum(times, tsig_comb, FS,
f_range=(3, 50), color='blue',
spectrum_kwargs={'nperseg' : 500},
ts_kwargs={'xlim' : [0, 5]},
psd_kwargs={'lw' : 2},
save_fig=True, file_path=TEST_PLOTS_PATH,
file_name='test_plot_combined_ts_psd2.png')
18 changes: 18 additions & 0 deletions neurodsp/tests/plts/test_time_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,3 +71,21 @@ def test_plot_bursts(tsig_burst):
plot_bursts(times, tsig_burst, bursts,
save_fig=True, file_path=TEST_PLOTS_PATH,
file_name='test_plot_bursts.png')

@plot_test
def test_plot_multi_time_series(tsig_comb):

times = create_times(N_SECONDS, FS)

plot_multi_time_series(times, [tsig_comb, tsig_comb],
save_fig=True, file_path=TEST_PLOTS_PATH,
file_name='test_plot_multi_time_series-1.png')

plot_multi_time_series(times, np.array([tsig_comb, tsig_comb, tsig_comb]),
colors=['red', 'green', 'blue'],
save_fig=True, file_path=TEST_PLOTS_PATH,
file_name='test_plot_multi_time_series-2.png')

plot_multi_time_series([times, times[:-500]], [tsig_comb, tsig_comb[:-500]],
save_fig=True, file_path=TEST_PLOTS_PATH,
file_name='test_plot_multi_time_series-3.png')
Loading