Skip to content

Commit

Permalink
Add EagerIterator shim
Browse files Browse the repository at this point in the history
`EagerIterator` was moved from `somacore` to `tiledbsoma` in single-cell-data/SOMA#244 / single-cell-data/TileDB-SOMA#3307 (released in somacore 1.0.24 / tiledbsoma 1.15.0). This `import` shim lets us support versions both before and after those releases.
  • Loading branch information
ryan-williams committed Jan 27, 2025
1 parent d92cb72 commit 4ff182c
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 11 deletions.
10 changes: 10 additions & 0 deletions src/tiledbsoma_ml/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,16 @@
import numpy as np
import numpy.typing as npt

# Compatibility shim: `EagerIterator` lives in `somacore` in older releases and
# in `tiledbsoma` in newer ones. Try the old location first and fall back to the
# new one when that import fails.
try:
# somacore<1.0.24 / tiledbsoma<1.15
from somacore.query._eager_iter import EagerIterator as _EagerIterator
except ImportError:
# somacore>=1.0.24 / tiledbsoma>=1.15
from tiledbsoma._eager_iter import EagerIterator as _EagerIterator

# Abstract over the import `try` above and re-export under a public name, so
# other modules (e.g. `batch_iterable`) can import it from `_utils`:
EagerIterator = _EagerIterator

_T_co = TypeVar("_T_co", covariant=True)
NDArrayNumber = npt.NDArray[np.number[Any]]

Expand Down
15 changes: 4 additions & 11 deletions src/tiledbsoma_ml/batch_iterable.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,7 @@
get_worker_world_rank,
)
from tiledbsoma_ml._experiment_locator import ExperimentLocator
from tiledbsoma_ml._utils import NDArrayNumber, batched, splits

try:
# somacore<1.0.24 / tiledbsoma<1.15
from somacore.query._eager_iter import EagerIterator as _EagerIterator
except ImportError:
# somacore>=1.0.24 / tiledbsoma>=1.15
from tiledbsoma._eager_iter import EagerIterator as _EagerIterator
from tiledbsoma_ml._utils import EagerIterator, NDArrayNumber, batched, splits

logger = logging.getLogger("tiledbsoma_ml.pytorch")

Expand Down Expand Up @@ -322,7 +315,7 @@ def __iter__(self) -> Iterator[Batch]:
obs_joinid_iter = self._create_obs_joinids_partition()
_mini_batch_iter = self._mini_batch_iter(exp.obs, X, obs_joinid_iter)
if self.use_eager_fetch:
_mini_batch_iter = _EagerIterator(
_mini_batch_iter = EagerIterator(
_mini_batch_iter, pool=exp.context.threadpool
)

Expand Down Expand Up @@ -458,7 +451,7 @@ def make_io_buffer(
for X_tbl in X_tbl_iter
)
if self.use_eager_fetch:
_io_buf_iter = _EagerIterator(_io_buf_iter, pool=X.context.threadpool)
_io_buf_iter = EagerIterator(_io_buf_iter, pool=X.context.threadpool)

# Now that X read is potentially in progress (in eager mode), go fetch obs data
# fmt: off
Expand Down Expand Up @@ -498,7 +491,7 @@ def _mini_batch_iter(

io_batch_iter = self._io_batch_iter(obs, X, obs_joinid_iter)
if self.use_eager_fetch:
io_batch_iter = _EagerIterator(io_batch_iter, pool=X.context.threadpool)
io_batch_iter = EagerIterator(io_batch_iter, pool=X.context.threadpool)

mini_batch_size = self.batch_size
result: Tuple[NDArrayNumber, pd.DataFrame] | None = None
Expand Down

0 comments on commit 4ff182c

Please sign in to comment.