Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Restore resampling to T1w target #3116

Merged
merged 19 commits into from
Oct 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 59 additions & 11 deletions fmriprep/interfaces/resampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@ class ResampleSeriesInputSpec(TraitedSpec):
in_file = File(exists=True, mandatory=True, desc="3D or 4D image file to resample")
ref_file = File(exists=True, mandatory=True, desc="File to resample in_file to")
transforms = InputMultiObject(
File(exists=True), mandatory=True, desc="Transform files, from in_file to ref_file (image mode)"
File(exists=True),
mandatory=True,
desc="Transform files, from in_file to ref_file (image mode)",
)
inverse = InputMultiObject(
traits.Bool,
Expand All @@ -48,6 +50,16 @@ class ResampleSeriesInputSpec(TraitedSpec):
desc="the phase-encoding direction corresponding to in_data",
)
num_threads = traits.Int(1, usedefault=True, desc="Number of threads to use for resampling")
output_data_type = traits.Str("float32", usedefault=True, desc="Data type of output image")
order = traits.Int(3, usedefault=True, desc="Order of interpolation (0=nearest, 3=cubic)")
mode = traits.Str(
'constant',
usedefault=True,
desc="How data is extended beyond its boundaries. "
"See scipy.ndimage.map_coordinates for more details.",
)
cval = traits.Float(0.0, usedefault=True, desc="Value to fill past edges of data")
prefilter = traits.Bool(True, usedefault=True, desc="Spline-prefilter data if order > 1")


class ResampleSeriesOutputSpec(TraitedSpec):
Expand Down Expand Up @@ -87,13 +99,18 @@ def _run_interface(self, runtime):

pe_info = [(pe_axis, -ro_time if (axis_flip ^ pe_flip) else ro_time)] * nvols

resampled = resample_bold(
resampled = resample_image(
source=source,
target=target,
transforms=transforms,
fieldmap=fieldmap,
pe_info=pe_info,
nthreads=self.inputs.num_threads,
output_dtype=self.inputs.output_data_type,
order=self.inputs.order,
mode=self.inputs.mode,
cval=self.inputs.cval,
prefilter=self.inputs.prefilter,
)
resampled.to_filename(out_path)

Expand All @@ -105,10 +122,16 @@ class ReconstructFieldmapInputSpec(TraitedSpec):
in_coeffs = InputMultiObject(
File(exists=True), mandatory=True, desc="SDCflows-style spline coefficient files"
)
target_ref_file = File(exists=True, mandatory=True, desc="Image to reconstruct the field in alignment with")
fmap_ref_file = File(exists=True, mandatory=True, desc="Reference file aligned with coefficients")
target_ref_file = File(
exists=True, mandatory=True, desc="Image to reconstruct the field in alignment with"
)
fmap_ref_file = File(
exists=True, mandatory=True, desc="Reference file aligned with coefficients"
)
transforms = InputMultiObject(
File(exists=True), mandatory=True, desc="Transform files, from in_file to ref_file (image mode)"
File(exists=True),
mandatory=True,
desc="Transform files, from in_file to ref_file (image mode)",
)
inverse = InputMultiObject(
traits.Bool,
Expand Down Expand Up @@ -252,6 +275,9 @@ def resample_vol(
coordinates = nb.affines.apply_affine(
hmc_xfm, coordinates.reshape(coords_shape[0], -1).T
).T.reshape(coords_shape)
else:
# Copy coordinates to avoid interfering with other calls
coordinates = coordinates.copy()

vsm = fmap_hz * pe_info[1]
coordinates[pe_info[0], ...] += vsm
Expand Down Expand Up @@ -346,15 +372,17 @@ async def resample_series_async(

semaphore = asyncio.Semaphore(max_concurrent)

out_array = np.zeros(coordinates.shape[1:] + data.shape[-1:], dtype=output_dtype)
# Order F ensures individual volumes are contiguous in memory
# Also matches NIfTI, making final save more efficient
out_array = np.zeros(coordinates.shape[1:] + data.shape[-1:], dtype=output_dtype, order='F')

tasks = [
asyncio.create_task(
worker(
partial(
resample_vol,
data=volume,
coordinates=coordinates.copy(),
coordinates=coordinates,
pe_info=pe_info[volid],
hmc_xfm=hmc_xfms[volid] if hmc_xfms else None,
fmap_hz=fmap_hz,
Expand Down Expand Up @@ -451,21 +479,26 @@ def resample_series(
)


def resample_bold(
def resample_image(
source: nb.Nifti1Image,
target: nb.Nifti1Image,
transforms: nt.TransformChain,
fieldmap: nb.Nifti1Image | None,
pe_info: list[tuple[int, float]] | None,
nthreads: int = 1,
output_dtype: np.dtype | str | None = 'f4',
order: int = 3,
mode: str = 'constant',
cval: float = 0.0,
prefilter: bool = True,
) -> nb.Nifti1Image:
"""Resample a 4D bold series into a target space, applying head-motion
"""Resample a 3- or 4D image into a target space, applying head-motion
and susceptibility-distortion correction simultaneously.

Parameters
----------
source
The 4D bold series to resample.
The 3D bold image or 4D bold series to resample.
target
An image sampled in the target space.
transforms
Expand All @@ -480,6 +513,17 @@ def resample_bold(
of the data array in the second dimension.
nthreads
Number of threads to use for parallel resampling
output_dtype
The dtype of the output array.
order
Order of interpolation (default: 3 = cubic)
mode
How ``data`` is extended beyond its boundaries. See
:func:`scipy.ndimage.map_coordinates` for more details.
cval
Value to fill past edges of ``data`` if ``mode`` is ``'constant'``.
prefilter
Determines if ``data`` is pre-filtered before interpolation.

Returns
-------
Expand Down Expand Up @@ -527,8 +571,12 @@ def resample_bold(
pe_info=pe_info,
hmc_xfms=hmc_xfms,
fmap_hz=fieldmap.get_fdata(dtype='f4'),
output_dtype='f4',
output_dtype=output_dtype,
nthreads=nthreads,
order=order,
mode=mode,
cval=cval,
prefilter=prefilter,
)
resampled_img = nb.Nifti1Image(resampled_data, target.affine, target.header)
resampled_img.set_data_dtype('f4')
Expand Down
3 changes: 3 additions & 0 deletions fmriprep/workflows/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,9 @@ def init_single_subject_wf(subject_id: str):
precomputed=functional_cache,
fieldmap_id=fieldmap_id,
)
if bold_wf is None:
continue

bold_wf.__desc__ = func_pre_desc + (bold_wf.__desc__ or "")

workflow.connect([
Expand Down
118 changes: 118 additions & 0 deletions fmriprep/workflows/bold/apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import nipype.interfaces.utility as niu
import nipype.pipeline.engine as pe
from niworkflows.interfaces.header import ValidateImage
from niworkflows.interfaces.nibabel import GenerateSamplingReference
from niworkflows.interfaces.utility import KeySelect
from niworkflows.utils.connections import listify

Expand All @@ -25,6 +26,110 @@
from niworkflows.utils.spaces import SpatialReferences


def init_bold_volumetric_resample_wf(
*,
metadata: dict,
fieldmap_id: str | None = None,
omp_nthreads: int = 1,
name: str = 'bold_volumetric_resample_wf',
) -> pe.Workflow:
workflow = pe.Workflow(name=name)

inputnode = pe.Node(
niu.IdentityInterface(
fields=[
"bold_file",
"bold_ref_file",
"target_ref_file",
"target_mask",
# HMC
"motion_xfm",
# SDC
"boldref2fmap_xfm",
"fmap_ref",
"fmap_coeff",
"fmap_id",
# Anatomical
"boldref2anat_xfm",
# Template
"anat2std_xfm",
],
),
name='inputnode',
)

outputnode = pe.Node(niu.IdentityInterface(fields=["bold_file"]), name='outputnode')

gen_ref = pe.Node(GenerateSamplingReference(), name='gen_ref', mem_gb=0.3)

boldref2target = pe.Node(niu.Merge(2), name='boldref2target')
bold2target = pe.Node(niu.Merge(2), name='bold2target')
resample = pe.Node(ResampleSeries(), name="resample", n_procs=omp_nthreads)

workflow.connect([
(inputnode, gen_ref, [
('bold_ref_file', 'moving_image'),
('target_ref_file', 'fixed_image'),
('target_mask', 'fov_mask'),
]),
(inputnode, boldref2target, [
('boldref2anat_xfm', 'in1'),
('anat2std_xfm', 'in2'),
]),
(inputnode, bold2target, [('motion_xfm', 'in1')]),
(inputnode, resample, [('bold_file', 'in_file')]),
(gen_ref, resample, [('out_file', 'ref_file')]),
(boldref2target, bold2target, [('out', 'in2')]),
(bold2target, resample, [('out', 'transforms')]),
(resample, outputnode, [('out_file', 'bold_file')]),
]) # fmt:skip

if not fieldmap_id:
return workflow

fmap_select = pe.Node(
KeySelect(fields=["fmap_ref", "fmap_coeff"], key=fieldmap_id),
name="fmap_select",
run_without_submitting=True,
)
distortion_params = pe.Node(
DistortionParameters(metadata=metadata),
name="distortion_params",
run_without_submitting=True,
)
fmap2target = pe.Node(niu.Merge(2), name='fmap2target')
inverses = pe.Node(niu.Function(function=_gen_inverses), name='inverses')

fmap_recon = pe.Node(ReconstructFieldmap(), name="fmap_recon")

workflow.connect([
(inputnode, fmap_select, [
("fmap_ref", "fmap_ref"),
("fmap_coeff", "fmap_coeff"),
("fmap_id", "keys"),
]),
(inputnode, distortion_params, [('bold_file', 'in_file')]),
(inputnode, fmap2target, [('boldref2fmap_xfm', 'in1')]),
(gen_ref, fmap_recon, [('out_file', 'target_ref_file')]),
(boldref2target, fmap2target, [('out', 'in2')]),
(boldref2target, inverses, [('out', 'inlist')]),
(fmap_select, fmap_recon, [
("fmap_coeff", "in_coeffs"),
("fmap_ref", "fmap_ref_file"),
]),
(fmap2target, fmap_recon, [('out', 'transforms')]),
(inverses, fmap_recon, [('out', 'inverse')]),
# Inject fieldmap correction into resample node
(distortion_params, resample, [
("readout_time", "ro_time"),
("pe_direction", "pe_dir"),
]),
(fmap_recon, resample, [('out_file', 'fieldmap')]),
]) # fmt:skip

return workflow


def init_bold_apply_wf(
*,
spaces: SpatialReferences,
Expand All @@ -49,3 +154,16 @@ def init_bold_apply_wf(
# )

return workflow


def _gen_inverses(inlist: list) -> list[bool]:
"""Create a list indicating the first transform should be inverted.

The input list is the collection of transforms that follow the
inverted one.
"""
from niworkflows.utils.connections import listify

if not inlist:
return [True]
return [True] + [False] * len(listify(inlist))
Loading