diff --git a/anndata/experimental/multi_files/_anncollection.py b/anndata/experimental/multi_files/_anncollection.py index 895305c39..3ab7201cb 100644 --- a/anndata/experimental/multi_files/_anncollection.py +++ b/anndata/experimental/multi_files/_anncollection.py @@ -466,6 +466,16 @@ def shape(self): """Shape of the lazily concatenated subset of the data matrix.""" return len(self.obs_names), len(self.var_names) + @property + def n_obs(self): + """Number of observations (``shape[0]``, the length of ``obs_names``).""" + return self.shape[0] + + @property + def n_vars(self): + """Number of variables/features (``shape[1]``, the length of ``var_names``).""" + return self.shape[1] + @property def convert(self): """On the fly converters for keys of attributes and data matrix. diff --git a/anndata/experimental/pytorch/_annloader.py b/anndata/experimental/pytorch/_annloader.py index 83e0c7906..409d3f3be 100644 --- a/anndata/experimental/pytorch/_annloader.py +++ b/anndata/experimental/pytorch/_annloader.py @@ -13,10 +13,10 @@ try: import torch - from torch.utils.data import Sampler, Dataset, DataLoader + from torch.utils.data import Sampler, BatchSampler, Dataset, DataLoader except ImportError: - warnings.warn("Сould not load pytorch.") - Sampler, Dataset, DataLoader = object, object, object + warnings.warn("Could not load pytorch.") + Sampler, BatchSampler, Dataset, DataLoader = object, object, object, object # Custom sampler to get proper batches instead of joined separate indices @@ -188,17 +188,21 @@ def __init__( if ( batch_size is not None and batch_size > 1 - and not has_sampler and not has_batch_sampler and not use_parallel ): drop_last = kwargs.pop("drop_last", False) - default_sampler = BatchIndexSampler( - len(dataset), batch_size, shuffle, drop_last - ) - super().__init__( - dataset, batch_size=None, sampler=default_sampler, **kwargs - ) + if has_sampler: + sampler = kwargs.pop("sampler") + sampler = BatchSampler( + sampler, batch_size=batch_size, drop_last=drop_last + ) + else: + sampler = BatchIndexSampler( + len(dataset), batch_size, shuffle, drop_last + ) + + super().__init__(dataset, 
batch_size=None, sampler=sampler, **kwargs) else: super().__init__(dataset, batch_size=batch_size, shuffle=shuffle, **kwargs) diff --git a/docs/release-latest.rst b/docs/release-latest.rst index c3f971f10..c2a4a3b17 100644 --- a/docs/release-latest.rst +++ b/docs/release-latest.rst @@ -25,6 +25,9 @@ This should make it much easier to support new datatypes, use partial access, an - In many cases :attr:`~anndata.AnnData.X` can now be `None` :pr:`463` :smaller:`R Cannoodt` :pr:`677` :smaller:`I Virshup`. Remaining work is documented in :issue:`467`. - Removed hard `xlrd` dependency :smaller:`I Virshup` - `obs` and `var` dataframes are no longer copied by default on `AnnData` instantiation :issue:`371` :smaller:`I Virshup` +- Added PyTorch dataloader :class:`~anndata.experimental.AnnLoader` and lazy concatenation object :class:`~anndata.experimental.AnnCollection`. See the `tutorials`_ :pr:`416` :smaller:`S Rybakov` + +.. _tutorials: https://anndata-tutorials.readthedocs.io/en/latest/index.html .. rubric:: Bug fixes