Skip to content
This repository has been archived by the owner on Oct 9, 2024. It is now read-only.

Commit

Permalink
Merge pull request #99 from YosefLab/jhong/dabugfix
Browse files Browse the repository at this point in the history
Fix de/da bug for non ordered int sample keys
  • Loading branch information
justjhong authored Apr 15, 2024
2 parents ee602c6 + 162b2a4 commit ec84c24
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 27 deletions.
13 changes: 5 additions & 8 deletions src/mrvi/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,9 @@ def __init__(
obs_df = adata.obs.copy()
obs_df = obs_df.loc[~obs_df._scvi_sample.duplicated("first")]
self.sample_info = obs_df.set_index("_scvi_sample").sort_index()
self.sample_key = self.adata_manager.get_state_registry("sample").original_key
self.sample_key = self.adata_manager.get_state_registry(
MRVI_REGISTRY_KEYS.SAMPLE_KEY
).original_key
self.sample_order = self.adata_manager.get_state_registry(
MRVI_REGISTRY_KEYS.SAMPLE_KEY
).categorical_mapping
Expand All @@ -127,11 +129,6 @@ def __init__(
)
self.init_params_ = self._get_init_params(locals())

@property
def original_sample_key(self):
"""Original sample key used for training the model."""
return self.adata_manager.registry["setup_args"]["sample_key"]

def to_device(self, device):
# TODO(jhong): remove this once we have a better way to handle device.
pass
Expand Down Expand Up @@ -865,7 +862,7 @@ def differential_abundance(
self.sample_info[
self.sample_info[sample_cov_key] == sample_cov_value
]
).index.to_numpy()
)[self.sample_key].to_numpy()
if sample_subset is not None:
cov_samples = np.intersect1d(cov_samples, np.array(sample_subset))
if len(cov_samples) == 0:
Expand Down Expand Up @@ -1470,7 +1467,7 @@ def _construct_design_matrix(
if (cov.dtype == str) or (cov.dtype == "category"):
cov = cov.cat.remove_unused_categories()
cov = pd.get_dummies(cov, drop_first=True)
cov_names = sample_cov_key + np.array(cov.columns)
cov_names = np.array([f"{sample_cov_key}_{col}" for col in cov.columns])
cov = cov.values
else:
cov_names = np.array([sample_cov_key])
Expand Down
36 changes: 17 additions & 19 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
def adata():
adata = synthetic_iid()
adata.obs["sample"] = np.random.choice(15, size=adata.shape[0])
adata.obs["sample_str"] = [chr(i + ord("a")) for i in adata.obs["sample"]]
meta1 = np.random.randint(0, 2, size=15)
adata.obs["meta1"] = meta1[adata.obs["sample"].values]
meta2 = np.random.randn(15)
Expand All @@ -23,7 +24,7 @@ def adata():


def test_mrvi(adata):
MrVI.setup_anndata(adata, sample_key="sample", batch_key="batch")
MrVI.setup_anndata(adata, sample_key="sample_str", batch_key="batch")
model = MrVI(adata)
model.train(1, check_val_every_n_epoch=1, train_size=0.5)
model.get_local_sample_distances()
Expand All @@ -36,7 +37,7 @@ def test_mrvi(adata):
"setup_kwargs, de_kwargs",
[
(
{"sample_key": "sample", "batch_key": "batch"},
{"sample_key": "sample_str", "batch_key": "batch"},
[
{
"sample_cov_keys": ["meta1_cat", "meta2", "cont_cov"],
Expand Down Expand Up @@ -64,7 +65,7 @@ def test_mrvi(adata):
),
(
{
"sample_key": "sample",
"sample_key": "sample_str",
"batch_key": "batch",
"continuous_covariate_keys": ["cont_cov"],
},
Expand All @@ -77,7 +78,7 @@ def test_mrvi(adata):
],
),
(
{"sample_key": "sample", "batch_key": "dummy_batch"},
{"sample_key": "sample_str", "batch_key": "dummy_batch"},
[
{
"sample_cov_keys": ["meta1_cat", "meta2", "cont_cov"],
Expand Down Expand Up @@ -105,6 +106,10 @@ def test_mrvi_de(adata, setup_kwargs, de_kwargs):
model.differential_expression(**de_kwarg)


@pytest.mark.parametrize(
"sample_key",
["sample", "sample_str"],
)
@pytest.mark.parametrize(
"da_kwargs",
[
Expand All @@ -114,8 +119,8 @@ def test_mrvi_de(adata, setup_kwargs, de_kwargs):
{"sample_cov_keys": ["meta1_cat", "batch"], "compute_log_enrichment": True},
],
)
def test_mrvi_da(adata, da_kwargs):
MrVI.setup_anndata(adata, sample_key="sample", batch_key="batch")
def test_mrvi_da(adata, sample_key, da_kwargs):
MrVI.setup_anndata(adata, sample_key=sample_key, batch_key="batch")
model = MrVI(adata)
model.train(1, check_val_every_n_epoch=1, train_size=0.5)
model.differential_abundance(**da_kwargs)
Expand Down Expand Up @@ -153,7 +158,7 @@ def test_mrvi_da(adata, da_kwargs):
def test_mrvi_model_kwargs(adata, model_kwargs):
MrVI.setup_anndata(
adata,
sample_key="sample",
sample_key="sample_str",
batch_key="batch",
continuous_covariate_keys=["cont_cov"],
)
Expand All @@ -166,14 +171,14 @@ def test_mrvi_model_kwargs(adata, model_kwargs):
def test_mrvi_sample_subset(adata):
MrVI.setup_anndata(
adata,
sample_key="sample",
sample_key="sample_str",
batch_key="batch",
continuous_covariate_keys=["cont_cov"],
)
model = MrVI(adata)
model.train(1, check_val_every_n_epoch=1, train_size=0.5)
sample_cov_keys = ["meta1_cat", "meta2", "cont_cov"]
sample_subset = list(range(8))
sample_subset = [chr(i + ord("a")) for i in range(8)]
model.differential_expression(
sample_cov_keys=sample_cov_keys, sample_subset=sample_subset
)
Expand All @@ -182,7 +187,7 @@ def test_mrvi_sample_subset(adata):
def test_mrvi_shrink_u(adata):
MrVI.setup_anndata(
adata,
sample_key="sample",
sample_key="sample_str",
batch_key="batch",
continuous_covariate_keys=["cont_cov"],
)
Expand All @@ -202,6 +207,7 @@ def test_mrvi_shrink_u(adata):
def adata_stratifications():
adata = synthetic_iid()
adata.obs["sample"] = np.random.choice(15, size=adata.shape[0])
adata.obs["sample_str"] = [chr(i + ord("a")) for i in adata.obs["sample"]]
meta1 = np.random.randint(0, 2, size=15)
adata.obs["meta1"] = meta1[adata.obs["sample"].values]
meta2 = np.random.randn(15)
Expand All @@ -214,7 +220,7 @@ def adata_stratifications():
def test_mrvi_stratifications(adata_stratifications):
MrVI.setup_anndata(
adata_stratifications,
sample_key="sample",
sample_key="sample_str",
batch_key="batch",
continuous_covariate_keys=["cont_cov"],
)
Expand All @@ -232,11 +238,3 @@ def test_mrvi_stratifications(adata_stratifications):
ct_dists = dists["label_2"]
assert ct_dists.shape == (2, 15, 15)
assert np.allclose(ct_dists[0].values, ct_dists[0].values.T, atol=1e-6)


def test_compute_local_statistics(adata):
n_sample = 15
adata.obs["sample"] = np.random.choice(n_sample, size=adata.shape[0])
meta1 = np.random.randint(0, 2, size=n_sample)
adata.obs["meta1"] = meta1[adata.obs["sample"].values]
MrVI.setup_anndata(adata, sample_key="sample", batch_key="batch")

0 comments on commit ec84c24

Please sign in to comment.