Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Render plot used everywhere #66

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ wandb/
*slurm*
tests/data/genomes
tests/data/test_pipeline
tests/data/pl_output/

# Sphinx documentation
_build
Expand All @@ -12,4 +13,4 @@ node_modules
.DS_Store
._.DS_Store
docs/tutorials/mouse_biccn.ipynb
docs/tutorials/.ipynb_checkpoints
docs/tutorials/.ipynb_checkpoints
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ test = [
"coverage",
"anndata",
"genomepy",
"modisco-lite",
]

[tool.coverage.run]
Expand Down
8 changes: 7 additions & 1 deletion src/crested/pl/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

def render_plot(
fig,
ax=None,
width: int = 8,
height: int = 8,
title: str | None = None,
Expand Down Expand Up @@ -36,6 +37,8 @@ def render_plot(
----------
fig
The figure object to render.
ax
The axis object to which to apply the customizations. If None, all axes in the figure are modified.
width
Width of the plot (inches).
height
Expand Down Expand Up @@ -78,7 +81,9 @@ def render_plot(
fig.supxlabel(supxlabel)
if supylabel:
fig.supylabel(supylabel)
for ax in fig.axes:
axes_to_modify = [ax] if ax else fig.axes

for ax in axes_to_modify:
if xlabel:
ax.set_xlabel(xlabel, fontsize=x_label_fontsize)
if ylabel:
Expand All @@ -89,6 +94,7 @@ def render_plot(
for label in ax.get_yticklabels():
label.set_fontsize(y_tick_fontsize)
label.set_rotation(y_label_rotation)

if tight_rect:
plt.tight_layout(rect=tight_rect)
else:
Expand Down
80 changes: 49 additions & 31 deletions src/crested/pl/hist/_locus_scoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,17 @@
import matplotlib.pyplot as plt
import numpy as np

from crested.pl._utils import render_plot


def locus_scoring(
scores: np.ndarray,
range: tuple[int, int],
gene_start: int | None = None,
gene_end: int | None = None,
title: str = "Predictions across Genomic Regions",
bigwig_values: np.ndarray | None = None,
bigwig_midpoints: list[int] | None = None,
filename: str | None = None,
**kwargs,
):
"""
Plot the predictions as a line chart over the entire genomic input and optionally indicate the gene locus.
Expand All @@ -33,14 +34,14 @@ def locus_scoring(
The start position of the gene locus to highlight on the plot.
gene_end
The end position of the gene locus to highlight on the plot.
title
The title of the plot.
bigwig_values
A numpy array of values extracted from a bigWig file for the same coordinates.
bigwig_midpoints
A list of base pair positions corresponding to the bigwig_values.
filename
The filename to save the plot to.
kwargs
Additional arguments passed to :func:`~crested.pl.render_plot` to
control the final plot output. Please see :func:`~crested.pl.render_plot`
for details.

See Also
--------
Expand All @@ -61,12 +62,20 @@ def locus_scoring(

.. image:: ../../../../docs/_static/img/examples/hist_locus_scoring.png
"""
# Plotting predictions
plt.figure(figsize=(30, 10))

# Top plot: Model predictions
plt.subplot(2, 1, 1)
plt.plot(
if bigwig_midpoints is not None and bigwig_values is not None:
nrows = 2
else:
nrows = 1
fig, axes = plt.subplots(
nrows,
1,
sharex=True,
)
if nrows == 1:
axes = [axes]

axes[0].plot(
np.arange(range[0], range[1]),
scores,
marker="o",
Expand All @@ -75,37 +84,46 @@ def locus_scoring(
label="Prediction Score",
)
if gene_start is not None and gene_end is not None:
plt.axvspan(gene_start, gene_end, color="red", alpha=0.3, label="Gene Locus")
plt.title(title)
plt.xlabel("Genomic Position")
plt.ylabel("Prediction Score")
plt.ylim(bottom=0)
plt.xticks(rotation=90)
plt.grid(True)
plt.legend()
axes[0].axvspan(
gene_start, gene_end, color="red", alpha=0.3, label="Gene Locus"
)

axes[0].set_xlabel("Genomic Position")
for label in axes[0].get_xticklabels():
label.set_rotation(90)
axes[0].set_ylabel("Prediction Score")
axes[0].set_ylim(bottom=0)
axes[0].grid(True)
axes[0].legend()

# Bottom plot: bigWig values
if bigwig_values is not None and bigwig_midpoints is not None:
plt.subplot(2, 1, 2)
plt.plot(
axes[1].plot(
bigwig_midpoints,
bigwig_values,
linestyle="-",
color="g",
label="bigWig Values",
)
if gene_start is not None and gene_end is not None:
plt.axvspan(
axes[1].axvspan(
gene_start, gene_end, color="red", alpha=0.3, label="Gene Locus"
)
plt.xlabel("Genomic Position")
plt.ylabel("bigWig Values")
plt.xticks(rotation=90)
plt.ylim(bottom=0)
plt.grid(True)
plt.legend()
axes[1].set_xlabel("Genomic Position")
axes[1].set_ylabel("bigWig Values")
axes[1].grid(True)
axes[1].legend()

plt.tight_layout()
if filename:
plt.savefig(filename)
plt.show()

default_height = 5 * nrows
default_width = 30

if "width" not in kwargs:
kwargs["width"] = default_width
if "height" not in kwargs:
kwargs["height"] = default_height
if "title" not in kwargs:
kwargs["title"] = "Predictions across Genomic Regions"

return render_plot(fig, **kwargs)
2 changes: 1 addition & 1 deletion src/crested/pl/patterns/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
def _optional_function_warning(*args, **kwargs):
logger.error(
"The requested functionality requires the 'tfmodisco' package, which is not installed. "
"Please install it with `pip install crested[tfmodisco]`.",
"Please install it with `pip install modisco-lite`.",
)


Expand Down
Loading
Loading