Skip to content

Commit

Permalink
Rename {EAQ,Batch}Iterable
Browse files Browse the repository at this point in the history
  • Loading branch information
ryan-williams committed Jan 27, 2025
1 parent 0d11360 commit d92cb72
Show file tree
Hide file tree
Showing 5 changed files with 14 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
will be returned with rank 1; in all other cases, objects are returned with rank 2."""


class ExperimentAxisQueryIterable(Iterable[Batch]):
class BatchIterable(Iterable[Batch]):
"""An :class:`Iterable` which reads ``X`` and ``obs`` data from a :class:`tiledbsoma.Experiment`, as
selected by a user-specified :class:`tiledbsoma.ExperimentAxisQuery`. Each step of the iterator
produces a batch containing equal-sized ``X`` and ``obs`` data, in the form of a :class:`numpy.ndarray` and
Expand Down
4 changes: 2 additions & 2 deletions src/tiledbsoma_ml/datapipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from torch.utils.data.dataset import Dataset
from torchdata.datapipes.iter import IterDataPipe

from tiledbsoma_ml.pytorch import Batch, ExperimentAxisQueryIterable
from tiledbsoma_ml.batch_iterable import Batch, BatchIterable


class ExperimentAxisQueryIterDataPipe(
Expand Down Expand Up @@ -49,7 +49,7 @@ def __init__(
deprecated
"""
super().__init__()
self._exp_iter = ExperimentAxisQueryIterable(
self._exp_iter = BatchIterable(
query=query,
X_name=X_name,
obs_column_names=obs_column_names,
Expand Down
4 changes: 2 additions & 2 deletions src/tiledbsoma_ml/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from somacore import ExperimentAxisQuery
from torch.utils.data import IterableDataset

from tiledbsoma_ml.pytorch import Batch, ExperimentAxisQueryIterable
from tiledbsoma_ml.batch_iterable import Batch, BatchIterable


class ExperimentAxisQueryIterableDataset(IterableDataset[Batch]): # type:ignore[misc]
Expand Down Expand Up @@ -143,7 +143,7 @@ def __init__(
"""
super().__init__()
self._exp_iter = ExperimentAxisQueryIterable(
self._exp_iter = BatchIterable(
query=query,
X_name=X_name,
obs_column_names=obs_column_names,
Expand Down
6 changes: 3 additions & 3 deletions tests/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
ExperimentAxisQueryIterableDataset,
ExperimentAxisQueryIterDataPipe,
)
from tiledbsoma_ml.pytorch import ExperimentAxisQueryIterable
from tiledbsoma_ml.batch_iterable import BatchIterable

assert_array_equal = partial(np.testing.assert_array_equal, strict=True)

Expand All @@ -30,11 +30,11 @@
ExperimentAxisQueryIterableDataset,
)
PipeClassType = Union[
Type[ExperimentAxisQueryIterable],
Type[BatchIterable],
IterableWrapperType,
]
PipeClasses = (
ExperimentAxisQueryIterable,
BatchIterable,
*IterableWrappers,
)
XValueGen = Callable[[range, range], spmatrix]
Expand Down
12 changes: 6 additions & 6 deletions tests/test_pytorch.py → tests/test_batch_iterable.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
pytorch_seq_x_value_gen,
pytorch_x_value_gen,
)
from tiledbsoma_ml.pytorch import ExperimentAxisQueryIterable
from tiledbsoma_ml.batch_iterable import BatchIterable


@pytest.mark.parametrize(
Expand Down Expand Up @@ -60,7 +60,7 @@ def test_non_batched(
assert X_batch.todense().tolist() == [expected_X]
else:
assert isinstance(X_batch, np.ndarray)
if PipeClass is ExperimentAxisQueryIterable:
if PipeClass is BatchIterable:
assert X_batch.shape == (1, 3)
assert X_batch.tolist() == [expected_X]
else:
Expand Down Expand Up @@ -409,7 +409,7 @@ def test_distributed_and_multiprocessing__returns_data_partition_for_rank(
mock_dist_get_world_size.return_value = world_size

with soma_experiment.axis_query(measurement_name="RNA") as query:
dp = ExperimentAxisQueryIterable(
dp = BatchIterable(
query,
X_name="raw",
obs_column_names=["soma_joinid"],
Expand Down Expand Up @@ -439,7 +439,7 @@ def test__shuffle(PipeClass: PipeClassType, soma_experiment: Experiment) -> None
)

all_rows = list(iter(dp))
if PipeClass is ExperimentAxisQueryIterable:
if PipeClass is BatchIterable:
assert all(np.squeeze(r[0], axis=0).shape == (1,) for r in all_rows)
else:
assert all(r[0].shape == (1,) for r in all_rows)
Expand All @@ -462,7 +462,7 @@ def test_experiment_axis_query_iterable_error_checks(
soma_experiment: Experiment,
) -> None:
with soma_experiment.axis_query(measurement_name="RNA") as query:
dp = ExperimentAxisQueryIterable(
dp = BatchIterable(
query,
X_name="raw",
shuffle=True,
Expand All @@ -471,7 +471,7 @@ def test_experiment_axis_query_iterable_error_checks(
dp[0]

with pytest.raises(ValueError):
ExperimentAxisQueryIterable(
BatchIterable(
query,
obs_column_names=(),
X_name="raw",
Expand Down

0 comments on commit d92cb72

Please sign in to comment.