Skip to content

Commit

Permalink
feat: cache psf (#26)
Browse files Browse the repository at this point in the history
  • Loading branch information
tlambert03 authored May 7, 2024
1 parent f373bc6 commit ac92442
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 7 deletions.
19 changes: 12 additions & 7 deletions src/microsim/psf.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from functools import cache
from typing import TYPE_CHECKING

import numpy as np
Expand Down Expand Up @@ -38,7 +39,7 @@ def simpson(
ni2sin2theta = objective.ni**2 * sintheta**2
nsroot = xp.sqrt(objective.ns**2 - ni2sin2theta)
ngroot = xp.sqrt(objective.ng**2 - ni2sin2theta)
_z = zv[:, xp.newaxis, xp.newaxis] if zv.ndim else zv
_z = xp.asarray(zv if np.isscalar(zv) else zv[:, xp.newaxis, xp.newaxis])
L0 = (
objective.ni * (ci - _z) * costheta
+ zp * nsroot
Expand Down Expand Up @@ -85,8 +86,9 @@ def _cast_objective(objective: ObjectiveKwargs | ObjectiveLens | None) -> Object
raise TypeError(f"Expected ObjectiveLens, got {type(objective)}")


@cache
def vectorial_rz(
zv: npt.NDArray,
zv: Sequence[float],
nx: int = 51,
pos: tuple[float, float, float] = (0, 0, 0),
dxy: float = 0.04,
Expand Down Expand Up @@ -123,7 +125,9 @@ def vectorial_rz(

step = p.half_angle / nSamples
theta = xp.arange(1, nSamples + 1) * step
simpson_integral = simpson(p, theta, constJ, zv, ci, zpos, wave_num, xp=xp)
simpson_integral = simpson(
p, theta, constJ, xp.asarray(zv), ci, zpos, wave_num, xp=xp
)
return 8.0 * np.pi / 3.0 * simpson_integral * (step / ud) ** 2


Expand Down Expand Up @@ -177,8 +181,9 @@ def rz_to_xyz(
# return o.reshape((nx, ny, nz)).T


@cache
def vectorial_psf(
zv: npt.NDArray,
zv: Sequence[float],
nx: int = 31,
ny: int | None = None,
pos: tuple[float, float, float] = (0, 0, 0),
Expand All @@ -190,7 +195,7 @@ def vectorial_psf(
xp: NumpyAPI | None = None,
) -> npt.NDArray:
xp = NumpyAPI.create(xp)
zv = xp.asarray(zv * 1e-6) # convert to meters
zv = tuple(np.asarray(zv) * 1e-6) # convert to meters
ny = ny or nx
rz = vectorial_rz(
zv, np.maximum(ny, nx), pos, dxy, wvl, objective=objective, sf=sf, xp=xp
Expand All @@ -203,9 +208,9 @@ def vectorial_psf(
return _psf


def _centered_zv(nz: int, dz: float, pz: float = 0) -> npt.NDArray:
def _centered_zv(nz: int, dz: float, pz: float = 0) -> tuple[float, ...]:
lim = (nz - 1) * dz / 2
return np.linspace(-lim + pz, lim + pz, nz)
return tuple(np.linspace(-lim + pz, lim + pz, nz))


def vectorial_psf_centered(
Expand Down
16 changes: 16 additions & 0 deletions src/microsim/schema/lens.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,22 @@ class ObjectiveLens(SimBaseModel):

magnification: float = Field(1, description="magnification of objective lens.")

def __hash__(self) -> int:
return hash(
(
self.numerical_aperture,
self.coverslip_ri,
self.coverslip_ri_spec,
self.immersion_medium_ri,
self.immersion_medium_ri_spec,
self.specimen_ri,
self.working_distance,
self.coverslip_thickness,
self.coverslip_thickness_spec,
self.magnification,
)
)

@model_validator(mode="before")
def _vroot(cls, values: Any) -> Any:
if isinstance(values, dict):
Expand Down

0 comments on commit ac92442

Please sign in to comment.