Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add plots for DE analysis and unify plotting API #654

Merged
merged 21 commits into from
Sep 16, 2024
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,13 @@ jobs:
- os: ubuntu-latest
python: "3.12"
run_mode: "slow"
# - os: ubuntu-latest
# python: "3.12"
# run_mode: "fast"
- os: ubuntu-latest
python: "3.12"
run_mode: slow
pip-flags: "--pre"
# - os: ubuntu-latest
# python: "3.12"
# run_mode: "fast"
# - os: ubuntu-latest
# python: "3.12"
# run_mode: slow
# pip-flags: "--pre"

env:
OS: ${{ matrix.os }}
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/_static/docstring_previews/de_volcano.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
54 changes: 54 additions & 0 deletions pertpy/_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from pathlib import Path
from textwrap import dedent

import matplotlib.pyplot as plt
from matplotlib.pyplot import Figure


def savefig_or_show(
writekey: str,
show: bool,
save: bool | str,
return_fig: bool = False,
dpi: int = 150,
ext: str = "png",
) -> Figure | None:
if isinstance(save, str):
for try_ext in [".svg", ".pdf", ".png"]:
if save.endswith(try_ext):
ext = try_ext[1:]
save = save.replace(try_ext, "")
break
writekey += f"_{save}"
save = True

if save:
Path.mkdir(Path("figures"), exist_ok=True)
plt.savefig(f"figures/{writekey}.{ext}", dpi=dpi, bbox_inches="tight")
if show:
plt.show()
if save:
plt.close() # clear figure
if return_fig:
return plt.gcf()
return None


def _doc_params(**kwds): # pragma: no cover
"""\
Docstrings should start with "\" in the first line for proper formatting.
"""

def dec(obj):
obj.__orig_doc__ = obj.__doc__
obj.__doc__ = dedent(obj.__doc__.format_map(kwds))
return obj

return dec


doc_common_plot_args = """\
show: if `True`, shows the plot.
save: if `True` or a `str`, save the figure. A string is appended to the default filename. Infer the filetype if ending on {`.pdf`, `.png`, `.svg`}.
return_fig: if `True`, returns figure of the plot.\
"""
14 changes: 12 additions & 2 deletions pertpy/metadata/_cell_line.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,15 @@
if TYPE_CHECKING:
from collections.abc import Iterable

from matplotlib.pyplot import Figure

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scanpy import settings
from scipy import stats

from pertpy._utils import _doc_params, doc_common_plot_args, savefig_or_show
from pertpy.data._dataloader import _download

from ._look_up import LookUp
Expand Down Expand Up @@ -690,6 +693,7 @@ def correlate(

return corr, pvals, new_corr, new_pvals

@_doc_params(common_plot_args=doc_common_plot_args)
def plot_correlation(
self,
adata: AnnData,
Expand All @@ -700,7 +704,10 @@ def plot_correlation(
metadata_key: str = "bulk_rna_broad",
category: str = "cell line",
subset_identifier: str | int | Iterable[str] | Iterable[int] | None = None,
) -> None:
show: bool = True,
save: str | bool = False,
return_fig: bool = False,
) -> Figure | None:
"""Visualise the correlation of cell lines with annotated metadata.

Args:
Expand All @@ -713,6 +720,8 @@ def plot_correlation(
subset_identifier: Selected identifiers for scatter plot visualization between the X matrix and `metadata_key`.
If not None, only the chosen cell line will be plotted, either specified as a value in `identifier` (string) or as an index number.
If None, all cell lines will be plotted.
{common_plot_args}

Returns:
Pearson correlation coefficients and their corresponding p-values for matched and unmatched cell lines separately.
"""
Expand Down Expand Up @@ -790,6 +799,7 @@ def plot_correlation(
"edgecolor": "black",
},
)
plt.show()

return savefig_or_show("cell_line_correlation", show=show, save=save, return_fig=return_fig)
else:
raise NotImplementedError
18 changes: 16 additions & 2 deletions pertpy/preprocessing/_guide_rna.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,14 @@
import uuid
from typing import TYPE_CHECKING

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
import scipy

from pertpy._utils import _doc_params, doc_common_plot_args, savefig_or_show

if TYPE_CHECKING:
from anndata import AnnData
from matplotlib.axes import Axes
Expand Down Expand Up @@ -106,12 +109,16 @@ def assign_to_max_guide(

return None

@_doc_params(common_plot_args=doc_common_plot_args)
def plot_heatmap(
self,
adata: AnnData,
layer: str | None = None,
order_by: np.ndarray | str | None = None,
key_to_save_order: str = None,
show: bool = True,
save: str | bool = False,
return_fig: bool = False,
**kwargs,
) -> list[Axes]:
"""Heatmap plotting of guide RNA expression matrix.
Expand All @@ -131,10 +138,11 @@ def plot_heatmap(
If a string is provided, adata.obs[order_by] will be used as the order.
If a numpy array is provided, the array will be used for ordering.
key_to_save_order: The obs key to save cell orders in the current plot. Only saves if not None.
{common_plot_args}
kwargs: Are passed to sc.pl.heatmap.

Returns:
List of Axes. Alternatively you can pass save or show parameters as they will be passed to sc.pl.heatmap.
If return_fig is True, returns a list of Axes. Alternatively you can pass save or show parameters as they will be passed to sc.pl.heatmap.
Order of cells in the y-axis will be saved on adata.obs[key_to_save_order] if provided.

Examples:
Expand Down Expand Up @@ -180,9 +188,15 @@ def plot_heatmap(
use_raw=False,
dendrogram=False,
layer=layer,
save=save,
show=False,
**kwargs,
)
finally:
del adata.obs[temp_col_name]

return axis_group
if show:
plt.show()
if return_fig:
return axis_group
return None
2 changes: 1 addition & 1 deletion pertpy/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def __init__(self, *args, **kwargs):
Tasccoda = lazy_import("pertpy.tools._coda._tasccoda", "Tasccoda", CODA_EXTRAS)

DE_EXTRAS = ["formulaic", "pydeseq2"]
EdgeR = lazy_import("pertpy.tools._differential_gene_expression", "EdgeR", DE_EXTRAS) # edgeR will be imported via rpy2
EdgeR = lazy_import("pertpy.tools._differential_gene_expression", "EdgeR", DE_EXTRAS) # edgeR will be imported via rpy2
PyDESeq2 = lazy_import("pertpy.tools._differential_gene_expression", "PyDESeq2", DE_EXTRAS)
Statsmodels = lazy_import("pertpy.tools._differential_gene_expression", "Statsmodels", DE_EXTRAS + ["statsmodels"])
TTest = lazy_import("pertpy.tools._differential_gene_expression", "TTest", DE_EXTRAS)
Expand Down
100 changes: 37 additions & 63 deletions pertpy/tools/_augur.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from anndata import AnnData
from joblib import Parallel, delayed
from lamin_utils import logger
from rich import print
from rich.progress import track
from scipy import sparse, stats
from sklearn.base import is_classifier, is_regressor
Expand All @@ -37,6 +36,8 @@
from skmisc.loess import loess
from statsmodels.stats.multitest import fdrcorrection

from pertpy._utils import _doc_params, doc_common_plot_args, savefig_or_show

if TYPE_CHECKING:
from matplotlib.axes import Axes
from matplotlib.figure import Figure
Expand Down Expand Up @@ -974,24 +975,26 @@ def predict_differential_prioritization(

return delta

@_doc_params(common_plot_args=doc_common_plot_args)
def plot_dp_scatter(
self,
results: pd.DataFrame,
top_n: int = None,
return_fig: bool | None = None,
ax: Axes = None,
show: bool | None = None,
save: str | bool | None = None,
) -> Axes | Figure | None:
show: bool = True,
save: str | bool = False,
return_fig: bool = False,
) -> Figure | None:
"""Plot scatterplot of differential prioritization.

Args:
results: Results after running differential prioritization.
top_n: optionally, the number of top prioritized cell types to label in the plot
ax: optionally, axes used to draw plot
{common_plot_args}

Returns:
Axes of the plot.
If `return_fig` is `True`, returns the figure, otherwise `None`.

Examples:
>>> import pertpy as pt
Expand Down Expand Up @@ -1038,37 +1041,30 @@ def plot_dp_scatter(
legend1 = ax.legend(*scatter.legend_elements(), loc="center left", title="z-scores", bbox_to_anchor=(1, 0.5))
ax.add_artist(legend1)

if save:
plt.savefig(save, bbox_inches="tight")
if show:
plt.show()
if return_fig:
return plt.gcf()
if not (show or save):
return ax
return None
return savefig_or_show("augur_dp_scatter", show=show, save=save, return_fig=return_fig)

@_doc_params(common_plot_args=doc_common_plot_args)
def plot_important_features(
self,
data: dict[str, Any],
key: str = "augurpy_results",
top_n: int = 10,
return_fig: bool | None = None,
ax: Axes = None,
show: bool | None = None,
save: str | bool | None = None,
) -> Axes | None:
show: bool = True,
save: str | bool = False,
return_fig: bool = False,
) -> Figure | None:
"""Plot a lollipop plot of the n features with largest feature importances.

Args:
results: results after running `predict()` as dictionary or the AnnData object.
data: results after running `predict()` as dictionary or the AnnData object.
key: Key in the AnnData object of the results
top_n: n number feature importance values to plot. Default is 10.
ax: optionally, axes used to draw plot
return_figure: if `True` returns figure of the plot, default is `False`
{common_plot_args}

Returns:
Axes of the plot.
If `return_fig` is `True`, returns the figure, otherwise `None`.

Examples:
>>> import pertpy as pt
Expand Down Expand Up @@ -1109,35 +1105,28 @@ def plot_important_features(
plt.ylabel("Gene")
plt.yticks(y_axes_range, n_features["genes"])

if save:
plt.savefig(save, bbox_inches="tight")
if show:
plt.show()
if return_fig:
return plt.gcf()
if not (show or save):
return ax
return None
return savefig_or_show("augur_important_features", show=show, save=save, return_fig=return_fig)

@_doc_params(common_plot_args=doc_common_plot_args)
def plot_lollipop(
self,
data: dict[str, Any],
data: dict[str, Any] | AnnData,
key: str = "augurpy_results",
return_fig: bool | None = None,
ax: Axes = None,
show: bool | None = None,
save: str | bool | None = None,
show: bool = True,
save: str | bool = False,
return_fig: bool = False,
) -> Axes | Figure | None:
"""Plot a lollipop plot of the mean augur values.

Args:
results: results after running `predict()` as dictionary or the AnnData object.
key: Key in the AnnData object of the results
ax: optionally, axes used to draw plot
return_figure: if `True` returns figure of the plot
data: results after running `predict()` as dictionary or the AnnData object.
key: .uns key in the results AnnData object.
ax: optionally, axes used to draw plot.
{common_plot_args}

Returns:
Axes of the plot.
If `return_fig` is `True`, returns the figure, otherwise `None`.

Examples:
>>> import pertpy as pt
Expand Down Expand Up @@ -1175,32 +1164,25 @@ def plot_lollipop(
plt.ylabel("Cell Type")
plt.yticks(y_axes_range, results["summary_metrics"].sort_values("mean_augur_score", axis=1).columns)

if save:
plt.savefig(save, bbox_inches="tight")
if show:
plt.show()
if return_fig:
return plt.gcf()
if not (show or save):
return ax
return None
return savefig_or_show("augur_lollipop", show=show, save=save, return_fig=return_fig)

@_doc_params(common_plot_args=doc_common_plot_args)
def plot_scatterplot(
self,
results1: dict[str, Any],
results2: dict[str, Any],
top_n: int = None,
return_fig: bool | None = None,
show: bool | None = None,
save: str | bool | None = None,
) -> Axes | Figure | None:
show: bool = True,
save: str | bool = False,
return_fig: bool = False,
) -> Figure | None:
"""Create scatterplot with two augur results.

Args:
results1: results after running `predict()`
results2: results after running `predict()`
top_n: optionally, the number of top prioritized cell types to label in the plot
return_figure: if `True` returns figure of the plot
{common_plot_args}

Returns:
Axes of the plot.
Expand Down Expand Up @@ -1249,12 +1231,4 @@ def plot_scatterplot(
plt.xlabel("Augur scores 1")
plt.ylabel("Augur scores 2")

if save:
plt.savefig(save, bbox_inches="tight")
if show:
plt.show()
if return_fig:
return plt.gcf()
if not (show or save):
return ax
return None
return savefig_or_show("augur_scatterplot", show=show, save=save, return_fig=return_fig)
Loading
Loading