Factor iterators, implement train/test splitting, update notebooks #27

Open · wants to merge 35 commits into main

Conversation

ryan-williams (Member)

Previous PR: #26

Changes:

  • Factor 3 iterators (which we compose during shuffling/fetching/batching; a rough sketch follows the list):
    • QueryIDs (GPU/worker partitioning, train/test splitting, shuffle-chunking)
    • IOBatches (re-batches shuffle-chunks, converts soma_joinid coords to X/obs rows)
    • GPUBatches (re-batches for GPU)
  • Update notebooks, add tests
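
Roughly, one worker's iteration composes like this (a plain-NumPy sketch of the three stages, not the actual classes, and not their exact shuffle semantics):

import numpy as np

# Illustrative only: one worker's 10 obs_joinids, shuffle_chunk_size=2, io_batch_size=4, batch_size=3.
rng = np.random.default_rng(111)
obs_joinids = np.arange(10, 20)

# QueryIDs: group into shuffle-chunks and permute chunk order
chunk_size = 2
chunks = [obs_joinids[i : i + chunk_size] for i in range(0, len(obs_joinids), chunk_size)]
rng.shuffle(chunks)

# IOBatches: concatenate shuffle-chunks and re-batch into IO batches
# (in the real iterator, each IO batch is what gets fetched as X/obs rows)
flat = np.concatenate(chunks)
io_batch_size = 4
io_batches = [flat[i : i + io_batch_size] for i in range(0, len(flat), io_batch_size)]

# GPUBatches: re-batch the IO-batch stream into mini batches for the GPU
batch_size = 3
stream = np.concatenate(io_batches)
mini_batches = [stream[i : i + batch_size] for i in range(0, len(stream), batch_size)]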

Workers / obs_joinids serialization

Some changes to what is serialized / sent to each worker:

  • We now fetch {obs,var}_joinids from the provided ExperimentAxisQuery early on, and discard it.
  • The user then has the option to apply a train/test split, before copies of the IDs are sent to each worker.
  • Each worker partitions its obs_joinids according to GPU rank and worker ID, as before.

obs_joinids should fit in-memory; X/obs rows don't, but are only iterated over in "IO batches" by each worker, never serialized.
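
For concreteness, the rank/worker partitioning is essentially the following arithmetic (an illustrative sketch, not the actual partition code):

import numpy as np

# 40 obs_joinids over world_size=2 GPU ranks x num_workers=2 DataLoader workers = 10 per worker
obs_joinids = np.arange(40)
gpu_splits = np.array_split(obs_joinids, 2)               # one slice per GPU rank
worker_splits = np.array_split(gpu_splits[0], 2)          # rank 0's slice, one sub-slice per worker
assert worker_splits[1].tolist() == list(range(10, 20))   # rank 0, worker 1 gets joinids 10..19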

Tests

I added many tests and more explicit checks, verifying specific (seeded) shuffle outputs. I've used pytest.fixture in some new (to me) ways, but I think the result conveys a lot of info with minimal boilerplate.

For example, here's a case that allocates 40 obs_joinids over 2 GPUs x 2 workers, shuffled and unshuffled, with shuffle_chunk_size=2, io_batch_size=4, and batch_size=3:

@param(obs_range=40, world_size=2, num_workers=2, shuffle_chunk_size=2, io_batch_size=4, batch_size=3)
@parametrize("seed,rank,worker_id,expected", [
    (False, 0, 0, [[ 0,  1,  2], [ 3,  4,  5], [ 6,  7,  8], [ 9]]),
    (False, 0, 1, [[10, 11, 12], [13, 14, 15], [16, 17, 18], [19]]),
    (False, 1, 0, [[20, 21, 22], [23, 24, 25], [26, 27, 28], [29]]),
    (False, 1, 1, [[30, 31, 32], [33, 34, 35], [36, 37, 38], [39]]),
    (  111, 0, 0, [[ 3,  2,  0], [ 1,  5,  4], [ 9,  8,  7], [ 6]]),
    (  111, 0, 1, [[13, 12, 10], [11, 15, 14], [19, 18, 17], [16]]),
    (  111, 1, 0, [[23, 22, 20], [21, 25, 24], [29, 28, 27], [26]]),
    (  111, 1, 1, [[33, 32, 30], [31, 35, 34], [39, 38, 37], [36]]),
])
def test_gpu_worker_partitioning__even(check):
    """40 rows / 2 GPUs / [2 workers per GPU] = 10 rows per worker.

    Those 10 are then shuffled in chunks of 2, concatenated/shuffled/fetched in IO batches of 4, and re-batched for GPU
    in 3's.

    Note that each worker's row idxs are a constant offset of the others'; the same shuffle-seed is used on each
    worker's row-idx range.
    """
    pass

A lot of nuances of the shuffling/batching behavior can be verified in a few lines there (and in a couple similar cases in test_dataset.py).

Notebooks

I've updated notebooks/ to use ExperimentDataset, and to be runnable by Papermill, which I've used in subsequent benchmarking (rw/27..rw/nbs).

@@ -19,4 +20,8 @@
__all__ = [
"ExperimentDataset",
"experiment_dataloader",
"Chunks",
Member

I'm confused why these classes (Chunks, Partition, and QueryIDs) are top-level public API. I.e., why are they exported? Reading the code, they look like implementation details, not API we want people coding against (i.e., stuff we will need to maintain b/w compat with).

Member

Thoughts on this? Is this driven by docs?

I just prefer that we keep the public API as clean and simple as possible, so we don't implicitly make commitments that affect b/w compat, etc.



@attrs.define(frozen=True)
class GPUBatches(Iterable[Batch]):
Member

I find the naming confusing -- these are not partitioned for a GPU, but are rather what is commonly known as a mini-batch.

Recommend sticking with standard nomenclature from the field

Collaborator

On top of that, this might just be a personal preference, so feel free to omit: I would change the name so it does not hide/shadow the functionality of this class. Since it is an iterator, a camel-case MiniBatchIterator (similar to EagerIterator) would make it more traceable and self-explanatory.

Member Author

I've changed it to MiniBatchIterable, to hopefully address both concerns.

@@ -62,7 +57,7 @@ class ExperimentDataset(IterableDataset[Batch]): # type: ignore[misc]
soma_joinid
0 57905025

When :obj:`__iter__ <.__iter__>` is invoked, ``obs_joinids`` goes through several partitioning, shuffling, and
When :obj:`__iter__ <.__iter__>` is invoked, |ED.obs_joinids| goes through several partitioning, shuffling, and
Member

While I get the appeal of this modified syntax in some cases, it really makes a mess of Python docstrings when used in their native format -- the Python help subsystem.

Suggest taking a look at help(tiledbsoma_ml) inside an ipython shell. For example, the above line renders as:

     |  When :obj:`__iter__ <.__iter__>` is invoked, |ED.obs_joinids|  goes through several partitioning, shuffling, and
     |  batching steps, ultimately yielding :class:`GPU batches <tiledbsoma_ml.common.Batch>` (tuples of matched ``X`` and
     |  ``obs`` rows):

which is quite a bit harder to read as simple prose. AFAIK, there is no common ground between the different markup/styling systems, but using these features seems to leave one of the primary use cases (i.e., in-notebook help(foo)) not working very well. Probably the best we can do is live with the extraneous :obj:, etc., but skip the use of anchors (e.g., ED.), which are effectively uninterpretable when viewed via the Python help system (or, as an alternative, use an anchor name which makes it super obvious what it points at).

Member Author

Agreed overall, and thanks for the pointer to help() output.

I've dropped the ED. prefix from obs_joinids references in this file, and expanded |Q.obs_joinids| to |QueryIDs.obs_joinids|. AFAICT a relative .obs_joinids should work across both files, but make html has not agreed.

I've also added replace:: directives for the other "links" on this line:

.. |__iter__| replace:: :obj:`__iter__ <.__iter__>`
.. |mini batches| replace:: :class:`"mini batches" <tiledbsoma_ml.common.MiniBatch>`

so it's now:

    When |__iter__| is invoked, |obs_joinids| goes through several partitioning, shuffling, and batching steps,
    ultimately yielding |mini batches| (tuples of matched ``X`` and ``obs`` rows):

Hopefully that minimizes help-str cruft, while still generating docsite hypertext. I settled on these bar-delimited abbreviations partly to keep docstrs readable in editors, which seems identical to the help-str use case.

Member

Definitely better! And I can confirm that if you make the in-IDE "view" better, it has the same effect on the ipython and notebook help system (AFAIK, they use exactly the same rendering).

world_size: int = field(init=False)

@classmethod
def create(
Member

Super tiny nit: changing from a normal class init to a factory method is going to trip up some folks. Suggestion: can you catch this mistake and give them a useful error?

Given how attrs is structured, I'm not sure it is simple to do, so this is an entirely optional idea. But I'll note that most PyTorch Dataset subclasses I've seen (e.g., in PyTorch or Torchvision docs) are initialized via __init__ -- i.e., they don't use a factory method.

Member Author

Agreed, I'd tried to avoid the factory previously, but failed.

From attrs' "init" docs, I think I found something that works (let the constructor take either {query,layer_name} or {x_locator,query_ids}). lmk what you think.
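
For reference, the general shape of that pattern with attrs is something like this (a hypothetical sketch; field names come from the comment above, everything else is illustrative, not the PR's actual code):

import attrs
from typing import Optional

@attrs.define
class ExperimentDataset:
    # Either {query, layer_name} or {x_locator, query_ids} must be provided.
    query: Optional["ExperimentAxisQuery"] = None
    layer_name: Optional[str] = None
    x_locator: Optional["XLocator"] = None
    query_ids: Optional["QueryIDs"] = None

    def __attrs_post_init__(self) -> None:
        from_query = self.query is not None and self.layer_name is not None
        from_ids = self.x_locator is not None and self.query_ids is not None
        if from_query == from_ids:
            raise ValueError("Pass either {query, layer_name} or {x_locator, query_ids}")
        if from_query:
            ...  # derive x_locator / query_ids from the query here (details elided)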

Member

I like it!

gpu_split = obs_joinids[gpu_splits[rank] : gpu_splits[rank + 1]]

# Trim all GPU splits to be of equal length (equivalent to a "drop_last"); required for distributed training.
# TODO: may need to add an option to do padding as well.
@bkmartinjr (Member), Jan 30, 2025

It would be good to add the "pad" option to the backlog, as this is very commonly used as well (it accomplishes the same thing by padding out the partial/final mini-batches, rather than throwing away the partial bits). This is sometimes preferred in the case that there are rare/unique samples in the partial split.

Member Author

Interesting; what would you recommend padding the final mini-batches with?

Member

There are examples of how this should be done in the PyTorch ecosystem. IIRC, the typical approach is to just duplicate the contents of the too-small minibatch until it is large enough. I.e., add duplicates of existing data.

E.g., if your mini batch size is 8, and you end up with [0,1,2,3,4,5], then just duplicate the last two, yielding [0,1,2,3,4,5,4,5]
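
A minimal sketch of that pad-by-duplication idea (not code from this PR; assumes the partial batch is at least half the target size):

import numpy as np

def pad_batch(batch: np.ndarray, batch_size: int) -> np.ndarray:
    """Repeat trailing elements of a too-small mini batch until it reaches batch_size."""
    n_pad = batch_size - len(batch)
    if n_pad <= 0:
        return batch
    return np.concatenate([batch, batch[-n_pad:]])

assert pad_batch(np.array([0, 1, 2, 3, 4, 5]), 8).tolist() == [0, 1, 2, 3, 4, 5, 4, 5]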

Member

Also, if you elect not to do this in this PR, which is fine, it would be good to capture this as a story. I'm pretty sure the request will come eventually.

Used as a compromise between a full random shuffle (optimal for training performance/convergence) and a
sequential, un-shuffled traversal (optimal for I/O efficiency).
"""
shuffle_chunks: List[NDArrayJoinId] = [
Member

nit: this variable should be of type Chunks

Member Author

Done

worker_id = partition.worker_id
n_workers = partition.n_workers
else:
rank, world_size = get_distributed_world_rank()
Member

this code path appears to be unused. Recommend removing it (and removing the optionality on the partition param). Or add a test case for it :-)

Member Author

Done (removed).

rng = np.random.default_rng(seed)
shuffled_joinids = rng.permutation(obs_joinids)

if method == "deterministic":
@bkmartinjr (Member), Jan 30, 2025

I believe I already mentioned this, but you will likely get requests for more types of splits (example: splits that are based upon the distribution of cells with a given metadata label, etc). Would be good to think about how this might have to evolve to support that.

Example: a 70/30 split by cell_type, meaning that for each unique cell_type value, 70% of the cells go into one split, 30% in the other.

The only concrete suggestion I can give today is that (perhaps) the method parameter should take a more complex type than a simple string, e.g., an object which could be sub-classed to include more complex params. Or add a method_args param which is a simple dict, for now.

Member

Oh, another random idea to toss into the pot: just let the user pass a lambda that does their own split (i.e., returns the split indices). As the user already has the experiment and query open, they can create a closure containing any required state.

Member Author

Interesting. One way to achieve "a 70/30 split by cell_type" with this PR's code would be:

  1. Create an EAQ that returns soma_joinid and cell_type, materialize its obs as a Pandas DF.
  2. groupby('cell_type'), then for each group: init a QueryIDs, call random_split (resulting in a list[(QueryIDs, QueryIDs)], a train and test QueryIDs for each group)
  3. "Unzip" that to (list[QueryIDs], list[QueryIDs]) (all train QueryIDs, all test QueryIDs), concatenate obs_joinids within each list[QueryIDs].

Not trivial, but I think most of it is "intrinsic complexity." QueryIDs is mostly a dataclass-wrapper around an obs_joinids array, so it shouldn't foreclose any arbitrary manipulations a user might want to do, and random_split should be usable at any stage of a user's pipeline.
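
Roughly, steps 1-3 could look like this (a sketch only; a plain random mask stands in for random_split here, and the QueryIDs re-wrapping is elided):

import numpy as np

# Step 1: materialize soma_joinid + cell_type for the query's obs (standard tiledbsoma read pattern)
obs = query.obs(column_names=["soma_joinid", "cell_type"]).concat().to_pandas()

# Steps 2-3: 70/30 split within each cell_type group, then concatenate the per-group splits
rng = np.random.default_rng(111)
train_parts, test_parts = [], []
for _, group in obs.groupby("cell_type"):
    ids = group["soma_joinid"].to_numpy()
    mask = rng.random(len(ids)) < 0.7
    train_parts.append(ids[mask])
    test_parts.append(ids[~mask])
train_joinids = np.concatenate(train_parts)
test_joinids = np.concatenate(test_parts)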

I'm guessing anything in this direction would be a follow-on PR, but happy to discuss further / hear other patterns you think we might encounter.

Member

I don't think we need to resolve this issue in this PR - I just wanted to give you a heads up about a possible future request, and get some ideas down about how we might handle it. Might be worth cloning this in a story?

with Experiment.open(
self.uri, tiledb_timestamp=self.tiledb_timestamp_ms, context=context
) as exp:
X = exp.ms[self.measurement_name].X[self.layer_name]
Member

Are we confident that opening the Experiment at a given timestamp will also open a sub-object (e.g., the X layer) as it was at that timestamp? I seem to recall the answer is no, but we should confirm.

If not, we need to capture the timestamps of all of the sub-objects to get a consistent point-in-time open.

I realize this is not new code - it is just something I noticed in this review. I would ask Isaiah how they work, or just run a quick test.

Member

Just confirming this experiment/research happened?

@ktsitsi (Collaborator) left a comment

Some miscs



data = np.vstack([list(obs_range)] * len(var_range)).flatten()
rows = np.vstack([list(obs_range)] * len(var_range)).flatten()
cols = np.column_stack([list(var_range)] * len(obs_range)).flatten()
return float(f"{r}.{str(c)[::-1]}")
Collaborator

Could we use another method here instead of reversing? Like scaling, for example, since the former might create some misinterpretations.

Member Author

Agreed the reversing is awkward, but other approaches I thought of seemed worse.

The values aren't handled directly in any tests, test-utils just receive "expected" row-idxs, and deep-cmp the rows' values (so the main goal is for each row to be distinct).

Since these end up as float32s (8 sigfigs), one "scaling" option would be r + c * 1e-4:

-# Before
-[
-    [ 0 , 0.1 , 0   , 0.3 , 0   ],
-    [ 1 , 0   , 1.2 , 0   , 1.4 ],
-    [ 0 , 2.1 , 0   , 2.3 , 0   ],
-    [ 3 , 0   , 3.2 , 0   , 3.4 ],
-    [ 0 , 4.1 , 0   , 4.3 , 0   ],
-]
+# After
+[
+    [ 0 , 0.0001 , 0      , 0.0003 , 0      ],
+    [ 1 , 0      , 1.0002 , 0      , 1.0004 ],
+    [ 0 , 2.0001 , 0      , 2.0003 , 0      ],
+    [ 3 , 0      , 3.0002 , 0      , 3.0004 ],
+    [ 0 , 4.0001 , 0      , 4.0003 , 0      ],
+]
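
Concretely, the helper would become something like (function name hypothetical):

def x_value(r: int, c: int) -> float:
    # Scaled encoding: row index in the integer part, column index in the 4th decimal
    # place (assumes c < 10_000, so column values stay distinct within a row).
    return r + c * 1e-4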

Either wfm, lmk what you prefer!

Collaborator

one "scaling" option would be r + c * 1e-4

imho this looks cleaner but I will leave it to your judgement.


def __iter__(self) -> Iterator[IOBatch]:
"""Emit |IOBatch|'s."""
# Create RNG - does not need to be identical across processes, but use the seed anyway
Member

I had to think about this comment a bit. It would be more helpful to our future selves if it said something like: Because obs/var IDs are pre-partitioned and shuffled, this RNG does not need to be identical across sub-processes.

Member Author

Good catch, paraphrased what you wrote.

)
# Round-trip though tuple avoids `TypeError: IntIndexer only supports array of type int64`.
# TODO: debug / work around that error; serde'ing the ndarray apparently results in a second np.int64 instance, that fails reference equality check vs. the version from the worker-process.
var_joinids = np.array(tuple(self.var_joinids))
Member

IntIndexer supports np.ndarray and Arrow integer arrays, as long as they are int64. Not sure why this is needed, but it will be quite expensive for large arrays.

I'd like to see this root caused in case it is actually an IntIndexer bug. Can you provide any info on the actual type (both container and dtype) of the object?

It would also be much better to do a zero-copy cast if possible.
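
For reference, a zero-copy cast along those lines might look like this (an assumption, untested against the IntIndexer error described in the TODO above):

import numpy as np

var_joinids = np.arange(5, dtype=np.int64)      # stand-in for self.var_joinids
cast = np.asarray(var_joinids, dtype=np.int64)  # no copy when already an int64 ndarray
assert cast is var_joinids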

Member

BTW, the types (and their casts) supported by IntIndexer are mostly done in the Python code (in tiledbsoma), so we can easily extend that if we have missed something useful.

Member

This issue is still open, and is costing a full copy of the var ids. I left a note in the new code to the same effect.

*fracs: float,
seed: Optional[int] = None,
method: SamplingMethod = "stochastic_rounding",
) -> Tuple[ExperimentDataset, ...]:
Member

needs docstring

Member Author

done (also renamed to random_split, to match naming used by PyTorch)
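
For example (an illustrative call matching the signature in the hunk above; whether it's a method or a free function isn't visible here, and the fractions/seed are just sample inputs):

train_ds, val_ds, test_ds = ds.random_split(0.7, 0.2, 0.1, seed=111, method="stochastic_rounding")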

@bkmartinjr (Member) left a comment

Lots of minor issues, but no large concerns with the overall direction of the PR, nor did I spot any bugs or functional concerns. I think with a cleanup pass, it is going to be GTG!

@ryan-williams (Member Author) left a comment

I believe I responded to everything except 2 comments, which I'll fork to Slack. I was planning to land this one after #26 anyway.


@ktsitsi (Collaborator) left a comment

LGTM after resolving open comments.
