From 95bf008ee62ce04896e322f5574d63f249ae1243 Mon Sep 17 00:00:00 2001 From: Ali Khan Date: Mon, 27 Jan 2025 12:01:37 -0500 Subject: [PATCH] global optimization working now too!! --- dask-stitch/Snakefile | 3 +- dask-stitch/scripts/global_optimization.py | 80 ++++++++++++---------- 2 files changed, 46 insertions(+), 37 deletions(-) diff --git a/dask-stitch/Snakefile b/dask-stitch/Snakefile index 4e99a87..a3464f4 100644 --- a/dask-stitch/Snakefile +++ b/dask-stitch/Snakefile @@ -66,9 +66,10 @@ rule compute_pairwise_correlation: rule global_optimization: input: - ome_zarr=expand('results/tile-{tile}_SPIM.ome.zarr', tile=range(gridx*gridy)), pairs='results/overlapping_pairs.txt', offsets='results/pairwise_offsets.txt' + params: + n_tiles=gridx*gridy output: optimized_translations='results/optimized_translations.txt' script: diff --git a/dask-stitch/scripts/global_optimization.py b/dask-stitch/scripts/global_optimization.py index 9ce994f..8993973 100644 --- a/dask-stitch/scripts/global_optimization.py +++ b/dask-stitch/scripts/global_optimization.py @@ -1,65 +1,73 @@ import numpy as np from scipy.optimize import least_squares -from zarrnii import ZarrNii +from scipy.sparse import csr_matrix +from scipy.sparse.linalg import lsqr -def global_optimization(ome_zarr_paths, overlapping_pairs, pairwise_offsets): +def calculate_global_translations(pairs, pairwise_offsets, num_tiles): """ - Perform global optimization to adjust translations for all tiles. + Calculate global translations from pairwise offsets. Parameters: - - ome_zarr_paths (list of str): List of paths to OME-Zarr datasets. - - overlapping_pairs (list of tuples): List of overlapping tile indices ((i, j)). - - pairwise_offsets (np.ndarray): Array of pairwise offsets (N, 3), where N is the number of pairs. + - pairs (np.ndarray): Array of shape (N, 2) containing indices of overlapping tile pairs. + - pairwise_offsets (np.ndarray): Array of shape (N, 3) containing pairwise offsets for each pair. + - num_tiles (int): Total number of tiles. Returns: - - np.ndarray: Optimized global translations of shape (T, 3). + - np.ndarray: Global translations for each tile of shape (num_tiles, 3). """ - # Number of tiles is the number of OME-Zarr paths - num_tiles = len(ome_zarr_paths) + # Number of pairwise offsets + num_pairs = pairs.shape[0] - # Initial translations (start with identity translation: no offsets) - initial_translations = np.zeros((num_tiles, 3)) + # Initialize the sparse matrix A and vector b + data = [] + row_indices = [] + col_indices = [] + b = np.zeros((num_pairs * 3,)) - # Flatten initial translations for optimization - x0 = initial_translations.flatten() + for k, (i, j) in enumerate(pairs): + # Each pair contributes to three equations (x, y, z components) + for d in range(3): # x=0, y=1, z=2 + row = 3 * k + d - def objective(x): - """ - Compute the residuals for global optimization. + # T_j - T_i = O_ij + data.append(-1) + row_indices.append(row) + col_indices.append(3 * i + d) # T_i[d] - Parameters: - - x (np.ndarray): Flattened translations array (T * 3,). + data.append(1) + row_indices.append(row) + col_indices.append(3 * j + d) # T_j[d] - Returns: - - np.ndarray: Residuals for least-squares optimization. - """ - translations = x.reshape((num_tiles, 3)) - residuals = [] + # Right-hand side + b[row] = pairwise_offsets[k, d] - for (i, j), offset in zip(overlapping_pairs, pairwise_offsets): - # Residual is the difference between the predicted and actual offset - predicted_offset = translations[j] - translations[i] - residuals.append(predicted_offset - offset) + # Convert to sparse matrix + A = csr_matrix((data, (row_indices, col_indices)), shape=(num_pairs * 3, num_tiles * 3)) - return np.concatenate(residuals) + # Anchor the first tile (T_0 = [0, 0, 0]) + anchor_rows = np.zeros((3, num_tiles * 3)) + for d in range(3): + anchor_rows[d, d] = 1 + A = csr_matrix(np.vstack([A.toarray(), anchor_rows])) + b = np.hstack([b, [0, 0, 0]]) - # Perform least-squares optimization - result = least_squares(objective, x0) + # Solve the linear system using least squares + x = lsqr(A, b)[0] - # Reshape result back to (T, 3) - optimized_translations = result.x.reshape((num_tiles, 3)) + # Reshape the result into (num_tiles, 3) + global_translations = x.reshape((num_tiles, 3)) - return optimized_translations + return global_translations # Example usage -overlapping_pairs = np.loadtxt(snakemake.input.pairs, dtype=int).tolist() # Overlapping pairs +overlapping_pairs = np.loadtxt(snakemake.input.pairs, dtype=int) # Overlapping pairs pairwise_offsets = np.loadtxt(snakemake.input.offsets, dtype=float) # Pairwise offsets -ome_zarr_paths = snakemake.input.ome_zarr # List of OME-Zarr paths +n_tiles = snakemake.params.n_tiles # Perform global optimization -optimized_translations = global_optimization(ome_zarr_paths, overlapping_pairs, pairwise_offsets) +optimized_translations = calculate_global_translations(overlapping_pairs, pairwise_offsets, n_tiles) # Save results np.savetxt(snakemake.output.optimized_translations, optimized_translations, fmt="%.6f")