feat: add possibility to simulate images in batches #95

Draft · federico-carrara wants to merge 18 commits into main
Conversation

federico-carrara (Contributor)

Description

This is a draft PR. For more context please have a look at #94 :)

federico-carrara marked this pull request as draft on February 5, 2025, 14:48

codecov bot commented Feb 5, 2025

Codecov Report

Attention: Patch coverage is 66.66667% with 1 line in your changes missing coverage. Please review.

Project coverage is 84.75%. Comparing base (729dcfe) to head (af726d4).

Files with missing lines             Patch %   Lines
src/microsim/schema/simulation.py    50.00%    1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main      #95      +/-   ##
==========================================
- Coverage   84.79%   84.75%   -0.04%     
==========================================
  Files          47       47              
  Lines        3156     3156              
==========================================
- Hits         2676     2675       -1     
- Misses        480      481       +1     


federico-carrara (Contributor, Author)

I made a few changes to enable batch processing.
Tested the new code on notebooks/cosem_multi_sample.ipynb and it works as expected.

I also left a couple of TODOs about things I was not sure about. For instance:
- Regarding apply_pre_quantization_binning() and apply_post_quantization_binning(), I am fairly sure something needs to change for batch processing. However, I did not fully understand the purpose of these functions, so I could not come up with a sensible fix.

In general, I will experiment with the changes on my own to look for potential pitfalls that I could not spot straight away.

Please let me know what you think; I'm happy to correct/discuss :) Thank you!!

tlambert03 (Owner) left a comment


some initial thoughts

Comment on lines +98 to +102
fp_names = [{lbl.fluorophore.name for lbl in s.labels} for s in self.samples]
if len({frozenset(s) for s in fp_names}) != 1:
    raise ValueError(
        "All samples in the batch must use the same set of fluorophores."
    )
tlambert03 (Owner):

curious about this. I can see that this would be necessary in the case where we assume that the sample "axis" is just an extension of a non-jagged array. But, maybe sample should be special in that case, and maybe samples should always be iterated over in a for-loop rather than processed in a vectorized fashion? (might help with memory bloat too)

that can be decided later. It's definitely easier to start with this restriction and relax it later.
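
For illustration, a minimal sketch of the for-loop alternative described above (simulate_one is a hypothetical per-sample helper, not an existing microsim function): iterating samples one at a time keeps peak memory at roughly one sample's worth of arrays and removes the need for the fluorophore sets to match across samples, at the cost of the vectorized speed-up.

    # Hypothetical sketch, not actual microsim code: iterate samples one at a time.
    results = []
    for sample in self.samples:
        # each iteration only holds one sample's arrays in memory
        results.append(simulate_one(sample))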

@@ -53,7 +53,7 @@ class Simulation(SimBaseModel):
 
     truth_space: Space
     output_space: Space | None = None
-    sample: Sample
+    samples: Sample | list[Sample]
tlambert03 (Owner):

Suggested change:
-    samples: Sample | list[Sample]
+    samples: list[Sample]

It's best in models to avoid unions (except perhaps in the case of Optional). If one type can serve both purposes, just use that one. Here we avoid having to check isinstance(samples, list) everywhere else in the code, and your field_validator is casting a single object to a list anyway.
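
To illustrate the point about downstream code (sim and process are hypothetical names, not microsim API): with the plain list annotation, consumers never need the isinstance check that the union type would force on them.

    # With samples: list[Sample], consumers can always iterate directly:
    for sample in sim.samples:
        process(sample)  # hypothetical per-sample work

    # With samples: Sample | list[Sample], every consumer would first need:
    samples = sim.samples if isinstance(sim.samples, list) else [sim.samples]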

Comment on lines 105 to 107
    @field_validator("samples")
    def _samples_to_list(value: Sample | list[Sample]) -> list[Sample]:
        return [value] if isinstance(value, Sample) else value
tlambert03 (Owner):

Suggested change:
-    @field_validator("samples")
-    def _samples_to_list(value: Sample | list[Sample]) -> list[Sample]:
-        return [value] if isinstance(value, Sample) else value
+    @field_validator("samples", mode="before")
+    @classmethod
+    def _samples_to_list(cls, value: Any) -> Sequence[Any]:
+        return [value] if not isinstance(value, (list, tuple)) else value
  1. Makes it a classmethod for clarity (pydantic would do this anyway).
  2. Adds the cls argument to the signature.
  3. Because it's a "before" validator, all we really need to do is make sure we're passing a list to pydantic's "actual" validation, which will cast everything to a Sample.
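
For reference, a self-contained sketch of how this "before" validator behaves, assuming pydantic v2 (Sample is stubbed down to a single field for the demo): both a single Sample and a list of Samples validate to list[Sample].

    from collections.abc import Sequence
    from typing import Any

    from pydantic import BaseModel, field_validator


    class Sample(BaseModel):  # stub for illustration only
        name: str


    class Simulation(BaseModel):
        samples: list[Sample]

        @field_validator("samples", mode="before")
        @classmethod
        def _samples_to_list(cls, value: Any) -> Sequence[Any]:
            # only ensure we hand pydantic a list; it casts the items to Sample
            return [value] if not isinstance(value, (list, tuple)) else value


    print(Simulation(samples=Sample(name="a")).samples)                      # [Sample(name='a')]
    print(Simulation(samples=[Sample(name="a"), Sample(name="b")]).samples)  # two Samples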

            DeprecationWarning,
            stacklevel=2,
        )
        return self.samples[0]
tlambert03 (Owner):

I suppose we don't have to worry about the case where there are zero samples right? sample was previously a required field, so there should always be a sample?

If that's the case, we might want to require that explicitly in the samples field definition. That can be done with annotated-types:

    samples: Annotated[list[Sample], MinLen(1)]
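
A minimal sketch of that constraint in isolation, assuming pydantic v2 with the annotated-types package (Sample stubbed again): an empty list now fails validation at construction time.

    from typing import Annotated

    from annotated_types import MinLen
    from pydantic import BaseModel, ValidationError


    class Sample(BaseModel):  # stub for illustration only
        name: str


    class Simulation(BaseModel):
        samples: Annotated[list[Sample], MinLen(1)]  # at least one sample required


    Simulation(samples=[Sample(name="a")])  # ok
    try:
        Simulation(samples=[])
    except ValidationError as err:
        print(err)  # list should have at least 1 item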

federico-carrara (Contributor, Author):

I totally agree with your point. However, I don't quite see how it relates to the lines your comment points to.

tlambert03 (Owner)

By the way, this PR starts to make #68 even more relevant. The idea there is that everything becomes a dask array and nothing gets computed until it's needed. That means you can pass these xarrays (backed by dask) all the way through the simulation chain without anything actually being computed; a computation graph is built up behind the scenes. Then, when you actually run the simulation or explicitly compute one of the intermediate objects, the graph is executed and the result calculated. It could even be done in chunks to work within memory limitations.
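
A rough, self-contained illustration of that lazy-evaluation idea with plain dask + xarray (not microsim's actual objects or pipeline): each operation only extends a task graph, and nothing is materialized until .compute() is called.

    import dask.array as da
    import xarray as xr

    # chunked, lazy "ground truth": nothing is allocated or computed yet
    truth = xr.DataArray(
        da.random.random((16, 256, 256), chunks=(1, 256, 256)),
        dims=("sample", "y", "x"),
    )

    # these steps only record operations in the task graph
    noise = xr.DataArray(
        da.random.poisson(10, size=(16, 256, 256), chunks=(1, 256, 256)),
        dims=("sample", "y", "x"),
    )
    noisy = truth * 1000.0 + noise

    # only now is the graph executed, one chunk at a time if memory requires it
    result = noisy.compute()
    print(type(result.data))  # numpy.ndarray after compute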

federico-carrara (Contributor, Author) commented Feb 13, 2025

Hi Talley! Thanks for your feedback. I integrated most of the things you pointed out.

Just FYI: over the last week I ran some data simulations with this code, and with a batch size of 16 I saw roughly a 50% speed-up compared to a simple for loop. With a larger batch size it gets even better, but then the memory constraints we discussed kick in.

Happy to learn more about Dask arrays; I'll take some time to read up on how they work and (I hope) we can discuss them next month at CSHL ;)
