Skip to content

Commit

Permalink
Add plots for DE analysis and unify plotting API (#654)
Browse files Browse the repository at this point in the history
* EdgeR example and first paired plot implementation

* Reuse plot parameters

* Added DE plotting functionalities

* Added plot previews

* Unify plotting API

* Updated all plots to new API

* Resolved ToDos

* Coda plots docs with dict

* Final coda docs edits

* Plot stars in legend in black

* Apply suggestions from code review

Co-authored-by: Lukas Heumos <lukas.heumos@posteo.net>

* Update pertpy/tools/_differential_gene_expression/_base.py

Co-authored-by: Lukas Heumos <lukas.heumos@posteo.net>

* Update pertpy/tools/_differential_gene_expression/_base.py

Co-authored-by: Lukas Heumos <lukas.heumos@posteo.net>

* Removed all save parameters

* Removed _utils.py file

* Fixed edgeR test

* Use p-values from results_df in plot_paired

* PR Reviews

* Pull submodule

---------

Co-authored-by: Lukas Heumos <lukas.heumos@posteo.net>
  • Loading branch information
Lilly-May and Zethson authored Sep 16, 2024
1 parent 2ad41a7 commit 99dcd18
Show file tree
Hide file tree
Showing 24 changed files with 936 additions and 383 deletions.
14 changes: 7 additions & 7 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,13 @@ jobs:
- os: ubuntu-latest
python: "3.12"
run_mode: "slow"
# - os: ubuntu-latest
# python: "3.12"
# run_mode: "fast"
- os: ubuntu-latest
python: "3.12"
run_mode: slow
pip-flags: "--pre"
# - os: ubuntu-latest
# python: "3.12"
# run_mode: "fast"
# - os: ubuntu-latest
# python: "3.12"
# run_mode: slow
# pip-flags: "--pre"

env:
OS: ${{ matrix.os }}
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/_static/docstring_previews/de_volcano.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 1 addition & 1 deletion docs/tutorials/notebooks
20 changes: 20 additions & 0 deletions pertpy/_doc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from textwrap import dedent


def _doc_params(**kwds): # pragma: no cover
"""\
Docstrings should start with "\" in the first line for proper formatting.
"""

def dec(obj):
obj.__orig_doc__ = obj.__doc__
obj.__doc__ = dedent(obj.__doc__.format_map(kwds))
return obj

return dec


doc_common_plot_args = """\
show: if `True`, shows the plot.
return_fig: if `True`, returns figure of the plot.\
"""
17 changes: 15 additions & 2 deletions pertpy/metadata/_cell_line.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,15 @@
if TYPE_CHECKING:
from collections.abc import Iterable

from matplotlib.pyplot import Figure

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scanpy import settings
from scipy import stats

from pertpy._doc import _doc_params, doc_common_plot_args
from pertpy.data._dataloader import _download

from ._look_up import LookUp
Expand Down Expand Up @@ -690,6 +693,7 @@ def correlate(

return corr, pvals, new_corr, new_pvals

@_doc_params(common_plot_args=doc_common_plot_args)
def plot_correlation(
self,
adata: AnnData,
Expand All @@ -700,7 +704,9 @@ def plot_correlation(
metadata_key: str = "bulk_rna_broad",
category: str = "cell line",
subset_identifier: str | int | Iterable[str] | Iterable[int] | None = None,
) -> None:
show: bool = True,
return_fig: bool = False,
) -> Figure | None:
"""Visualise the correlation of cell lines with annotated metadata.
Args:
Expand All @@ -713,6 +719,8 @@ def plot_correlation(
subset_identifier: Selected identifiers for scatter plot visualization between the X matrix and `metadata_key`.
If not None, only the chosen cell line will be plotted, either specified as a value in `identifier` (string) or as an index number.
If None, all cell lines will be plotted.
{common_plot_args}
Returns:
Pearson correlation coefficients and their corresponding p-values for matched and unmatched cell lines separately.
"""
Expand Down Expand Up @@ -790,6 +798,11 @@ def plot_correlation(
"edgecolor": "black",
},
)
plt.show()

if show:
plt.show()
if return_fig:
return plt.gcf()
return None
else:
raise NotImplementedError
25 changes: 19 additions & 6 deletions pertpy/preprocessing/_guide_rna.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,17 @@
import uuid
from typing import TYPE_CHECKING

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
import scipy

from pertpy._doc import _doc_params, doc_common_plot_args

if TYPE_CHECKING:
from anndata import AnnData
from matplotlib.axes import Axes
from matplotlib.pyplot import Figure


class GuideAssignment:
Expand Down Expand Up @@ -106,14 +109,18 @@ def assign_to_max_guide(

return None

@_doc_params(common_plot_args=doc_common_plot_args)
def plot_heatmap(
self,
adata: AnnData,
*,
layer: str | None = None,
order_by: np.ndarray | str | None = None,
key_to_save_order: str = None,
show: bool = True,
return_fig: bool = False,
**kwargs,
) -> list[Axes]:
) -> Figure | None:
"""Heatmap plotting of guide RNA expression matrix.
Assuming guides have sparse expression, this function reorders cells
Expand All @@ -131,11 +138,12 @@ def plot_heatmap(
If a string is provided, adata.obs[order_by] will be used as the order.
If a numpy array is provided, the array will be used for ordering.
key_to_save_order: The obs key to save cell orders in the current plot. Only saves if not None.
{common_plot_args}
kwargs: Are passed to sc.pl.heatmap.
Returns:
List of Axes. Alternatively you can pass save or show parameters as they will be passed to sc.pl.heatmap.
Order of cells in the y-axis will be saved on adata.obs[key_to_save_order] if provided.
If `return_fig` is `True`, returns the figure, otherwise `None`.
Order of cells in the y-axis will be saved on `adata.obs[key_to_save_order]` if provided.
Examples:
Each cell is assigned to gRNA that occurs at least 5 times in the respective cell, which is then
Expand Down Expand Up @@ -172,17 +180,22 @@ def plot_heatmap(
adata.obs[key_to_save_order] = pd.Categorical(order)

try:
axis_group = sc.pl.heatmap(
fig = sc.pl.heatmap(
adata[order, :],
var_names=adata.var.index.tolist(),
groupby=temp_col_name,
cmap="viridis",
use_raw=False,
dendrogram=False,
layer=layer,
show=False,
**kwargs,
)
finally:
del adata.obs[temp_col_name]

return axis_group
if show:
plt.show()
if return_fig:
return fig
return None
2 changes: 1 addition & 1 deletion pertpy/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def __init__(self, *args, **kwargs):
Tasccoda = lazy_import("pertpy.tools._coda._tasccoda", "Tasccoda", CODA_EXTRAS)

DE_EXTRAS = ["formulaic", "pydeseq2"]
EdgeR = lazy_import("pertpy.tools._differential_gene_expression", "EdgeR", DE_EXTRAS) # edgeR will be imported via rpy2
EdgeR = lazy_import("pertpy.tools._differential_gene_expression", "EdgeR", DE_EXTRAS) # edgeR will be imported via rpy2
PyDESeq2 = lazy_import("pertpy.tools._differential_gene_expression", "PyDESeq2", DE_EXTRAS)
Statsmodels = lazy_import("pertpy.tools._differential_gene_expression", "Statsmodels", DE_EXTRAS + ["statsmodels"])
TTest = lazy_import("pertpy.tools._differential_gene_expression", "TTest", DE_EXTRAS)
Expand Down
78 changes: 34 additions & 44 deletions pertpy/tools/_augur.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from anndata import AnnData
from joblib import Parallel, delayed
from lamin_utils import logger
from rich import print
from rich.progress import track
from scipy import sparse, stats
from sklearn.base import is_classifier, is_regressor
Expand All @@ -37,6 +36,8 @@
from skmisc.loess import loess
from statsmodels.stats.multitest import fdrcorrection

from pertpy._doc import _doc_params, doc_common_plot_args

if TYPE_CHECKING:
from matplotlib.axes import Axes
from matplotlib.figure import Figure
Expand Down Expand Up @@ -974,24 +975,26 @@ def predict_differential_prioritization(

return delta

@_doc_params(common_plot_args=doc_common_plot_args)
def plot_dp_scatter(
self,
results: pd.DataFrame,
*,
top_n: int = None,
return_fig: bool | None = None,
ax: Axes = None,
show: bool | None = None,
save: str | bool | None = None,
) -> Axes | Figure | None:
show: bool = True,
return_fig: bool = False,
) -> Figure | None:
"""Plot scatterplot of differential prioritization.
Args:
results: Results after running differential prioritization.
top_n: optionally, the number of top prioritized cell types to label in the plot
ax: optionally, axes used to draw plot
{common_plot_args}
Returns:
Axes of the plot.
If `return_fig` is `True`, returns the figure, otherwise `None`.
Examples:
>>> import pertpy as pt
Expand Down Expand Up @@ -1038,37 +1041,34 @@ def plot_dp_scatter(
legend1 = ax.legend(*scatter.legend_elements(), loc="center left", title="z-scores", bbox_to_anchor=(1, 0.5))
ax.add_artist(legend1)

if save:
plt.savefig(save, bbox_inches="tight")
if show:
plt.show()
if return_fig:
return plt.gcf()
if not (show or save):
return ax
return None

@_doc_params(common_plot_args=doc_common_plot_args)
def plot_important_features(
self,
data: dict[str, Any],
*,
key: str = "augurpy_results",
top_n: int = 10,
return_fig: bool | None = None,
ax: Axes = None,
show: bool | None = None,
save: str | bool | None = None,
) -> Axes | None:
show: bool = True,
return_fig: bool = False,
) -> Figure | None:
"""Plot a lollipop plot of the n features with largest feature importances.
Args:
results: results after running `predict()` as dictionary or the AnnData object.
data: results after running `predict()` as dictionary or the AnnData object.
key: Key in the AnnData object of the results
top_n: n number feature importance values to plot. Default is 10.
ax: optionally, axes used to draw plot
return_figure: if `True` returns figure of the plot, default is `False`
{common_plot_args}
Returns:
Axes of the plot.
If `return_fig` is `True`, returns the figure, otherwise `None`.
Examples:
>>> import pertpy as pt
Expand Down Expand Up @@ -1109,35 +1109,32 @@ def plot_important_features(
plt.ylabel("Gene")
plt.yticks(y_axes_range, n_features["genes"])

if save:
plt.savefig(save, bbox_inches="tight")
if show:
plt.show()
if return_fig:
return plt.gcf()
if not (show or save):
return ax
return None

@_doc_params(common_plot_args=doc_common_plot_args)
def plot_lollipop(
self,
data: dict[str, Any],
data: dict[str, Any] | AnnData,
*,
key: str = "augurpy_results",
return_fig: bool | None = None,
ax: Axes = None,
show: bool | None = None,
save: str | bool | None = None,
) -> Axes | Figure | None:
show: bool = True,
return_fig: bool = False,
) -> Figure | None:
"""Plot a lollipop plot of the mean augur values.
Args:
results: results after running `predict()` as dictionary or the AnnData object.
key: Key in the AnnData object of the results
ax: optionally, axes used to draw plot
return_figure: if `True` returns figure of the plot
data: results after running `predict()` as dictionary or the AnnData object.
key: .uns key in the results AnnData object.
ax: optionally, axes used to draw plot.
{common_plot_args}
Returns:
Axes of the plot.
If `return_fig` is `True`, returns the figure, otherwise `None`.
Examples:
>>> import pertpy as pt
Expand Down Expand Up @@ -1175,32 +1172,29 @@ def plot_lollipop(
plt.ylabel("Cell Type")
plt.yticks(y_axes_range, results["summary_metrics"].sort_values("mean_augur_score", axis=1).columns)

if save:
plt.savefig(save, bbox_inches="tight")
if show:
plt.show()
if return_fig:
return plt.gcf()
if not (show or save):
return ax
return None

@_doc_params(common_plot_args=doc_common_plot_args)
def plot_scatterplot(
self,
results1: dict[str, Any],
results2: dict[str, Any],
*,
top_n: int = None,
return_fig: bool | None = None,
show: bool | None = None,
save: str | bool | None = None,
) -> Axes | Figure | None:
show: bool = True,
return_fig: bool = False,
) -> Figure | None:
"""Create scatterplot with two augur results.
Args:
results1: results after running `predict()`
results2: results after running `predict()`
top_n: optionally, the number of top prioritized cell types to label in the plot
return_figure: if `True` returns figure of the plot
{common_plot_args}
Returns:
Axes of the plot.
Expand Down Expand Up @@ -1249,12 +1243,8 @@ def plot_scatterplot(
plt.xlabel("Augur scores 1")
plt.ylabel("Augur scores 2")

if save:
plt.savefig(save, bbox_inches="tight")
if show:
plt.show()
if return_fig:
return plt.gcf()
if not (show or save):
return ax
return None
Loading

0 comments on commit 99dcd18

Please sign in to comment.