diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 41b68fd2..e03399e5 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -27,13 +27,13 @@ jobs: - os: ubuntu-latest python: "3.12" run_mode: "slow" -# - os: ubuntu-latest -# python: "3.12" -# run_mode: "fast" - - os: ubuntu-latest - python: "3.12" - run_mode: slow - pip-flags: "--pre" + # - os: ubuntu-latest + # python: "3.12" + # run_mode: "fast" + # - os: ubuntu-latest + # python: "3.12" + # run_mode: slow + # pip-flags: "--pre" env: OS: ${{ matrix.os }} diff --git a/docs/_static/docstring_previews/de_fold_change.png b/docs/_static/docstring_previews/de_fold_change.png new file mode 100644 index 00000000..4c7b88ed Binary files /dev/null and b/docs/_static/docstring_previews/de_fold_change.png differ diff --git a/docs/_static/docstring_previews/de_multicomparison_fc.png b/docs/_static/docstring_previews/de_multicomparison_fc.png new file mode 100644 index 00000000..191ce034 Binary files /dev/null and b/docs/_static/docstring_previews/de_multicomparison_fc.png differ diff --git a/docs/_static/docstring_previews/de_paired_expression.png b/docs/_static/docstring_previews/de_paired_expression.png new file mode 100644 index 00000000..0f7298df Binary files /dev/null and b/docs/_static/docstring_previews/de_paired_expression.png differ diff --git a/docs/_static/docstring_previews/de_volcano.png b/docs/_static/docstring_previews/de_volcano.png new file mode 100644 index 00000000..78324540 Binary files /dev/null and b/docs/_static/docstring_previews/de_volcano.png differ diff --git a/docs/_static/docstring_previews/pseudobulk_samples.png b/docs/_static/docstring_previews/pseudobulk_samples.png new file mode 100644 index 00000000..d26f6840 Binary files /dev/null and b/docs/_static/docstring_previews/pseudobulk_samples.png differ diff --git a/docs/tutorials/notebooks b/docs/tutorials/notebooks index 7b014d1f..55d33929 160000 --- a/docs/tutorials/notebooks +++ b/docs/tutorials/notebooks @@ -1 +1 @@ -Subproject commit 7b014d1f11135821667188a4d26bf2156fb18b9a +Subproject commit 55d33929ce3ab5d4d38d5ae8b7684c822f720124 diff --git a/pertpy/_doc.py b/pertpy/_doc.py new file mode 100644 index 00000000..14c7fb31 --- /dev/null +++ b/pertpy/_doc.py @@ -0,0 +1,20 @@ +from textwrap import dedent + + +def _doc_params(**kwds): # pragma: no cover + """\ + Docstrings should start with "\" in the first line for proper formatting. + """ + + def dec(obj): + obj.__orig_doc__ = obj.__doc__ + obj.__doc__ = dedent(obj.__doc__.format_map(kwds)) + return obj + + return dec + + +doc_common_plot_args = """\ +show: if `True`, shows the plot. + return_fig: if `True`, returns figure of the plot.\ +""" diff --git a/pertpy/metadata/_cell_line.py b/pertpy/metadata/_cell_line.py index b0e79864..ac89ae59 100644 --- a/pertpy/metadata/_cell_line.py +++ b/pertpy/metadata/_cell_line.py @@ -8,12 +8,15 @@ if TYPE_CHECKING: from collections.abc import Iterable + from matplotlib.pyplot import Figure + import matplotlib.pyplot as plt import numpy as np import pandas as pd from scanpy import settings from scipy import stats +from pertpy._doc import _doc_params, doc_common_plot_args from pertpy.data._dataloader import _download from ._look_up import LookUp @@ -690,6 +693,7 @@ def correlate( return corr, pvals, new_corr, new_pvals + @_doc_params(common_plot_args=doc_common_plot_args) def plot_correlation( self, adata: AnnData, @@ -700,7 +704,9 @@ def plot_correlation( metadata_key: str = "bulk_rna_broad", category: str = "cell line", subset_identifier: str | int | Iterable[str] | Iterable[int] | None = None, - ) -> None: + show: bool = True, + return_fig: bool = False, + ) -> Figure | None: """Visualise the correlation of cell lines with annotated metadata. Args: @@ -713,6 +719,8 @@ def plot_correlation( subset_identifier: Selected identifiers for scatter plot visualization between the X matrix and `metadata_key`. If not None, only the chosen cell line will be plotted, either specified as a value in `identifier` (string) or as an index number. If None, all cell lines will be plotted. + {common_plot_args} + Returns: Pearson correlation coefficients and their corresponding p-values for matched and unmatched cell lines separately. """ @@ -790,6 +798,11 @@ def plot_correlation( "edgecolor": "black", }, ) - plt.show() + + if show: + plt.show() + if return_fig: + return plt.gcf() + return None else: raise NotImplementedError diff --git a/pertpy/preprocessing/_guide_rna.py b/pertpy/preprocessing/_guide_rna.py index 9ebded21..50fad2bf 100644 --- a/pertpy/preprocessing/_guide_rna.py +++ b/pertpy/preprocessing/_guide_rna.py @@ -3,14 +3,17 @@ import uuid from typing import TYPE_CHECKING +import matplotlib.pyplot as plt import numpy as np import pandas as pd import scanpy as sc import scipy +from pertpy._doc import _doc_params, doc_common_plot_args + if TYPE_CHECKING: from anndata import AnnData - from matplotlib.axes import Axes + from matplotlib.pyplot import Figure class GuideAssignment: @@ -106,14 +109,18 @@ def assign_to_max_guide( return None + @_doc_params(common_plot_args=doc_common_plot_args) def plot_heatmap( self, adata: AnnData, + *, layer: str | None = None, order_by: np.ndarray | str | None = None, key_to_save_order: str = None, + show: bool = True, + return_fig: bool = False, **kwargs, - ) -> list[Axes]: + ) -> Figure | None: """Heatmap plotting of guide RNA expression matrix. Assuming guides have sparse expression, this function reorders cells @@ -131,11 +138,12 @@ def plot_heatmap( If a string is provided, adata.obs[order_by] will be used as the order. If a numpy array is provided, the array will be used for ordering. key_to_save_order: The obs key to save cell orders in the current plot. Only saves if not None. + {common_plot_args} kwargs: Are passed to sc.pl.heatmap. Returns: - List of Axes. Alternatively you can pass save or show parameters as they will be passed to sc.pl.heatmap. - Order of cells in the y-axis will be saved on adata.obs[key_to_save_order] if provided. + If `return_fig` is `True`, returns the figure, otherwise `None`. + Order of cells in the y-axis will be saved on `adata.obs[key_to_save_order]` if provided. Examples: Each cell is assigned to gRNA that occurs at least 5 times in the respective cell, which is then @@ -172,7 +180,7 @@ def plot_heatmap( adata.obs[key_to_save_order] = pd.Categorical(order) try: - axis_group = sc.pl.heatmap( + fig = sc.pl.heatmap( adata[order, :], var_names=adata.var.index.tolist(), groupby=temp_col_name, @@ -180,9 +188,14 @@ def plot_heatmap( use_raw=False, dendrogram=False, layer=layer, + show=False, **kwargs, ) finally: del adata.obs[temp_col_name] - return axis_group + if show: + plt.show() + if return_fig: + return fig + return None diff --git a/pertpy/tools/__init__.py b/pertpy/tools/__init__.py index d0f26e77..10565cee 100644 --- a/pertpy/tools/__init__.py +++ b/pertpy/tools/__init__.py @@ -46,7 +46,7 @@ def __init__(self, *args, **kwargs): Tasccoda = lazy_import("pertpy.tools._coda._tasccoda", "Tasccoda", CODA_EXTRAS) DE_EXTRAS = ["formulaic", "pydeseq2"] -EdgeR = lazy_import("pertpy.tools._differential_gene_expression", "EdgeR", DE_EXTRAS) # edgeR will be imported via rpy2 +EdgeR = lazy_import("pertpy.tools._differential_gene_expression", "EdgeR", DE_EXTRAS) # edgeR will be imported via rpy2 PyDESeq2 = lazy_import("pertpy.tools._differential_gene_expression", "PyDESeq2", DE_EXTRAS) Statsmodels = lazy_import("pertpy.tools._differential_gene_expression", "Statsmodels", DE_EXTRAS + ["statsmodels"]) TTest = lazy_import("pertpy.tools._differential_gene_expression", "TTest", DE_EXTRAS) diff --git a/pertpy/tools/_augur.py b/pertpy/tools/_augur.py index 81d317bc..bd226f6f 100644 --- a/pertpy/tools/_augur.py +++ b/pertpy/tools/_augur.py @@ -15,7 +15,6 @@ from anndata import AnnData from joblib import Parallel, delayed from lamin_utils import logger -from rich import print from rich.progress import track from scipy import sparse, stats from sklearn.base import is_classifier, is_regressor @@ -37,6 +36,8 @@ from skmisc.loess import loess from statsmodels.stats.multitest import fdrcorrection +from pertpy._doc import _doc_params, doc_common_plot_args + if TYPE_CHECKING: from matplotlib.axes import Axes from matplotlib.figure import Figure @@ -974,24 +975,26 @@ def predict_differential_prioritization( return delta + @_doc_params(common_plot_args=doc_common_plot_args) def plot_dp_scatter( self, results: pd.DataFrame, + *, top_n: int = None, - return_fig: bool | None = None, ax: Axes = None, - show: bool | None = None, - save: str | bool | None = None, - ) -> Axes | Figure | None: + show: bool = True, + return_fig: bool = False, + ) -> Figure | None: """Plot scatterplot of differential prioritization. Args: results: Results after running differential prioritization. top_n: optionally, the number of top prioritized cell types to label in the plot ax: optionally, axes used to draw plot + {common_plot_args} Returns: - Axes of the plot. + If `return_fig` is `True`, returns the figure, otherwise `None`. Examples: >>> import pertpy as pt @@ -1038,37 +1041,34 @@ def plot_dp_scatter( legend1 = ax.legend(*scatter.legend_elements(), loc="center left", title="z-scores", bbox_to_anchor=(1, 0.5)) ax.add_artist(legend1) - if save: - plt.savefig(save, bbox_inches="tight") if show: plt.show() if return_fig: return plt.gcf() - if not (show or save): - return ax return None + @_doc_params(common_plot_args=doc_common_plot_args) def plot_important_features( self, data: dict[str, Any], + *, key: str = "augurpy_results", top_n: int = 10, - return_fig: bool | None = None, ax: Axes = None, - show: bool | None = None, - save: str | bool | None = None, - ) -> Axes | None: + show: bool = True, + return_fig: bool = False, + ) -> Figure | None: """Plot a lollipop plot of the n features with largest feature importances. Args: - results: results after running `predict()` as dictionary or the AnnData object. + data: results after running `predict()` as dictionary or the AnnData object. key: Key in the AnnData object of the results top_n: n number feature importance values to plot. Default is 10. ax: optionally, axes used to draw plot - return_figure: if `True` returns figure of the plot, default is `False` + {common_plot_args} Returns: - Axes of the plot. + If `return_fig` is `True`, returns the figure, otherwise `None`. Examples: >>> import pertpy as pt @@ -1109,35 +1109,32 @@ def plot_important_features( plt.ylabel("Gene") plt.yticks(y_axes_range, n_features["genes"]) - if save: - plt.savefig(save, bbox_inches="tight") if show: plt.show() if return_fig: return plt.gcf() - if not (show or save): - return ax return None + @_doc_params(common_plot_args=doc_common_plot_args) def plot_lollipop( self, - data: dict[str, Any], + data: dict[str, Any] | AnnData, + *, key: str = "augurpy_results", - return_fig: bool | None = None, ax: Axes = None, - show: bool | None = None, - save: str | bool | None = None, - ) -> Axes | Figure | None: + show: bool = True, + return_fig: bool = False, + ) -> Figure | None: """Plot a lollipop plot of the mean augur values. Args: - results: results after running `predict()` as dictionary or the AnnData object. - key: Key in the AnnData object of the results - ax: optionally, axes used to draw plot - return_figure: if `True` returns figure of the plot + data: results after running `predict()` as dictionary or the AnnData object. + key: .uns key in the results AnnData object. + ax: optionally, axes used to draw plot. + {common_plot_args} Returns: - Axes of the plot. + If `return_fig` is `True`, returns the figure, otherwise `None`. Examples: >>> import pertpy as pt @@ -1175,32 +1172,29 @@ def plot_lollipop( plt.ylabel("Cell Type") plt.yticks(y_axes_range, results["summary_metrics"].sort_values("mean_augur_score", axis=1).columns) - if save: - plt.savefig(save, bbox_inches="tight") if show: plt.show() if return_fig: return plt.gcf() - if not (show or save): - return ax return None + @_doc_params(common_plot_args=doc_common_plot_args) def plot_scatterplot( self, results1: dict[str, Any], results2: dict[str, Any], + *, top_n: int = None, - return_fig: bool | None = None, - show: bool | None = None, - save: str | bool | None = None, - ) -> Axes | Figure | None: + show: bool = True, + return_fig: bool = False, + ) -> Figure | None: """Create scatterplot with two augur results. Args: results1: results after running `predict()` results2: results after running `predict()` top_n: optionally, the number of top prioritized cell types to label in the plot - return_figure: if `True` returns figure of the plot + {common_plot_args} Returns: Axes of the plot. @@ -1249,12 +1243,8 @@ def plot_scatterplot( plt.xlabel("Augur scores 1") plt.ylabel("Augur scores 2") - if save: - plt.savefig(save, bbox_inches="tight") if show: plt.show() if return_fig: return plt.gcf() - if not (show or save): - return ax return None diff --git a/pertpy/tools/_cinemaot.py b/pertpy/tools/_cinemaot.py index a4a77273..f79ea951 100644 --- a/pertpy/tools/_cinemaot.py +++ b/pertpy/tools/_cinemaot.py @@ -18,9 +18,12 @@ from sklearn.linear_model import LinearRegression from sklearn.neighbors import NearestNeighbors +from pertpy._doc import _doc_params, doc_common_plot_args + if TYPE_CHECKING: from anndata import AnnData from matplotlib.axes import Axes + from matplotlib.pyplot import Figure from statsmodels.tools.typing import ArrayLike @@ -639,6 +642,7 @@ def attribution_scatter( s_effect = (np.linalg.norm(e1, axis=0) + 1e-6) / (np.linalg.norm(e0, axis=0) + 1e-6) return c_effect, s_effect + @_doc_params(common_plot_args=doc_common_plot_args) def plot_vis_matching( self, adata: AnnData, @@ -647,16 +651,17 @@ def plot_vis_matching( control: str, de_label: str, source_label: str, + *, matching_rep: str = "ot", resolution: float = 0.5, normalize: str = "col", title: str = "CINEMA-OT matching matrix", min_val: float = 0.01, - show: bool = True, - save: str | None = None, ax: Axes | None = None, + show: bool = True, + return_fig: bool = False, **kwargs, - ) -> None: + ) -> Figure | None: """Visualize the CINEMA-OT matching matrix. Args: @@ -670,11 +675,12 @@ def plot_vis_matching( normalize: normalize the coarse-grained matching matrix by row / column. title: the title for the figure. min_val: The min value to truncate the matching matrix. - show: Show the plot, do not return axis. - save: If `True` or a `str`, save the figure. A string is appended to the default filename. - Infer the filetype if ending on {`'.pdf'`, `'.png'`, `'.svg'`}. + {common_plot_args} **kwargs: Other parameters to input for seaborn.heatmap. + Returns: + If `return_fig` is `True`, returns the figure, otherwise `None`. + Examples: >>> import pertpy as pt >>> adata = pt.dt.cinemaot_example() @@ -710,12 +716,12 @@ def plot_vis_matching( g = sns.heatmap(df, annot=True, ax=ax, **kwargs) plt.title(title) - _utils.savefig_or_show("matching_heatmap", show=show, save=save) - if not show: - if ax is not None: - return ax - else: - return g + + if show: + plt.show() + if return_fig: + return g + return None class Xi: diff --git a/pertpy/tools/_coda/_base_coda.py b/pertpy/tools/_coda/_base_coda.py index c2cc4cba..65110586 100644 --- a/pertpy/tools/_coda/_base_coda.py +++ b/pertpy/tools/_coda/_base_coda.py @@ -26,6 +26,8 @@ from rich.table import Table from scipy.cluster import hierarchy as sp_hierarchy +from pertpy._doc import _doc_params, doc_common_plot_args + if TYPE_CHECKING: from collections.abc import Sequence @@ -1023,7 +1025,7 @@ def get_node_df(self, data: AnnData | MuData, modality_key: str = "coda"): >>> key_added="lineage", add_level_name=True >>> ) >>> mdata = tasccoda.prepare( - >>> mdata, formula="Health", reference_cell_type="automatic", tree_key="lineage", pen_args={"phi": 0} + >>> mdata, formula="Health", reference_cell_type="automatic", tree_key="lineage", pen_args={"phi" : 0} >>> ) >>> tasccoda.run_nuts(mdata, num_samples=1000, num_warmup=100, rng_key=42) >>> node_effects = tasccoda.get_node_df(mdata) @@ -1185,22 +1187,21 @@ def _stackbar( # pragma: no cover return ax + @_doc_params(common_plot_args=doc_common_plot_args) def plot_stacked_barplot( # pragma: no cover self, data: AnnData | MuData, feature_name: str, + *, modality_key: str = "coda", palette: ListedColormap | None = cm.tab20, show_legend: bool | None = True, level_order: list[str] = None, figsize: tuple[float, float] | None = None, dpi: int | None = 100, - return_fig: bool | None = None, - ax: plt.Axes | None = None, - show: bool | None = None, - save: str | bool | None = None, - **kwargs, - ) -> plt.Axes | plt.Figure | None: + show: bool = True, + return_fig: bool = False, + ) -> Figure | None: """Plots a stacked barplot for all levels of a covariate or all samples (if feature_name=="samples"). Args: @@ -1212,9 +1213,10 @@ def plot_stacked_barplot( # pragma: no cover palette: The matplotlib color map for the barplot. show_legend: If True, adds a legend. level_order: Custom ordering of bars on the x-axis. + {common_plot_args} Returns: - A :class:`~matplotlib.axes.Axes` object + If `return_fig` is `True`, returns the figure, otherwise `None`. Examples: >>> import pertpy as pt @@ -1239,7 +1241,7 @@ def plot_stacked_barplot( # pragma: no cover if level_order: assert set(level_order) == set(data.obs.index), "level order is inconsistent with levels" data = data[level_order] - ax = self._stackbar( + self._stackbar( data.X, type_names=data.var.index, title="samples", @@ -1265,7 +1267,7 @@ def plot_stacked_barplot( # pragma: no cover l_indices = np.where(data.obs[feature_name] == levels[level]) feature_totals[level] = np.sum(data.X[l_indices], axis=0) - ax = self._stackbar( + self._stackbar( feature_totals, type_names=ct_names, title=feature_name, @@ -1276,19 +1278,17 @@ def plot_stacked_barplot( # pragma: no cover show_legend=show_legend, ) - if save: - plt.savefig(save, bbox_inches="tight") if show: plt.show() if return_fig: return plt.gcf() - if not (show or save): - return ax return None + @_doc_params(common_plot_args=doc_common_plot_args) def plot_effects_barplot( # pragma: no cover self, data: AnnData | MuData, + *, modality_key: str = "coda", covariates: str | list | None = None, parameter: Literal["log2-fold change", "Final Parameter", "Expected Sample"] = "log2-fold change", @@ -1300,11 +1300,9 @@ def plot_effects_barplot( # pragma: no cover args_barplot: dict | None = None, figsize: tuple[float, float] | None = None, dpi: int | None = 100, - return_fig: bool | None = None, - ax: plt.Axes | None = None, - show: bool | None = None, - save: str | bool | None = None, - ) -> plt.Axes | plt.Figure | sns.axisgrid.FacetGrid | None: + show: bool = True, + return_fig: bool = False, + ) -> Figure | None: """Barplot visualization for effects. The effect results for each covariate are shown as a group of barplots, with intra--group separation by cell types. @@ -1323,10 +1321,10 @@ def plot_effects_barplot( # pragma: no cover palette: The seaborn color map for the barplot. level_order: Custom ordering of bars on the x-axis. args_barplot: Arguments passed to sns.barplot. + {common_plot_args} Returns: - Depending on `plot_facets`, returns a :class:`~matplotlib.axes.Axes` (`plot_facets = False`) - or :class:`~sns.axisgrid.FacetGrid` (`plot_facets = True`) object + If `return_fig` is `True`, returns the figure, otherwise `None`. Examples: >>> import pertpy as pt @@ -1437,16 +1435,6 @@ def plot_effects_barplot( # pragma: no cover if ax.get_xticklabels()[0]._text == "zero": ax.set_xticks([]) - if save: - plt.savefig(save, bbox_inches="tight") - if show: - plt.show() - if return_fig: - return plt.gcf() - if not (show or save): - return g - return None - # If not plot as facets, call barplot to plot cell types on the x-axis. else: _, ax = plt.subplots(figsize=figsize, dpi=dpi) @@ -1478,20 +1466,18 @@ def plot_effects_barplot( # pragma: no cover cell_types = pd.unique(plot_df["Cell Type"]) ax.set_xticklabels(cell_types, rotation=90) - if save: - plt.savefig(save, bbox_inches="tight") - if show: - plt.show() - if return_fig: - return plt.gcf() - if not (show or save): - return ax - return None + if show: + plt.show() + if return_fig: + return plt.gcf() + return None + @_doc_params(common_plot_args=doc_common_plot_args) def plot_boxplots( # pragma: no cover self, data: AnnData | MuData, feature_name: str, + *, modality_key: str = "coda", y_scale: Literal["relative", "log", "log10", "count"] = "relative", plot_facets: bool = False, @@ -1504,11 +1490,9 @@ def plot_boxplots( # pragma: no cover level_order: list[str] = None, figsize: tuple[float, float] | None = None, dpi: int | None = 100, - return_fig: bool | None = None, - ax: plt.Axes | None = None, - show: bool | None = None, - save: str | bool | None = None, - ) -> plt.Axes | plt.Figure | sns.axisgrid.FacetGrid | None: + show: bool = True, + return_fig: bool = False, + ) -> Figure | None: """Grouped boxplot visualization. The cell counts for each cell type are shown as a group of boxplots @@ -1530,10 +1514,10 @@ def plot_boxplots( # pragma: no cover palette: The seaborn color map for the barplot. show_legend: If True, adds a legend. level_order: Custom ordering of bars on the x-axis. + {common_plot_args} Returns: - Depending on `plot_facets`, returns a :class:`~matplotlib.axes.Axes` (`plot_facets = False`) - or :class:`~sns.axisgrid.FacetGrid` (`plot_facets = True`) object + If `return_fig` is `True`, returns the figure, otherwise `None`. Examples: >>> import pertpy as pt @@ -1651,16 +1635,6 @@ def plot_boxplots( # pragma: no cover **args_swarmplot, ).set_titles("{col_name}") - if save: - plt.savefig(save, bbox_inches="tight") - if show: - plt.show() - if return_fig: - return plt.gcf() - if not (show or save): - return g - return None - # If not plot as facets, call boxplot to plot cell types on the x-axis. else: if level_order: @@ -1724,19 +1698,17 @@ def plot_boxplots( # pragma: no cover title=feature_name, ) - if save: - plt.savefig(save, bbox_inches="tight") - if show: - plt.show() - if return_fig: - return plt.gcf() - if not (show or save): - return ax - return None + if show: + plt.show() + if return_fig: + return plt.gcf() + return None + @_doc_params(common_plot_args=doc_common_plot_args) def plot_rel_abundance_dispersion_plot( # pragma: no cover self, data: AnnData | MuData, + *, modality_key: str = "coda", abundant_threshold: float | None = 0.9, default_color: str | None = "Grey", @@ -1744,11 +1716,10 @@ def plot_rel_abundance_dispersion_plot( # pragma: no cover label_cell_types: bool = True, figsize: tuple[float, float] | None = None, dpi: int | None = 100, - return_fig: bool | None = None, ax: plt.Axes | None = None, - show: bool | None = None, - save: str | bool | None = None, - ) -> plt.Axes | plt.Figure | None: + show: bool = True, + return_fig: bool = False, + ) -> Figure | None: """Plots total variance of relative abundance versus minimum relative abundance of all cell types for determination of a reference cell type. If the count of the cell type is larger than 0 in more than abundant_threshold percent of all samples, the cell type will be marked in a different color. @@ -1763,9 +1734,10 @@ def plot_rel_abundance_dispersion_plot( # pragma: no cover figsize: Figure size. dpi: Dpi setting. ax: A matplotlib axes object. Only works if plotting a single component. + {common_plot_args} Returns: - A :class:`~matplotlib.axes.Axes` object + If `return_fig` is `True`, returns the figure, otherwise `None`. Examples: >>> import pertpy as pt @@ -1849,19 +1821,17 @@ def label_point(x, y, val, ax): ax.legend(loc="upper left", bbox_to_anchor=(1, 1), ncol=1, title="Is abundant") - if save: - plt.savefig(save, bbox_inches="tight") if show: plt.show() if return_fig: return plt.gcf() - if not (show or save): - return ax return None + @_doc_params(common_plot_args=doc_common_plot_args) def plot_draw_tree( # pragma: no cover self, data: AnnData | MuData, + *, modality_key: str = "coda", tree: str = "tree", # Also type ete3.Tree. Omitted due to import errors tight_text: bool | None = False, @@ -1869,8 +1839,9 @@ def plot_draw_tree( # pragma: no cover units: Literal["px", "mm", "in"] | None = "px", figsize: tuple[float, float] | None = (None, None), dpi: int | None = 100, - show: bool | None = True, - save: str | bool | None = None, + save: str | bool = False, + show: bool = True, + return_fig: bool = False, ) -> Tree | None: """Plot a tree using input ete3 tree object. @@ -1881,12 +1852,11 @@ def plot_draw_tree( # pragma: no cover tight_text: When False, boundaries of the text are approximated according to general font metrics, producing slightly worse aligned text faces but improving the performance of tree visualization in scenes with a lot of text faces. show_scale: Include the scale legend in the tree image or not. - show: If True, plot the tree inline. If false, return tree and tree_style objects. - file_name: Path to the output image file. Valid extensions are .SVG, .PDF, .PNG. - Output image can be saved whether show is True or not. units: Unit of image sizes. “px”: pixels, “mm”: millimeters, “in”: inches. figsize: Figure size. dpi: Dots per inches. + save: Save the tree plot to a file. You can specify the file name here. + {common_plot_args} Returns: Depending on `show`, returns :class:`ete3.TreeNode` and :class:`ete3.TreeStyle` (`show = False`) or plot the tree inline (`show = False`) @@ -1901,7 +1871,7 @@ def plot_draw_tree( # pragma: no cover >>> key_added="lineage", add_level_name=True >>> ) >>> mdata = tasccoda.prepare( - >>> mdata, formula="Health", reference_cell_type="automatic", tree_key="lineage", pen_args={"phi": 0} + >>> mdata, formula="Health", reference_cell_type="automatic", tree_key="lineage", pen_args=dict(phi=0) >>> ) >>> tasccoda.run_nuts(mdata, num_samples=1000, num_warmup=100, rng_key=42) >>> tasccoda.plot_draw_tree(mdata, tree="lineage") @@ -1936,13 +1906,16 @@ def my_layout(node): tree.render(save, tree_style=tree_style, units=units, w=figsize[0], h=figsize[1], dpi=dpi) # type: ignore if show: return tree.render("%%inline", tree_style=tree_style, units=units, w=figsize[0], h=figsize[1], dpi=dpi) # type: ignore - else: + if return_fig: return tree, tree_style + return None + @_doc_params(common_plot_args=doc_common_plot_args) def plot_draw_effects( # pragma: no cover self, data: AnnData | MuData, covariate: str, + *, modality_key: str = "coda", tree: str = "tree", # Also type ete3.Tree. Omitted due to import errors show_legend: bool | None = None, @@ -1952,8 +1925,9 @@ def plot_draw_effects( # pragma: no cover units: Literal["px", "mm", "in"] | None = "px", figsize: tuple[float, float] | None = (None, None), dpi: int | None = 100, - show: bool | None = True, - save: str | None = None, + save: str | bool = False, + show: bool = True, + return_fig: bool = False, ) -> Tree | None: """Plot a tree with colored circles on the nodes indicating significant effects with bar plots which indicate leave-level significant effects. @@ -1968,15 +1942,15 @@ def plot_draw_effects( # pragma: no cover tight_text: When False, boundaries of the text are approximated according to general font metrics, producing slightly worse aligned text faces but improving the performance of tree visualization in scenes with a lot of text faces. show_scale: Include the scale legend in the tree image or not. - show: If True, plot the tree inline. If false, return tree and tree_style objects. - file_name: Path to the output image file. valid extensions are .SVG, .PDF, .PNG. Output image can be saved whether show is True or not. units: Unit of image sizes. “px”: pixels, “mm”: millimeters, “in”: inches. figsize: Figure size. dpi: Dots per inches. + save: Save the tree plot to a file. You can specify the file name here. + {common_plot_args} Returns: - Depending on `show`, returns :class:`ete3.TreeNode` and :class:`ete3.TreeStyle` (`show = False`) - or plot the tree inline (`show = False`) + Returns :class:`ete3.TreeNode` and :class:`ete3.TreeStyle` (`return_fig = False`) + or plot the tree inline (`show = True`) Examples: >>> import pertpy as pt @@ -1988,7 +1962,7 @@ def plot_draw_effects( # pragma: no cover >>> key_added="lineage", add_level_name=True >>> ) >>> mdata = tasccoda.prepare( - >>> mdata, formula="Health", reference_cell_type="automatic", tree_key="lineage", pen_args={"phi": 0} + >>> mdata, formula="Health", reference_cell_type="automatic", tree_key="lineage", pen_args=dict(phi=0) >>> ) >>> tasccoda.run_nuts(mdata, num_samples=1000, num_warmup=100, rng_key=42) >>> tasccoda.plot_draw_effects(mdata, covariate="Health[T.Inflamed]", tree="lineage") @@ -2117,52 +2091,55 @@ def my_layout(node): plt.xlim(-leaf_eff_max, leaf_eff_max) plt.subplots_adjust(wspace=0) - if save is not None: + if save: plt.savefig(save) - if save is not None and not show_leaf_effects: + if save and not show_leaf_effects: tree2.render(save, tree_style=tree_style, units=units) if show: if not show_leaf_effects: return tree2.render("%%inline", tree_style=tree_style, units=units, w=figsize[0], h=figsize[1], dpi=dpi) - else: + if return_fig: if not show_leaf_effects: return tree2, tree_style return None + @_doc_params(common_plot_args=doc_common_plot_args) def plot_effects_umap( # pragma: no cover self, mdata: MuData, effect_name: str | list | None, cluster_key: str, + *, modality_key_1: str = "rna", modality_key_2: str = "coda", color_map: Colormap | str | None = None, palette: str | Sequence[str] | None = None, - return_fig: bool | None = None, ax: Axes = None, - show: bool = None, - save: str | bool | None = None, + show: bool = True, + return_fig: bool = False, **kwargs, - ) -> plt.Axes | plt.Figure | None: + ) -> Figure | None: """Plot a UMAP visualization colored by effect strength. Effect results in .varm of aggregated sample-level AnnData (default is data['coda']) are assigned to cell-level AnnData (default is data['rna']) depending on the cluster they were assigned to. Args: - mudata: MuData object. + mdata: MuData object. effect_name: The name of the effect results in .varm of aggregated sample-level AnnData to plot cluster_key: The cluster information in .obs of cell-level AnnData (default is data['rna']). To assign cell types' effects to original cells. modality_key_1: Key to the cell-level AnnData in the MuData object. modality_key_2: Key to the aggregated sample-level AnnData object in the MuData object. - show: Whether to display the figure or return axis. + color_map: The color map to use for plotting. + palette: The color palette to use for plotting. ax: A matplotlib axes object. Only works if plotting a single component. + {common_plot_args} **kwargs: All other keyword arguments are passed to `scanpy.plot.umap()` Returns: - If `show==False` a :class:`~matplotlib.axes.Axes` or a list of it. + If `return_fig` is `True`, returns the figure, otherwise `None`. Examples: >>> import pertpy as pt @@ -2182,7 +2159,7 @@ def plot_effects_umap( # pragma: no cover >>> modality_key="coda", >>> reference_cell_type="18", >>> formula="condition", - >>> pen_args={"phi": 0, "lambda_1": 3.5}, + >>> pen_args=dict(phi=0, lambda_1=3.5), >>> tree_key="tree" >>> ) >>> tasccoda_model.run_nuts( @@ -2220,7 +2197,7 @@ def plot_effects_umap( # pragma: no cover else: vmax = max(data_rna.obs[effect].max() for _, effect in enumerate(effect_name)) - return sc.pl.umap( + fig = sc.pl.umap( data_rna, color=effect_name, vmax=vmax, @@ -2229,11 +2206,16 @@ def plot_effects_umap( # pragma: no cover color_map=color_map, return_fig=return_fig, ax=ax, - show=show, - save=save, + show=False, **kwargs, ) + if show: + plt.show() + if return_fig: + return fig + return None + def get_a( tree: tt.core.ToyTree, diff --git a/pertpy/tools/_dialogue.py b/pertpy/tools/_dialogue.py index becd4a7b..9eaf349e 100644 --- a/pertpy/tools/_dialogue.py +++ b/pertpy/tools/_dialogue.py @@ -25,6 +25,8 @@ from sparsecca import lp_pmd, multicca_permute, multicca_pmd from statsmodels.sandbox.stats.multicomp import multipletests +from pertpy._doc import _doc_params, doc_common_plot_args + if TYPE_CHECKING: from matplotlib.axes import Axes from matplotlib.figure import Figure @@ -1059,18 +1061,18 @@ def get_extrema_MCP_genes(self, ct_subs: dict, fraction: float = 0.1): return rank_dfs + @_doc_params(common_plot_args=doc_common_plot_args) def plot_split_violins( self, adata: AnnData, split_key: str, celltype_key: str, + *, split_which: tuple[str, str] = None, mcp: str = "mcp_0", - return_fig: bool | None = None, - ax: Axes | None = None, - save: bool | str | None = None, - show: bool | None = None, - ) -> Axes | Figure | None: + show: bool = True, + return_fig: bool = False, + ) -> Figure | None: """Plots split violin plots for a given MCP and split variable. Any cells with a value for split_key not in split_which are removed from the plot. @@ -1081,9 +1083,10 @@ def plot_split_violins( celltype_key: Key for cell type annotations. split_which: Which values of split_key to plot. Required if more than 2 values in split_key. mcp: Key for MCP data. + {common_plot_args} Returns: - A :class:`~matplotlib.axes.Axes` object + If `return_fig` is `True`, returns the figure, otherwise `None`. Examples: >>> import pertpy as pt @@ -1105,30 +1108,26 @@ def plot_split_violins( df[split_key] = df[split_key].cat.remove_unused_categories() ax = sns.violinplot(data=df, x=celltype_key, y=mcp, hue=split_key, split=True) - ax.set_xticklabels(ax.get_xticklabels(), rotation=90) - if save: - plt.savefig(save, bbox_inches="tight") if show: plt.show() if return_fig: return plt.gcf() - if not (show or save): - return ax return None + @_doc_params(common_plot_args=doc_common_plot_args) def plot_pairplot( self, adata: AnnData, celltype_key: str, color: str, sample_id: str, + *, mcp: str = "mcp_0", - return_fig: bool | None = None, - show: bool | None = None, - save: bool | str | None = None, - ) -> PairGrid | Figure | None: + show: bool = True, + return_fig: bool = False, + ) -> Figure | None: """Generate a pairplot visualization for multi-cell perturbation (MCP) data. Computes the mean of a specified MCP feature (mcp) for each combination of sample and cell type, @@ -1140,9 +1139,10 @@ def plot_pairplot( color: Key in `adata.obs` for color annotations. This parameter is used as the hue sample_id: Key in `adata.obs` for the sample annotations. mcp: Key in `adata.obs` for MCP feature values. + {common_plot_args} Returns: - Seaborn Pairgrid object. + If `return_fig` is `True`, returns the figure, otherwise `None`. Examples: >>> import pertpy as pt @@ -1165,14 +1165,10 @@ def plot_pairplot( aggstats = aggstats.loc[list(mcp_pivot.index), :] aggstats[color] = aggstats["top"] mcp_pivot = pd.concat([mcp_pivot, aggstats[color]], axis=1) - ax = sns.pairplot(mcp_pivot, hue=color, corner=True) + sns.pairplot(mcp_pivot, hue=color, corner=True) - if save: - plt.savefig(save, bbox_inches="tight") if show: plt.show() if return_fig: return plt.gcf() - if not (show or save): - return ax return None diff --git a/pertpy/tools/_differential_gene_expression/_base.py b/pertpy/tools/_differential_gene_expression/_base.py index 4ec2e841..c6c30293 100644 --- a/pertpy/tools/_differential_gene_expression/_base.py +++ b/pertpy/tools/_differential_gene_expression/_base.py @@ -1,7 +1,9 @@ +import math import os from abc import ABC, abstractmethod +from collections.abc import Sequence from dataclasses import dataclass -from itertools import chain +from itertools import chain, zip_longest from types import MappingProxyType import adjustText @@ -10,9 +12,15 @@ import matplotlib.pyplot as plt import numpy as np import pandas as pd +import scipy import seaborn as sns +import statsmodels +from lamin_utils import logger +from matplotlib.pyplot import Figure from matplotlib.ticker import MaxNLocator +from pertpy._doc import _doc_params, doc_common_plot_args +from pertpy.tools import PseudobulkSpace from pertpy.tools._differential_gene_expression._checks import check_is_numeric_matrix from pertpy.tools._differential_gene_expression._formulaic import ( AmbiguousAttributeError, @@ -91,9 +99,28 @@ def compare_groups( Returns: Pandas dataframe with results ordered by significance. If multiple comparisons were performed this is indicated in an additional column. + + Examples: + >>> # Example with EdgeR + >>> import pertpy as pt + >>> adata = pt.dt.zhang_2021() + >>> adata.layers["counts"] = adata.X.copy() + >>> ps = pt.tl.PseudobulkSpace() + >>> pdata = ps.compute( + ... adata, + ... target_col="Patient", + ... groups_col="Cluster", + ... layer_key="counts", + ... mode="sum", + ... min_cells=10, + ... min_counts=1000, + ... ) + >>> edgr = pt.tl.EdgeR(pdata, design="~Efficacy+Treatment") + >>> res_df = edgr.compare_groups(pdata, column="Efficacy", baseline="SD", groups_to_compare=["PR", "PD"]) """ ... + @_doc_params(common_plot_args=doc_common_plot_args) def plot_volcano( self, data: pd.DataFrame | ad.AnnData, @@ -115,13 +142,14 @@ def plot_volcano( figsize: tuple[int, int] = (5, 5), legend_pos: tuple[float, float] = (1.6, 1), point_sizes: tuple[int, int] = (15, 150), - save: bool | str | None = None, shapes: list[str] | None = None, shape_order: list[str] | None = None, x_label: str | None = None, y_label: str | None = None, + show: bool = True, + return_fig: bool = False, **kwargs: int, - ) -> None: + ) -> Figure | None: """Creates a volcano plot from a pandas DataFrame or Anndata. Args: @@ -143,12 +171,40 @@ def plot_volcano( top_right_frame: Whether to show the top and right frame of the plot. figsize: Size of the figure. legend_pos: Position of the legend as determined by matplotlib. - save: Saves the plot if True or to the path provided. shapes: List of matplotlib marker ids. shape_order: Order of categories for shapes. x_label: Label for the x-axis. y_label: Label for the y-axis. + {common_plot_args} **kwargs: Additional arguments for seaborn.scatterplot. + + Returns: + If `return_fig` is `True`, returns the figure, otherwise `None`. + + Examples: + >>> # Example with EdgeR + >>> import pertpy as pt + >>> adata = pt.dt.zhang_2021() + >>> adata.layers["counts"] = adata.X.copy() + >>> ps = pt.tl.PseudobulkSpace() + >>> pdata = ps.compute( + ... adata, + ... target_col="Patient", + ... groups_col="Cluster", + ... layer_key="counts", + ... mode="sum", + ... min_cells=10, + ... min_counts=1000, + ... ) + >>> edgr = pt.tl.EdgeR(pdata, design="~Efficacy+Treatment") + >>> edgr.fit() + >>> res_df = edgr.test_contrasts( + ... edgr.contrast(column="Treatment", baseline="Chemo", group_to_compare="Anti-PD-L1+Chemo") + ... ) + >>> edgr.plot_volcano(res_df, log2fc_thresh=0) + + Preview: + .. image:: /_static/docstring_previews/de_volcano.png """ if colors is None: colors = ["gray", "#D62728", "#1F77B4"] @@ -243,7 +299,7 @@ def _map_genes_categories_highlight( if varm_key is None: raise ValueError("Please pass a .varm key to use for plotting") - raise NotImplementedError("Anndata not implemented yet") + raise NotImplementedError("Anndata not implemented yet") # TODO: Implement this df = data.varm[varm_key].copy() df = data.copy(deep=True) @@ -449,20 +505,407 @@ def _map_genes_categories_highlight( plt.legend(loc=1, bbox_to_anchor=legend_pos, frameon=False) - # TODO replace with scanpy save style - if save: - files = os.listdir() - for x in range(100): - file_pref = "volcano_" + "%02d" % (x,) - if len([x for x in files if x.startswith(file_pref)]) == 0: - plt.savefig(file_pref + ".png", dpi=300, bbox_inches="tight") - plt.savefig(file_pref + ".svg", bbox_inches="tight") - break - elif isinstance(save, str): - plt.savefig(save + ".png", dpi=300, bbox_inches="tight") - plt.savefig(save + ".svg", bbox_inches="tight") + if show: + plt.show() + if return_fig: + return plt.gcf() + return None + + @_doc_params(common_plot_args=doc_common_plot_args) + def plot_paired( + self, + adata: ad.AnnData, + results_df: pd.DataFrame, + groupby: str, + pairedby: str, + *, + var_names: Sequence[str] = None, + n_top_vars: int = 15, + layer: str = None, + pvalue_col: str = "adj_p_value", + symbol_col: str = "variable", + n_cols: int = 4, + panel_size: tuple[int, int] = (5, 5), + show_legend: bool = True, + size: int = 10, + y_label: str = "expression", + pvalue_template=lambda x: f"p={x:.2e}", + boxplot_properties=None, + palette=None, + show: bool = True, + return_fig: bool = False, + ) -> Figure | None: + """Creates a pairwise expression plot from a Pandas DataFrame or Anndata. + + Visualizes a panel of paired scatterplots per variable. + + Args: + adata: AnnData object, can be pseudobulked. + results_df: DataFrame with results from a differential expression test. + groupby: .obs column containing the grouping. Must contain exactly two different values. + pairedby: .obs column containing the pairing (e.g. "patient_id"). If None, an independent t-test is performed. + var_names: Variables to plot. + n_top_vars: Number of top variables to plot. + layer: Layer to use for plotting. + pvalue_col: Column name of the p values. + symbol_col: Column name of gene IDs. + n_cols: Number of columns in the plot. + panel_size: Size of each panel. + show_legend: Whether to show the legend. + size: Size of the points. + y_label: Label for the y-axis. + pvalue_template: Template for the p-value string displayed in the title of each panel. + boxplot_properties: Additional properties for the boxplot, passed to seaborn.boxplot. + palette: Color palette for the line- and stripplot. + {common_plot_args} + + Returns: + If `return_fig` is `True`, returns the figure, otherwise `None`. + + Examples: + >>> # Example with EdgeR + >>> import pertpy as pt + >>> adata = pt.dt.zhang_2021() + >>> adata.layers["counts"] = adata.X.copy() + >>> ps = pt.tl.PseudobulkSpace() + >>> pdata = ps.compute( + ... adata, + ... target_col="Patient", + ... groups_col="Cluster", + ... layer_key="counts", + ... mode="sum", + ... min_cells=10, + ... min_counts=1000, + ... ) + >>> edgr = pt.tl.EdgeR(pdata, design="~Efficacy+Treatment") + >>> edgr.fit() + >>> res_df = edgr.test_contrasts( + ... edgr.contrast(column="Treatment", baseline="Chemo", group_to_compare="Anti-PD-L1+Chemo") + ... ) + >>> edgr.plot_paired(pdata, results_df=res_df, n_top_vars=8, groupby="Treatment", pairedby="Efficacy") + + Preview: + .. image:: /_static/docstring_previews/de_paired_expression.png + """ + if boxplot_properties is None: + boxplot_properties = {} + groups = adata.obs[groupby].unique() + if len(groups) != 2: + raise ValueError("The number of groups in the group_by column must be exactly 2 to enable paired testing") + + if var_names is None: + var_names = results_df.head(n_top_vars)[symbol_col].tolist() + + adata = adata[:, var_names] + + if any(adata.obs[[groupby, pairedby]].value_counts() > 1): + logger.info("Performing pseudobulk for paired samples") + ps = PseudobulkSpace() + adata = ps.compute( + adata, target_col=groupby, groups_col=pairedby, layer_key=layer, mode="sum", min_cells=1, min_counts=1 + ) + + if layer is not None: + X = adata.layers[layer] + else: + X = adata.X + try: + X = X.toarray() + except AttributeError: + pass + + groupby_cols = [pairedby, groupby] + df = adata.obs.loc[:, groupby_cols].join(pd.DataFrame(X, index=adata.obs_names, columns=var_names)) + + # remove unpaired samples + paired_samples = set(df[df[groupby] == groups[0]][pairedby]) & set(df[df[groupby] == groups[1]][pairedby]) + df = df[df[pairedby].isin(paired_samples)] + removed_samples = adata.obs[pairedby].nunique() - len(df[pairedby].unique()) + if removed_samples > 0: + logger.warning(f"{removed_samples} unpaired samples removed") + + pvalues = results_df.set_index(symbol_col).loc[var_names, pvalue_col].values + df.reset_index(drop=False, inplace=True) + + # transform data for seaborn + df_melt = df.melt( + id_vars=groupby_cols, + var_name="var", + value_name="val", + ) + + n_panels = len(var_names) + nrows = math.ceil(n_panels / n_cols) + ncols = min(n_cols, n_panels) + + fig, axes = plt.subplots( + nrows, + ncols, + figsize=(ncols * panel_size[0], nrows * panel_size[1]), + tight_layout=True, + squeeze=False, + ) + axes = axes.flatten() + for i, (var, ax) in enumerate(zip_longest(var_names, axes)): + if var is not None: + sns.boxplot( + x=groupby, + data=df_melt.loc[df_melt["var"] == var], + y="val", + ax=ax, + color="white", + fliersize=0, + **boxplot_properties, + ) + if pairedby is not None: + sns.lineplot( + x=groupby, + data=df_melt.loc[df_melt["var"] == var], + y="val", + ax=ax, + hue=pairedby, + legend=False, + errorbar=None, + palette=palette, + ) + jitter = 0 if pairedby else True + sns.stripplot( + x=groupby, + data=df_melt.loc[df_melt["var"] == var], + y="val", + ax=ax, + hue=pairedby, + jitter=jitter, + size=size, + linewidth=1, + palette=palette, + ) + + ax.set_xlabel("") + ax.tick_params( + axis="x", + labelsize=15, + ) + ax.legend().set_visible(False) + ax.set_ylabel(y_label) + ax.set_title(f"{var}\n{pvalue_template(pvalues[i])}") + else: + ax.set_visible(False) + fig.tight_layout() + + if show_legend is True: + axes[n_panels - 1].legend().set_visible(True) + axes[n_panels - 1].legend( + bbox_to_anchor=(0.5, -0.1), loc="upper center", ncol=adata.obs[pairedby].nunique() + ) + + plt.tight_layout() + if show: + plt.show() + if return_fig: + return plt.gcf() + return None + + @_doc_params(common_plot_args=doc_common_plot_args) + def plot_fold_change( + self, + results_df: pd.DataFrame, + *, + var_names: Sequence[str] = None, + n_top_vars: int = 15, + log2fc_col: str = "log_fc", + symbol_col: str = "variable", + y_label: str = "Log2 fold change", + figsize: tuple[int, int] = (10, 5), + show: bool = True, + return_fig: bool = False, + **barplot_kwargs, + ) -> Figure | None: + """Plot a metric from the results as a bar chart, optionally with additional information about paired samples in a scatter plot. + + Args: + results_df: DataFrame with results from DE analysis. + var_names: Variables to plot. If None, the top n_top_vars variables based on the log2 fold change are plotted. + n_top_vars: Number of top variables to plot. The top and bottom n_top_vars variables are plotted, respectively. + log2fc_col: Column name of log2 Fold-Change values. + symbol_col: Column name of gene IDs. + y_label: Label for the y-axis. + figsize: Size of the figure. + {common_plot_args} + **barplot_kwargs: Additional arguments for seaborn.barplot. + + Returns: + If `return_fig` is `True`, returns the figure, otherwise `None`. + + Examples: + >>> # Example with EdgeR + >>> import pertpy as pt + >>> adata = pt.dt.zhang_2021() + >>> adata.layers["counts"] = adata.X.copy() + >>> ps = pt.tl.PseudobulkSpace() + >>> pdata = ps.compute( + ... adata, + ... target_col="Patient", + ... groups_col="Cluster", + ... layer_key="counts", + ... mode="sum", + ... min_cells=10, + ... min_counts=1000, + ... ) + >>> edgr = pt.tl.EdgeR(pdata, design="~Efficacy+Treatment") + >>> edgr.fit() + >>> res_df = edgr.test_contrasts( + ... edgr.contrast(column="Treatment", baseline="Chemo", group_to_compare="Anti-PD-L1+Chemo") + ... ) + >>> edgr.plot_fold_change(res_df) + + Preview: + .. image:: /_static/docstring_previews/de_fold_change.png + """ + if var_names is None: + var_names = results_df.sort_values(log2fc_col, ascending=False).head(n_top_vars)[symbol_col].tolist() + var_names += results_df.sort_values(log2fc_col, ascending=True).head(n_top_vars)[symbol_col].tolist() + assert len(var_names) == 2 * n_top_vars + + df = results_df[results_df[symbol_col].isin(var_names)] + df.sort_values(log2fc_col, ascending=False, inplace=True) + + plt.figure(figsize=figsize) + sns.barplot( + x=symbol_col, + y=log2fc_col, + data=df, + palette="RdBu", + legend=False, + **barplot_kwargs, + ) + plt.xticks(rotation=90) + plt.xlabel("") + plt.ylabel(y_label) + + if show: + plt.show() + if return_fig: + return plt.gcf() + return None + + @_doc_params(common_plot_args=doc_common_plot_args) + def plot_multicomparison_fc( + self, + results_df: pd.DataFrame, + *, + n_top_vars=15, + contrast_col: str = "contrast", + log2fc_col: str = "log_fc", + pvalue_col: str = "adj_p_value", + symbol_col: str = "variable", + marker_size: int = 100, + figsize: tuple[int, int] = (10, 2), + x_label: str = "Contrast", + y_label: str = "Gene", + show: bool = True, + return_fig: bool = False, + **heatmap_kwargs, + ) -> Figure | None: + """Plot a matrix of log2 fold changes from the results. + + Args: + results_df: DataFrame with results from DE analysis. + n_top_vars: Number of top variables to plot per group. + contrast_col: Column in results_df containing information about the contrast. + log2fc_col: Column in results_df containing the log2 fold change. + pvalue_col: Column in results_df containing the p-value. Can be used to switch between adjusted and unadjusted p-values. + symbol_col: Column in results_df containing the gene symbol. + marker_size: Size of the biggest marker for significant variables. + figsize: Size of the figure. + x_label: Label for the x-axis. + y_label: Label for the y-axis. + {common_plot_args} + **heatmap_kwargs: Additional arguments for seaborn.heatmap. + + Returns: + If `return_fig` is `True`, returns the figure, otherwise `None`. + + Examples: + >>> # Example with EdgeR + >>> import pertpy as pt + >>> adata = pt.dt.zhang_2021() + >>> adata.layers["counts"] = adata.X.copy() + >>> ps = pt.tl.PseudobulkSpace() + >>> pdata = ps.compute( + ... adata, + ... target_col="Patient", + ... groups_col="Cluster", + ... layer_key="counts", + ... mode="sum", + ... min_cells=10, + ... min_counts=1000, + ... ) + >>> edgr = pt.tl.EdgeR(pdata, design="~Efficacy+Treatment") + >>> res_df = edgr.compare_groups(pdata, column="Efficacy", baseline="SD", groups_to_compare=["PR", "PD"]) + >>> edgr.plot_multicomparison_fc(res_df) + + Preview: + .. image:: /_static/docstring_previews/de_multicomparison_fc.png + """ + groups = results_df[contrast_col].unique().tolist() + + results_df["abs_log_fc"] = results_df[log2fc_col].abs() + + def _get_significance(p_val): + if p_val < 0.001: + return "< 0.001" + elif p_val < 0.01: + return "< 0.01" + elif p_val < 0.1: + return "< 0.1" + else: + return "n.s." + + results_df["significance"] = results_df[pvalue_col].apply(_get_significance) + + var_names = [] + for group in groups: + var_names += ( + results_df[results_df[contrast_col] == group] + .sort_values("abs_log_fc", ascending=False) + .head(n_top_vars)[symbol_col] + .tolist() + ) + + results_df = results_df[results_df[symbol_col].isin(var_names)] + df = results_df.pivot(index=contrast_col, columns=symbol_col, values=log2fc_col)[var_names] + + plt.figure(figsize=figsize) + sns.heatmap(df, **heatmap_kwargs, cmap="coolwarm", center=0, cbar_kws={"label": "Log2 fold change"}) + + _size = {"< 0.001": marker_size, "< 0.01": math.floor(marker_size / 2), "< 0.1": math.floor(marker_size / 4)} + x_locs, x_labels = plt.xticks()[0], [label.get_text() for label in plt.xticks()[1]] + y_locs, y_labels = plt.yticks()[0], [label.get_text() for label in plt.yticks()[1]] + + for _i, row in results_df.iterrows(): + if row["significance"] != "n.s.": + plt.scatter( + x=x_locs[x_labels.index(row[symbol_col])], + y=y_locs[y_labels.index(row[contrast_col])], + s=_size[row["significance"]], + marker="*", + c="white", + ) + + plt.scatter([], [], s=marker_size, marker="*", c="black", label="< 0.001") + plt.scatter([], [], s=math.floor(marker_size / 2), marker="*", c="black", label="< 0.01") + plt.scatter([], [], s=math.floor(marker_size / 4), marker="*", c="black", label="< 0.1") + plt.legend(title="Significance", bbox_to_anchor=(1.2, -0.05)) + + plt.xlabel(x_label) + plt.ylabel(y_label) - plt.show() + if show: + plt.show() + if return_fig: + return plt.gcf() + return None class LinearModelBase(MethodBase): @@ -583,9 +1026,10 @@ def test_reduced(self, modelB): modelB: the reduced model against which to test. Example: - modelA = Model().fit() - modelB = Model().fit() - modelA.test_reduced(modelB) + >>> import pertpy as pt + >>> modelA = Model().fit() + >>> modelB = Model().fit() + >>> modelA.test_reduced(modelB) """ raise NotImplementedError diff --git a/pertpy/tools/_differential_gene_expression/_edger.py b/pertpy/tools/_differential_gene_expression/_edger.py index f50c1300..98849476 100644 --- a/pertpy/tools/_differential_gene_expression/_edger.py +++ b/pertpy/tools/_differential_gene_expression/_edger.py @@ -2,7 +2,7 @@ import numpy as np import pandas as pd -from scanpy import logging +from lamin_utils import logger from scipy.sparse import issparse from ._base import LinearModelBase @@ -60,13 +60,13 @@ def fit(self, **kwargs): # adata, design, mask, layer dge = edger.DGEList(counts=expr_r, samples=self.adata.obs) - logging.info("Calculating NormFactors") + logger.info("Calculating NormFactors") dge = edger.calcNormFactors(dge) - logging.info("Estimating Dispersions") + logger.info("Estimating Dispersions") dge = edger.estimateDisp(dge, design=self.design) - logging.info("Fitting linear model") + logger.info("Fitting linear model") fit = edger.glmQLFit(dge, design=self.design, **kwargs) ro.globalenv["fit"] = fit diff --git a/pertpy/tools/_enrichment.py b/pertpy/tools/_enrichment.py index 7574cd7a..8c6b6bc5 100644 --- a/pertpy/tools/_enrichment.py +++ b/pertpy/tools/_enrichment.py @@ -3,6 +3,7 @@ from typing import Any, Literal import blitzgsea +import matplotlib.pyplot as plt import numpy as np import pandas as pd import scanpy as sc @@ -14,6 +15,7 @@ from scipy.stats import hypergeom from statsmodels.stats.multitest import multipletests +from pertpy._doc import _doc_params, doc_common_plot_args from pertpy.metadata import Drug @@ -290,9 +292,11 @@ def gsea( return enrichment + @_doc_params(common_plot_args=doc_common_plot_args) def plot_dotplot( self, adata: AnnData, + *, targets: dict[str, dict[str, list[str]]] = None, source: Literal["chembl", "dgidb", "pharmgkb"] = "chembl", category_name: str = "interaction_type", @@ -300,10 +304,10 @@ def plot_dotplot( groupby: str = None, key: str = "pertpy_enrichment", ax: Axes | None = None, - save: bool | str | None = None, - show: bool | None = None, + show: bool = True, + return_fig: bool = False, **kwargs, - ) -> DotPlot | dict | None: + ) -> DotPlot | None: """Plots a dotplot by groupby and categories. Wraps scanpy's dotplot but formats it nicely by categories. @@ -319,11 +323,11 @@ def plot_dotplot( category_name: The name of category used to generate a nested drug target set when `targets=None` and `source=dgidb|pharmgkb`. groupby: dotplot groupby such as clusters or cell types. key: Prefix key of enrichment results in `uns`. + {common_plot_args} kwargs: Passed to scanpy dotplot. Returns: - If `return_fig` is `True`, returns a :class:`~scanpy.pl.DotPlot` object, - else if `show` is false, return axes dict. + If `return_fig` is `True`, returns the figure, otherwise `None`. Examples: >>> import pertpy as pt @@ -403,21 +407,27 @@ def plot_dotplot( "var_group_labels": var_group_labels, } - return sc.pl.dotplot( + fig = sc.pl.dotplot( enrichment_score_adata, groupby=groupby, swap_axes=True, ax=ax, - save=save, - show=show, + show=False, **plot_args, **kwargs, ) + if show: + plt.show() + if return_fig: + return fig + return None + def plot_gsea( self, adata: AnnData, enrichment: dict[str, pd.DataFrame], + *, n: int = 10, key: str = "pertpy_enrichment_gsea", interactive_plot: bool = False, diff --git a/pertpy/tools/_milo.py b/pertpy/tools/_milo.py index 6fa99959..53bda6b9 100644 --- a/pertpy/tools/_milo.py +++ b/pertpy/tools/_milo.py @@ -1,6 +1,5 @@ from __future__ import annotations -import logging import random import re from typing import TYPE_CHECKING, Literal @@ -14,6 +13,8 @@ from lamin_utils import logger from mudata import MuData +from pertpy._doc import _doc_params, doc_common_plot_args + if TYPE_CHECKING: from collections.abc import Sequence @@ -125,7 +126,7 @@ def make_nhoods( try: use_rep = adata.uns["neighbors"]["params"]["use_rep"] except KeyError: - logging.warning("Using X_pca as default embedding") + logger.warning("Using X_pca as default embedding") use_rep = "X_pca" try: knn_graph = adata.obsp["connectivities"].copy() @@ -136,7 +137,7 @@ def make_nhoods( try: use_rep = adata.uns[neighbors_key]["params"]["use_rep"] except KeyError: - logging.warning("Using X_pca as default embedding") + logger.warning("Using X_pca as default embedding") use_rep = "X_pca" knn_graph = adata.obsp[neighbors_key + "_connectivities"].copy() @@ -713,9 +714,11 @@ def _graph_spatial_fdr( sample_adata.var["SpatialFDR"] = np.nan sample_adata.var.loc[keep_nhoods, "SpatialFDR"] = adjp + @_doc_params(common_plot_args=doc_common_plot_args) def plot_nhood_graph( self, mdata: MuData, + *, alpha: float = 0.1, min_logFC: float = 0, min_size: int = 10, @@ -724,10 +727,10 @@ def plot_nhood_graph( color_map: Colormap | str | None = None, palette: str | Sequence[str] | None = None, ax: Axes | None = None, - show: bool | None = None, - save: bool | str | None = None, + show: bool = True, + return_fig: bool = False, **kwargs, - ) -> None: + ) -> Figure | None: """Visualize DA results on abstracted graph (wrapper around sc.pl.embedding) Args: @@ -737,9 +740,7 @@ def plot_nhood_graph( min_size: Minimum size of nodes in visualization. (default: 10) plot_edges: If edges for neighbourhood overlaps whould be plotted. title: Plot title. - show: Show the plot, do not return axis. - save: If `True` or a `str`, save the figure. A string is appended to the default filename. - Infer the filetype if ending on {`'.pdf'`, `'.png'`, `'.svg'`}. + {common_plot_args} **kwargs: Additional arguments to `scanpy.pl.embedding`. Examples: @@ -782,7 +783,7 @@ def plot_nhood_graph( vmax = np.max([nhood_adata.obs["graph_color"].max(), abs(nhood_adata.obs["graph_color"].min())]) vmin = -vmax - sc.pl.embedding( + fig = sc.pl.embedding( nhood_adata, "X_milo_graph", color="graph_color", @@ -798,33 +799,42 @@ def plot_nhood_graph( color_map=color_map, palette=palette, ax=ax, - show=show, - save=save, + show=False, **kwargs, ) + if show: + plt.show() + if return_fig: + return fig + return None + + @_doc_params(common_plot_args=doc_common_plot_args) def plot_nhood( self, mdata: MuData, ix: int, + *, feature_key: str | None = "rna", basis: str = "X_umap", color_map: Colormap | str | None = None, palette: str | Sequence[str] | None = None, - return_fig: bool | None = None, ax: Axes | None = None, - show: bool | None = None, - save: bool | str | None = None, + show: bool = True, + return_fig: bool = False, **kwargs, - ) -> None: + ) -> Figure | None: """Visualize cells in a neighbourhood. Args: mdata: MuData object with feature_key slot, storing neighbourhood assignments in `mdata[feature_key].obsm['nhoods']` ix: index of neighbourhood to visualize + feature_key: Key in mdata to the cell-level AnnData object. basis: Embedding to use for visualization. - show: Show the plot, do not return axis. - save: If True or a str, save the figure. A string is appended to the default filename. Infer the filetype if ending on {'.pdf', '.png', '.svg'}. + color_map: Colormap to use for coloring. + palette: Color palette to use for coloring. + ax: Axes to plot on. + {common_plot_args} **kwargs: Additional arguments to `scanpy.pl.embedding`. Examples: @@ -842,7 +852,7 @@ def plot_nhood( .. image:: /_static/docstring_previews/milo_nhood.png """ mdata[feature_key].obs["Nhood"] = mdata[feature_key].obsm["nhoods"][:, ix].toarray().ravel() - sc.pl.embedding( + fig = sc.pl.embedding( mdata[feature_key], basis, color="Nhood", @@ -852,32 +862,43 @@ def plot_nhood( palette=palette, return_fig=return_fig, ax=ax, - show=show, - save=save, + show=False, **kwargs, ) + if show: + plt.show() + if return_fig: + return fig + return None + + @_doc_params(common_plot_args=doc_common_plot_args) def plot_da_beeswarm( self, mdata: MuData, + *, feature_key: str | None = "rna", anno_col: str = "nhood_annotation", alpha: float = 0.1, subset_nhoods: list[str] = None, palette: str | Sequence[str] | dict[str, str] | None = None, - return_fig: bool | None = None, - save: bool | str | None = None, - show: bool | None = None, - ) -> Figure | Axes | None: + show: bool = True, + return_fig: bool = False, + ) -> Figure | None: """Plot beeswarm plot of logFC against nhood labels Args: mdata: MuData object + feature_key: Key in mdata to the cell-level AnnData object. anno_col: Column in adata.uns['nhood_adata'].obs to use as annotation. (default: 'nhood_annotation'.) alpha: Significance threshold. (default: 0.1) subset_nhoods: List of nhoods to plot. If None, plot all nhoods. palette: Name of Seaborn color palette for violinplots. Defaults to pre-defined category colors for violinplots. + {common_plot_args} + + Returns: + If `return_fig` is `True`, returns the figure, otherwise `None`. Examples: >>> import pertpy as pt @@ -973,29 +994,23 @@ def plot_da_beeswarm( plt.legend(loc="upper left", title=f"< {int(alpha * 100)}% SpatialFDR", bbox_to_anchor=(1, 1), frameon=False) plt.axvline(x=0, ymin=0, ymax=1, color="black", linestyle="--") - if save: - plt.savefig(save, bbox_inches="tight") - return None if show: plt.show() - return None if return_fig: return plt.gcf() - if (not show and not save) or (show is None and save is None): - return plt.gca() - return None + @_doc_params(common_plot_args=doc_common_plot_args) def plot_nhood_counts_by_cond( self, mdata: MuData, test_var: str, + *, subset_nhoods: list[str] = None, log_counts: bool = False, - return_fig: bool | None = None, - save: bool | str | None = None, - show: bool | None = None, - ) -> Figure | Axes | None: + show: bool = True, + return_fig: bool = False, + ) -> Figure | None: """Plot boxplot of cell numbers vs condition of interest. Args: @@ -1003,6 +1018,10 @@ def plot_nhood_counts_by_cond( test_var: Name of column in adata.obs storing condition of interest (y-axis for boxplot) subset_nhoods: List of obs_names for neighbourhoods to include in plot. If None, plot all nhoods. log_counts: Whether to plot log1p of cell counts. + {common_plot_args} + + Returns: + If `return_fig` is `True`, returns the figure, otherwise `None`. """ try: nhood_adata = mdata["milo"].T.copy() @@ -1031,15 +1050,8 @@ def plot_nhood_counts_by_cond( plt.xticks(rotation=90) plt.xlabel(test_var) - if save: - plt.savefig(save, bbox_inches="tight") - return None if show: plt.show() - return None if return_fig: return plt.gcf() - if not (show or save): - return plt.gca() - return None diff --git a/pertpy/tools/_mixscape.py b/pertpy/tools/_mixscape.py index 2184714c..ade7ea49 100644 --- a/pertpy/tools/_mixscape.py +++ b/pertpy/tools/_mixscape.py @@ -18,6 +18,7 @@ from sklearn.mixture import GaussianMixture import pertpy as pt +from pertpy._doc import _doc_params, doc_common_plot_args if TYPE_CHECKING: from collections.abc import Sequence @@ -25,6 +26,7 @@ from anndata import AnnData from matplotlib.axes import Axes from matplotlib.colors import Colormap + from matplotlib.pyplot import Figure from scipy import sparse @@ -506,21 +508,21 @@ def _define_normal_mixscape(self, X: np.ndarray | sparse.spmatrix | pd.DataFrame return [mu, sd] + @_doc_params(common_plot_args=doc_common_plot_args) def plot_barplot( # pragma: no cover self, adata: AnnData, guide_rna_column: str, + *, mixscape_class_global: str = "mixscape_class_global", axis_text_x_size: int = 8, axis_text_y_size: int = 6, axis_title_size: int = 8, legend_title_size: int = 8, legend_text_size: int = 8, - return_fig: bool | None = None, - ax: Axes | None = None, - show: bool | None = None, - save: bool | str | None = None, - ): + show: bool = True, + return_fig: bool = False, + ) -> Figure | None: """Barplot to visualize perturbation scores calculated by the `mixscape` function. Args: @@ -528,12 +530,10 @@ def plot_barplot( # pragma: no cover guide_rna_column: The column of `.obs` with guide RNA labels. The target gene labels. The format must be g<#>. Examples are 'STAT2g1' and 'ATF2g1'. mixscape_class_global: The column of `.obs` with mixscape global classification result (perturbed, NP or NT). - show: Show the plot, do not return axis. - save: If True or a str, save the figure. A string is appended to the default filename. - Infer the filetype if ending on {'.pdf', '.png', '.svg'}. + {common_plot_args} Returns: - If `show==False`, return a :class:`~matplotlib.axes.Axes. + If `return_fig` is `True`, returns the figure, otherwise `None`. Examples: >>> import pertpy as pt @@ -565,33 +565,31 @@ def plot_barplot( # pragma: no cover all_cells_percentage["guide_number"] = "g" + all_cells_percentage["guide_number"] NP_KO_cells = all_cells_percentage[all_cells_percentage["gene"] != "NT"] - if show: - color_mapping = {"KO": "salmon", "NP": "lightgray", "NT": "grey"} - unique_genes = NP_KO_cells["gene"].unique() - fig, axs = plt.subplots(int(len(unique_genes) / 5), 5, figsize=(25, 25), sharey=True) - for i, gene in enumerate(unique_genes): - ax = axs[int(i / 5), i % 5] - grouped_df = ( - NP_KO_cells[NP_KO_cells["gene"] == gene] - .groupby(["guide_number", "mixscape_class_global"], observed=False)["value"] - .sum() - .unstack() - ) - grouped_df.plot( - kind="bar", - stacked=True, - color=[color_mapping[col] for col in grouped_df.columns], - ax=ax, - width=0.8, - legend=False, - ) - ax.set_title( - gene, bbox={"facecolor": "white", "edgecolor": "black", "pad": 1}, fontsize=axis_title_size - ) - ax.set(xlabel="sgRNA", ylabel="% of cells") - sns.despine(ax=ax, top=True, right=True, left=False, bottom=False) - ax.set_xticklabels(ax.get_xticklabels(), rotation=0, ha="right", fontsize=axis_text_x_size) - ax.set_yticklabels(ax.get_yticklabels(), rotation=0, fontsize=axis_text_y_size) + color_mapping = {"KO": "salmon", "NP": "lightgray", "NT": "grey"} + unique_genes = NP_KO_cells["gene"].unique() + fig, axs = plt.subplots(int(len(unique_genes) / 5), 5, figsize=(25, 25), sharey=True) + for i, gene in enumerate(unique_genes): + ax = axs[int(i / 5), i % 5] + grouped_df = ( + NP_KO_cells[NP_KO_cells["gene"] == gene] + .groupby(["guide_number", "mixscape_class_global"], observed=False)["value"] + .sum() + .unstack() + ) + grouped_df.plot( + kind="bar", + stacked=True, + color=[color_mapping[col] for col in grouped_df.columns], + ax=ax, + width=0.8, + legend=False, + ) + ax.set_title(gene, bbox={"facecolor": "white", "edgecolor": "black", "pad": 1}, fontsize=axis_title_size) + ax.set(xlabel="sgRNA", ylabel="% of cells") + sns.despine(ax=ax, top=True, right=True, left=False, bottom=False) + ax.set_xticklabels(ax.get_xticklabels(), rotation=0, ha="right", fontsize=axis_text_x_size) + ax.set_yticklabels(ax.get_yticklabels(), rotation=0, fontsize=axis_text_y_size) + fig.subplots_adjust(right=0.8) fig.subplots_adjust(hspace=0.5, wspace=0.5) ax.legend( @@ -603,25 +601,29 @@ def plot_barplot( # pragma: no cover title_fontsize=legend_title_size, ) - plt.tight_layout() - _utils.savefig_or_show("mixscape_barplot", show=show, save=save) + if show: + plt.show() + if return_fig: + return plt.gcf() + return None + @_doc_params(common_plot_args=doc_common_plot_args) def plot_heatmap( # pragma: no cover self, adata: AnnData, labels: str, target_gene: str, control: str, + *, layer: str | None = None, method: str | None = "wilcoxon", subsample_number: int | None = 900, vmin: float | None = -2, vmax: float | None = 2, - return_fig: bool | None = None, - show: bool | None = None, - save: bool | str | None = None, + show: bool = True, + return_fig: bool = False, **kwds, - ) -> Axes | None: + ) -> Figure | None: """Heatmap plot using mixscape results. Requires `pt.tl.mixscape()` to be run first. Args: @@ -634,14 +636,11 @@ def plot_heatmap( # pragma: no cover subsample_number: Subsample to this number of observations. vmin: The value representing the lower limit of the color scale. Values smaller than vmin are plotted with the same color as vmin. vmax: The value representing the upper limit of the color scale. Values larger than vmax are plotted with the same color as vmax. - show: Show the plot, do not return axis. - save: If `True` or a `str`, save the figure. A string is appended to the default filename. - Infer the filetype if ending on {`'.pdf'`, `'.png'`, `'.svg'`}. - ax: A matplotlib axes object. Only works if plotting a single component. + {common_plot_args} **kwds: Additional arguments to `scanpy.pl.rank_genes_groups_heatmap`. Returns: - If `show==False`, return a :class:`~matplotlib.axes.Axes`. + If `return_fig` is `True`, returns the figure, otherwise `None`. Examples: >>> import pertpy as pt @@ -663,35 +662,39 @@ def plot_heatmap( # pragma: no cover sc.pp.scale(adata_subset, max_value=vmax) sc.pp.subsample(adata_subset, n_obs=subsample_number) - return sc.pl.rank_genes_groups_heatmap( + fig = sc.pl.rank_genes_groups_heatmap( adata_subset, groupby="mixscape_class", vmin=vmin, vmax=vmax, n_genes=20, groups=["NT"], - return_fig=return_fig, - show=show, - save=save, + show=False, **kwds, ) + if show: + plt.show() + if return_fig: + return fig + return None + + @_doc_params(common_plot_args=doc_common_plot_args) def plot_perturbscore( # pragma: no cover self, adata: AnnData, labels: str, target_gene: str, + *, mixscape_class: str = "mixscape_class", color: str = "orange", palette: dict[str, str] = None, split_by: str = None, before_mixscape: bool = False, perturbation_type: str = "KO", - return_fig: bool | None = None, - ax: Axes | None = None, - show: bool | None = None, - save: bool | str | None = None, - ) -> None: + show: bool = True, + return_fig: bool = False, + ) -> Figure | None: """Density plots to visualize perturbation scores calculated by the `pt.tl.mixscape` function. Requires `pt.tl.mixscape` to be run first. @@ -710,6 +713,10 @@ def plot_perturbscore( # pragma: no cover before_mixscape: Option to split densities based on mixscape classification (default) or original target gene classification. Default is set to NULL and plots cells by original class ID. perturbation_type: Specify type of CRISPR perturbation expected for labeling mixscape classifications. + {common_plot_args} + + Returns: + If `return_fig` is `True`, returns the figure, otherwise `None`. Examples: Visualizing the perturbation scores for the cells in a dataset: @@ -778,14 +785,11 @@ def plot_perturbscore( # pragma: no cover plt.legend(title="gene_target", title_fontsize=14, fontsize=12) sns.despine() - if save: - plt.savefig(save, bbox_inches="tight") if show: plt.show() if return_fig: return plt.gcf() - if not (show or save): - return plt.gca() + return None # If before_mixscape is False, split densities based on mixscape classifications else: @@ -843,19 +847,18 @@ def plot_perturbscore( # pragma: no cover plt.legend(title="mixscape class", title_fontsize=14, fontsize=12) sns.despine() - if save: - plt.savefig(save, bbox_inches="tight") if show: plt.show() if return_fig: return plt.gcf() - if not (show or save): - return plt.gca() + return None + @_doc_params(common_plot_args=doc_common_plot_args) def plot_violin( # pragma: no cover self, adata: AnnData, target_gene_idents: str | list[str], + *, keys: str | Sequence[str] = "mixscape_class_p_ko", groupby: str | None = "mixscape_class", log: bool = False, @@ -872,10 +875,10 @@ def plot_violin( # pragma: no cover ylabel: str | Sequence[str] | None = None, rotation: float | None = None, ax: Axes | None = None, - show: bool | None = None, - save: bool | str | None = None, + show: bool = True, + return_fig: bool = False, **kwargs, - ): + ) -> Axes | Figure | None: """Violin plot using mixscape results. Requires `pt.tl.mixscape` to be run first. @@ -892,14 +895,12 @@ def plot_violin( # pragma: no cover xlabel: Label of the x-axis. Defaults to `groupby` if `rotation` is `None`, otherwise, no label is shown. ylabel: Label of the y-axis. If `None` and `groupby` is `None`, defaults to `'value'`. If `None` and `groubpy` is not `None`, defaults to `keys`. - show: Show the plot, do not return axis. - save: If `True` or a `str`, save the figure. A string is appended to the default filename. - Infer the filetype if ending on {`'.pdf'`, `'.png'`, `'.svg'`}. ax: A matplotlib axes object. Only works if plotting a single component. + {common_plot_args} **kwargs: Additional arguments to `seaborn.violinplot`. Returns: - A :class:`~matplotlib.axes.Axes` object if `ax` is `None` else `None`. + If `return_fig` is `True`, returns the figure (as Axes list if it's a multi-panel plot), otherwise `None`. Examples: >>> import pertpy as pt @@ -1045,20 +1046,24 @@ def plot_violin( # pragma: no cover show = settings.autoshow if show is None else show if hue is not None and stripplot is True: plt.legend(handles, labels) - _utils.savefig_or_show("mixscape_violin", show=show, save=save) - if not show: + if show: + plt.show() + if return_fig: if multi_panel and groupby is None and len(ys) == 1: return g elif len(axs) == 1: return axs[0] else: return axs + return None + @_doc_params(common_plot_args=doc_common_plot_args) def plot_lda( # pragma: no cover self, adata: AnnData, control: str, + *, mixscape_class: str = "mixscape_class", mixscape_class_global: str = "mixscape_class_global", perturbation_type: str | None = "KO", @@ -1066,12 +1071,11 @@ def plot_lda( # pragma: no cover n_components: int | None = None, color_map: Colormap | str | None = None, palette: str | Sequence[str] | None = None, - return_fig: bool | None = None, ax: Axes | None = None, - show: bool | None = None, - save: bool | str | None = None, + show: bool = True, + return_fig: bool = False, **kwds, - ) -> None: + ) -> Figure | None: """Visualizing perturbation responses with Linear Discriminant Analysis. Requires `pt.tl.mixscape()` to be run first. Args: @@ -1082,9 +1086,7 @@ def plot_lda( # pragma: no cover perturbation_type: Specify type of CRISPR perturbation expected for labeling mixscape classifications. lda_key: If not specified, lda looks .uns["mixscape_lda"] for the LDA results. n_components: The number of dimensions of the embedding. - show: Show the plot, do not return axis. - save: If `True` or a `str`, save the figure. A string is appended to the default filename. - Infer the filetype if ending on {`'.pdf'`, `'.png'`, `'.svg'`}. + {common_plot_args} **kwds: Additional arguments to `scanpy.pl.umap`. Examples: @@ -1112,14 +1114,19 @@ def plot_lda( # pragma: no cover n_components = adata_subset.uns[lda_key].shape[1] sc.pp.neighbors(adata_subset, use_rep=lda_key) sc.tl.umap(adata_subset, n_components=n_components) - sc.pl.umap( + fig = sc.pl.umap( adata_subset, color=mixscape_class, palette=palette, color_map=color_map, return_fig=return_fig, - show=show, - save=save, + show=False, ax=ax, **kwds, ) + + if show: + plt.show() + if return_fig: + return fig + return None diff --git a/pertpy/tools/_perturbation_space/_perturbation_space.py b/pertpy/tools/_perturbation_space/_perturbation_space.py index 3d6c4f67..9a5a475c 100644 --- a/pertpy/tools/_perturbation_space/_perturbation_space.py +++ b/pertpy/tools/_perturbation_space/_perturbation_space.py @@ -41,7 +41,7 @@ def compute_control_diff( # type: ignore Args: adata: Anndata object of size cells x genes. target_col: .obs column name that stores the label of the perturbation applied to each cell. - group_col: .obs column name that stores the label of the group of eah cell. If None, ignore groups. + group_col: .obs column name that stores the label of the group of each cell. If None, ignore groups. reference_key: The key of the control values. layer_key: Key of the AnnData layer to use for computation. new_layer_key: the results are stored in the given layer. diff --git a/pertpy/tools/_perturbation_space/_simple.py b/pertpy/tools/_perturbation_space/_simple.py index c59604b2..a3cd954d 100644 --- a/pertpy/tools/_perturbation_space/_simple.py +++ b/pertpy/tools/_perturbation_space/_simple.py @@ -1,13 +1,20 @@ from __future__ import annotations +from typing import TYPE_CHECKING + import decoupler as dc +import matplotlib.pyplot as plt import numpy as np from anndata import AnnData from sklearn.cluster import DBSCAN, KMeans +from pertpy._doc import _doc_params, doc_common_plot_args from pertpy.tools._perturbation_space._clustering import ClusteringSpace from pertpy.tools._perturbation_space._perturbation_space import PerturbationSpace +if TYPE_CHECKING: + from matplotlib.pyplot import Figure + class CentroidSpace(PerturbationSpace): """Computes the centroids per perturbation of a pre-computed embedding.""" @@ -168,6 +175,49 @@ def compute( return ps_adata + @_doc_params(common_plot_args=doc_common_plot_args) + def plot_psbulk_samples( + self, + adata: AnnData, + groupby: str, + *, + show: bool = True, + return_fig: bool = False, + **kwargs, + ) -> Figure | None: + """Plot the pseudobulk samples of an AnnData object. + + Plot the count number vs. the number of cells per pseudobulk sample. + + Args: + adata: Anndata containing pseudobulk samples. + groupby: `.obs` column to color the samples by. + {common_plot_args} + **kwargs: Are passed to decoupler's plot_psbulk_samples. + + Returns: + If `return_fig` is `True`, returns the figure, otherwise `None`. + + Examples: + >>> import pertpy as pt + >>> adata = pt.dt.zhang_2021() + >>> ps = pt.tl.PseudobulkSpace() + >>> pdata = ps.compute( + ... adata, target_col="Patient", groups_col="Cluster", mode="sum", min_cells=10, min_counts=1000 + ... ) + >>> ps.plot_psbulk_samples(pdata, groupby=["Patient", "Major celltype"], figsize=(12, 4)) + + Preview: + .. image:: /_static/docstring_previews/pseudobulk_samples.png + """ + fig = dc.plot_psbulk_samples(adata, groupby, return_fig=True, **kwargs) + + if show: + plt.show() + if return_fig: + return fig + return None + class KMeansSpace(ClusteringSpace): """Computes K-Means clustering of the expression values.""" diff --git a/pertpy/tools/_scgen/_scgen.py b/pertpy/tools/_scgen/_scgen.py index bdd19eed..eca887ee 100644 --- a/pertpy/tools/_scgen/_scgen.py +++ b/pertpy/tools/_scgen/_scgen.py @@ -18,12 +18,16 @@ from scvi.model.base import BaseModelClass, JaxTrainingMixin from scvi.utils import setup_anndata_dsp +from pertpy._doc import _doc_params, doc_common_plot_args + from ._scgenvae import JaxSCGENVAE from ._utils import balancer, extractor if TYPE_CHECKING: from collections.abc import Sequence + from matplotlib.pyplot import Figure + font = {"family": "Arial", "size": 14} @@ -377,9 +381,8 @@ def plot_reg_mean_plot( condition_key: str, axis_keys: dict[str, str], labels: dict[str, str], - save: str | bool | None = None, + *, gene_list: list[str] = None, - show: bool = False, top_100_genes: list[str] = None, verbose: bool = False, legend: bool = True, @@ -387,6 +390,8 @@ def plot_reg_mean_plot( x_coeff: float = 0.30, y_coeff: float = 0.8, fontsize: float = 14, + show: bool = False, + save: str | bool | None = None, **kwargs, ) -> tuple[float, float] | float: """Plots mean matching for a set of specified genes. @@ -397,21 +402,23 @@ def plot_reg_mean_plot( corresponding to batch and cell type metadata, respectively. condition_key: The key for the condition axis_keys: Dictionary of `adata.obs` keys that are used by the axes of the plot. Has to be in the following form: - `{"x": "Key for x-axis", "y": "Key for y-axis"}`. - labels: Dictionary of axes labels of the form `{"x": "x-axis-name", "y": "y-axis name"}`. - path_to_save: path to save the plot. - save: Specify if the plot should be saved or not. + {`x`: `Key for x-axis`, `y`: `Key for y-axis`}. + labels: Dictionary of axes labels of the form {`x`: `x-axis-name`, `y`: `y-axis name`}. gene_list: list of gene names to be plotted. - show: if `True`: will show to the plot after saving it. top_100_genes: List of the top 100 differentially expressed genes. Specify if you want the top 100 DEGs to be assessed extra. - verbose: Specify if you want information to be printed while creating the plot., + verbose: Specify if you want information to be printed while creating the plot. legend: Whether to plot a legend. title: Set if you want the plot to display a title. x_coeff: Offset to print the R^2 value in x-direction. y_coeff: Offset to print the R^2 value in y-direction. fontsize: Fontsize used for text in the plot. + show: if `True`, will show to the plot after saving it. + save: Specify if the plot should be saved or not. **kwargs: + Returns: + Returns R^2 value for all genes and R^2 value for top 100 DEGs if `top_100_genes` is not `None`. + Examples: >>> import pertpy as pt >>> data = pt.dt.kang_2018() @@ -498,6 +505,7 @@ def plot_reg_mean_plot( r"$\mathrm{R^2_{\mathrm{\mathsf{top\ 100\ DEGs}}}}$= " + f"{r_value_diff ** 2:.2f}", fontsize=kwargs.get("textsize", fontsize), ) + if save: plt.savefig(save, bbox_inches="tight") if show: @@ -514,16 +522,17 @@ def plot_reg_var_plot( condition_key: str, axis_keys: dict[str, str], labels: dict[str, str], - save: str | bool | None = None, + *, gene_list: list[str] = None, top_100_genes: list[str] = None, - show: bool = False, legend: bool = True, title: str = None, verbose: bool = False, x_coeff: float = 0.3, y_coeff: float = 0.8, fontsize: float = 14, + show: bool = True, + save: str | bool | None = None, **kwargs, ) -> tuple[float, float] | float: """Plots variance matching for a set of specified genes. @@ -534,19 +543,18 @@ def plot_reg_var_plot( corresponding to batch and cell type metadata, respectively. condition_key: Key of the condition. axis_keys: Dictionary of `adata.obs` keys that are used by the axes of the plot. Has to be in the following form: - `{"x": "Key for x-axis", "y": "Key for y-axis"}`. - labels: Dictionary of axes labels of the form `{"x": "x-axis-name", "y": "y-axis name"}`. - path_to_save: path to save the plot. - save: Specify if the plot should be saved or not. + {"x": "Key for x-axis", "y": "Key for y-axis"}. + labels: Dictionary of axes labels of the form {"x": "x-axis-name", "y": "y-axis name"}. gene_list: list of gene names to be plotted. - show: if `True`: will show to the plot after saving it. top_100_genes: List of the top 100 differentially expressed genes. Specify if you want the top 100 DEGs to be assessed extra. - legend: Whether to plot a elgend + legend: Whether to plot a legend. title: Set if you want the plot to display a title. verbose: Specify if you want information to be printed while creating the plot. x_coeff: Offset to print the R^2 value in x-direction. y_coeff: Offset to print the R^2 value in y-direction. fontsize: Fontsize used for text in the plot. + show: if `True`, will show to the plot after saving it. + save: Specify if the plot should be saved or not. """ import seaborn as sns @@ -636,6 +644,7 @@ def plot_reg_var_plot( else: return r_value**2 + @_doc_params(common_plot_args=doc_common_plot_args) def plot_binary_classifier( self, scgen: Scgen, @@ -643,10 +652,11 @@ def plot_binary_classifier( delta: np.ndarray, ctrl_key: str, stim_key: str, - show: bool = False, - save: str | bool | None = None, + *, fontsize: float = 14, - ) -> plt.Axes | None: + show: bool = True, + return_fig: bool = False, + ) -> Figure | None: """Plots the dot product between delta and latent representation of a linear classifier. Builds a linear classifier based on the dot product between @@ -661,9 +671,11 @@ def plot_binary_classifier( delta: Difference between stimulated and control cells in latent space ctrl_key: Key for `control` part of the `data` found in `condition_key`. stim_key: Key for `stimulated` part of the `data` found in `condition_key`. - path_to_save: Path to save the plot. - save: Specify if the plot should be saved or not. fontsize: Set the font size of the plot. + {common_plot_args} + + Returns: + If `return_fig` is `True`, returns the figure, otherwise `None`. """ plt.close("all") adata = scgen._validate_anndata(adata) @@ -693,12 +705,10 @@ def plot_binary_classifier( ax = plt.gca() ax.grid(False) - if save: - plt.savefig(save, bbox_inches="tight") if show: plt.show() - if not (show or save): - return ax + if return_fig: + return plt.gcf() return None diff --git a/tests/tools/_differential_gene_expression/test_edger.py b/tests/tools/_differential_gene_expression/test_edger.py index b00e826b..27b4584a 100644 --- a/tests/tools/_differential_gene_expression/test_edger.py +++ b/tests/tools/_differential_gene_expression/test_edger.py @@ -6,7 +6,7 @@ def test_edger_simple(test_adata): 1. Initialized 2. Fitted - 3. and that test_contrast returns a DataFrame with the correct number of rows. + 3. That test_contrast returns a DataFrame with the correct number of rows """ method = EdgeR(adata=test_adata, design="~condition") method.fit()