From 4ff182cda11957b8d986cc8754c85233d2ba8fe0 Mon Sep 17 00:00:00 2001
From: Ryan Williams
Date: Thu, 2 Jan 2025 14:00:00 -0500
Subject: [PATCH] Add `EagerIterator` shim

`EagerIterator` was moved from `somacore` to `tiledbsoma` in
https://github.com/single-cell-data/SOMA/pull/244 /
https://github.com/single-cell-data/TileDB-SOMA/pull/3307 (released in
somacore 1.0.24 / tiledbsoma 1.15.0). This `import` shim lets us support
versions both before and after that move.
---
 src/tiledbsoma_ml/_utils.py         | 10 ++++++++++
 src/tiledbsoma_ml/batch_iterable.py | 15 ++++-----------
 2 files changed, 14 insertions(+), 11 deletions(-)

diff --git a/src/tiledbsoma_ml/_utils.py b/src/tiledbsoma_ml/_utils.py
index 83456eb..170fa85 100644
--- a/src/tiledbsoma_ml/_utils.py
+++ b/src/tiledbsoma_ml/_utils.py
@@ -11,6 +11,16 @@
 import numpy as np
 import numpy.typing as npt
 
+try:
+    # somacore<1.0.24 / tiledbsoma<1.15
+    from somacore.query._eager_iter import EagerIterator as _EagerIterator
+except ImportError:
+    # somacore>=1.0.24 / tiledbsoma>=1.15
+    from tiledbsoma._eager_iter import EagerIterator as _EagerIterator
+
+# Abstract over the import `try` above; re-export for use by other modules:
+EagerIterator = _EagerIterator
+
 _T_co = TypeVar("_T_co", covariant=True)
 
 NDArrayNumber = npt.NDArray[np.number[Any]]
diff --git a/src/tiledbsoma_ml/batch_iterable.py b/src/tiledbsoma_ml/batch_iterable.py
index 0c8f638..fd8173c 100644
--- a/src/tiledbsoma_ml/batch_iterable.py
+++ b/src/tiledbsoma_ml/batch_iterable.py
@@ -34,14 +34,7 @@
     get_worker_world_rank,
 )
 from tiledbsoma_ml._experiment_locator import ExperimentLocator
-from tiledbsoma_ml._utils import NDArrayNumber, batched, splits
-
-try:
-    # somacore<1.0.24 / tiledbsoma<1.15
-    from somacore.query._eager_iter import EagerIterator as _EagerIterator
-except ImportError:
-    # somacore>=1.0.24 / tiledbsoma>=1.15
-    from tiledbsoma._eager_iter import EagerIterator as _EagerIterator
+from tiledbsoma_ml._utils import EagerIterator, NDArrayNumber, batched, splits
 
 logger = logging.getLogger("tiledbsoma_ml.pytorch")
 
@@ -322,7 +315,7 @@ def __iter__(self) -> Iterator[Batch]:
             obs_joinid_iter = self._create_obs_joinids_partition()
             _mini_batch_iter = self._mini_batch_iter(exp.obs, X, obs_joinid_iter)
             if self.use_eager_fetch:
-                _mini_batch_iter = _EagerIterator(
+                _mini_batch_iter = EagerIterator(
                     _mini_batch_iter, pool=exp.context.threadpool
                 )
 
@@ -458,7 +451,7 @@ def make_io_buffer(
             for X_tbl in X_tbl_iter
         )
         if self.use_eager_fetch:
-            _io_buf_iter = _EagerIterator(_io_buf_iter, pool=X.context.threadpool)
+            _io_buf_iter = EagerIterator(_io_buf_iter, pool=X.context.threadpool)
 
         # Now that X read is potentially in progress (in eager mode), go fetch obs data
         # fmt: off
@@ -498,7 +491,7 @@ def _mini_batch_iter(
 
         io_batch_iter = self._io_batch_iter(obs, X, obs_joinid_iter)
         if self.use_eager_fetch:
-            io_batch_iter = _EagerIterator(io_batch_iter, pool=X.context.threadpool)
+            io_batch_iter = EagerIterator(io_batch_iter, pool=X.context.threadpool)
 
         mini_batch_size = self.batch_size
         result: Tuple[NDArrayNumber, pd.DataFrame] | None = None
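
For reference, a minimal usage sketch, not part of the patch: it assumes `tiledbsoma-ml` (with a compatible `somacore`/`tiledbsoma`) is installed, and that any `concurrent.futures` executor works as the `pool` argument, mirroring the `pool=...threadpool` call sites above; the producer function below is a hypothetical stand-in. Whichever underlying class the `try`/`except` resolved is what gets used.

    # Hypothetical standalone example; only `EagerIterator` comes from the patch.
    from concurrent.futures import ThreadPoolExecutor

    from tiledbsoma_ml._utils import EagerIterator  # shim-resolved class

    def slow_batches():
        # Stand-in for an expensive producer, e.g. an X/obs read per IO batch.
        for i in range(3):
            yield [i] * 4

    with ThreadPoolExecutor(max_workers=2) as pool:
        # The wrapped iterator is advanced eagerly on the pool, so the next item
        # may already be available by the time the consumer asks for it.
        for batch in EagerIterator(slow_batches(), pool=pool):
            print(batch)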