Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: higher level psf caching #28

Merged
merged 5 commits into from
May 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 27 additions & 2 deletions src/microsim/psf.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,6 @@ def rz_to_xyz(
# return o.reshape((nx, ny, nz)).T


@cache
def vectorial_psf(
zv: Sequence[float],
nx: int = 31,
Expand Down Expand Up @@ -355,8 +354,34 @@ def make_psf(
dz, _dy, dx = space.scale
ex_wvl_um = channel.excitation.bandcenter * 1e-3
em_wvl_um = channel.emission.bandcenter * 1e-3
objective = _cast_objective(objective)
return cached_psf(
nz=nz,
nx=nx,
dx=dx,
dz=dz,
ex_wvl_um=ex_wvl_um,
em_wvl_um=em_wvl_um,
objective=_cast_objective(objective),
pinhole_au=pinhole_au,
max_au_relative=max_au_relative,
xp=xp,
)


# variant of make_psf that only accepts hashable arguments
@cache
def cached_psf(
nz: int,
nx: int,
dx: float,
dz: float,
ex_wvl_um: float,
em_wvl_um: float,
objective: ObjectiveLens,
pinhole_au: float | None,
max_au_relative: float | None,
xp: NumpyAPI,
) -> ArrayProtocol:
# now restrict nx to no more than max_au_relative
if max_au_relative is not None:
airy_radius = 0.61 * ex_wvl_um / objective.numerical_aperture
Expand Down
13 changes: 13 additions & 0 deletions src/microsim/schema/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def __init__(self) -> None:
from scipy import signal, special, stats
from scipy.ndimage import map_coordinates

self._random_seed: int | None = None
self.xp = np
self.signal = signal
self.stats = stats
Expand All @@ -72,6 +73,7 @@ def float_dtype(self, dtype: npt.DTypeLike) -> None:
)

def set_random_seed(self, seed: int) -> None:
self._random_seed = seed
self.xp.random.seed(seed)

def asarray(
Expand Down Expand Up @@ -119,6 +121,15 @@ def _array_assign(
arr[mask] = value # type: ignore
return arr

# WARNING: these hash and eq methods may be problematic later?
# the goal is to make any instance of a NumpyAPI hashable and equal to any
# other instance, as long as they are of the same type and random seed.
def __hash__(self) -> int:
return hash(type(self)) + hash(self._random_seed)

def __eq__(self, other: Any) -> bool:
return type(self) == type(other)


class JaxAPI(NumpyAPI):
def __init__(self) -> None:
Expand All @@ -129,6 +140,7 @@ def __init__(self) -> None:

from ._jax_bessel import j0, j1

self._random_seed: int | None = None
self.xp = jax.numpy
self.signal = signal
self.stats = stats
Expand All @@ -144,6 +156,7 @@ def random(self) -> ModuleType: # TODO
def set_random_seed(self, seed: int) -> None:
from jax.random import PRNGKey

self._random_seed = seed
self._key = PRNGKey(seed)
# FIXME
# tricky... we actually still do use the numpy random seed in addition to
Expand Down