Skip to content

Commit

Permalink
Merge pull request #5 from khanlab/optim-dask
Browse files Browse the repository at this point in the history
Reduce runtime for warping to template, by initial downsampling
  • Loading branch information
akhanf authored Jun 3, 2024
2 parents d5565a9 + c8fdff1 commit a970349
Show file tree
Hide file tree
Showing 7 changed files with 72 additions and 25 deletions.
39 changes: 37 additions & 2 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ ome-zarr = "^0.9.0"
pybids = "^0.16.5"
sparse = "^0.15.1"
zarrnii = "^0.1.1a1"
bokeh = "^3.4.1"

[tool.poetry.scripts]
spimquant = "spimquant.run:main"
Expand Down
31 changes: 18 additions & 13 deletions spimquant/workflow/rules/templatereg.smk
Original file line number Diff line number Diff line change
Expand Up @@ -275,8 +275,7 @@ rule affine_zarr_to_template_nii:
xfm_ras=rules.affine_reg.output.xfm_ras,
ref_nii=bids_tpl(root=root, template="{template}", suffix="anat.nii.gz"),
params:
chunks=(50, 50, 50),
zooms=None,
ref_opts={"chunks": (1, 50, 50, 50)},
output:
nii=bids(
root=root,
Expand All @@ -298,8 +297,7 @@ rule affine_zarr_to_template_ome_zarr:
xfm_ras=rules.affine_reg.output.xfm_ras,
ref_nii=bids_tpl(root=root, template="{template}", suffix="anat.nii.gz"),
params:
chunks=(50, 50, 50),
zooms=None,
ref_opts={"chunks": (1, 50, 50, 50)},
output:
ome_zarr=directory(
bids(
Expand All @@ -324,8 +322,10 @@ rule deform_zarr_to_template_nii:
warp_nii=rules.deform_reg.output.warp,
ref_nii=bids_tpl(root=root, template="{template}", suffix="anat.nii.gz"),
params:
chunks=(50, 50, 50),
zooms=None,
flo_opts={"level": 2}, #downsampling level to use (TODO: set this automatically based on ref resolution?)
do_downsample=True, #whether to perform further downsampling before transforming
downsample_opts={'along_z': 4}, #could also be determined automatically
ref_opts={"chunks": (1, 100, 100, 100)},
output:
nii=bids(
root=root,
Expand All @@ -337,6 +337,7 @@ rule deform_zarr_to_template_nii:
**inputs["spim"].wildcards
),
threads: 32
container: None
script:
"../scripts/deform_to_template_nii.py"

Expand All @@ -348,13 +349,17 @@ rule deform_to_template_nii_zoomed:
warp_nii=rules.deform_reg.output.warp,
ref_nii=bids_tpl(root=root, template="{template}", suffix="anat.nii.gz"),
params:
chunks=(50, 50, 50),
zooms=lambda wildcards: (
float(wildcards.res) / 1000,
float(wildcards.res) / 1000,
float(wildcards.res) / 1000,
),
#None #same resolution as template if NOne
flo_opts={"level": 1}, #downsampling level to use (TODO: set this automatically based on ref resolution?)
do_downsample=False, #whether to perform further downsampling before transforming
downsample_opts={'along_z': 4}, #could also be determined automatically
ref_opts=lambda wildcards: {
"chunks": (1, 50, 50, 50),
"zooms": (
float(wildcards.res) / 1000,
float(wildcards.res) / 1000,
float(wildcards.res) / 1000,
),
},
output:
nii=bids(
root=root,
Expand Down
2 changes: 1 addition & 1 deletion spimquant/workflow/scripts/affine_to_template_nii.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

#member function of floating image
flo_znimg = ZarrNii.from_path(snakemake.input.ome_zarr, channels=[channel_index])
ref_znimg = ZarrNii.from_path_as_ref(snakemake.input.ref_nii, channels=[channel_index],chunks=snakemake.params.chunks,zooms=snakemake.params.zooms)
ref_znimg = ZarrNii.from_path_as_ref(snakemake.input.ref_nii, channels=[channel_index],**snakemake.params.ref_opts)

out_znimg = flo_znimg.apply_transform(Transform.affine_ras_from_txt(snakemake.input.xfm_ras),ref_znimg=ref_znimg)

Expand Down
2 changes: 1 addition & 1 deletion spimquant/workflow/scripts/affine_to_template_ome_zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@


flo_znimg = ZarrNii.from_path(snakemake.input.ome_zarr, channels=[channel_index])
ref_znimg = ZarrNii.from_path_as_ref(snakemake.input.ref_nii, channels=[channel_index],chunks=snakemake.params.chunks,zooms=snakemake.params.zooms)
ref_znimg = ZarrNii.from_path_as_ref(snakemake.input.ref_nii, channels=[channel_index],**snakemake.params.ref_opts)

out_znimg = flo_znimg.apply_transform(Transform.affine_ras_from_txt(snakemake.input.xfm_ras),ref_znimg=ref_znimg)

Expand Down
20 changes: 13 additions & 7 deletions spimquant/workflow/scripts/deform_to_template_nii.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import zarr
import nibabel as nib
from zarrnii import ZarrNii, Transform
from dask.diagnostics import ProgressBar
from dask.distributed import Client

client = Client(n_workers=4, threads_per_worker=2,processes=False)
print(client.dashboard_link)

#get channel index from omero metadata
zi = zarr.open(snakemake.input.ome_zarr)
Expand All @@ -11,17 +14,20 @@


#member function of floting image
flo_znimg = ZarrNii.from_path(snakemake.input.ome_zarr, channels=[channel_index])
ref_znimg = ZarrNii.from_path_as_ref(snakemake.input.ref_nii, channels=[channel_index],chunks=snakemake.params.chunks,zooms=snakemake.params.zooms)
flo_znimg = ZarrNii.from_path(snakemake.input.ome_zarr, channels=[channel_index], **snakemake.params.flo_opts)
ref_znimg = ZarrNii.from_path_as_ref(snakemake.input.ref_nii, channels=[channel_index],**snakemake.params.ref_opts)

if snakemake.params.do_downsample:
flo_ds_znimg = flo_znimg.downsample(**snakemake.params.downsample_opts)

deform_znimg = flo_znimg.apply_transform(Transform.displacement_from_nifti(snakemake.input.warp_nii),

deform_znimg = flo_ds_znimg.apply_transform(Transform.displacement_from_nifti(snakemake.input.warp_nii),
Transform.affine_ras_from_txt(snakemake.input.xfm_ras),
ref_znimg=ref_znimg)

with ProgressBar():

deform_znimg.to_nifti(snakemake.output.nii, scheduler='single-threaded')

deform_znimg.to_nifti(snakemake.output.nii)



client.close()
2 changes: 1 addition & 1 deletion spimquant/workflow/scripts/deform_to_template_ome_zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from dask.diagnostics import ProgressBar

flo_znimg = ZarrNii.from_path(snakemake.input.ome_zarr, channels=[snakemake.params.channel_index])
ref_znimg = ZarrNii.from_path_as_ref(snakemake.input.ref_nii, channels=[snakemake.params.channel_index],chunks=snakemake.params.chunks,zooms=snakemake.params.zooms)
ref_znimg = ZarrNii.from_path_as_ref(snakemake.input.ref_nii, channels=[snakemake.params.channel_index],**snakemake.params.ref_opts)


out_znimg = flo_znimg.apply_transform(Transform.displacement_from_nifti(snakemake.input.warp_nii),
Expand Down

0 comments on commit a970349

Please sign in to comment.