Skip to content

Commit

Permalink
WIP - pausing for now
Browse files Browse the repository at this point in the history
  • Loading branch information
akhanf committed Dec 12, 2024
1 parent c972839 commit 10e5a0d
Show file tree
Hide file tree
Showing 3 changed files with 711 additions and 154 deletions.
111 changes: 6 additions & 105 deletions dask-stitch/scripts/fuse_volume.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def wrapped_process_chunk(block, block_info=None):
wrapped_process_chunk,
chunks=chunk_shape,
dtype=np.float32,
shape=fused_shape,
#shape=fused_shape,
)

return fused_volume
Expand All @@ -175,112 +175,13 @@ def wrapped_process_chunk(block, block_info=None):
# Fuse the volume
fused_volume = fuse_volume(ome_zarr_paths, optimized_translations, fused_shape, chunk_shape, bbox_min, voxel_size)

# Save the fused volume
fused_volume.to_zarr(snakemake.output.fused_volume)




def process_chunk(chunk_data, block_info, ome_zarr_paths, optimized_translations, bbox_min, voxel_size, chunk_shape):
"""
Process a single chunk of the fused volume.
Parameters:
- chunk_data (np.ndarray): Placeholder data for the chunk.
- block_info (dict): Information about the block being processed.
- ome_zarr_paths (list of str): List of paths to OME-Zarr datasets.
- optimized_translations (np.ndarray): Optimized translations for each tile.
- bbox_min (np.ndarray): Minimum physical coordinates of the fused volume.
- voxel_size (np.ndarray): Voxel size in physical units.
- chunk_shape (tuple): Shape of the chunk being processed.
Returns:
- np.ndarray: Fused chunk.
"""
# Extract chunk location and physical bounding box
chunk_start = np.array(block_info[0]["chunk-location"]) * voxel_size + bbox_min
chunk_end = chunk_start + np.array(chunk_shape) * voxel_size

chunk = np.zeros(chunk_shape, dtype=np.float32)
weight = np.zeros(chunk_shape, dtype=np.float32)

for path, translation in zip(ome_zarr_paths, optimized_translations):
znimg = ZarrNii.from_path(path)
tile = znimg.darr.squeeze().compute()
affine = znimg.vox2ras.affine

# Resample tile to chunk
resampled_tile = resample_tile_to_chunk(tile, affine, translation, chunk_start, chunk_end, chunk_shape)

# Fuse by summing intensities and weights
mask = resampled_tile > 0
chunk[mask] += resampled_tile[mask]
weight[mask] += 1

# Avoid division by zero
fused_chunk = np.divide(chunk, weight, out=np.zeros_like(chunk), where=weight > 0)
return fused_chunk


def fuse_volume(ome_zarr_paths, optimized_translations, fused_shape, chunk_shape, bbox_min, voxel_size):
"""
Fuse all tiles into a single volume.
Parameters:
- ome_zarr_paths (list of str): List of paths to OME-Zarr datasets.
- optimized_translations (np.ndarray): Optimized translations for each tile.
- fused_shape (tuple): Shape of the final fused volume.
- chunk_shape (tuple): Shape of each chunk.
- bbox_min (np.ndarray): Minimum coordinates of the fused volume.
- voxel_size (np.ndarray): Voxel size in physical units.
Returns:
- dask.array: Fused volume.
"""
# Wrap process_chunk for Dask
def wrapped_process_chunk(chunk_data, block_info=None):
return process_chunk(
chunk_data,
block_info,
ome_zarr_paths,
optimized_translations,
bbox_min,
voxel_size,
chunk_shape,
)

# Define Dask array for the fused volume
fused_volume = da.map_blocks(
wrapped_process_chunk,
chunks=chunk_shape,
dtype=np.float32,
shape=fused_shape,
)

return fused_volume


# Example usage
ome_zarr_paths = snakemake.input.ome_zarr # List of input OME-Zarr paths
optimized_translations = np.loadtxt(snakemake.input.optimized_translations, dtype=float) # Optimized translations

# Compute the fused volume shape
bbox_min, bbox_max, voxel_size = compute_fused_volume_shape(ome_zarr_paths, optimized_translations)


print(bbox_min)
print(bbox_max)
print(voxel_size)
assert all(voxel_size > 0), "Voxel size must be greater than zero."

fused_shape = tuple(np.ceil((bbox_max - bbox_min) / voxel_size).astype(int))
chunk_shape = (64, 64, 64) # Example chunk size

# Fuse the volume
fused_volume = fuse_volume(ome_zarr_paths, optimized_translations, fused_shape, chunk_shape, bbox_min, voxel_size)
print(fused_volume.shape)

# Save the fused volume
fused_volume.to_zarr(snakemake.output.fused_volume)
znimg = ZarrNii.from_darr(fused_volume)

znimg.to_ome_zarr(snakemake.output.ome_zarr)
znimg.to_nifti(snakemake.output.nifti)



Expand Down
Loading

0 comments on commit 10e5a0d

Please sign in to comment.