Fix presence matrix calculation (and derived values) #1320

Merged · 15 commits · Jan 30, 2025
@@ -52,10 +52,8 @@

@attrs.define
class PresenceResult:
dataset_id: str
dataset_soma_joinid: int
eb_name: str
data: npt.NDArray[np.bool_]
cols: npt.NDArray[np.int64]
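
Since presence is now captured by column indices alone (every recorded column is, by definition, present), the boolean data vector in PresenceResult became redundant. A minimal sketch of constructing the slimmed-down record, with hypothetical joinids:

import numpy as np

# Hypothetical values for illustration only.
pr = PresenceResult(
    dataset_id="d0",
    dataset_soma_joinid=0,
    eb_name="homo_sapiens",
    cols=np.array([3, 17, 42], dtype=np.int64),  # var joinids measured by this dataset
)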


@@ -67,10 +65,6 @@ class AxisStats:
var_stats: pd.DataFrame


AccumulateXResult = tuple[PresenceResult, AxisStats]
AccumulateXResults = Sequence[AccumulateXResult]


def _assert_open_for_write(obj: somacore.SOMAObject | None) -> None:
assert obj is not None
assert obj.exists(obj.uri)
@@ -132,7 +126,7 @@ def __init__(self, specification: ExperimentSpecification):
self.experiment: soma.Experiment | None = None # initialized in create()
self.experiment_uri: str | None = None # initialized in create()
self.global_var_joinids: pd.DataFrame | None = None
self.presence: dict[int, tuple[npt.NDArray[np.bool_], npt.NDArray[np.int64]]] = {}
self.presence: dict[int, npt.NDArray[np.int64]] = {}

@property
def name(self) -> str:
@@ -242,9 +236,8 @@ def populate_presence_matrix(self, datasets: list[Dataset]) -> None:

# LIL is a fast way to create a sparse matrix
pm = sparse.lil_matrix((max_dataset_joinid + 1, self.n_var), dtype=bool)
for dataset_joinid, presence in self.presence.items():
data, cols = presence
pm[dataset_joinid, cols] = data
for dataset_joinid, cols in self.presence.items():
pm[dataset_joinid, cols] = 1

pm = pm.tocoo()
pm.eliminate_zeros()
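
For illustration, a self-contained sketch of the new construction path, using hypothetical joinids: each dataset row of the LIL matrix is filled from its column-index array alone, then the whole matrix is converted to COO for writing.

import numpy as np
from scipy import sparse

# Hypothetical presence map: dataset soma_joinid -> var joinids measured by that dataset.
presence = {0: np.array([0, 2, 4]), 1: np.array([1, 2])}
n_var = 5
max_dataset_joinid = max(presence)

pm = sparse.lil_matrix((max_dataset_joinid + 1, n_var), dtype=bool)
for dataset_joinid, cols in presence.items():
    pm[dataset_joinid, cols] = 1  # the column indices alone imply presence

pm = pm.tocoo()
print(pm.toarray())
# [[ True False  True False  True]
#  [False  True  True False False]]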
@@ -457,14 +450,12 @@ def compute_X_file_stats(

obs_stats = res["obs_stats"]
var_stats = res["var_stats"]
obs_stats["n_measured_vars"] = (var_stats.nnz > 0).sum()
var_stats.loc[var_stats.nnz > 0, "n_measured_obs"] = n_obs
obs_stats["n_measured_vars"] = var_stats.shape[0]
var_stats["n_measured_obs"] = n_obs
res["presence"].append(
PresenceResult(
dataset_id,
dataset_soma_joinid,
eb_name,
(var_stats.nnz > 0).to_numpy(),
var_stats.index.to_numpy(),
),
)
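
This is the heart of the fix: a feature now counts as measured when it appears in the dataset's var axis at all, rather than only when it has nonzero counts. A worked toy comparison of the two calculations:

import pandas as pd

# Hypothetical per-dataset var_stats; the index holds global var joinids.
var_stats = pd.DataFrame({"nnz": [10, 0, 3]}, index=[101, 102, 103])

old_n_measured_vars = int((var_stats.nnz > 0).sum())  # 2: zero-count feature excluded
new_n_measured_vars = var_stats.shape[0]              # 3: every feature in the dataset counts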
@@ -713,10 +704,7 @@ def populate_X_layers(

for presence in eb_summary["presence"]:
assert presence.eb_name == eb.name
eb.presence[presence.dataset_soma_joinid] = (
presence.data,
presence.cols,
)
eb.presence[presence.dataset_soma_joinid] = presence.cols


class SummaryStats(TypedDict):
@@ -31,7 +31,7 @@ def get_obs_stats(
"raw_variance_nnz": raw_variance_nnz.astype(
CENSUS_OBS_TABLE_SPEC.field("raw_variance_nnz").to_pandas_dtype()
),
"n_measured_vars": -1, # placeholder
"n_measured_vars": -1, # handled at the dataset level in compute_X_file_stats
}
)
assert len(obs_stats) == raw_X.shape[0]
@@ -53,7 +53,7 @@ def get_var_stats(
var_stats = pd.DataFrame(
data={
"nnz": nnz.astype(CENSUS_VAR_TABLE_SPEC.field("nnz").to_pandas_dtype()),
"n_measured_obs": 0, # placeholder
"n_measured_obs": 0, # handled at the dataset level in compute_X_file_stats
}
)
assert len(var_stats) == raw_X.shape[1]
@@ -9,7 +9,7 @@
from collections.abc import Iterable, Sequence
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Self, TypeVar
from typing import Any, Self, TypeVar, cast

import dask
import numpy as np
@@ -533,7 +533,7 @@ def _validate_X_layers_has_unique_coords(


def validate_X_layers_presence(
soma_path: str, datasets: list[Dataset], experiment_specifications: list[ExperimentSpecification]
soma_path: str, datasets: list[Dataset], experiment_specifications: list[ExperimentSpecification], assets_path: str
) -> Delayed[bool]:
"""Validate that the presence matrix accurately summarizes X[raw] for each experiment.

@@ -543,6 +543,15 @@
3. The presence mask for each dataset is correct
"""

def _read_var_names(path: str) -> npt.NDArray[np.object_]:
import h5py
from anndata.experimental import read_elem

with h5py.File(path) as f:
index_key = f["var"].attrs["_index"]
var_names = read_elem(f["var"][index_key])
return cast(npt.NDArray[np.object_], var_names)

@logit(logger)
def _validate_X_layers_presence_general(experiment_specifications: list[ExperimentSpecification]) -> bool:
for es in experiment_specifications:
@@ -570,29 +579,29 @@ def _validate_X_layers_presence_general(experiment_specifications: list[Experime

@logit(logger, msg="{0.dataset_id}")
def _validate_X_layers_presence(
dataset: Dataset, experiment_specifications: list[ExperimentSpecification], soma_path: str
dataset: Dataset,
experiment_specifications: list[ExperimentSpecification],
soma_path: str,
assets_path: str,
) -> bool:
"""For a given dataset and experiment, confirm that the presence matrix matches contents of X[raw]."""
for es in experiment_specifications:
with open_experiment(soma_path, es) as exp:
obs_df = (
exp.obs.read(
value_filter=f"dataset_id == '{dataset.soma_joinid}'",
value_filter=f"dataset_id == '{dataset.dataset_id}'",
column_names=["soma_joinid", "n_measured_vars"],
)
.concat()
.to_pandas()
)
if len(obs_df) > 0: # skip empty experiments
X_raw = exp.ms[MEASUREMENT_RNA_NAME].X["raw"]

presence_accumulator = np.zeros((X_raw.shape[1]), dtype=np.bool_)
for block, _ in (
X_raw.read(coords=(obs_df.soma_joinids.to_numpy(), slice(None)))
.blockwise(axis=0, size=2**20, eager=False, reindex_disable_on_axis=[0, 1])
.tables()
):
presence_accumulator[block["soma_dim_1"].to_numpy()] = 1
feature_ids = pd.Index(
exp.ms[MEASUREMENT_RNA_NAME]
.var.read(column_names=["feature_id"])
.concat()
.to_pandas()["feature_id"]
)

presence = (
exp.ms[MEASUREMENT_RNA_NAME][FEATURE_DATASET_PRESENCE_MATRIX_NAME]
@@ -601,17 +610,22 @@ def _validate_X_layers_presence(
.concat()
)

assert np.array_equal(presence_accumulator, presence), "Presence value does not match X[raw]"
# Get soma_joinids for features in the original h5ad
orig_feature_ids = _read_var_names(f"{assets_path}/{dataset.dataset_h5ad_path}")
orig_indices = np.sort(feature_ids.get_indexer(feature_ids.intersection(orig_feature_ids)))

assert (
obs_df.n_measured_vars.to_numpy() == presence_accumulator.sum()
).all(), f"{es.name}:{dataset.dataset_id} obs.n_measured_vars incorrect."
np.testing.assert_array_equal(presence["soma_dim_1"], orig_indices)

return True

check_presence_values = (
dask.bag.from_sequence(datasets, partition_size=8)
.map(_validate_X_layers_presence, soma_path=soma_path, experiment_specifications=experiment_specifications)
.map(
_validate_X_layers_presence,
soma_path=soma_path,
experiment_specifications=experiment_specifications,
assets_path=assets_path,
)
.reduction(all, all)
.to_delayed()
)
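
The expected soma_dim_1 values above are derived purely from index arithmetic against the h5ad's var names; a small sketch of that step with made-up feature IDs:

import numpy as np
import pandas as pd

feature_ids = pd.Index(["ENSG1", "ENSG2", "ENSG3", "ENSG4"])  # experiment var order
orig_feature_ids = ["ENSG4", "ENSG2", "ENSG9"]  # h5ad var_names; ENSG9 is not in the Census

# Positions of the h5ad's features along the experiment's var axis.
orig_indices = np.sort(feature_ids.get_indexer(feature_ids.intersection(orig_feature_ids)))
print(orig_indices)  # [1 3]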
@@ -968,9 +982,14 @@ def validate_internal_consistency(
"""
datasets_df["presence_sum_var_axis"] = presence.sum(axis=1).A1
tmp = obs.merge(datasets_df, left_on="dataset_id", right_on="dataset_id")
assert (
tmp.n_measured_vars == tmp.presence_sum_var_axis
).all(), f"{eb.name}: obs.n_measured_vars does not match presence matrix."
try:
np.testing.assert_array_equal(
tmp["n_measured_vars"],
tmp["presence_sum_var_axis"],
)
except AssertionError as e:
e.add_note(f"{eb.name}: obs.n_measured_vars does not match presence matrix.")
raise
del tmp

# Assertion 3 - var.n_measured_obs is consistent with presence matrix
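
BaseException.add_note (Python 3.11+) attaches context that is printed with the traceback, so the numpy mismatch report from np.testing.assert_array_equal is preserved alongside the experiment name. A minimal sketch of the pattern:

import numpy as np

try:
    np.testing.assert_array_equal([1, 2, 3], [1, 2, 4])
except AssertionError as e:
    e.add_note("example_experiment: arrays disagree")  # note is rendered with the traceback
    raise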
@@ -1091,7 +1110,7 @@ def validate_soma(args: CensusBuildArgs, client: dask.distributed.Client) -> das
dask.delayed(validate_X_layers_schema)(soma_path, experiment_specifications, eb_info),
validate_X_layers_normalized(soma_path, experiment_specifications),
validate_X_layers_has_unique_coords(soma_path, experiment_specifications),
validate_X_layers_presence(soma_path, datasets, experiment_specifications),
validate_X_layers_presence(soma_path, datasets, experiment_specifications, assets_path),
)
)
],
9 changes: 5 additions & 4 deletions tools/cellxgene_census_builder/tests/anndata/test_anndata.py
@@ -20,30 +20,31 @@
from ..conftest import GENE_IDS, ORGANISMS, get_anndata


def test_open_anndata(datasets: list[Dataset]) -> None:
def test_open_anndata(datasets: list[Dataset], census_build_args: CensusBuildArgs) -> None:
"""`open_anndata` should open the h5ads for each of the datasets in the argument,
and yield both the dataset and the corresponding AnnData object.
This test does not involve additional filtering steps.
The `datasets` used here have no raw layer.
"""
assets_path = census_build_args.h5ads_path.as_posix()

def _todense(X: npt.NDArray[np.float32] | sparse.spmatrix) -> npt.NDArray[np.float32]:
if isinstance(X, np.ndarray):
return X
else:
return cast(npt.NDArray[np.float32], X.todense())

result = [(d, open_anndata(d, base_path=".")) for d in datasets]
result = [(d, open_anndata(d, base_path=assets_path)) for d in datasets]
assert len(result) == len(datasets) and len(datasets) > 0
for i, (dataset, anndata_obj) in enumerate(result):
assert dataset == datasets[i]
opened_anndata = anndata.read_h5ad(dataset.dataset_h5ad_path)
opened_anndata = anndata.read_h5ad(f"{assets_path}/{dataset.dataset_h5ad_path}")
assert opened_anndata.obs.equals(anndata_obj.obs)
assert opened_anndata.var.equals(anndata_obj.var)
assert np.array_equal(_todense(opened_anndata.X), _todense(anndata_obj.X))

# also check context manager
with open_anndata(datasets[0], base_path=".") as ad:
with open_anndata(datasets[0], base_path=assets_path) as ad:
assert ad.n_obs == len(ad.obs)


22 changes: 16 additions & 6 deletions tools/cellxgene_census_builder/tests/conftest.py
@@ -1,4 +1,5 @@
import pathlib
from functools import partial
from typing import Literal

import anndata
@@ -43,8 +44,17 @@ def get_anndata(
n_cells = 4
n_genes = len(gene_ids)
rng = np.random.default_rng()
min_X_val = 1 if no_zero_counts else 0
X = rng.integers(min_X_val, min_X_val + 5, size=(n_cells, n_genes)).astype(np.float32)
if no_zero_counts:
X = rng.integers(1, 6, size=(n_cells, n_genes)).astype(np.float32)
else:
X = sparse.random(
n_cells,
n_genes,
density=0.5,
random_state=rng,
data_rvs=partial(rng.integers, 1, 6),
dtype=np.float32,
).toarray()

# Builder code currently assumes (and enforces) that ALL cells (rows) contain at least
# one non-zero value in their count matrix. Enforce this assumption, as the rng will
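
For context on the new fixture: sparse.random calls data_rvs with the number of nonzero entries to draw, so partial(rng.integers, 1, 6) fills roughly half the cells with integer counts in [1, 6). A standalone sketch of just that call:

from functools import partial

import numpy as np
from scipy import sparse

rng = np.random.default_rng(0)
X = sparse.random(
    4,
    3,
    density=0.5,
    random_state=rng,
    data_rvs=partial(rng.integers, 1, 6),  # called as data_rvs(k) for k nonzero values
    dtype=np.float32,
).toarray()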
@@ -148,10 +158,10 @@ def datasets(census_build_args: CensusBuildArgs) -> list[Dataset]:
for organism in ORGANISMS:
for i in range(NUM_DATASET):
h5ad = get_anndata(
organism, GENE_IDS[i], no_zero_counts=True, assay_ontology_term_id=ASSAY_IDS[i], X_format=X_FORMAT[i]
organism, GENE_IDS[i], no_zero_counts=False, assay_ontology_term_id=ASSAY_IDS[i], X_format=X_FORMAT[i]
)
h5ad_path = f"{assets_path}/{organism.name}_{i}.h5ad"
h5ad.write_h5ad(h5ad_path)
h5ad_name = f"{organism.name}_{i}.h5ad"
h5ad.write_h5ad(f"{assets_path}/{h5ad_name}")
datasets.append(
Dataset(
dataset_id=f"{organism.name}_{i}",
@@ -160,7 +170,7 @@
collection_id=f"id_{organism.name}",
collection_name=f"collection_{organism.name}",
dataset_asset_h5ad_uri="mock",
dataset_h5ad_path=h5ad_path,
dataset_h5ad_path=h5ad_name,
dataset_version_id=f"{organism.name}_{i}_v0",
),
)