Skip to content

Commit

Permalink
Add enrichment (#482)
Browse files Browse the repository at this point in the history
* Add Drug metadata from chembl

Signed-off-by: zethson <lukas.heumos@posteo.net>

* Add score

Signed-off-by: zethson <lukas.heumos@posteo.net>

* Add docs

Signed-off-by: zethson <lukas.heumos@posteo.net>

* Add dotplot

Signed-off-by: zethson <lukas.heumos@posteo.net>

* Add hypergeometric

Signed-off-by: zethson <lukas.heumos@posteo.net>

* Add gsea + plot

Signed-off-by: zethson <lukas.heumos@posteo.net>

* Refactoring

Signed-off-by: zethson <lukas.heumos@posteo.net>

* Python 3.10+

Signed-off-by: zethson <lukas.heumos@posteo.net>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: zethson <lukas.heumos@posteo.net>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
Zethson and pre-commit-ci[bot] authored Jan 4, 2024
1 parent 2c6e816 commit 03428b3
Show file tree
Hide file tree
Showing 17 changed files with 568 additions and 119 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ jobs:
matrix:
include:
- os: ubuntu-latest
python: "3.9"
python: "3.10"
- os: ubuntu-latest
python: "3.11"
- os: ubuntu-latest
Expand Down
37 changes: 30 additions & 7 deletions docs/usage/usage.md
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ sccoda.plot_effects_barplot(
)
```

### Multi-cellular or gene programs
### Multi-cellular and gene programs

#### DIALOGUE

Expand Down Expand Up @@ -284,6 +284,29 @@ all_results, new_mcps = dl.multilevel_modeling(
)
```

#### Enrichment

```{eval-rst}
.. currentmodule:: pertpy
```

```{eval-rst}
.. autosummary::
:toctree: tools
tools.Enrichment
```

```python
import pertpy as pt
import scanpy as sc

adata = sc.datasets.pbmc3k_processed()

pt_enricher = pt.tl.Enrichment()
pt_enricher.score(adata)
```

### Distances and Permutation Tests

General purpose functions for distances and permutation tests.
Expand Down Expand Up @@ -484,12 +507,6 @@ ps_adata = ps.compute(

See [perturbation space tutorial](https://pertpy.readthedocs.io/en/latest/tutorials/notebooks/perturbation_space.html) for a more elaborate tutorial.

## Plots

Every tool has a set of plotting functions that start with `plot_`.

However, we are planning to offer more general plots at a later point.

## MetaData

MetaData provides tooling to fetch and add more metadata to perturbations by querying a couple of databases.
Expand Down Expand Up @@ -527,4 +544,10 @@ Available databases for mechanism of action metadata:
metadata.CellLine
metadata.Compound
metadata.Moa
metadata.Drug
```

## Plots

Every tool has a set of plotting functions that start with `plot_`.
However, we are planning to offer more general plots at a later point.
2 changes: 1 addition & 1 deletion pertpy/data/_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
def _download( # pragma: no cover
url: str,
output_file_name: str = None,
output_path: Union[str, Path] = None,
output_path: str | Path = None,
block_size: int = 1024,
overwrite: bool = False,
is_zip: bool = False,
Expand Down
2 changes: 1 addition & 1 deletion pertpy/metadata/_cell_line.py
Original file line number Diff line number Diff line change
Expand Up @@ -710,7 +710,7 @@ def plot_correlation(
plt.ylabel("Baseline")
else:
subset_identifier_list = (
[subset_identifier] if isinstance(subset_identifier, (str, int)) else list(subset_identifier)
[subset_identifier] if isinstance(subset_identifier, str | int) else list(subset_identifier)
)

if all(isinstance(id, int) and 0 <= id < adata.n_obs for id in subset_identifier_list):
Expand Down
15 changes: 10 additions & 5 deletions pertpy/metadata/_drug.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from __future__ import annotations

import json
from collections import ChainMap
from pathlib import Path
from typing import TYPE_CHECKING

import numpy as np
import pandas as pd
from rich import print
from scanpy import settings
Expand All @@ -22,17 +23,21 @@ class Drug(MetaData):

def __init__(self):
# Prepared in https://github.com/theislab/pertpy-datasets/blob/main/chembl_data.ipynb
chembl_path = Path(settings.cachedir) / "chembl.parquet"
chembl_path = Path(settings.cachedir) / "chembl.json"
if not Path(chembl_path).exists():
print("[bold yellow]No metadata file was found for chembl. Starting download now.")
_download(
url="https://figshare.com/ndownloader/files/43848687",
output_file_name="chembl.parquet",
url="https://figshare.com/ndownloader/files/43871718",
output_file_name="chembl.json",
output_path=settings.cachedir,
block_size=4096,
is_zip=False,
)
self.chembl = pd.read_parquet(chembl_path)
with chembl_path.open() as file:
chembl_json = json.load(file)
self._chembl_json = chembl_json
targets = dict(ChainMap(*[chembl_json[cat] for cat in chembl_json]))
self.chembl = pd.DataFrame([{"Compound": k, "Targets": v} for k, v in targets.items()])
self.chembl.rename(columns={"Targets": "targets", "Compound": "compounds"}, inplace=True)

def annotate(self, adata: AnnData, copy: bool = False) -> AnnData:
Expand Down
106 changes: 53 additions & 53 deletions pertpy/plot/_coda.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@ def __stackbar( # pragma: no cover
type_names: list[str],
title: str,
level_names: list[str],
figsize: Optional[tuple[float, float]] = None,
dpi: Optional[int] = 100,
cmap: Optional[ListedColormap] = cm.tab20,
show_legend: Optional[bool] = True,
figsize: tuple[float, float] | None = None,
dpi: int | None = 100,
cmap: ListedColormap | None = cm.tab20,
show_legend: bool | None = True,
) -> plt.Axes:
"""Plots a stacked barplot for one (discrete) covariate.
Expand Down Expand Up @@ -57,7 +57,7 @@ def __stackbar( # pragma: no cover
cum_bars = np.zeros(n_bars)

for n in range(n_types):
bars = [i / j * 100 for i, j in zip([y[k][n] for k in range(n_bars)], sample_sums)]
bars = [i / j * 100 for i, j in zip([y[k][n] for k in range(n_bars)], sample_sums, strict=False)]
plt.bar(
r,
bars,
Expand All @@ -80,13 +80,13 @@ def __stackbar( # pragma: no cover

@staticmethod
def stacked_barplot( # pragma: no cover
data: Union[AnnData, MuData],
data: AnnData | MuData,
feature_name: str,
modality_key: str = "coda",
figsize: Optional[tuple[float, float]] = None,
dpi: Optional[int] = 100,
cmap: Optional[ListedColormap] = cm.tab20,
show_legend: Optional[bool] = True,
figsize: tuple[float, float] | None = None,
dpi: int | None = 100,
cmap: ListedColormap | None = cm.tab20,
show_legend: bool | None = True,
level_order: list[str] = None,
) -> plt.Axes:
"""Plots a stacked barplot for all levels of a covariate or all samples (if feature_name=="samples").
Expand Down Expand Up @@ -136,19 +136,19 @@ def stacked_barplot( # pragma: no cover

@staticmethod
def effects_barplot( # pragma: no cover
data: Union[AnnData, MuData],
data: AnnData | MuData,
modality_key: str = "coda",
covariates: Optional[Union[str, list]] = None,
covariates: str | list | None = None,
parameter: Literal["log2-fold change", "Final Parameter", "Expected Sample"] = "log2-fold change",
plot_facets: bool = True,
plot_zero_covariate: bool = True,
plot_zero_cell_type: bool = False,
figsize: Optional[tuple[float, float]] = None,
dpi: Optional[int] = 100,
cmap: Optional[Union[str, ListedColormap]] = cm.tab20,
figsize: tuple[float, float] | None = None,
dpi: int | None = 100,
cmap: str | ListedColormap | None = cm.tab20,
level_order: list[str] = None,
args_barplot: Optional[dict] = None,
) -> Optional[Union[plt.Axes, sns.axisgrid.FacetGrid]]:
args_barplot: dict | None = None,
) -> plt.Axes | sns.axisgrid.FacetGrid | None:
"""Barplot visualization for effects.
The effect results for each covariate are shown as a group of barplots, with intra--group separation by cell types.
Expand Down Expand Up @@ -213,21 +213,21 @@ def effects_barplot( # pragma: no cover

@staticmethod
def boxplots( # pragma: no cover
data: Union[AnnData, MuData],
data: AnnData | MuData,
feature_name: str,
modality_key: str = "coda",
y_scale: Literal["relative", "log", "log10", "count"] = "relative",
plot_facets: bool = False,
add_dots: bool = False,
cell_types: Optional[list] = None,
args_boxplot: Optional[dict] = None,
args_swarmplot: Optional[dict] = None,
figsize: Optional[tuple[float, float]] = None,
dpi: Optional[int] = 100,
cmap: Optional[str] = "Blues",
show_legend: Optional[bool] = True,
cell_types: list | None = None,
args_boxplot: dict | None = None,
args_swarmplot: dict | None = None,
figsize: tuple[float, float] | None = None,
dpi: int | None = 100,
cmap: str | None = "Blues",
show_legend: bool | None = True,
level_order: list[str] = None,
) -> Optional[Union[plt.Axes, sns.axisgrid.FacetGrid]]:
) -> plt.Axes | sns.axisgrid.FacetGrid | None:
"""Grouped boxplot visualization. The cell counts for each cell type are shown as a group of boxplots,
with intra--group separation by a covariate from data.obs.
Expand Down Expand Up @@ -292,14 +292,14 @@ def boxplots( # pragma: no cover

@staticmethod
def rel_abundance_dispersion_plot( # pragma: no cover
data: Union[AnnData, MuData],
data: AnnData | MuData,
modality_key: str = "coda",
abundant_threshold: Optional[float] = 0.9,
default_color: Optional[str] = "Grey",
abundant_color: Optional[str] = "Red",
abundant_threshold: float | None = 0.9,
default_color: str | None = "Grey",
abundant_color: str | None = "Red",
label_cell_types: bool = True,
figsize: Optional[tuple[float, float]] = None,
dpi: Optional[int] = 100,
figsize: tuple[float, float] | None = None,
dpi: int | None = 100,
ax: Axes = None,
) -> plt.Axes:
"""Plots total variance of relative abundance versus minimum relative abundance of all cell types for determination of a reference cell type.
Expand Down Expand Up @@ -357,17 +357,17 @@ def rel_abundance_dispersion_plot( # pragma: no cover

@staticmethod
def draw_tree( # pragma: no cover
data: Union[AnnData, MuData],
data: AnnData | MuData,
modality_key: str = "coda",
tree: str = "tree", # Also type ete3.Tree. Omitted due to import errors
tight_text: Optional[bool] = False,
show_scale: Optional[bool] = False,
show: Optional[bool] = True,
file_name: Optional[str] = None,
units: Optional[Literal["px", "mm", "in"]] = "px",
h: Optional[float] = None,
w: Optional[float] = None,
dpi: Optional[int] = 90,
tight_text: bool | None = False,
show_scale: bool | None = False,
show: bool | None = True,
file_name: str | None = None,
units: Literal["px", "mm", "in"] | None = "px",
h: float | None = None,
w: float | None = None,
dpi: int | None = 90,
):
"""Plot a tree using input ete3 tree object.
Expand Down Expand Up @@ -439,20 +439,20 @@ def draw_tree( # pragma: no cover

@staticmethod
def draw_effects( # pragma: no cover
data: Union[AnnData, MuData],
data: AnnData | MuData,
covariate: str,
modality_key: str = "coda",
tree: str = "tree", # Also type ete3.Tree. Omitted due to import errors
show_legend: Optional[bool] = None,
show_leaf_effects: Optional[bool] = False,
tight_text: Optional[bool] = False,
show_scale: Optional[bool] = False,
show: Optional[bool] = True,
file_name: Optional[str] = None,
units: Optional[Literal["px", "mm", "in"]] = "in",
h: Optional[float] = None,
w: Optional[float] = None,
dpi: Optional[int] = 90,
show_legend: bool | None = None,
show_leaf_effects: bool | None = False,
tight_text: bool | None = False,
show_scale: bool | None = False,
show: bool | None = True,
file_name: str | None = None,
units: Literal["px", "mm", "in"] | None = "in",
h: float | None = None,
w: float | None = None,
dpi: int | None = 90,
):
"""Plot a tree with colored circles on the nodes indicating significant effects with bar plots which indicate leave-level significant effects.
Expand Down Expand Up @@ -529,7 +529,7 @@ def draw_effects( # pragma: no cover
@staticmethod
def effects_umap( # pragma: no cover
data: MuData,
effect_name: Optional[Union[str, list]],
effect_name: str | list | None,
cluster_key: str,
modality_key_1: str = "rna",
modality_key_2: str = "coda",
Expand Down
1 change: 1 addition & 0 deletions pertpy/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from pertpy.tools._differential_gene_expression import DifferentialGeneExpression
from pertpy.tools._distances._distance_tests import DistanceTest
from pertpy.tools._distances._distances import Distance
from pertpy.tools._enrichment import Enrichment
from pertpy.tools._milo import Milo
from pertpy.tools._mixscape import Mixscape
from pertpy.tools._perturbation_space._clustering import ClusteringSpace
Expand Down
6 changes: 3 additions & 3 deletions pertpy/tools/_augur.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,7 +493,7 @@ def run_cross_validation(
# feature importances
feature_importances = defaultdict(list)
if isinstance(self.estimator, RandomForestClassifier) or isinstance(self.estimator, RandomForestRegressor):
for fold, estimator in list(zip(range(len(results["estimator"])), results["estimator"])):
for fold, estimator in list(zip(range(len(results["estimator"])), results["estimator"], strict=False)):
feature_importances["genes"].extend(x.columns.tolist())
feature_importances["feature_importances"].extend(estimator.feature_importances_.tolist())
feature_importances["subsample_idx"].extend(len(x.columns) * [subsample_idx])
Expand All @@ -502,7 +502,7 @@ def run_cross_validation(
# standardized coefficients with Agresti method
# cf. https://think-lab.github.io/d/205/#3
if isinstance(self.estimator, LogisticRegression):
for fold, self.estimator in list(zip(range(len(results["estimator"])), results["estimator"])):
for fold, self.estimator in list(zip(range(len(results["estimator"])), results["estimator"], strict=False)):
feature_importances["genes"].extend(x.columns.tolist())
feature_importances["feature_importances"].extend(
(self.estimator.coef_ * self.estimator.coef_.std()).flatten().tolist()
Expand Down Expand Up @@ -809,7 +809,7 @@ def predict(
* (len(results["feature_importances"]["genes"]) - len(results["feature_importances"]["cell_type"]))
)

for idx, cv in zip(range(n_subsamples), results[cell_type]):
for idx, cv in zip(range(n_subsamples), results[cell_type], strict=False):
results["full_results"]["idx"].extend([idx] * folds)
results["full_results"]["augur_score"].extend(cv["test_augur_score"])
results["full_results"]["folds"].extend(range(folds))
Expand Down
12 changes: 8 additions & 4 deletions pertpy/tools/_cinemaot.py
Original file line number Diff line number Diff line change
Expand Up @@ -667,7 +667,7 @@ def x_ordered_rank(self):
# same as pandas rank method 'first'
rankdata = ss.rankdata(randomized, method="ordinal")
# Reindexing based on pairs of indices before and after
unrandomized = [rankdata[j] for i, j in sorted(zip(randomized_indices, range(len_x)))]
unrandomized = [rankdata[j] for i, j in sorted(zip(randomized_indices, range(len_x), strict=False))]
return unrandomized

@property
Expand Down Expand Up @@ -705,6 +705,7 @@ def mean_absolute(self):
for x, y in zip(
x1,
x2,
strict=False,
)
]
)
Expand Down Expand Up @@ -751,13 +752,16 @@ def pval_asymptotic(self, ties: bool = False):
ind = [i + 1 for i in range(self.sample_size)]
ind2 = [2 * self.sample_size - 2 * ind[i - 1] + 1 for i in ind]

a = np.mean([i * j * j for i, j in zip(ind2, sorted_ordered_x_rank)]) / self.sample_size
a = np.mean([i * j * j for i, j in zip(ind2, sorted_ordered_x_rank, strict=False)]) / self.sample_size

c = np.mean([i * j for i, j in zip(ind2, sorted_ordered_x_rank)]) / self.sample_size
c = np.mean([i * j for i, j in zip(ind2, sorted_ordered_x_rank, strict=False)]) / self.sample_size

cq = np.cumsum(sorted_ordered_x_rank)

m = [(i + (self.sample_size - j) * k) / self.sample_size for i, j, k in zip(cq, ind, sorted_ordered_x_rank)]
m = [
(i + (self.sample_size - j) * k) / self.sample_size
for i, j, k in zip(cq, ind, sorted_ordered_x_rank, strict=False)
]

b = np.mean([np.square(i) for i in m])
v = (a - 2 * b + np.square(c)) / np.square(self.inverse_g_mean)
Expand Down
Loading

0 comments on commit 03428b3

Please sign in to comment.