Skip to content

Commit

Permalink
WIP: getting the full stiching pipeline tested out
Browse files Browse the repository at this point in the history
- ground truth translation not correct yet when nonzero
  • Loading branch information
akhanf committed Jan 26, 2025
1 parent de52398 commit 84e0986
Show file tree
Hide file tree
Showing 7 changed files with 135 additions and 83 deletions.
44 changes: 39 additions & 5 deletions dask-stitch/Snakefile
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ wildcard_constraints:

rule all:
input:
'results/fused_SPIM.nii'
'results/fused_SPIM.nii',
'results/fused_desc-groundtruth_SPIM.nii'

rule create_test_dataset_single_ome_zarr:
params:
Expand All @@ -25,8 +26,17 @@ rule create_test_dataset_single_ome_zarr:
ome_zarr=directory('results/tile-{tile}_SPIM.ome.zarr'),
nifti='results/tile-{tile}_SPIM.nii',
nifti_fromzarr='results/tile-{tile}_fromzarr_SPIM.nii',
true_offset='results/tile-{tile}_desc-groundtruth_offset.txt'
script: 'scripts/create_test_dataset_singletile.py'

rule concat_ground_truth_translations:
input:
ome_zarr=expand('results/tile-{tile}_desc-groundtruth_offset.txt', tile=range(gridx*gridy)),
output:
true_translations='results/groundtruth_translations.txt'
shell:
'cat {input} > {output}'


rule find_overlapping_pairs:
input:
Expand Down Expand Up @@ -55,25 +65,49 @@ rule global_optimization:
'scripts/global_optimization.py'



rule assign_translations:
input:
ome_zarr=expand('results/tile-{tile}_SPIM.ome.zarr', tile=range(gridx*gridy)),
optimized_translations='results/optimized_translations.txt'
translations='results/optimized_translations.txt'
output:
nifti=expand('results/tile-{tile}_optimized_SPIM.nii', tile=range(gridx*gridy))
niftis=expand('results/tile-{tile}_optimized_SPIM.nii', tile=range(gridx*gridy)),
ome_zarrs=directory(expand('results/tile-{tile}_optimized_SPIM.ome.zarr', tile=range(gridx*gridy)))
script:
"scripts/assign_translation.py"

rule fuse_volume:

rule assign_true_translations:
input:
ome_zarr=expand('results/tile-{tile}_SPIM.ome.zarr', tile=range(gridx*gridy)),
#optimized_translations='results/optimized_translations.txt'
translations='results/groundtruth_translations.txt'
output:
niftis=expand('results/tile-{tile}_groundtruth_SPIM.nii', tile=range(gridx*gridy)),
ome_zarrs=directory(expand('results/tile-{tile}_groundtruth_SPIM.ome.zarr', tile=range(gridx*gridy)))
script:
"scripts/assign_translation.py"


rule fuse_volume:
input:
ome_zarr=expand('results/tile-{tile}_optimized_SPIM.ome.zarr', tile=range(gridx*gridy)),
nifti=expand('results/tile-{tile}_optimized_SPIM.nii', tile=range(gridx*gridy)),
output:
ome_zarr = directory('results/fused_SPIM.ome.zarr'),
nifti = 'results/fused_SPIM.nii'
script:
'scripts/fuse_volume.py'

rule fuse_true_volume:
input:
ome_zarr=expand('results/tile-{tile}_groundtruth_SPIM.ome.zarr', tile=range(gridx*gridy)),
nifti=expand('results/tile-{tile}_groundtruth_SPIM.nii', tile=range(gridx*gridy)),
output:
ome_zarr = directory('results/fused_desc-groundtruth_SPIM.ome.zarr'),
nifti = 'results/fused_desc-groundtruth_SPIM.nii'
script:
'scripts/fuse_volume.py'



rule test_rtree_index:
Expand Down
59 changes: 59 additions & 0 deletions dask-stitch/lib/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,65 @@
#dask.config.set(scheduler='synchronous') # overwrite default with single-threaded scheduler


def find_overlapping_pairs(ome_zarr_paths):
"""
Identify overlapping tile pairs based on their physical offsets.
Parameters:
- ome_zarr_paths (list of str): List of paths to OME-Zarr datasets.
Returns:
- List of tuples: Each tuple is a pair of overlapping tile indices ((i, j)).
"""
from zarrnii import ZarrNii

# Read physical transformations and calculate bounding boxes
bounding_boxes = []
for path in ome_zarr_paths:
znimg = ZarrNii.from_ome_zarr(path)

tile_shape = znimg.darr.shape[1:]

affine = np.eye(4,4)
affine[:3, :3] = np.diag(znimg.get_zooms(axes_order=znimg.axes_order)) # Set the zooms (scaling factors) along the diagonal
affine[:3, 3] = znimg.get_origin(axes_order=znimg.axes_order) # Set the translation (origin)

# Compute physical bounding box using affine
corners = [
np.array([0, 0, 0, 1]),
np.array([tile_shape[0], 0, 0, 1]),
np.array([0, tile_shape[1], 0, 1]),
np.array([tile_shape[0], tile_shape[1], 0, 1]),
np.array([0, 0, tile_shape[0], 1]),
np.array([tile_shape[0], 0, tile_shape[2], 1]),
np.array([0, tile_shape[1], tile_shape[2], 1]),
np.array([tile_shape[0], tile_shape[1], tile_shape[2], 1]),
]
corners_physical = np.dot(affine, np.array(corners).T).T[:, :3] # Drop homogeneous coordinate
bbox_min = corners_physical.min(axis=0)
bbox_max = corners_physical.max(axis=0)

bounding_boxes.append((bbox_min, bbox_max))

# Find overlapping pairs
overlapping_pairs = []
for i, (bbox1_min, bbox1_max) in enumerate(bounding_boxes):
for j, (bbox2_min, bbox2_max) in enumerate(bounding_boxes):
if i >= j:
continue # Avoid duplicate pairs and self-comparison

# Check for overlap in all dimensions
overlap = all(
bbox1_min[d] < bbox2_max[d] and bbox1_max[d] > bbox2_min[d]
for d in range(3)
)
if overlap:
overlapping_pairs.append((i, j))

return overlapping_pairs



def compute_chunk_bounding_boxes(dask_array, zooms, origin, tile_index):
"""
Compute the bounding boxes of each chunk in a Dask array in physical space.
Expand Down
31 changes: 16 additions & 15 deletions dask-stitch/scripts/assign_translation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,26 @@
import numpy as np
from zarrnii import ZarrNii

def assign_translations(ome_zarr_paths, optimized_translations, output_paths):
def assign_translations(ome_zarr_paths, translations, output_niftis, output_ome_zarrs):
"""
Update the OME-Zarr datasets with optimized translations by modifying the vox2ras.
Parameters:
- ome_zarr_paths (list of str): List of input OME-Zarr dataset paths.
- optimized_translations (np.ndarray): Optimized translations (T, 3).
- output_paths (list of str): List of paths to save the updated OME-Zarr datasets.
- translations (np.ndarray): Optimized translations (T, 3).
- output_niftis (list of str): List of paths to save the updated nifti datasets.
- output_ome_zarrs (list of str): List of paths to save the updated OME-Zarr datasets.
"""
if len(ome_zarr_paths) != len(optimized_translations):
if len(ome_zarr_paths) != len(translations):
raise ValueError("Number of OME-Zarr paths must match the number of optimized translations.")

for i, (path, translation, output_path) in enumerate(zip(ome_zarr_paths, optimized_translations, output_paths)):
for i, (path, translation, out_nii, out_zarr) in enumerate(zip(ome_zarr_paths, translations, output_niftis, output_ome_zarrs)):
print(path)
print(output_path)
print(translation)
# Load the OME-Zarr dataset
znimg = ZarrNii.from_path(path)
znimg = ZarrNii.from_ome_zarr(path)
# Update the affine matrix
updated_affine = znimg.vox2ras.affine.copy()
updated_affine = znimg.affine
print(f'original affine: {updated_affine}')

#HACK FIX:
Expand All @@ -34,20 +34,21 @@ def assign_translations(ome_zarr_paths, optimized_translations, output_paths):
updated_affine[:3, 3] += translation # Add the optimized translation
print(f'updated affine: {updated_affine}')

znimg.vox2ras.affine = updated_affine
znimg.ras2vox.affine = np.linalg.inv(updated_affine)
znimg.affine = updated_affine

# Save the updated ZarrNii
znimg.to_nifti(output_path)
znimg.to_nifti(out_nii)
znimg.to_ome_zarr(out_zarr)

print(f"Updated translations saved to: {output_path}")
print(f"Updated translations saved to: {out_nii} and {out_zarr}")


# Example usage
ome_zarr_paths = snakemake.input.ome_zarr # List of input OME-Zarr paths
optimized_translations = np.loadtxt(snakemake.input.optimized_translations, dtype=float) # Optimized translations
output_paths = snakemake.output.nifti # List of output nifti paths
translations = np.loadtxt(snakemake.input.translations, dtype=float) # Optimized translations
output_niftis = snakemake.output.niftis # List of output nifti paths
output_ome_zarrs = snakemake.output.ome_zarrs # List of output nifti paths

# Assign translations
assign_translations(ome_zarr_paths, optimized_translations, output_paths)
assign_translations(ome_zarr_paths, translations, output_niftis, output_ome_zarrs)

16 changes: 9 additions & 7 deletions dask-stitch/scripts/compute_pairwise_correlation.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,18 +129,20 @@ def compute_pairwise_correlation(ome_zarr_paths, overlapping_pairs, output_shape

for pair_index, (i, j) in enumerate(overlapping_pairs):
# Load the two images and their affines
znimg1 = ZarrNii.from_path(ome_zarr_paths[i])
znimg2 = ZarrNii.from_path(ome_zarr_paths[j])
znimg1 = ZarrNii.from_ome_zarr(ome_zarr_paths[i])
znimg2 = ZarrNii.from_ome_zarr(ome_zarr_paths[j])

img1 = znimg1.darr.squeeze().compute()
img2 = znimg2.darr.squeeze().compute()

affine1 = znimg1.vox2ras.affine
affine2 = znimg2.vox2ras.affine
#HACK FIX
# affine1[:3,3] = -1 * np.flip(affine1[:3,3])
# affine2[:3,3] = -1 * np.flip(affine2[:3,3])
#TODO: should make this a class member function, or just return the reordered affine
affine1 = np.eye(4,4)
affine1[:3, :3] = np.diag(znimg1.get_zooms(axes_order=znimg1.axes_order)) # Set the zooms (scaling factors) along the diagonal
affine1[:3, 3] = znimg1.get_origin(axes_order=znimg1.axes_order) # Set the translation (origin)

affine2 = np.eye(4,4)
affine2[:3, :3] = np.diag(znimg2.get_zooms(axes_order=znimg2.axes_order)) # Set the zooms (scaling factors) along the diagonal
affine2[:3, 3] = znimg2.get_origin(axes_order=znimg2.axes_order) # Set the translation (origin)


# Compute the corrected bounding boxes
Expand Down
10 changes: 8 additions & 2 deletions dask-stitch/scripts/create_test_dataset_singletile.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,15 @@ def create_test_dataset_single(tile_index, template="MNI152NLin2009cAsym", res=2


# TODO: Simulate error by applying a transformation to the image before

random_offset_low=-0
random_offset_high=0
# initially lets just do a random jitter
offset = np.random.uniform(random_offset_low, random_offset_high, size=(grid_shape[0],grid_shape[1],3)) # Random 3D offsets for each tile

#save this offset to a text file so we know the ground truth
np.savetxt(snakemake.output.true_offset, -offset[x,y,:].reshape((1,3)), fmt="%.6f")

# initially lets just do a random jitter:
offset = np.random.uniform(0, 0, size=(grid_shape[0],grid_shape[1],3)) # Random 3D offsets for each tile

xfm_img_data = affine_transform(img_data,matrix=np.eye(3,3),offset=offset[x,y,:],order=3,mode='nearest')

Expand Down
55 changes: 1 addition & 54 deletions dask-stitch/scripts/find_overlapping_pairs.py
Original file line number Diff line number Diff line change
@@ -1,58 +1,5 @@
import numpy as np

def find_overlapping_pairs(ome_zarr_paths):
"""
Identify overlapping tile pairs based on their physical offsets.
Parameters:
- ome_zarr_paths (list of str): List of paths to OME-Zarr datasets.
Returns:
- List of tuples: Each tuple is a pair of overlapping tile indices ((i, j)).
"""
from zarrnii import ZarrNii

# Read physical transformations and calculate bounding boxes
bounding_boxes = []
for path in ome_zarr_paths:
znimg = ZarrNii.from_path(path)
affine = znimg.vox2ras.affine # 4x4 matrix
tile_shape = znimg.darr.shape[1:]

# Compute physical bounding box using affine
corners = [
np.array([0, 0, 0, 1]),
np.array([tile_shape[2], 0, 0, 1]),
np.array([0, tile_shape[1], 0, 1]),
np.array([tile_shape[2], tile_shape[1], 0, 1]),
np.array([0, 0, tile_shape[0], 1]),
np.array([tile_shape[2], 0, tile_shape[0], 1]),
np.array([0, tile_shape[1], tile_shape[0], 1]),
np.array([tile_shape[2], tile_shape[1], tile_shape[0], 1]),
]
corners_physical = np.dot(affine, np.array(corners).T).T[:, :3] # Drop homogeneous coordinate
bbox_min = corners_physical.min(axis=0)
bbox_max = corners_physical.max(axis=0)

bounding_boxes.append((bbox_min, bbox_max))

# Find overlapping pairs
overlapping_pairs = []
for i, (bbox1_min, bbox1_max) in enumerate(bounding_boxes):
for j, (bbox2_min, bbox2_max) in enumerate(bounding_boxes):
if i >= j:
continue # Avoid duplicate pairs and self-comparison

# Check for overlap in all dimensions
overlap = all(
bbox1_min[d] < bbox2_max[d] and bbox1_max[d] > bbox2_min[d]
for d in range(3)
)
if overlap:
overlapping_pairs.append((i, j))

return overlapping_pairs

from lib.utils import *

overlapping_pairs = find_overlapping_pairs(snakemake.input)
np.savetxt(snakemake.output.txt,np.array(overlapping_pairs),fmt='%d')
Expand Down
3 changes: 3 additions & 0 deletions dask-stitch/scripts/fuse_volume.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
# Paths to input tiles
ome_zarr_paths = snakemake.input.ome_zarr




# Build R-tree index for all chunks in all tiles
rtree_idx, znimgs = build_rtree_index(ome_zarr_paths)

Expand Down

0 comments on commit 84e0986

Please sign in to comment.