Skip to content

Commit

Permalink
Some additions to AnnLoader and AnnCollection (#704)
Browse files Browse the repository at this point in the history
* add automatic bacthing

* release note
  • Loading branch information
Koncopd authored Feb 13, 2022
1 parent 00a1022 commit 283b0c1
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 9 deletions.
10 changes: 10 additions & 0 deletions anndata/experimental/multi_files/_anncollection.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,16 @@ def shape(self):
"""Shape of the lazily concatenated subset of the data matrix."""
return len(self.obs_names), len(self.var_names)

@property
def n_obs(self):
"""Number of observations."""
return self.shape[0]

@property
def n_vars(self):
"""Number of variables/features."""
return self.shape[1]

@property
def convert(self):
"""On the fly converters for keys of attributes and data matrix.
Expand Down
22 changes: 13 additions & 9 deletions anndata/experimental/pytorch/_annloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@

try:
import torch
from torch.utils.data import Sampler, Dataset, DataLoader
from torch.utils.data import Sampler, BatchSampler, Dataset, DataLoader
except ImportError:
warnings.warn("Сould not load pytorch.")
Sampler, Dataset, DataLoader = object, object, object
Sampler, BatchSampler, Dataset, DataLoader = object, object, object, object


# Custom sampler to get proper batches instead of joined separate indices
Expand Down Expand Up @@ -188,17 +188,21 @@ def __init__(
if (
batch_size is not None
and batch_size > 1
and not has_sampler
and not has_batch_sampler
and not use_parallel
):
drop_last = kwargs.pop("drop_last", False)
default_sampler = BatchIndexSampler(
len(dataset), batch_size, shuffle, drop_last
)

super().__init__(
dataset, batch_size=None, sampler=default_sampler, **kwargs
)
if has_sampler:
sampler = kwargs.pop("sampler")
sampler = BatchSampler(
sampler, batch_size=batch_size, drop_last=drop_last
)
else:
sampler = BatchIndexSampler(
len(dataset), batch_size, shuffle, drop_last
)

super().__init__(dataset, batch_size=None, sampler=sampler, **kwargs)
else:
super().__init__(dataset, batch_size=batch_size, shuffle=shuffle, **kwargs)
3 changes: 3 additions & 0 deletions docs/release-latest.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ This should make it much easier to support new datatypes, use partial access, an
- In many cases :attr:`~anndata.AnnData.X` can now be `None` :pr:`463` :smaller:`R Cannoodt` :pr:`677` :smaller:`I Virshup`. Remaining work is documented in :issue:`467`.
- Removed hard `xlrd` dependency :smaller:`I Virshup`
- `obs` and `var` dataframes are no longer copied by default on `AnnData` instantiation :issue:`371` :smaller:`I Virshup`
- Added PyTorch dataloader :class:`~anndata.experimental.AnnLoader` and lazy concatenation object :class:`~anndata.experimental.AnnCollection`. See the `tutorials`_ :pr:`416` :smaller:`S Rybakov`

.. _tutorials: https://anndata-tutorials.readthedocs.io/en/latest/index.html

.. rubric:: Bug fixes

Expand Down

0 comments on commit 283b0c1

Please sign in to comment.