diff --git a/README.md b/README.md index e4bbbb2..6373eae 100644 --- a/README.md +++ b/README.md @@ -17,6 +17,14 @@ expect rapid changes and breakages. pip install git+https://github.com/tlambert03/microsim ``` +If available, microsim can use either Jax or Cupy to accelerate computations. +These are not installed by default, see the +[jax](https://jax.readthedocs.io/en/latest/installation.html) +or [cupy](https://docs.cupy.dev/en/stable/install.html) installation instructions. + +```bash + + ## Usage Construct and run a @@ -48,4 +56,4 @@ ortho_plot(result) ## Documentation See the API Reference () for details -on the `Simulation` object and options for all of the fields. \ No newline at end of file +on the `Simulation` object and options for all of the fields. diff --git a/examples/basic_confocal.py b/examples/basic_confocal.py index 97b6616..ddd0b12 100644 --- a/examples/basic_confocal.py +++ b/examples/basic_confocal.py @@ -2,25 +2,14 @@ from microsim.util import ortho_plot sim = ms.Simulation( - truth_space=ms.ShapeScaleSpace(shape=(128, 512, 512), scale=(0.02, 0.01, 0.01)), + truth_space=ms.ShapeScaleSpace(shape=(128, 1024, 1024), scale=(0.02, 0.01, 0.01)), output_space={"downscale": 8}, sample=ms.Sample(labels=[ms.MatsLines(density=0.5, length=30, azimuth=5, max_r=1)]), - modality=ms.Confocal(pinhole_au=1), - settings=ms.Settings(random_seed=100), - detector=ms.CameraCCD( - qe=0.82, - gain=1, - full_well=18000, # e - dark_current=0.0005, # e/pix/sec - clock_induced_charge=1, - read_noise=6, - bit_depth=12, - offset=100, - # not used here - readout_rate=1, - photodiode_size=1, - ), - output_path="au1.tif", + modality=ms.Confocal(pinhole_au=0.5), + settings=ms.Settings(random_seed=100, max_psf_radius_aus=8), + detector=ms.CameraCCD(qe=0.82, read_noise=6, bit_depth=12), + # output_path="au1.tif", ) + result = sim.run() ortho_plot(result.data) diff --git a/examples/confocal.ipynb b/examples/confocal.ipynb index 3700d22..6828f8c 100644 --- a/examples/confocal.ipynb +++ b/examples/confocal.ipynb @@ -469,7 +469,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.10" + "version": "3.11.8" }, "orig_nbformat": 4 }, diff --git a/pyproject.toml b/pyproject.toml index 0f8f4b5..b8ee562 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,11 +36,12 @@ classifiers = [ ] # add your package dependencies here dependencies = [ - 'pydantic_settings', - "dask[array]", - "numpy", - "psfmodels", + # strictly necessary "pydantic>=2.4", + "numpy", + "annotated_types", + # probably optional + 'pydantic-settings', "scipy", "tqdm", "xarray", diff --git a/src/microsim/psf.py b/src/microsim/psf.py index d7c1e93..917f04f 100644 --- a/src/microsim/psf.py +++ b/src/microsim/psf.py @@ -1,11 +1,22 @@ -from collections.abc import Mapping, Sequence -from typing import Any +from __future__ import annotations + +from typing import TYPE_CHECKING import numpy as np import numpy.typing as npt +import tqdm from microsim.schema.backend import NumpyAPI -from microsim.schema.lens import ObjectiveLens +from microsim.schema.lens import ObjectiveKwargs, ObjectiveLens + +from ._data_array import ArrayProtocol + +if TYPE_CHECKING: + from collections.abc import Sequence + + from microsim._data_array import ArrayProtocol + from microsim.schema.optical_config import OpticalConfig + from microsim.schema.space import SpaceProtocol def simpson( @@ -66,19 +77,26 @@ def simpson( return xp.real(sum_I0**2 + 2.0 * sum_I1**2 + sum_I2**2) # type: ignore +def _cast_objective(objective: ObjectiveKwargs | ObjectiveLens | None) -> ObjectiveLens: + if isinstance(objective, ObjectiveLens): + return objective + if objective is None or isinstance(objective, dict): + return ObjectiveLens.model_validate(objective or {}) + raise TypeError(f"Expected ObjectiveLens, got {type(objective)}") + + def vectorial_rz( zv: npt.NDArray, nx: int = 51, pos: tuple[float, float, float] = (0, 0, 0), dxy: float = 0.04, wvl: float = 0.6, - params: Mapping | None = None, + objective: ObjectiveKwargs | ObjectiveLens | None = None, sf: int = 3, xp: NumpyAPI | None = None, ) -> npt.NDArray: xp = NumpyAPI.create(xp) - p = ObjectiveLens(**(params or {})) - + p = _cast_objective(objective) wave_num = 2 * np.pi / (wvl * 1e-6) xpos, ypos, zpos = pos @@ -105,7 +123,7 @@ def vectorial_rz( step = p.half_angle / nSamples theta = xp.arange(1, nSamples + 1) * step - simpson_integral = simpson(p, theta, constJ, zv, ci, zpos, wave_num) + simpson_integral = simpson(p, theta, constJ, zv, ci, zpos, wave_num, xp=xp) return 8.0 * np.pi / 3.0 * simpson_integral * (step / ud) ** 2 @@ -167,7 +185,7 @@ def vectorial_psf( pos: tuple[float, float, float] = (0, 0, 0), dxy: float = 0.05, wvl: float = 0.6, - params: Mapping | None = None, + objective: ObjectiveKwargs | ObjectiveLens | None = None, sf: int = 3, normalize: bool = True, xp: NumpyAPI | None = None, @@ -175,7 +193,9 @@ def vectorial_psf( xp = NumpyAPI.create(xp) zv = xp.asarray(zv * 1e-6) # convert to meters ny = ny or nx - rz = vectorial_rz(zv, np.maximum(ny, nx), pos, dxy, wvl, params, sf, xp=xp) + rz = vectorial_rz( + zv, np.maximum(ny, nx), pos, dxy, wvl, objective=objective, sf=sf, xp=xp + ) _psf = rz_to_xyz(rz, (ny, nx), sf, off=xp.asarray(pos[:2]) / (dxy * 1e-6)) if normalize: _psf /= xp.max(_psf) @@ -187,22 +207,175 @@ def _centered_zv(nz: int, dz: float, pz: float = 0) -> npt.NDArray: return np.linspace(-lim + pz, lim + pz, nz) -def vectorial_psf_centered(nz: int, dz: float = 0.05, **kwargs: Any) -> npt.NDArray: +def vectorial_psf_centered( + nz: int, + dz: float = 0.05, + pz: float = 0, + nx: int = 31, + ny: int | None = None, + pos: tuple[float, float, float] = (0, 0, 0), + dxy: float = 0.05, + wvl: float = 0.6, + objective: ObjectiveKwargs | ObjectiveLens | None = None, + sf: int = 3, + normalize: bool = True, + xp: NumpyAPI | None = None, +) -> npt.NDArray: """Compute a vectorial model of the microscope point spread function. The point source is always in the center of the output volume. """ - zv = _centered_zv(nz, dz, kwargs.get("pz", 0)) - return vectorial_psf(zv, **kwargs) + return vectorial_psf( + zv=_centered_zv(nz, dz, pz), + nx=nx, + ny=ny, + pos=pos, + dxy=dxy, + wvl=wvl, + objective=objective, + sf=sf, + normalize=normalize, + xp=xp, + ) + + +def make_confocal_psf( + nz: int, + ex_wvl_um: float = 0.475, + em_wvl_um: float = 0.525, + pinhole_au: float = 1.0, + dz: float = 0.05, + pz: float = 0, + nx: int = 31, + ny: int | None = None, + pos: tuple[float, float, float] = (0, 0, 0), + dxy: float = 0.05, + objective: ObjectiveKwargs | ObjectiveLens | None = None, + sf: int = 3, + normalize: bool = True, + xp: NumpyAPI | None = None, +) -> np.ndarray: + """Create a confocal PSF. + + This function creates a confocal PSF by multiplying the excitation PSF with + the emission PSF convolved with a pinhole mask. + + All extra keyword arguments are passed to `vectorial_psf_centered`. + """ + xp = NumpyAPI.create(xp) + + objective = _cast_objective(objective) + ex_psf = vectorial_psf_centered( + nz=nz, + wvl=ex_wvl_um, + dz=dz, + pz=pz, + nx=nx, + ny=ny, + pos=pos, + dxy=dxy, + objective=objective, + xp=xp, + sf=sf, + normalize=normalize, + ) + em_psf = vectorial_psf_centered( + nz=nz, + wvl=em_wvl_um, + dz=dz, + pz=pz, + nx=nx, + ny=ny, + pos=pos, + dxy=dxy, + objective=objective, + xp=xp, + sf=sf, + normalize=normalize, + ) + + # The effective emission PSF is the regular emission PSF convolved with the + # pinhole mask. The pinhole mask is a disk with diameter equal to the pinhole + # size in AU, converted to pixels. + pinhole = _pinhole_mask( + nxy=ex_psf.shape[-1], + pinhole_au=pinhole_au, + wvl=em_wvl_um, + na=objective.numerical_aperture, + dxy=dxy, + xp=xp, + ) + pinhole = xp.asarray(pinhole) + eff_em_psf = xp.empty_like(em_psf) + for i in tqdm.trange(len(em_psf), desc="convolving em_psf with pinhole..."): + plane = xp.fftconvolve(xp.asarray(em_psf[i]), pinhole, mode="same") + eff_em_psf = xp._array_assign(eff_em_psf, i, plane) -# if __name__ == "__main__": -# zv = np.linspace(-3, 3, 61) -# from time import perf_counter + # The final PSF is the excitation PSF multiplied by the effective emission PSF. + return xp.asarray(ex_psf) * eff_em_psf # type: ignore -# t0 = perf_counter() -# psf = vectorial_psf(zv, nx=512) -# t1 = perf_counter() -# print(psf.shape) -# print(t1 - t0) -# assert np.allclose(np.load("out.npy"), psf, atol=0.1) + +def _pinhole_mask( + nxy: int, + pinhole_au: float, + wvl: float, + na: float, + dxy: float, + xp: NumpyAPI | None = None, +) -> npt.NDArray: + """Create a 2D circular pinhole mask of specified `pinhole_au`.""" + xp = NumpyAPI.create(xp) + + pinhole_size = pinhole_au * 0.61 * wvl / na + pinhole_px = pinhole_size / dxy + + x = xp.arange(nxy) - nxy // 2 + xx, yy = xp.meshgrid(x, x) + r = xp.sqrt(xx**2 + yy**2) + return (r <= pinhole_px).astype(int) # type: ignore + + +def make_psf( + space: SpaceProtocol, + channel: OpticalConfig, + objective: ObjectiveLens, + pinhole_au: float | None = None, + max_au_relative: float | None = None, + xp: NumpyAPI | None = None, +) -> ArrayProtocol: + xp = NumpyAPI.create(xp) + nz, _ny, nx = space.shape + dz, _dy, dx = space.scale + ex_wvl_um = channel.excitation.bandcenter * 1e-3 + em_wvl_um = channel.emission.bandcenter * 1e-3 + objective = _cast_objective(objective) + + # now restrict nx to no more than max_au_relative + if max_au_relative is not None: + airy_radius = 0.61 * ex_wvl_um / objective.numerical_aperture + n_pix_per_airy_radius = airy_radius / dx + max_nx = int(n_pix_per_airy_radius * max_au_relative * 2) + nx = min(nx, max_nx) + # if even make odd + if nx % 2 == 0: + nx += 1 + + if pinhole_au is None: + psf = vectorial_psf_centered( + wvl=em_wvl_um, nz=nz + 1, nx=nx + 1, dz=dz, dxy=dx, objective=objective + ) + else: + psf = make_confocal_psf( + nz=nz, + ex_wvl_um=ex_wvl_um, + em_wvl_um=em_wvl_um, + pinhole_au=pinhole_au, + nx=nx, + dz=dz, + dxy=dx, + objective=objective, + xp=xp, + ) + + return xp.asarray(psf) # type: ignore diff --git a/src/microsim/schema/lens.py b/src/microsim/schema/lens.py index 6343c3b..670e410 100644 --- a/src/microsim/schema/lens.py +++ b/src/microsim/schema/lens.py @@ -1,9 +1,22 @@ -from typing import Any +from typing import Any, TypedDict import numpy as np from pydantic import BaseModel, Field, model_validator +class ObjectiveKwargs(TypedDict, total=False): + numerical_aperture: float + coverslip_ri: float + coverslip_ri_spec: float + immersion_medium_ri: float + immersion_medium_ri_spec: float + specimen_ri: float + working_distance: float + coverslip_thickness: float + coverslip_thickness_spec: float + magnification: float + + class ObjectiveLens(BaseModel): numerical_aperture: float = Field(1.4, alias="na") coverslip_ri: float = 1.515 # coverslip RI experimental value (ng) diff --git a/src/microsim/schema/modality/confocal.py b/src/microsim/schema/modality/confocal.py index 7c66d29..3ffd136 100644 --- a/src/microsim/schema/modality/confocal.py +++ b/src/microsim/schema/modality/confocal.py @@ -1,15 +1,14 @@ -from typing import TYPE_CHECKING, Annotated, Literal, cast +from typing import Annotated, Literal from annotated_types import Ge from pydantic import BaseModel from microsim._data_array import DataArray +from microsim.psf import make_psf from microsim.schema.backend import NumpyAPI from microsim.schema.lens import ObjectiveLens from microsim.schema.optical_config import OpticalConfig - -if TYPE_CHECKING: - from microsim.schema.space import Space +from microsim.schema.settings import Settings class Confocal(BaseModel): @@ -21,25 +20,19 @@ def render( truth: DataArray, channel: OpticalConfig, objective_lens: ObjectiveLens, + settings: Settings, xp: NumpyAPI | None = None, ) -> DataArray: - from microsim.util import make_confocal_psf - - xp = xp or NumpyAPI() + xp = NumpyAPI.create(xp) - # FIXME, this is probably derivable from truth.coords - truth_space = cast("Space", truth.attrs["space"]) - psf = make_confocal_psf( - ex_wvl_um=channel.excitation.bandcenter * 1e-3, - em_wvl_um=channel.emission.bandcenter * 1e-3, + psf = make_psf( + space=truth.attrs["space"], + channel=channel, + objective=objective_lens, pinhole_au=self.pinhole_au, - nz=truth_space.shape[-3] + 1, - nx=truth_space.shape[-1] + 1, - dz=truth_space.scale[-3], - dxy=truth_space.scale[-1], - params={"NA": objective_lens.numerical_aperture}, + max_au_relative=settings.max_psf_radius_aus, xp=xp, ) - psf = xp.asarray(psf) + img = xp.fftconvolve(truth.data, psf, mode="same") return DataArray(img, coords=truth.coords, attrs=truth.attrs) diff --git a/src/microsim/schema/modality/widefield.py b/src/microsim/schema/modality/widefield.py index 9d53673..a06c926 100644 --- a/src/microsim/schema/modality/widefield.py +++ b/src/microsim/schema/modality/widefield.py @@ -1,15 +1,13 @@ -from typing import TYPE_CHECKING, Literal, cast +from typing import Literal from pydantic import BaseModel from microsim._data_array import DataArray -from microsim.psf import vectorial_psf_centered +from microsim.psf import make_psf from microsim.schema.backend import NumpyAPI from microsim.schema.lens import ObjectiveLens from microsim.schema.optical_config import OpticalConfig - -if TYPE_CHECKING: - from microsim.schema.space import Space +from microsim.schema.settings import Settings class Widefield(BaseModel): @@ -20,21 +18,18 @@ def render( truth: DataArray, channel: OpticalConfig, objective_lens: ObjectiveLens, + settings: Settings, xp: NumpyAPI | None = None, ) -> DataArray: xp = NumpyAPI.create(xp) - # FIXME, this is probably derivable from truth.coords - truth_space = cast("Space", truth.attrs["space"]) - em_psf = vectorial_psf_centered( - wvl=channel.emission.bandcenter * 1e-3, - nz=truth_space.shape[-3] + 1, - nx=truth_space.shape[-1] + 1, - dz=truth_space.scale[-3], - dxy=truth_space.scale[-1], - params={"NA": objective_lens.numerical_aperture}, + em_psf = make_psf( + space=truth.attrs["space"], + channel=channel, + objective=objective_lens, + max_au_relative=settings.max_psf_radius_aus, + xp=xp, ) - em_psf = xp.asarray(em_psf) img = xp.fftconvolve(truth.data, em_psf, mode="same") return DataArray(img, coords=truth.coords, attrs=truth.attrs) diff --git a/src/microsim/schema/settings.py b/src/microsim/schema/settings.py index 02e04e1..8039e18 100644 --- a/src/microsim/schema/settings.py +++ b/src/microsim/schema/settings.py @@ -12,6 +12,15 @@ class Settings(BaseSettings): random_seed: int | None = Field( default_factory=lambda: random.randint(0, 2**32 - 1) ) + max_psf_radius_aus: float | None = Field( + 8, + description=( + "When simulating, restrict generated lateral PSF size to no more than this " + "many Airy units. Decreasing this can *dramatically* speed up simulations, " + "but will decrease accuracy. If `None`, no restriction is applied, and the " + "psf will be generated to the full extent of the simulated space." + ), + ) def backend_module(self) -> NumpyAPI: backend = NumpyAPI.create(self.np_backend) diff --git a/src/microsim/schema/simulation.py b/src/microsim/schema/simulation.py index f076247..df668e4 100644 --- a/src/microsim/schema/simulation.py +++ b/src/microsim/schema/simulation.py @@ -88,7 +88,13 @@ def optical_image( raise ValueError("truth must be a DataArray") # let the given modality render the as an image (convolved, etc..) channel = self.channels[channel_idx] # TODO - return self.modality.render(truth, channel, self.objective_lens, xp=self._xp) + return self.modality.render( + truth, + channel, + objective_lens=self.objective_lens, + settings=self.settings, + xp=self._xp, + ) def digital_image( self, diff --git a/src/microsim/simulate/_camera.py b/src/microsim/simulate/_camera.py index 39a0d33..9ea54b1 100644 --- a/src/microsim/simulate/_camera.py +++ b/src/microsim/simulate/_camera.py @@ -49,6 +49,8 @@ def simulate_camera( exposure_s = exposure_ms / 1000 incident_photons = image * exposure_s + # restrict to positive values + incident_photons = xp.maximum(incident_photons, 0) # sample poisson noise if add_poisson: diff --git a/src/microsim/util.py b/src/microsim/util.py index 41539a3..7fd314a 100644 --- a/src/microsim/util.py +++ b/src/microsim/util.py @@ -2,18 +2,14 @@ import itertools import warnings -from typing import TYPE_CHECKING, Any, Protocol, TypeVar, cast +from typing import TYPE_CHECKING, Protocol, TypeVar, cast import numpy as np import numpy.typing as npt import tqdm -from dask.array.core import normalize_chunks from scipy import signal -from microsim.schema.backend import NumpyAPI - from ._data_array import ArrayProtocol, DataArray -from .psf import vectorial_psf_centered if TYPE_CHECKING: from collections.abc import Callable, Iterator, Sequence @@ -189,6 +185,8 @@ def tiled_convolve( func: Convolver = signal.convolve, dtype: DTypeLike | None = None, ) -> npt.NDArray: + from dask.array.core import normalize_chunks + if chunks is None: chunks = getattr(in1, "chunks", None) or (100,) * in1.ndim # TODO: change 100 @@ -215,74 +213,6 @@ def tiled_convolve( return out -def make_confocal_psf( - ex_wvl_um: float = 0.475, - em_wvl_um: float = 0.525, - pinhole_au: float = 1.0, - xp: NumpyAPI | None = None, - **kwargs: Any, -) -> np.ndarray: - """Create a confocal PSF. - - This function creates a confocal PSF by multiplying the excitation PSF with - the emission PSF convolved with a pinhole mask. - - All extra keyword arguments are passed to `vectorial_psf_centered`. - """ - xp = NumpyAPI.create(xp) - kwargs.pop("wvl", None) - params: dict = kwargs.setdefault("params", {}) - na = params.setdefault("NA", 1.4) - dxy = kwargs.setdefault("dxy", 0.01) - - print("making excitation PSF...") - ex_psf = vectorial_psf_centered(wvl=ex_wvl_um, **kwargs) - print("making emission PSF...") - em_psf = vectorial_psf_centered(wvl=em_wvl_um, **kwargs) - - # The effective emission PSF is the regular emission PSF convolved with the - # pinhole mask. The pinhole mask is a disk with diameter equal to the pinhole - # size in AU, converted to pixels. - pinhole = _pinhole_mask( - nxy=ex_psf.shape[-1], - pinhole_au=pinhole_au, - wvl=em_wvl_um, - na=na, - dxy=dxy, - xp=xp, - ) - pinhole = xp.asarray(pinhole) - - print("convolving em_psf with pinhole...") - eff_em_psf = xp.empty_like(em_psf) - for i in tqdm.trange(len(em_psf)): - plane = xp.fftconvolve(xp.asarray(em_psf[i]), pinhole, mode="same") - eff_em_psf = xp._array_assign(eff_em_psf, i, plane) - - # The final PSF is the excitation PSF multiplied by the effective emission PSF. - return xp.asarray(ex_psf) * eff_em_psf # type: ignore - - -def _pinhole_mask( - nxy: int, - pinhole_au: float, - wvl: float, - na: float, - dxy: float, - xp: NumpyAPI | None = None, -) -> npt.NDArray: - """Create a 2D circular pinhole mask of specified `pinhole_au`.""" - xp = NumpyAPI.create(xp) - - pinhole_size = pinhole_au * 0.61 * wvl / na - pinhole_px = pinhole_size / dxy - - x = xp.arange(nxy) - nxy // 2 - xx, yy = xp.meshgrid(x, x) - r = xp.sqrt(xx**2 + yy**2) - return (r <= pinhole_px).astype(int) # type: ignore - - # convenience function we'll use a couple times def ortho_plot(img: ArrayProtocol, gamma: float = 0.5, mip: bool = False) -> None: import matplotlib.pyplot as plt