Skip to content
This repository has been archived by the owner on Oct 9, 2024. It is now read-only.

Commit

Permalink
implement log enrichment scores
Browse files: browse the repository at this point in the history
  • Loading branch information
justjhong committed Mar 28, 2024
1 parent 78e441a commit 91abaa3
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 2 deletions.
41 changes: 40 additions & 1 deletion src/mrvi/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -776,6 +776,7 @@ def differential_abundance(
adata: AnnData | None = None,
sample_cov_keys: list[str] | None = None,
sample_subset: list[str] | None = None,
compute_log_enrichment: bool = False,
batch_size: int = 128,
) -> xr.Dataset:
"""
Expand All @@ -794,6 +795,8 @@ def differential_abundance(
when computing the differential abundance. At the moment, only discrete covariates are supported.
sample_subset
Only compute differential abundance for these sample labels.
compute_log_enrichment
Whether to compute the log enrichment scores for each covariate value.
batch_size
Minibatch size for computing the differential abundance.
Expand Down Expand Up @@ -852,12 +855,16 @@ def differential_abundance(
return log_probs_arr

sample_cov_log_probs_map = {} # maps sample_cov_key to log_probs dataframe (n_cells, n_cov_values)
sample_cov_log_enrichs_map = {} # maps sample_cov_key to log_enrichs dataframe (n_cells, n_cov_values)
for sample_cov_key in sample_cov_keys:
sample_cov_unique_values = self.sample_info[sample_cov_key].unique()
per_val_log_probs = {}
per_val_log_enrichs = {}
for sample_cov_value in sample_cov_unique_values:
cov_samples = (
self.sample_info[sample_cov_key] == sample_cov_value
self.sample_info[
self.sample_info[sample_cov_key] == sample_cov_value
]
).index.to_numpy()
if sample_subset is not None:
cov_samples = np.intersect1d(cov_samples, np.array(sample_subset))
Expand All @@ -869,9 +876,31 @@ def differential_abundance(
sel_log_probs.shape[1]
)
per_val_log_probs[sample_cov_value] = val_log_probs

if compute_log_enrichment:
rest_samples = np.setdiff1d(unique_samples, cov_samples)
if len(rest_samples) == 0:
warnings.warn(
f"All samples have {sample_cov_key}={sample_cov_value}. Skipping log enrichment computation.",
UserWarning,
stacklevel=2,
)
continue
rest_log_probs = log_probs_arr.log_probs.loc[
{"sample": rest_samples}
]
rest_val_log_probs = logsumexp(rest_log_probs, axis=1) - np.log(
rest_log_probs.shape[1]
)
enrichment_scores = val_log_probs - rest_val_log_probs
per_val_log_enrichs[sample_cov_value] = enrichment_scores
sample_cov_log_probs_map[sample_cov_key] = pd.DataFrame.from_dict(
per_val_log_probs
)
if compute_log_enrichment and len(per_val_log_enrichs) > 0:
sample_cov_log_enrichs_map[sample_cov_key] = pd.DataFrame.from_dict(
per_val_log_enrichs
)

coords = {
"cell_name": adata.obs_names.to_numpy(),
Expand All @@ -891,6 +920,16 @@ def differential_abundance(
for sample_cov_key, sample_cov_log_probs in sample_cov_log_probs_map.items()
},
}
if compute_log_enrichment:
data_vars.update(
{
f"{sample_cov_key}_log_enrichs": (
["cell_name", sample_cov_key],
sample_cov_log_enrichs.values,
)
for sample_cov_key, sample_cov_log_enrichs in sample_cov_log_enrichs_map.items()
}
)
return xr.Dataset(data_vars, coords=coords)

def get_outlier_cell_sample_pairs(
Expand Down
7 changes: 6 additions & 1 deletion tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,12 @@ def test_mrvi_de(adata, setup_kwargs, de_kwargs):

@pytest.mark.parametrize(
"da_kwargs",
[{"sample_cov_keys": ["meta1_cat"]}, {"sample_cov_keys": ["meta1_cat", "batch"]}],
[
{"sample_cov_keys": ["meta1_cat"]},
{"sample_cov_keys": ["meta1_cat", "batch"]},
{"sample_cov_keys": ["meta1_cat"], "compute_log_enrichment": True},
{"sample_cov_keys": ["meta1_cat", "batch"], "compute_log_enrichment": True},
],
)
def test_mrvi_da(adata, da_kwargs):
MrVI.setup_anndata(adata, sample_key="sample", batch_key="batch")
Expand Down

0 comments on commit 91abaa3

Please sign in to comment.