Skip to content

Commit

Permalink
global optimization working now too!!
Browse files Browse the repository at this point in the history
  • Loading branch information
akhanf committed Jan 27, 2025
1 parent bcf58f0 commit 95bf008
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 37 deletions.
3 changes: 2 additions & 1 deletion dask-stitch/Snakefile
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,10 @@ rule compute_pairwise_correlation:

rule global_optimization:
input:
ome_zarr=expand('results/tile-{tile}_SPIM.ome.zarr', tile=range(gridx*gridy)),
pairs='results/overlapping_pairs.txt',
offsets='results/pairwise_offsets.txt'
params:
n_tiles=gridx*gridy
output:
optimized_translations='results/optimized_translations.txt'
script:
Expand Down
80 changes: 44 additions & 36 deletions dask-stitch/scripts/global_optimization.py
Original file line number Diff line number Diff line change
@@ -1,65 +1,73 @@
import numpy as np
from scipy.optimize import least_squares
from zarrnii import ZarrNii

from scipy.sparse import csr_matrix
from scipy.sparse.linalg import lsqr

def global_optimization(ome_zarr_paths, overlapping_pairs, pairwise_offsets):
def calculate_global_translations(pairs, pairwise_offsets, num_tiles):
"""
Perform global optimization to adjust translations for all tiles.
Calculate global translations from pairwise offsets.
Parameters:
- ome_zarr_paths (list of str): List of paths to OME-Zarr datasets.
- overlapping_pairs (list of tuples): List of overlapping tile indices ((i, j)).
- pairwise_offsets (np.ndarray): Array of pairwise offsets (N, 3), where N is the number of pairs.
- pairs (np.ndarray): Array of shape (N, 2) containing indices of overlapping tile pairs.
- pairwise_offsets (np.ndarray): Array of shape (N, 3) containing pairwise offsets for each pair.
- num_tiles (int): Total number of tiles.
Returns:
- np.ndarray: Optimized global translations of shape (T, 3).
- np.ndarray: Global translations for each tile of shape (num_tiles, 3).
"""
# Number of tiles is the number of OME-Zarr paths
num_tiles = len(ome_zarr_paths)
# Number of pairwise offsets
num_pairs = pairs.shape[0]

# Initial translations (start with identity translation: no offsets)
initial_translations = np.zeros((num_tiles, 3))
# Initialize the sparse matrix A and vector b
data = []
row_indices = []
col_indices = []
b = np.zeros((num_pairs * 3,))

# Flatten initial translations for optimization
x0 = initial_translations.flatten()
for k, (i, j) in enumerate(pairs):
# Each pair contributes to three equations (x, y, z components)
for d in range(3): # x=0, y=1, z=2
row = 3 * k + d

def objective(x):
"""
Compute the residuals for global optimization.
# T_j - T_i = O_ij
data.append(-1)
row_indices.append(row)
col_indices.append(3 * i + d) # T_i[d]

Parameters:
- x (np.ndarray): Flattened translations array (T * 3,).
data.append(1)
row_indices.append(row)
col_indices.append(3 * j + d) # T_j[d]

Returns:
- np.ndarray: Residuals for least-squares optimization.
"""
translations = x.reshape((num_tiles, 3))
residuals = []
# Right-hand side
b[row] = pairwise_offsets[k, d]

for (i, j), offset in zip(overlapping_pairs, pairwise_offsets):
# Residual is the difference between the predicted and actual offset
predicted_offset = translations[j] - translations[i]
residuals.append(predicted_offset - offset)
# Convert to sparse matrix
A = csr_matrix((data, (row_indices, col_indices)), shape=(num_pairs * 3, num_tiles * 3))

return np.concatenate(residuals)
# Anchor the first tile (T_0 = [0, 0, 0])
anchor_rows = np.zeros((3, num_tiles * 3))
for d in range(3):
anchor_rows[d, d] = 1
A = csr_matrix(np.vstack([A.toarray(), anchor_rows]))
b = np.hstack([b, [0, 0, 0]])

# Perform least-squares optimization
result = least_squares(objective, x0)
# Solve the linear system using least squares
x = lsqr(A, b)[0]

# Reshape result back to (T, 3)
optimized_translations = result.x.reshape((num_tiles, 3))
# Reshape the result into (num_tiles, 3)
global_translations = x.reshape((num_tiles, 3))

return optimized_translations
return global_translations


# Example usage
overlapping_pairs = np.loadtxt(snakemake.input.pairs, dtype=int).tolist() # Overlapping pairs
overlapping_pairs = np.loadtxt(snakemake.input.pairs, dtype=int) # Overlapping pairs
pairwise_offsets = np.loadtxt(snakemake.input.offsets, dtype=float) # Pairwise offsets
ome_zarr_paths = snakemake.input.ome_zarr # List of OME-Zarr paths
n_tiles = snakemake.params.n_tiles

# Perform global optimization
optimized_translations = global_optimization(ome_zarr_paths, overlapping_pairs, pairwise_offsets)
optimized_translations = calculate_global_translations(overlapping_pairs, pairwise_offsets, n_tiles)

# Save results
np.savetxt(snakemake.output.optimized_translations, optimized_translations, fmt="%.6f")
Expand Down

0 comments on commit 95bf008

Please sign in to comment.