From 7d3f10dea67e31cb4fb5c037e606b1c39c6671d4 Mon Sep 17 00:00:00 2001 From: Oscar Esteban Date: Fri, 13 Nov 2020 10:55:26 +0100 Subject: [PATCH] ENH: A 3D tensor B-Spline approximator and extrapolator This PR finally adds an implementation for B-Spline smoothing and extrapolation of fieldmaps. References: #71, #22. Resolves: #72. Resolves: #14. --- sdcflows/interfaces/bspline.py | 188 ++++++++++++++++++++ sdcflows/workflows/fit/fieldmap.py | 26 +-- sdcflows/workflows/fit/tests/test_phdiff.py | 2 +- setup.cfg | 1 + 4 files changed, 203 insertions(+), 14 deletions(-) create mode 100644 sdcflows/interfaces/bspline.py diff --git a/sdcflows/interfaces/bspline.py b/sdcflows/interfaces/bspline.py new file mode 100644 index 0000000000..bf2e344b27 --- /dev/null +++ b/sdcflows/interfaces/bspline.py @@ -0,0 +1,188 @@ +# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*- +# vi: set ft=python sts=4 ts=4 sw=4 et: +""" +B-Spline filtering. + + .. testsetup:: + + >>> tmpdir = getfixture('tmpdir') + >>> tmp = tmpdir.chdir() # changing to a temporary directory + >>> nb.Nifti1Image(np.zeros((90, 90, 60)), None, None).to_filename( + ... tmpdir.join('epi.nii.gz').strpath) + +""" +from pathlib import Path +import numpy as np +import nibabel as nb + +from nipype.utils.filemanip import fname_presuffix +from nipype.interfaces.base import ( + BaseInterfaceInputSpec, + TraitedSpec, + File, + traits, + SimpleInterface, + InputMultiObject, + OutputMultiObject, +) + + +DEFAULT_ZOOMS_MM = (40.0, 40.0, 20.0) # For human adults (mid-frequency), in mm +DEFAULT_LF_ZOOMS_MM = (100.0, 100.0, 40.0) # For human adults (low-frequency), in mm +DEFAULT_HF_ZOOMS_MM = (16.0, 16.0, 10.0) # For human adults (high-frequency), in mm + + +class _BSplineApproxInputSpec(BaseInterfaceInputSpec): + in_data = File(exists=True, mandatory=True, desc="path to a fieldmap") + in_mask = File(exists=True, mandatory=True, desc="path to a brain mask") + bs_spacing = InputMultiObject( + [DEFAULT_ZOOMS_MM], + traits.Tuple(traits.Float, traits.Float, traits.Float), + usedefault=True, + desc="spacing between B-Spline control points", + ) + ridge_alpha = traits.Float( + 1e-4, usedefault=True, desc="controls the regularization" + ) + + +class _BSplineApproxOutputSpec(TraitedSpec): + out_field = File(exists=True) + out_coeff = OutputMultiObject(File(exists=True)) + + +class BSplineApprox(SimpleInterface): + """ + Approximate the field to smooth it removing spikes and extrapolating beyond the brain mask. + + Examples + -------- + + """ + + input_spec = _BSplineApproxInputSpec + output_spec = _BSplineApproxOutputSpec + + def _run_interface(self, runtime): + from gridbspline.maths import cubic + from sklearn import linear_model as lm + + _vbspl = np.vectorize(cubic) + + # Load in the fieldmap + fmapnii = nb.load(self.inputs.in_data) + data = fmapnii.get_fdata() + mask = nb.load(self.inputs.in_mask).get_fdata() > 0 + bs_spacing = [np.array(sp, dtype="float32") for sp in self.inputs.bs_spacing] + + # Calculate B-Splines grid(s) + bs_levels = [] + for sp in bs_spacing: + bs_levels.append(bspline_grid(fmapnii, control_zooms_mm=sp)) + + # Calculate spatial location of voxels, and normalize per B-Spline grid + fmap_points = grid_coords(fmapnii) + sample_points = [] + for sp in bs_spacing: + sample_points.append((fmap_points / sp).astype("float32")) + + # Calculate the spatial location of control points + bs_x = [] + ncoeff = [] + for sp, level, points in zip(bs_spacing, bs_levels, sample_points): + ncoeff.append(level.dataobj.size) + control_points = grid_coords(level, control_zooms_mm=sp) + bs_x.append(control_points[:, np.newaxis, :] - points[np.newaxis, ...]) + + # Calculate the cubic spline weights per dimension and tensor-product + dist = np.vstack(bs_x) + dist_support = (np.abs(dist) < 2).all(axis=-1) + weights = _vbspl(dist[dist_support]).prod(axis=-1) + + # Compose the interpolation matrix + interp_mat = np.zeros(dist.shape[:2]) + interp_mat[dist_support] = weights + + # Fit the model + model = lm.Ridge(alpha=self.inputs.ridge_alpha, fit_intercept=False) + model.fit( + interp_mat[..., mask.reshape(-1)].T, # Regress only within brainmask + data[mask], + ) + + # Store outputs + out_name = str( + Path( + fname_presuffix( + self.inputs.in_data, suffix="_field", newpath=runtime.cwd + ) + ).absolute() + ) + hdr = fmapnii.header.copy() + hdr.set_data_dtype("float32") + nb.Nifti1Image( + (model.intercept_ + np.array(model.coef_) @ interp_mat) + .astype("float32") # Interpolation + .reshape(data.shape), + fmapnii.affine, + hdr, + ).to_filename(out_name) + self._results["out_field"] = out_name + + index = 0 + self._results["out_coeff"] = [] + for i, (n, bsl) in enumerate(zip(ncoeff, bs_levels)): + out_level = out_name.replace("_field.", f"_coeff{i:03}.") + nb.Nifti1Image( + np.array(model.coef_, dtype="float32")[index : index + n].reshape( + bsl.shape + ), + bsl.affine, + bsl.header, + ).to_filename(out_level) + index += n + self._results["out_coeff"].append(out_level) + return runtime + + +def bspline_grid(img, control_zooms_mm=DEFAULT_ZOOMS_MM): + """Calculate a Nifti1Image object encoding the location of control points.""" + if isinstance(img, (str, Path)): + img = nb.load(img) + + im_zooms = np.array(img.header.get_zooms()) + im_shape = np.array(img.shape[:3]) + + # Calculate the direction cosines of the target image + dir_cos = img.affine[:3, :3] / im_zooms + + # Initialize the affine of the B-Spline grid + bs_affine = np.diag(np.hstack((np.array(control_zooms_mm) @ dir_cos, 1))) + bs_zooms = nb.affines.voxel_sizes(bs_affine) + + # Calculate the shape of the B-Spline grid + im_extent = im_zooms * (im_shape - 1) + bs_shape = (im_extent // bs_zooms + 3).astype(int) + + # Center both images + im_center = img.affine @ np.hstack((0.5 * (im_shape - 1), 1)) + bs_center = bs_affine @ np.hstack((0.5 * (bs_shape - 1), 1)) + bs_affine[:3, 3] = im_center[:3] - bs_center[:3] + + return nb.Nifti1Image(np.zeros(bs_shape, dtype="float32"), bs_affine) + + +def grid_coords(img, control_zooms_mm=None, dtype="float32"): + """Create a linear space of physical coordinates.""" + if isinstance(img, (str, Path)): + img = nb.load(img) + + grid = np.array( + np.meshgrid(*[range(s) for s in img.shape[:3]]), dtype=dtype + ).reshape(3, -1) + coords = (img.affine @ np.vstack((grid, np.ones(grid.shape[-1])))).T[..., :3] + + if control_zooms_mm is not None: + coords /= np.array(control_zooms_mm) + + return coords.astype(dtype) diff --git a/sdcflows/workflows/fit/fieldmap.py b/sdcflows/workflows/fit/fieldmap.py index 9d0b272288..892097d89a 100644 --- a/sdcflows/workflows/fit/fieldmap.py +++ b/sdcflows/workflows/fit/fieldmap.py @@ -109,7 +109,7 @@ from niworkflows.engine.workflows import LiterateWorkflow as Workflow -def init_fmap_wf(omp_nthreads=1, mode="phasediff", name="fmap_wf"): +def init_fmap_wf(omp_nthreads=1, debug=False, mode="phasediff", name="fmap_wf"): """ Estimate the fieldmap based on a field-mapping MRI acquisition. @@ -156,6 +156,10 @@ def init_fmap_wf(omp_nthreads=1, mode="phasediff", name="fmap_wf"): pair. """ + from ...interfaces.bspline import ( + BSplineApprox, DEFAULT_LF_ZOOMS_MM, DEFAULT_HF_ZOOMS_MM + ) + workflow = Workflow(name=name) inputnode = pe.Node( @@ -167,19 +171,19 @@ def init_fmap_wf(omp_nthreads=1, mode="phasediff", name="fmap_wf"): ) magnitude_wf = init_magnitude_wf(omp_nthreads=omp_nthreads) - fmap_postproc_wf = init_fmap_postproc_wf(omp_nthreads=omp_nthreads) + bs_filter = pe.Node(BSplineApprox( + bs_spacing=[DEFAULT_LF_ZOOMS_MM] if debug else [DEFAULT_LF_ZOOMS_MM, DEFAULT_HF_ZOOMS_MM], + ), n_procs=omp_nthreads, name="bs_filter") # fmt: off workflow.connect([ (inputnode, magnitude_wf, [("magnitude", "inputnode.magnitude")]), - (magnitude_wf, fmap_postproc_wf, [ - ("outputnode.fmap_mask", "inputnode.fmap_mask"), - ("outputnode.fmap_ref", "inputnode.fmap_ref")]), + (magnitude_wf, bs_filter, [("outputnode.fmap_mask", "in_mask")]), (magnitude_wf, outputnode, [ ("outputnode.fmap_mask", "fmap_mask"), ("outputnode.fmap_ref", "fmap_ref"), ]), - (fmap_postproc_wf, outputnode, [("outputnode.out_fmap", "fmap")]), + (bs_filter, outputnode, [("out_field", "fmap")]), ]) # fmt: on @@ -198,13 +202,12 @@ def init_fmap_wf(omp_nthreads=1, mode="phasediff", name="fmap_wf"): ("outputnode.fmap_ref", "inputnode.magnitude"), ("outputnode.fmap_mask", "inputnode.mask"), ]), - (phdiff_wf, fmap_postproc_wf, [ - ("outputnode.fieldmap", "inputnode.fmap"), + (phdiff_wf, bs_filter, [ + ("outputnode.fieldmap", "in_data"), ]), ]) # fmt: on else: - from niworkflows.interfaces.nibabel import ApplyMask from niworkflows.interfaces.images import IntraModalMerge workflow.__desc__ = """\ @@ -215,13 +218,10 @@ def init_fmap_wf(omp_nthreads=1, mode="phasediff", name="fmap_wf"): fmapmrg = pe.Node( IntraModalMerge(zero_based_avg=False, hmc=False), name="fmapmrg" ) - applymsk = pe.Node(ApplyMask(), name="applymsk") # fmt: off workflow.connect([ (inputnode, fmapmrg, [("fieldmap", "in_files")]), - (fmapmrg, applymsk, [("out_avg", "in_file")]), - (magnitude_wf, applymsk, [("outputnode.fmap_mask", "in_mask")]), - (applymsk, fmap_postproc_wf, [("out_file", "inputnode.fmap")]), + (fmapmrg, bs_filter, [("out_avg", "in_data")]), ]) # fmt: on diff --git a/sdcflows/workflows/fit/tests/test_phdiff.py b/sdcflows/workflows/fit/tests/test_phdiff.py index a8ed63e291..8365de5f5f 100644 --- a/sdcflows/workflows/fit/tests/test_phdiff.py +++ b/sdcflows/workflows/fit/tests/test_phdiff.py @@ -34,7 +34,7 @@ def test_phdiff(tmpdir, datadir, workdir, outdir, fmap_path): wf = Workflow( name=f"phdiff_{fmap_path[0].name.replace('.nii.gz', '').replace('-', '_')}" ) - phdiff_wf = init_fmap_wf(omp_nthreads=2) + phdiff_wf = init_fmap_wf(omp_nthreads=2, debug=True) phdiff_wf.inputs.inputnode.fieldmap = fieldmaps phdiff_wf.inputs.inputnode.magnitude = [ f.replace("diff", "1").replace("phase", "magnitude") for f, _ in fieldmaps diff --git a/setup.cfg b/setup.cfg index 4a7a2bef48..326ef98316 100644 --- a/setup.cfg +++ b/setup.cfg @@ -26,6 +26,7 @@ setup_requires = setuptools_scm >= 3.4 toml install_requires = + gridbspline nibabel >=3.0.1 niflow-nipype1-workflows ~= 0.0.1 nipype >=1.5.1,<2.0