Skip to content

Commit

Permalink
[perf] Improve spot finding in estimate_grid_orientation_from_img (#2987
Browse files Browse the repository at this point in the history
)

* [perf] Improve spot finding when distance between spots >> spot size

Improved spot finding and grid fitting accuracy in `estimate_grid_orientation_from_img` and `find_spot_positions` by:
- Using a configurable minimum distance between spots during local maxima finding. This addresses issues when the intensity difference between spots is high and single bright spots could be detected twice, while less bright spots were not detected.
- Splitting `min_distance` and `len_object` parameters in `peak_local_max` and `find_spot_positions` to refine spot size calculation and edge exclusion relative to minimum spot distance.

* Update after PR review

- Add pitch to simulator microscope file
- added comment to explain factor 0.75
- added pitch as an input argument to the spot-grid.py script, if not provided it will try to use the detector metadata
- Update the lens magnification in the simulated yaml file to match the hardware, this ensures the spot-grid script finds the right spots for the pitch matching the pitch on the hardware

* [refactor] Update after PR comments

- make DEFAULT_PITCH a constant in fastem.py to have a single source of truth
- clean up spot-grid.py
  • Loading branch information
tepals authored Jan 28, 2025
1 parent 961c225 commit 1ec1149
Show file tree
Hide file tree
Showing 8 changed files with 117 additions and 29 deletions.
3 changes: 2 additions & 1 deletion install/linux/usr/share/odemis/sim/fastem-sim-asm.odm.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,7 @@ FASTEM-sim: {
"percntl_im_rng": [1, 99],
# List[int, int] min/max percentile of the dynamic range to which the histogram of the image will be stretched.
"percntl_dyn_rng": [10, 90],
"pitch": 3.2e-6 # float [m], the expected pitch between two beamlets in meters
},
},
}
Expand Down Expand Up @@ -366,7 +367,7 @@ FASTEM-sim: {
class: static.OpticalLens,
role: lens,
init: {
mag: 60, # ratio, magnifying; higher magnification is a stronger simulated blur
mag: 40, # ratio, magnifying; higher magnification is a stronger simulated blur
na: 0.95, # numerical aperture
ri: 1, # refractive index
},
Expand Down
37 changes: 31 additions & 6 deletions scripts/spot-grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import wx.lib.wxcairo

from odemis import dataio, model
from odemis.acq.fastem import DEFAULT_PITCH
from odemis.cli.video_displayer import VideoDisplayer
from odemis.driver import ueye
from odemis.gui.conf.file import AcquisitionConfig
Expand All @@ -35,6 +36,8 @@
PIXEL_SIZE_SAMPLE_PLANE = 3.45e-6 # m
DEFAULT_MAGNIFICATION = 40
PIXEL_SIZE = PIXEL_SIZE_SAMPLE_PLANE / DEFAULT_MAGNIFICATION
# 0.75 is a safety factor to allow for some variation in spot positions
MIN_DIST_SPOTS = int(0.75 * DEFAULT_PITCH / PIXEL_SIZE)


class VideoDisplayerGrid(VideoDisplayer):
Expand All @@ -43,16 +46,19 @@ class VideoDisplayerGrid(VideoDisplayer):
It should be pretty much platform independent.
"""

def __init__(self, title="Live image", size=(640, 480), gridsize=None, pixel_size=PIXEL_SIZE):
def __init__(self, title="Live image", size=(640, 480), gridsize=None, pixel_size=PIXEL_SIZE,
min_dist_spots=MIN_DIST_SPOTS):
"""
Displays the window on the screen
size (2-tuple int,int): X and Y size of the window at initialisation
pixel_size (float): pixel size in m
min_dist_spots (int): minimum distance between spots in pixels
Note that the size of the window automatically adapts afterwards to the
coming pictures
"""
self.app = ImageWindowApp(title, size, pixel_size)
self.gridsize = (8, 8) if gridsize is None else gridsize
self.min_dist_spots = min_dist_spots
self.acqui_conf = AcquisitionConfig()

def new_image(self, data):
Expand All @@ -68,6 +74,7 @@ def new_image(self, data):
AffineTransform,
sigma=1.45,
threshold_rel=self.acqui_conf.spot_grid_threshold,
min_distance=self.min_dist_spots,
)
grid = unit_gridpoints(self.gridsize, mode="ji")
self.app.spots = tform_ji.apply(grid)
Expand Down Expand Up @@ -244,17 +251,19 @@ def start_generate(self):
self.notify(self._detector.array)


def live_display(ccd, dataflow, pixel_size, kill_ccd=True, gridsize=None):
def live_display(ccd, dataflow, pixel_size, kill_ccd=True, gridsize=None, min_dist_spots=MIN_DIST_SPOTS):
"""
Acquire an image from one (or more) dataflow and display it with a spot grid overlay.
ccd: a camera object
dataflow: dataflow to access
pixel_size (float): pixel size in m
kill_ccd: True if it is required to terminate the ccd after closing the window
gridsize: size of the grid of spots.
min_dist_spots: minimum distance between spots in pixels
"""
# create a window
window = VideoDisplayerGrid("Live from %s.%s" % (ccd.role, "data"), ccd.resolution.value, gridsize, pixel_size)
window = VideoDisplayerGrid("Live from %s.%s" % (ccd.role, "data"), ccd.resolution.value, gridsize, pixel_size,
min_dist_spots)
im_passer = ImagePasser()
t = threading.Thread(target=image_update, args=(im_passer, window))
t.daemon = True
Expand Down Expand Up @@ -293,6 +302,8 @@ def main(args):
help="size of the grid of spots in x y, default 8 8")
parser.add_argument("--magnification", dest="magnification", type=float,
help="magnification (typically 40 or 50)")
parser.add_argument("--pitch", dest="pitch", type=float, default=None,
help=f"pitch in meters (defaults to {DEFAULT_PITCH:0.1e})")
parser.add_argument("--log-level", dest="loglev", metavar="<level>", type=int, choices=[0, 1, 2],
default=0, help="set verbosity level (0-2, default = 0)")
options = parser.parse_args(args[1:])
Expand Down Expand Up @@ -327,21 +338,35 @@ def main(args):
logging.warning("No magnification specified, falling back to %s.", magnification)
pixel_size = PIXEL_SIZE_SAMPLE_PLANE / magnification

pitch = options.pitch
if not pitch:
try:
mppc = model.getComponent(role="mppc")
mppc_md = mppc.getMetadata()
pitch = mppc_md.get(model.MD_CALIB, {}).get("pitch", DEFAULT_PITCH)
except Exception as ex:
logging.debug("Failed to read pitch from mppc, ex: %s", ex)
pitch = DEFAULT_PITCH

# 0.75 is a safety factor to allow for some variation in spot positions
min_dist_spots = int(0.75 * pitch / pixel_size)

if options.filename:
logging.info("Will process image file %s" % options.filename)
converter = dataio.find_fittest_converter(options.filename, default=None, mode=os.O_RDONLY)
data = converter.read_data(options.filename)[0]
fakeccd = StaticCCD(options.filename, "fakeccd", data)
live_display(fakeccd, fakeccd.data, pixel_size, gridsize=options.gridsize)
live_display(fakeccd, fakeccd.data, pixel_size, gridsize=options.gridsize, min_dist_spots=min_dist_spots)
elif options.role:
if get_backend_status() != BACKEND_RUNNING:
raise ValueError("Backend is not running while role command is specified.")
ccd = model.getComponent(role=options.role)
live_display(ccd, ccd.data, pixel_size, kill_ccd=False, gridsize=options.gridsize)
live_display(ccd, ccd.data, pixel_size, kill_ccd=False, gridsize=options.gridsize,
min_dist_spots=min_dist_spots)
else:
ccd = ueye.Camera("camera", "ccd", device=None)
ccd.SetFrameRate(2)
live_display(ccd, ccd.data, pixel_size, gridsize=options.gridsize)
live_display(ccd, ccd.data, pixel_size, gridsize=options.gridsize, min_dist_spots=min_dist_spots)
return 0


Expand Down
20 changes: 17 additions & 3 deletions src/odemis/acq/fastem.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@
# The executor is a single object, independent of how many times the module (fastem.py) is loaded.
_executor = model.CancellableThreadPoolExecutor(max_workers=1)

DEFAULT_PITCH = 3.2e-6 # distance between spots in m

# TODO: Normally we do not use component names in code, only roles. Store in the roles in the SETTINGS_SELECTION,
# and at init lookup the role -> name conversion (using model.getComponent(role=role)).
# Selection of components, VAs and values to save with the ROA acquisition, structured: {component: {VA: value}}
Expand Down Expand Up @@ -292,6 +294,16 @@ def __init__(self, scanner, multibeam, descanner, detector, stage, scan_stage, c
# save the initial multibeam resolution, because the resolution will get updated if save_full_cells is True
self._old_res = self._multibeam.resolution.value

# Calculate the expected minimum distance between spots in the grid on the diagnostic camera
detector_md = detector.getMetadata()
ccd_md = ccd.getMetadata()
self._exp_pitch_m = detector_md.get(model.MD_CALIB, {}).get("pitch", DEFAULT_PITCH) # m
lens_mag = ccd_md.get(model.MD_LENS_MAG)
ccd_px_size = ccd_md.get(model.MD_SENSOR_PIXEL_SIZE)
exp_pitch_px = self._exp_pitch_m * lens_mag / ccd_px_size[0]
# 0.75 is a safety factor to allow for some variation in spot positions
self._min_dist_spots = int(0.75 * exp_pitch_px)

beam_shift_path = fastem_util.create_image_dir("beam-shift-correction")
# If there is a project name the path will be
# [image-dir]/beam-shift-correction/[timestamp]/[project-name]/[roa-name]_[slice-idx]
Expand Down Expand Up @@ -457,8 +469,8 @@ def acquire_roa(self, dataflow):
logging.debug(f"Will run beam shift correction for field index {field_idx}")
try:
new_beam_shift = self.correct_beam_shift()
# The difference in x or y should not be larger than 2 micrometers
if any(map(lambda n, p: abs(n - p) > 2e-6, new_beam_shift, prev_beam_shift)):
# The difference in x or y should not be larger than half a pitch
if any(map(lambda n, p: abs(n - p) > 0.5 * self._exp_pitch_m, new_beam_shift, prev_beam_shift)):
raise ValueError(
f"Difference in beam shift is larger than 2 µm, therefore it most likely failed. "
f"Previous beam shift: {prev_beam_shift}, new beam shift: {new_beam_shift}"
Expand Down Expand Up @@ -663,7 +675,9 @@ def correct_beam_shift(self):
# asap=False: wait until new image is acquired (don't read from buffer)
ccd_image = self._ccd.data.get(asap=False)
tform, error = estimate_grid_orientation_from_img(ccd_image, (8, 8), SimilarityTransform, sigma,
threshold_rel=self._spot_grid_thresh)
threshold_rel=self._spot_grid_thresh,
min_distance=self._min_dist_spots,
)
logging.debug(f"Found center of grid at {tform.translation}, error: {error}.")

# Determine the shift of the spots, by subtracting the good multiprobe position from the average (center)
Expand Down
35 changes: 21 additions & 14 deletions src/odemis/acq/test/fastem_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from odemis.acq import fastem, stream
from odemis.acq.acqmng import SettingsObserver
from odemis.acq.align.fastem import Calibrations
from odemis.acq.fastem import SETTINGS_SELECTION
from odemis.acq.fastem import DEFAULT_PITCH, SETTINGS_SELECTION
from odemis.gui.comp.fastem_roa import FastEMROA
from odemis.gui.comp.overlay.shapes import EditableShape
from odemis.gui.model.main_gui_data import FastEMMainGUIData
Expand Down Expand Up @@ -757,10 +757,10 @@ def test_get_abs_stage_movement(self):
# Create an ROA with the coordinates of the field.
roa_name = "test_megafield_id"
roa = FastEMROA(shape=MockEditableShape(),
main_data=self.main_data,
overlap=0.0,
name=roa_name,
slice_index=0)
main_data=self.main_data,
overlap=0.0,
name=roa_name,
slice_index=0)
roa.shape._points = points
roa.shape.points.value = points

Expand Down Expand Up @@ -845,10 +845,10 @@ def test_get_abs_stage_movement_overlap(self):
# Create an ROA with the coordinates of the field.
roa_name = "test_megafield_id"
roa = FastEMROA(shape=MockEditableShape(),
main_data=self.main_data,
overlap=overlap,
name=roa_name,
slice_index=0)
main_data=self.main_data,
overlap=overlap,
name=roa_name,
slice_index=0)
roa.shape._points = points
roa.shape.points.value = points

Expand Down Expand Up @@ -1130,10 +1130,10 @@ def test_get_pos_first_tile(self):
# Create an ROA with the coordinates of the field.
roa_name = "test_megafield_id"
roa = FastEMROA(shape=MockEditableShape(),
main_data=self.main_data,
overlap=0.0,
name=roa_name,
slice_index=0)
main_data=self.main_data,
overlap=0.0,
name=roa_name,
slice_index=0)
roa.shape._points = points
roa.shape.points.value = points

Expand Down Expand Up @@ -1295,6 +1295,9 @@ def setUpClass(cls):
cls.mppc.frameDuration.value = 0.1
cls.mppc.cellCompleteResolution.value = (900, 900)
cls.mppc.shape = (8, 8)
cls.mppc.configure_mock(
**{"getMetadata.return_value": {model.MD_CALIB: {"pitch": DEFAULT_PITCH},}}
)

cls.multibeam = Mock()
cls.multibeam.configure_mock(**{"getMetadata.return_value": {model.MD_SCAN_OFFSET_CALIB: [0.01, 0.01],
Expand Down Expand Up @@ -1328,7 +1331,11 @@ def setUpClass(cls):
image[54:150:12, 54:150:12] = 1
image = model.DataArray(input_array=image)
cls.ccd.data.configure_mock(**{"get.return_value": image})
cls.ccd.configure_mock(**{"getMetadata.return_value": {model.MD_FAV_POS_ACTIVE: {"j": 100, "i": 100}}})
cls.ccd.configure_mock(**{"getMetadata.return_value": {
model.MD_FAV_POS_ACTIVE: {"j": 100, "i": 100},
model.MD_SENSOR_PIXEL_SIZE: (3.45e-6, 3.45e-6),
model.MD_LENS_MAG: 40,
}})
cls.ccd.pointSpreadFunctionSize.value = 1
cls.ccd.pixelSize.value = (1.0e-7, 1.0e-7)

Expand Down
14 changes: 12 additions & 2 deletions src/odemis/util/peak_local_max.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,7 @@ def peak_local_max(
num_peaks: Optional[int] = None,
footprint: Optional[np.ndarray] = None,
p_norm: float = np.inf,
len_object: Optional[int] = None,
) -> np.ndarray:
"""
Find peaks in an image as coordinate list.
Expand All @@ -378,7 +379,7 @@ def peak_local_max(
image : ndarray
Input image.
min_distance : int, optional
The minimal allowed distance separating peaks. To find the
The minimal allowed distance in pixels separating peaks. To find the
maximum number of peaks, use `min_distance=1`.
threshold_abs : float, optional
Minimum intensity of peaks. By default, the absolute threshold is
Expand Down Expand Up @@ -406,6 +407,12 @@ def peak_local_max(
A finite large p may cause a ValueError if overflow can occur.
``inf`` corresponds to the Chebyshev distance and 2 to the
Euclidean distance.
len_object : int, optional
The length of the object in pixels. This parameter is used to determine
the border width for peak exclusion when `exclude_border` is True. If
not provided, `min_distance` is used as the length of the object. This
can be useful when the object size is known and different from
`min_distance`.
Returns
-------
Expand Down Expand Up @@ -460,7 +467,10 @@ def peak_local_max(
stacklevel=2,
)

border_width = _get_excluded_border_width(image, min_distance, exclude_border)
if not len_object:
len_object = min_distance

border_width = _get_excluded_border_width(image, len_object, exclude_border)

threshold = _get_threshold(image, threshold_abs, threshold_rel)

Expand Down
6 changes: 5 additions & 1 deletion src/odemis/util/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,6 +532,7 @@ def estimate_grid_orientation_from_img(
threshold_abs: Optional[float] = None,
threshold_rel: Optional[float] = None,
num_spots: Optional[int] = None,
min_distance: Optional[int] = None,
) -> Tuple[T, float]:
"""
Image based estimation of the orientation of a square grid of points.
Expand Down Expand Up @@ -559,6 +560,9 @@ def estimate_grid_orientation_from_img(
`num_spots = shape[0] * shape[1]` as default when set to `None`. Set
`num_spots = 0` to not impose a maximum. Note that this behavior is
different from odemis.util.spot.find_spot_position().
min_distance : int, optional
The minimal allowed distance in pixels separating peaks. To find the
maximum number of peaks, use `min_distance=1`.
Returns
-------
Expand All @@ -578,5 +582,5 @@ def estimate_grid_orientation_from_img(
num_spots = shape[0] * shape[1]
elif num_spots == 0:
num_spots = None
ji = find_spot_positions(image, sigma, threshold_abs, threshold_rel, num_spots)
ji = find_spot_positions(image, sigma, threshold_abs, threshold_rel, num_spots, min_distance)
return estimate_grid_orientation(ji, shape, transform_type)
11 changes: 9 additions & 2 deletions src/odemis/util/spot.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,7 @@ def find_spot_positions(
threshold_abs: Optional[float] = None,
threshold_rel: Optional[float] = None,
num_spots: Optional[int] = None,
min_distance: Optional[int] = None,
) -> numpy.ndarray:
"""
Find the center coordinates of spots with the highest intensity in an
Expand All @@ -390,6 +391,9 @@ def find_spot_positions(
num_spots : int, optional
Maximum number of spots. When the number of spots exceeds `num_spots`,
return `num_spots` peaks based on highest spot intensity.
min_distance : int, optional
The minimal allowed distance in pixels separating peaks. To find the
maximum number of peaks, use `min_distance=1`.
Returns
-------
Expand All @@ -399,15 +403,18 @@ def find_spot_positions(
"""
size = int(round(3 * sigma))
min_distance = 2 * size
filtered = bandpass_filter(image, sigma, min_distance)
len_object = 2 * size # typical length of a spot
min_distance = len_object if not min_distance else min_distance # distance between spots
filtered = bandpass_filter(image, sigma, len_object)
coordinates = peak_local_max(
filtered,
min_distance=min_distance,
threshold_abs=threshold_abs,
threshold_rel=threshold_rel,
exclude_border=False,
num_peaks=num_spots,
p_norm=2,
len_object=len_object,
)

# Improve coordinate estimate using radial symmetry center.
Expand Down
20 changes: 20 additions & 0 deletions src/odemis/util/test/spot_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,26 @@ def test_multiple(self):
numpy.testing.assert_array_equal(sorted(indices), range(len(loc)))
numpy.testing.assert_array_less(distances, 0.05)

def test_spots_close_to_edge(self):
"""
`find_spot_positions` should find all spot positions in a generated
test image when the minimum distance between spots is larger than the
distance between the spots and the edge of the image.
"""
# set a grid of 8 by 8 points to 1 at the top left of the image
image = numpy.zeros((256, 256))
image[4:100:12, 8:104:12] = 1
expected_ji = numpy.column_stack(numpy.where(image))

ji = spot.find_spot_positions(image, sigma=0.75, min_distance=12)

# Sort both arrays to ensure consistent ordering
expected_ji_sorted = expected_ji[numpy.lexsort((expected_ji[:, 1], expected_ji[:, 0]))]
ji_sorted = ji[numpy.lexsort((ji[:, 1], ji[:, 0]))]

# Check if the sorted arrays are equal
numpy.testing.assert_array_almost_equal(expected_ji_sorted, ji_sorted)


if __name__ == "__main__":
unittest.main()

0 comments on commit 1ec1149

Please sign in to comment.