Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] - Plotting updates #343

Merged
merged 18 commits into from
Feb 4, 2025
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,7 @@ Spectral
:toctree: generated/

plot_power_spectra
plot_spectra_3D
plot_scv
plot_scv_rs_lines
plot_scv_rs_matrix
Expand Down Expand Up @@ -465,6 +466,15 @@ Time Frequency

plot_timefrequency

Aperiodic
~~~~~~~~~

.. currentmodule:: neurodsp.plts
.. autosummary::
:toctree: generated/

plot_autocorr

Combined
~~~~~~~~

Expand Down
1 change: 1 addition & 0 deletions neurodsp/plts/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@
from .spectral import (plot_power_spectra, plot_spectral_hist,
plot_scv, plot_scv_rs_lines, plot_scv_rs_matrix)
from .timefrequency import plot_timefrequency
from .aperiodic import plot_autocorr
from .combined import plot_timeseries_and_spectra
35 changes: 35 additions & 0 deletions neurodsp/plts/aperiodic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
"""Plotting functions for neurodsp.aperiodic."""

from neurodsp.plts.style import style_plot
from neurodsp.plts.utils import check_ax, savefig, prepare_multi_plot

####################################################################################################
####################################################################################################

@savefig
@style_plot
def plot_autocorr(timepoints, autocorrs, labels=None, colors=None, ax=None, **kwargs):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Only thought here is whether to include it in aperiodic.py or somewhere else, since acf isn't restricted to aperiodic signals.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeh.... agreed. I wasn't sure where to put it, but since we already have neurodsp/aperiodic/autocorr as the home of the function to compute autocorrelation, this seemed like the most consistent spot for the plot function. I don't know what a better name / place for these things is - so unless we want to re-org, move both I think this works best for now?

"""Plot autocorrelation results.

Parameters
----------
timepoints : 1d array
Time points, in samples, at which autocorrelations are computed.
autocorrs : array
Autocorrelation values, across time lags.
labels : str or list of str, optional
Labels for each time series.
colors : str or list of str
Colors to use to plot lines.
ax : matplotlib.Axes, optional
Figure axes upon which to plot.
**kwargs
Keyword arguments for customizing the plot.
"""

ax = check_ax(ax, figsize=kwargs.pop('figsize', (6, 5)))

for time, ac, label, color in zip(*prepare_multi_plot(timepoints, autocorrs, labels, colors)):
ax.plot(time, ac, label=label, color=color)

ax.set(xlabel='Lag (Samples)', ylabel='Autocorrelation')
83 changes: 68 additions & 15 deletions neurodsp/plts/spectral.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
"""Plotting functions for neurodsp.spectral."""

from itertools import repeat, cycle

import numpy as np
import matplotlib.pyplot as plt

from neurodsp.plts.style import style_plot
from neurodsp.plts.utils import check_ax, savefig
from neurodsp.plts.utils import check_ax, savefig, prepare_multi_plot

###################################################################################################
###################################################################################################
Expand Down Expand Up @@ -47,18 +45,8 @@ def plot_power_spectra(freqs, powers, labels=None, colors=None, ax=None, **kwarg

ax = check_ax(ax, figsize=kwargs.pop('figsize', (6, 6)))

freqs = repeat(freqs) if isinstance(freqs, np.ndarray) and freqs.ndim == 1 else freqs
powers = [powers] if isinstance(powers, np.ndarray) and powers.ndim == 1 else powers

if labels is not None:
labels = [labels] if not isinstance(labels, list) else labels
else:
labels = repeat(labels)

colors = repeat(colors) if not isinstance(colors, list) else cycle(colors)

for freq, power, color, label in zip(freqs, powers, colors, labels):
ax.loglog(freq, power, color=color, label=label)
for freq, power, label, color in zip(*prepare_multi_plot(freqs, powers, labels, colors)):
ax.loglog(freq, power, label=label, color=color)

ax.set_xlabel('Frequency (Hz)')
ax.set_ylabel('Power ($V^2/Hz$)')
Expand Down Expand Up @@ -235,3 +223,68 @@ def plot_spectral_hist(freqs, power_bins, spectral_hist, spectrum_freqs=None,
if spectrum is not None:
plt_inds = np.logical_and(spectrum_freqs >= freqs[0], spectrum_freqs <= freqs[-1])
ax.plot(spectrum_freqs[plt_inds], np.log10(spectrum[plt_inds]), color='w', alpha=0.8)


@savefig
@style_plot
def plot_spectra_3D(freqs, powers, log_freqs=False, log_powers=True,
TomDonoghue marked this conversation as resolved.
Show resolved Hide resolved
colors=None, orientation=(20, -50), **kwargs):
"""Plot a series of power spectra in a 3D plot.

Parameters
----------
freqs : 1d or 2d array or list of 1d array
Frequency vector.
powers : 2d array or list of 1d array
Power values.
log_freqs : bool, optional, default: False
Whether to plot the frequency values in log10 space.
log_powers : bool, optional, default: True
Whether to plot the power values in log10 space.
colors : str or list of str
Colors to use to plot lines.
orientation : tuple of int
Orientation to set the 3D plot.
**kwargs
Keyword arguments for customizing the plot.

Examples
--------
Plot power spectra in 3D:

>>> from neurodsp.sim import sim_combined
>>> from neurodsp.spectral import compute_spectrum
>>> sig1 = sim_combined(n_seconds=10, fs=500,
... components={'sim_powerlaw': {'exponent' : -1},
... 'sim_bursty_oscillation' : {'freq': 10}})
>>> sig2 = sim_combined(n_seconds=10, fs=500,
... components={'sim_powerlaw': {'exponent' : -1.5},
... 'sim_bursty_oscillation' : {'freq': 10}})
>>> freqs1, powers1 = compute_spectrum(sig1, fs=500)
>>> freqs2, powers2 = compute_spectrum(sig2, fs=500)
>>> plot_spectra_3D([freqs1, freqs2], [powers1, powers2])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I noticed the z-axis label is cutoff when I copy/paste the example into a notebook. I tried plt.tight_layout() but it gave an error about the size of the fig not having large enough margins. Maybe a solution would be to add an ax or fig kwarg

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmmm, this seems to be an issue with how jupyter notebooks do inline plotting:
matplotlib/matplotlib#28117

When saving out, the figure looks fine, so I don't think we are messing up the plot per se. I'm not sure if there's an action item here - I don't think we want to overtune too much to address a quirk in notebooks.

More broadly, I do think it makes sense to add more access to ax / fig - so I've added a check_ax_3D that will allow for managing / specifying 3D axes.

Also, it seems with the current notebook quirk, changing the zoom can fix the render (https://stackoverflow.com/questions/77577613/matplotlib-3d-plot-z-label-cut-off), so, for example the following make the whole plot visible: plt.gca().set_box_aspect(None, zoom=0.80)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added an argument to plot_spectra_3d to be able to pass in a zoom argument, so now in the notebook case one can set the scaling directly as wanted when defining the plot

"""

fig = plt.figure()
ax = fig.add_subplot(projection='3d')

n_spectra = len(powers)

for ind, (freq, power, _, color) in \
enumerate(zip(*prepare_multi_plot(freqs, powers, None, colors))):
ax.plot(xs=np.log10(freq) if log_freqs else freq,
ys=[ind] * len(freq),
zs=np.log10(power) if log_powers else power,
color=color,
**kwargs)

ax.set(
xlabel='Frequency (Hz)',
ylabel='Channels',
zlabel='Power',
ylim=[0, n_spectra - 1],
)

yticks = list(range(n_spectra))
ax.set_yticks(yticks, yticks)
ax.view_init(*orientation)
33 changes: 13 additions & 20 deletions neurodsp/plts/time_series.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
"""Plots for time series."""

from itertools import repeat, cycle
from itertools import repeat

import numpy as np
import matplotlib.pyplot as plt

from neurodsp.plts.style import style_plot
from neurodsp.plts.utils import check_ax, savefig
from neurodsp.plts.utils import check_ax, savefig, prepare_multi_plot
from neurodsp.utils.data import create_samples
from neurodsp.utils.checks import check_param_options

Expand Down Expand Up @@ -49,18 +49,12 @@

ax = check_ax(ax, kwargs.pop('figsize', (15, 3)))

sigs = [sigs] if (isinstance(sigs, np.ndarray) and sigs.ndim == 1) else sigs
times, xlabel = _check_times(times, sigs)

if labels is not None:
labels = [labels] if not isinstance(labels, list) else labels
else:
labels = repeat(labels)
times, sigs, colors, labels = prepare_multi_plot(times, sigs, colors, labels)

# If not provided, default colors for up to two signals to be black & red
if not colors and len(sigs) <= 2:
if isinstance(colors, repeat) and next(colors) is None and len(sigs) <= 2:
colors = ['k', 'r']
colors = repeat(colors) if not isinstance(colors, list) else cycle(colors)

for time, sig, color, label in zip(times, sigs, colors, labels):
ax.plot(time, sig, color=color, label=label)
Expand Down Expand Up @@ -174,32 +168,31 @@
Keyword arguments for customizing the plot.
"""

colors = 'black' if not colors else colors
colors = repeat(colors) if isinstance(colors, str) else iter(colors)

ax = check_ax(ax, figsize=plt_kwargs.pop('figsize', (15, 5)))

sigs = [sigs] if (isinstance(sigs, np.ndarray) and sigs.ndim == 1) else sigs
colors = 'black' if not colors else colors

times, xlabel = _check_times(times, sigs)
times, sigs, _, colors = prepare_multi_plot(times, sigs, None, colors)

step = 0.8 * np.ptp(sigs[0])

for ind, (time, sig) in enumerate(zip(times, sigs)):
ax.plot(time, sig+step*ind, color=next(colors), **plt_kwargs)

ax.set(yticks=[])
ax.set_xlabel(xlabel)
ax.set_ylabel('Channels')
ax.set(xlabel=xlabel, ylabel='Channels', yticks=[])


def _check_times(times, sigs):
"""Helper function to check a times definition passed into a time series plot function."""

xlabel = 'Time (s)'
if times is None:
times = create_samples(len(sigs[0]))
if isinstance(sigs, list) or (isinstance(sigs, np.ndarray) and sigs.ndim == 2):
n_samples = len(sigs[0])
else:
n_samples = len(sigs)

Check warning on line 194 in neurodsp/plts/time_series.py

View check run for this annotation

Codecov / codecov/patch

neurodsp/plts/time_series.py#L194

Added line #L194 was not covered by tests
times = create_samples(n_samples)
xlabel = 'Samples'

times = repeat(times) if (isinstance(times, np.ndarray) and times.ndim == 1) else times

return times, xlabel
42 changes: 42 additions & 0 deletions neurodsp/plts/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
from copy import deepcopy
from functools import wraps
from os.path import join as pjoin
from itertools import repeat, cycle

import numpy as np
import matplotlib.pyplot as plt

from neurodsp.plts.settings import SUPTITLE_FONTSIZE
Expand Down Expand Up @@ -155,3 +157,43 @@ def make_axes(n_rows, n_cols, figsize=None, row_size=4, col_size=3.6,
**title_kwargs)

return axes


def prepare_multi_plot(xs, ys, labels=None, colors=None):
"""Prepare inputs for plotting one or more elements in a loop.

Parameters
----------
xs, ys : 1d or 2d array
Plot data.
labels : str or list
Label(s) for the plot input(s).
colors : str or iterable
Color(s) to plot input(s).

Returns
-------
xs, ys : iterable
Plot data.
labels : iterable
Label(s) for the plot input(s).
colors : iterable
Color(s) to plot input(s).

Notes
-----
This function takes inputs that can reflect one or more plot elements, and
prepares the inputs to be iterable for plotting in a loop.
"""

xs = repeat(xs) if isinstance(xs, np.ndarray) and xs.ndim == 1 else xs
ys = [ys] if isinstance(ys, np.ndarray) and ys.ndim == 1 else ys

if labels is not None:
labels = [labels] if not isinstance(labels, list) else labels
else:
labels = repeat(labels)

colors = repeat(colors) if not isinstance(colors, list) else cycle(colors)
TomDonoghue marked this conversation as resolved.
Show resolved Hide resolved

return xs, ys, labels, colors
25 changes: 25 additions & 0 deletions neurodsp/tests/plts/test_aperiodic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
"""Tests for neurodsp.plts.aperiodic."""

from neurodsp.aperiodic.autocorr import compute_autocorr

from neurodsp.tests.settings import TEST_PLOTS_PATH, FS
from neurodsp.tests.tutils import plot_test

from neurodsp.plts.aperiodic import *

###################################################################################################
###################################################################################################

def tests_plot_autocorr(tsig, tsig_comb):

times1, acs1 = compute_autocorr(tsig, max_lag=150)
times2, acs2 = compute_autocorr(tsig_comb, max_lag=150)

plot_autocorr(times1, acs1,
save_fig=True, file_path=TEST_PLOTS_PATH,
file_name='test_plot_autocorr-1.png')

plot_autocorr([times1, times2], [acs1, acs2],
labels=['first', 'second'], colors=['k', 'r'],
save_fig=True, file_path=TEST_PLOTS_PATH,
file_name='test_plot_autocorr-2.png')
15 changes: 15 additions & 0 deletions neurodsp/tests/plts/test_spectral.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,18 @@ def test_plot_spectral_hist(tsig_comb):
spectrum=spectrum, spectrum_freqs=spectrum_freqs,
save_fig=True, file_path=TEST_PLOTS_PATH,
file_name='test_plot_spectral_hist.png')

@plot_test
def test_plot_spectra_3D(tsig_comb, tsig_burst):

freqs1, powers1 = compute_spectrum(tsig_comb, FS)
freqs2, powers2 = compute_spectrum(tsig_burst, FS)

plot_spectra_3D([freqs1, freqs2], [powers1, powers2],
save_fig=True, file_path=TEST_PLOTS_PATH,
file_name='test_plot_spectral3D_1.png')

plot_spectra_3D(freqs1, [powers1, powers2, powers1, powers2],
colors=['r', 'y', 'b', 'g'],
save_fig=True, file_path=TEST_PLOTS_PATH,
file_name='test_plot_spectral3D_2.png')
27 changes: 27 additions & 0 deletions neurodsp/tests/plts/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
"""Tests for neurodsp.plts.utils."""

import os
import itertools

import numpy as np
import matplotlib as mpl

from neurodsp.tests.settings import TEST_PLOTS_PATH
Expand Down Expand Up @@ -74,3 +76,28 @@ def test_make_axes():
axes = make_axes(2, 2)
assert axes.shape == (2, 2)
assert isinstance(axes[0, 0], mpl.axes._axes.Axes)

def test_prepare_multi_plot():

xs1 = np.array([1, 2, 3])
ys1 = np.array([1, 2, 3])
labels1 = None
colors1 = None

# 1 input
xs1o, ys1o, labels1o, colors1o = prepare_multi_plot(xs1, ys1, labels1, colors1)
assert isinstance(xs1o, itertools.repeat)
assert isinstance(ys1o, list)
assert isinstance(labels1o, itertools.repeat)
assert isinstance(colors1o, itertools.repeat)

# multiple inputs
xs2 = [np.array([1, 2, 3]), np.array([4, 5, 6])]
ys2 = [np.array([1, 2, 3]), np.array([4, 5, 6])]
labels2 = ['A', 'B']
colors2 = ['blue', 'red']
xs2o, ys2o, labels2o, colors2o = prepare_multi_plot(xs2, ys2, labels2, colors2)
assert isinstance(xs2o, list)
assert isinstance(ys2o, list)
assert isinstance(labels2o, list)
assert isinstance(colors2o, itertools.cycle)
Loading