From 86c06f7eb579bfb4e9e842edb032174a177017f2 Mon Sep 17 00:00:00 2001 From: Talley Lambert Date: Tue, 7 May 2024 14:56:05 -0400 Subject: [PATCH 1/4] feat: better psf caching --- src/microsim/psf.py | 29 +++++++++++++++++++++++++++-- src/microsim/schema/backend.py | 9 +++++++++ 2 files changed, 36 insertions(+), 2 deletions(-) diff --git a/src/microsim/psf.py b/src/microsim/psf.py index 1695e80..0d74ee3 100644 --- a/src/microsim/psf.py +++ b/src/microsim/psf.py @@ -181,7 +181,6 @@ def rz_to_xyz( # return o.reshape((nx, ny, nz)).T -@cache def vectorial_psf( zv: Sequence[float], nx: int = 31, @@ -355,8 +354,34 @@ def make_psf( dz, _dy, dx = space.scale ex_wvl_um = channel.excitation.bandcenter * 1e-3 em_wvl_um = channel.emission.bandcenter * 1e-3 - objective = _cast_objective(objective) + return cached_psf( + nz=nz, + nx=nx, + dx=dx, + dz=dz, + ex_wvl_um=ex_wvl_um, + em_wvl_um=em_wvl_um, + objective=_cast_objective(objective), + pinhole_au=pinhole_au, + max_au_relative=max_au_relative, + xp=xp, + ) + +# variant of make_psf that only accepts hashable arguments +@cache +def cached_psf( + nz: int, + nx: int, + dx: float, + dz: float, + ex_wvl_um: float, + em_wvl_um: float, + objective: ObjectiveLens, + pinhole_au: float | None, + max_au_relative: float | None, + xp: NumpyAPI, +) -> ArrayProtocol: # now restrict nx to no more than max_au_relative if max_au_relative is not None: airy_radius = 0.61 * ex_wvl_um / objective.numerical_aperture diff --git a/src/microsim/schema/backend.py b/src/microsim/schema/backend.py index 860a9f1..60201dc 100644 --- a/src/microsim/schema/backend.py +++ b/src/microsim/schema/backend.py @@ -51,6 +51,7 @@ def __init__(self) -> None: from scipy import signal, special, stats from scipy.ndimage import map_coordinates + self._random_seed = None self.xp = np self.signal = signal self.stats = stats @@ -72,6 +73,7 @@ def float_dtype(self, dtype: npt.DTypeLike) -> None: ) def set_random_seed(self, seed: int) -> None: + self._random_seed = seed self.xp.random.seed(seed) def asarray( @@ -119,6 +121,12 @@ def _array_assign( arr[mask] = value # type: ignore return arr + def __hash__(self) -> int: + return hash(type(self)) + hash(self._random_seed) + + def __eq__(self, other: Any) -> bool: + return type(self) == type(other) + class JaxAPI(NumpyAPI): def __init__(self) -> None: @@ -144,6 +152,7 @@ def random(self) -> ModuleType: # TODO def set_random_seed(self, seed: int) -> None: from jax.random import PRNGKey + self._random_seed = seed self._key = PRNGKey(seed) # FIXME # tricky... we actually still do use the numpy random seed in addition to From e5e43f2d0a059fb3fee3afe1094b265b5370e8ec Mon Sep 17 00:00:00 2001 From: Talley Lambert Date: Tue, 7 May 2024 14:58:34 -0400 Subject: [PATCH 2/4] add comment --- src/microsim/schema/backend.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/microsim/schema/backend.py b/src/microsim/schema/backend.py index 60201dc..20439eb 100644 --- a/src/microsim/schema/backend.py +++ b/src/microsim/schema/backend.py @@ -121,6 +121,9 @@ def _array_assign( arr[mask] = value # type: ignore return arr + # WARNING: these hash and eq methods may be problematic later? + # the goal is to make any instance of a NumpyAPI hashable and equal to any + # other instance, as long as they are of the same type and random seed. def __hash__(self) -> int: return hash(type(self)) + hash(self._random_seed) From 04233482055c4c6406cb6ed9109df50aaa3b53b6 Mon Sep 17 00:00:00 2001 From: Talley Lambert Date: Tue, 7 May 2024 15:02:20 -0400 Subject: [PATCH 3/4] fix: lint --- src/microsim/schema/backend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/microsim/schema/backend.py b/src/microsim/schema/backend.py index 20439eb..847d190 100644 --- a/src/microsim/schema/backend.py +++ b/src/microsim/schema/backend.py @@ -51,7 +51,7 @@ def __init__(self) -> None: from scipy import signal, special, stats from scipy.ndimage import map_coordinates - self._random_seed = None + self._random_seed: int | None = None self.xp = np self.signal = signal self.stats = stats From daedf0bf05e30520175559871bb0c731578dd6bf Mon Sep 17 00:00:00 2001 From: Talley Lambert Date: Tue, 7 May 2024 15:04:44 -0400 Subject: [PATCH 4/4] fix jax --- src/microsim/schema/backend.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/microsim/schema/backend.py b/src/microsim/schema/backend.py index 847d190..d1dfa46 100644 --- a/src/microsim/schema/backend.py +++ b/src/microsim/schema/backend.py @@ -140,6 +140,7 @@ def __init__(self) -> None: from ._jax_bessel import j0, j1 + self._random_seed: int | None = None self.xp = jax.numpy self.signal = signal self.stats = stats