Skip to content

Commit

Permalink
Smaller psf (#16)
Browse files Browse the repository at this point in the history
* smaller psf

* cap psf size
  • Loading branch information
tlambert03 authored Apr 11, 2024
1 parent 65f41b1 commit 75889b3
Show file tree
Hide file tree
Showing 12 changed files with 271 additions and 152 deletions.
10 changes: 9 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,14 @@ expect rapid changes and breakages.
pip install git+https://github.com/tlambert03/microsim
```

If available, microsim can use either Jax or Cupy to accelerate computations.
These are not installed by default, see the
[jax](https://jax.readthedocs.io/en/latest/installation.html)
or [cupy](https://docs.cupy.dev/en/stable/install.html) installation instructions.

```bash


## Usage

Construct and run a
Expand Down Expand Up @@ -48,4 +56,4 @@ ortho_plot(result)
## Documentation

See the API Reference (<https://tlambert03.github.io/microsim/api/>) for details
on the `Simulation` object and options for all of the fields.
on the `Simulation` object and options for all of the fields.
23 changes: 6 additions & 17 deletions examples/basic_confocal.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,14 @@
from microsim.util import ortho_plot

sim = ms.Simulation(
truth_space=ms.ShapeScaleSpace(shape=(128, 512, 512), scale=(0.02, 0.01, 0.01)),
truth_space=ms.ShapeScaleSpace(shape=(128, 1024, 1024), scale=(0.02, 0.01, 0.01)),
output_space={"downscale": 8},
sample=ms.Sample(labels=[ms.MatsLines(density=0.5, length=30, azimuth=5, max_r=1)]),
modality=ms.Confocal(pinhole_au=1),
settings=ms.Settings(random_seed=100),
detector=ms.CameraCCD(
qe=0.82,
gain=1,
full_well=18000, # e
dark_current=0.0005, # e/pix/sec
clock_induced_charge=1,
read_noise=6,
bit_depth=12,
offset=100,
# not used here
readout_rate=1,
photodiode_size=1,
),
output_path="au1.tif",
modality=ms.Confocal(pinhole_au=0.5),
settings=ms.Settings(random_seed=100, max_psf_radius_aus=8),
detector=ms.CameraCCD(qe=0.82, read_noise=6, bit_depth=12),
# output_path="au1.tif",
)

result = sim.run()
ortho_plot(result.data)
2 changes: 1 addition & 1 deletion examples/confocal.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -469,7 +469,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.10"
"version": "3.11.8"
},
"orig_nbformat": 4
},
Expand Down
9 changes: 5 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,12 @@ classifiers = [
]
# add your package dependencies here
dependencies = [
'pydantic_settings',
"dask[array]",
"numpy",
"psfmodels",
# strictly necessary
"pydantic>=2.4",
"numpy",
"annotated_types",
# probably optional
'pydantic-settings',
"scipy",
"tqdm",
"xarray",
Expand Down
215 changes: 194 additions & 21 deletions src/microsim/psf.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,22 @@
from collections.abc import Mapping, Sequence
from typing import Any
from __future__ import annotations

from typing import TYPE_CHECKING

import numpy as np
import numpy.typing as npt
import tqdm

from microsim.schema.backend import NumpyAPI
from microsim.schema.lens import ObjectiveLens
from microsim.schema.lens import ObjectiveKwargs, ObjectiveLens

from ._data_array import ArrayProtocol

if TYPE_CHECKING:
from collections.abc import Sequence

from microsim._data_array import ArrayProtocol
from microsim.schema.optical_config import OpticalConfig
from microsim.schema.space import SpaceProtocol


def simpson(
Expand Down Expand Up @@ -66,19 +77,26 @@ def simpson(
return xp.real(sum_I0**2 + 2.0 * sum_I1**2 + sum_I2**2) # type: ignore


def _cast_objective(objective: ObjectiveKwargs | ObjectiveLens | None) -> ObjectiveLens:
if isinstance(objective, ObjectiveLens):
return objective
if objective is None or isinstance(objective, dict):
return ObjectiveLens.model_validate(objective or {})
raise TypeError(f"Expected ObjectiveLens, got {type(objective)}")


def vectorial_rz(
zv: npt.NDArray,
nx: int = 51,
pos: tuple[float, float, float] = (0, 0, 0),
dxy: float = 0.04,
wvl: float = 0.6,
params: Mapping | None = None,
objective: ObjectiveKwargs | ObjectiveLens | None = None,
sf: int = 3,
xp: NumpyAPI | None = None,
) -> npt.NDArray:
xp = NumpyAPI.create(xp)
p = ObjectiveLens(**(params or {}))

p = _cast_objective(objective)
wave_num = 2 * np.pi / (wvl * 1e-6)

xpos, ypos, zpos = pos
Expand All @@ -105,7 +123,7 @@ def vectorial_rz(

step = p.half_angle / nSamples
theta = xp.arange(1, nSamples + 1) * step
simpson_integral = simpson(p, theta, constJ, zv, ci, zpos, wave_num)
simpson_integral = simpson(p, theta, constJ, zv, ci, zpos, wave_num, xp=xp)
return 8.0 * np.pi / 3.0 * simpson_integral * (step / ud) ** 2


Expand Down Expand Up @@ -167,15 +185,17 @@ def vectorial_psf(
pos: tuple[float, float, float] = (0, 0, 0),
dxy: float = 0.05,
wvl: float = 0.6,
params: Mapping | None = None,
objective: ObjectiveKwargs | ObjectiveLens | None = None,
sf: int = 3,
normalize: bool = True,
xp: NumpyAPI | None = None,
) -> npt.NDArray:
xp = NumpyAPI.create(xp)
zv = xp.asarray(zv * 1e-6) # convert to meters
ny = ny or nx
rz = vectorial_rz(zv, np.maximum(ny, nx), pos, dxy, wvl, params, sf, xp=xp)
rz = vectorial_rz(
zv, np.maximum(ny, nx), pos, dxy, wvl, objective=objective, sf=sf, xp=xp
)
_psf = rz_to_xyz(rz, (ny, nx), sf, off=xp.asarray(pos[:2]) / (dxy * 1e-6))
if normalize:
_psf /= xp.max(_psf)
Expand All @@ -187,22 +207,175 @@ def _centered_zv(nz: int, dz: float, pz: float = 0) -> npt.NDArray:
return np.linspace(-lim + pz, lim + pz, nz)


def vectorial_psf_centered(nz: int, dz: float = 0.05, **kwargs: Any) -> npt.NDArray:
def vectorial_psf_centered(
nz: int,
dz: float = 0.05,
pz: float = 0,
nx: int = 31,
ny: int | None = None,
pos: tuple[float, float, float] = (0, 0, 0),
dxy: float = 0.05,
wvl: float = 0.6,
objective: ObjectiveKwargs | ObjectiveLens | None = None,
sf: int = 3,
normalize: bool = True,
xp: NumpyAPI | None = None,
) -> npt.NDArray:
"""Compute a vectorial model of the microscope point spread function.
The point source is always in the center of the output volume.
"""
zv = _centered_zv(nz, dz, kwargs.get("pz", 0))
return vectorial_psf(zv, **kwargs)
return vectorial_psf(
zv=_centered_zv(nz, dz, pz),
nx=nx,
ny=ny,
pos=pos,
dxy=dxy,
wvl=wvl,
objective=objective,
sf=sf,
normalize=normalize,
xp=xp,
)


def make_confocal_psf(
nz: int,
ex_wvl_um: float = 0.475,
em_wvl_um: float = 0.525,
pinhole_au: float = 1.0,
dz: float = 0.05,
pz: float = 0,
nx: int = 31,
ny: int | None = None,
pos: tuple[float, float, float] = (0, 0, 0),
dxy: float = 0.05,
objective: ObjectiveKwargs | ObjectiveLens | None = None,
sf: int = 3,
normalize: bool = True,
xp: NumpyAPI | None = None,
) -> np.ndarray:
"""Create a confocal PSF.
This function creates a confocal PSF by multiplying the excitation PSF with
the emission PSF convolved with a pinhole mask.
All extra keyword arguments are passed to `vectorial_psf_centered`.
"""
xp = NumpyAPI.create(xp)

objective = _cast_objective(objective)
ex_psf = vectorial_psf_centered(
nz=nz,
wvl=ex_wvl_um,
dz=dz,
pz=pz,
nx=nx,
ny=ny,
pos=pos,
dxy=dxy,
objective=objective,
xp=xp,
sf=sf,
normalize=normalize,
)
em_psf = vectorial_psf_centered(
nz=nz,
wvl=em_wvl_um,
dz=dz,
pz=pz,
nx=nx,
ny=ny,
pos=pos,
dxy=dxy,
objective=objective,
xp=xp,
sf=sf,
normalize=normalize,
)

# The effective emission PSF is the regular emission PSF convolved with the
# pinhole mask. The pinhole mask is a disk with diameter equal to the pinhole
# size in AU, converted to pixels.
pinhole = _pinhole_mask(
nxy=ex_psf.shape[-1],
pinhole_au=pinhole_au,
wvl=em_wvl_um,
na=objective.numerical_aperture,
dxy=dxy,
xp=xp,
)
pinhole = xp.asarray(pinhole)

eff_em_psf = xp.empty_like(em_psf)
for i in tqdm.trange(len(em_psf), desc="convolving em_psf with pinhole..."):
plane = xp.fftconvolve(xp.asarray(em_psf[i]), pinhole, mode="same")
eff_em_psf = xp._array_assign(eff_em_psf, i, plane)

# if __name__ == "__main__":
# zv = np.linspace(-3, 3, 61)
# from time import perf_counter
# The final PSF is the excitation PSF multiplied by the effective emission PSF.
return xp.asarray(ex_psf) * eff_em_psf # type: ignore

# t0 = perf_counter()
# psf = vectorial_psf(zv, nx=512)
# t1 = perf_counter()
# print(psf.shape)
# print(t1 - t0)
# assert np.allclose(np.load("out.npy"), psf, atol=0.1)

def _pinhole_mask(
nxy: int,
pinhole_au: float,
wvl: float,
na: float,
dxy: float,
xp: NumpyAPI | None = None,
) -> npt.NDArray:
"""Create a 2D circular pinhole mask of specified `pinhole_au`."""
xp = NumpyAPI.create(xp)

pinhole_size = pinhole_au * 0.61 * wvl / na
pinhole_px = pinhole_size / dxy

x = xp.arange(nxy) - nxy // 2
xx, yy = xp.meshgrid(x, x)
r = xp.sqrt(xx**2 + yy**2)
return (r <= pinhole_px).astype(int) # type: ignore


def make_psf(
space: SpaceProtocol,
channel: OpticalConfig,
objective: ObjectiveLens,
pinhole_au: float | None = None,
max_au_relative: float | None = None,
xp: NumpyAPI | None = None,
) -> ArrayProtocol:
xp = NumpyAPI.create(xp)
nz, _ny, nx = space.shape
dz, _dy, dx = space.scale
ex_wvl_um = channel.excitation.bandcenter * 1e-3
em_wvl_um = channel.emission.bandcenter * 1e-3
objective = _cast_objective(objective)

# now restrict nx to no more than max_au_relative
if max_au_relative is not None:
airy_radius = 0.61 * ex_wvl_um / objective.numerical_aperture
n_pix_per_airy_radius = airy_radius / dx
max_nx = int(n_pix_per_airy_radius * max_au_relative * 2)
nx = min(nx, max_nx)
# if even make odd
if nx % 2 == 0:
nx += 1

if pinhole_au is None:
psf = vectorial_psf_centered(
wvl=em_wvl_um, nz=nz + 1, nx=nx + 1, dz=dz, dxy=dx, objective=objective
)
else:
psf = make_confocal_psf(
nz=nz,
ex_wvl_um=ex_wvl_um,
em_wvl_um=em_wvl_um,
pinhole_au=pinhole_au,
nx=nx,
dz=dz,
dxy=dx,
objective=objective,
xp=xp,
)

return xp.asarray(psf) # type: ignore
15 changes: 14 additions & 1 deletion src/microsim/schema/lens.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,22 @@
from typing import Any
from typing import Any, TypedDict

import numpy as np
from pydantic import BaseModel, Field, model_validator


class ObjectiveKwargs(TypedDict, total=False):
numerical_aperture: float
coverslip_ri: float
coverslip_ri_spec: float
immersion_medium_ri: float
immersion_medium_ri_spec: float
specimen_ri: float
working_distance: float
coverslip_thickness: float
coverslip_thickness_spec: float
magnification: float


class ObjectiveLens(BaseModel):
numerical_aperture: float = Field(1.4, alias="na")
coverslip_ri: float = 1.515 # coverslip RI experimental value (ng)
Expand Down
Loading

0 comments on commit 75889b3

Please sign in to comment.