Skip to content

Commit

Permalink
fix: problems in orientation + lightweight debug execution
Browse files Browse the repository at this point in the history
* Address the problems of the HCP dataset (which is not "plumb" in AFNI's terms).
* Avoid extrapolation in debug execution to shave off some memory
  • Loading branch information
oesteban committed Nov 18, 2020
1 parent 76f6557 commit f66b148
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 62 deletions.
145 changes: 88 additions & 57 deletions sdcflows/interfaces/bspline.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,15 @@ class _BSplineApproxInputSpec(BaseInterfaceInputSpec):
"mode", "median", "mean", "no", usedefault=True,
desc="strategy to recenter the distribution of the input fieldmap"
)
extrapolate = traits.Bool(True, usedefault=True,
desc="generate a field, extrapolated outside the brain mask")


class _BSplineApproxOutputSpec(TraitedSpec):
out_field = File(exists=True)
out_coeff = OutputMultiObject(File(exists=True))
out_error = File(exists=True)
out_extrapolated = File()


class BSplineApprox(SimpleInterface):
Expand All @@ -73,8 +76,9 @@ def _run_interface(self, runtime):

# Load in the fieldmap
fmapnii = nb.load(self.inputs.in_data)
data = fmapnii.get_fdata()
nsamples = data.size
data = fmapnii.get_fdata(dtype="float32")
oriented_nii = canonical_orientation(fmapnii)
oriented_nii.to_filename("data.nii.gz")
mask = nb.load(self.inputs.in_mask).get_fdata() > 0
bs_spacing = [np.array(sp, dtype="float32") for sp in self.inputs.bs_spacing]

Expand All @@ -87,58 +91,35 @@ def _run_interface(self, runtime):
elif self.inputs.recenter == "mean":
data -= np.mean(data[mask])

# Calculate B-Splines grid(s)
bs_levels = []
for sp in bs_spacing:
bs_levels.append(bspline_grid(fmapnii, control_zooms_mm=sp))

# Calculate spatial location of voxels, and normalize per B-Spline grid
fmap_points = grid_coords(fmapnii)
sample_points = []
for sp in bs_spacing:
sample_points.append((fmap_points / sp).astype("float32"))
mask_indices = np.argwhere(mask)
fmap_points = (
oriented_nii.affine.astype("float32") @ (
np.vstack((mask_indices.T, np.ones((1, mask_indices.shape[0],), dtype=int)))
)
)[:3].T

# Calculate the spatial location of control points
bs_levels = []
w_l = []
ncoeff = []
for sp, level, points in zip(bs_spacing, bs_levels, sample_points):
for sp in bs_spacing:
level = bspline_grid(oriented_nii, control_zooms_mm=sp)
bs_levels.append(level)
ncoeff.append(level.dataobj.size)
_w = np.ones((ncoeff[-1], nsamples), dtype="float32")

_gc = grid_coords(level, control_zooms_mm=sp)

for i in range(3):
d = np.abs((_gc[:, np.newaxis, i] - points[np.newaxis, :, i])[_w > 1e-6])
_w[_w > 1e-6] *= np.piecewise(
d,
[d >= 2.0, d < 1.0, (d >= 1.0) & (d < 2)],
[0.,
lambda d: (4. - 6. * d ** 2 + 3. * d ** 3) / 6.,
lambda d: (2. - d) ** 3 / 6.]
)

_w[_w < 1e-6] = 0.0
w_l.append(_w)

# Calculate the cubic spline weights per dimension and tensor-product
weights = np.vstack(w_l)
dist_support = weights > 0.0
w_l.append(bspline_weights(fmap_points, level))

# Compose the interpolation matrix
interp_mat = np.zeros((np.sum(ncoeff), nsamples))
interp_mat[dist_support] = weights[dist_support]
regressors = np.vstack(w_l)

# Fit the model
model = lm.Ridge(alpha=self.inputs.ridge_alpha, fit_intercept=False)
model.fit(
interp_mat[..., mask.reshape(-1)].T, # Regress only within brainmask
data[mask],
)
model.fit(regressors.T, data[mask])

fit_data = (
(np.array(model.coef_) @ interp_mat) # Interpolation
interp_data = np.zeros_like(data)
interp_data[mask] = (
(np.array(model.coef_) @ regressors) # Interpolation
.astype("float32")
.reshape(data.shape)
)

# Store outputs
Expand All @@ -151,7 +132,7 @@ def _run_interface(self, runtime):
)
hdr = fmapnii.header.copy()
hdr.set_data_dtype("float32")
nb.Nifti1Image(fit_data, fmapnii.affine, hdr).to_filename(out_name)
nb.Nifti1Image(interp_data, fmapnii.affine, hdr).to_filename(out_name)
self._results["out_field"] = out_name

index = 0
Expand All @@ -170,11 +151,46 @@ def _run_interface(self, runtime):

# Write out fitting-error map
self._results["out_error"] = out_name.replace("_field.", "_error.")
nb.Nifti1Image(data - fit_data * mask, fmapnii.affine, fmapnii.header).to_filename(
nb.Nifti1Image(data * mask - interp_data, fmapnii.affine, fmapnii.header).to_filename(
self._results["out_error"])

if not self.inputs.extrapolate:
return runtime

bg_indices = np.argwhere(~mask)
bg_points = (
oriented_nii.affine.astype("float32") @ (
np.vstack((bg_indices.T, np.ones((1, bg_indices.shape[0],), dtype=int)))
)
)[:3].T

extrapolators = np.vstack(
[bspline_weights(bg_points, level) for level in bs_levels]
)
interp_data[~mask] = (
(np.array(model.coef_) @ extrapolators) # Extrapolation
.astype("float32")
)
self._results["out_extrapolated"] = out_name.replace("_field.", "_extra.")
nb.Nifti1Image(interp_data, fmapnii.affine, hdr).to_filename(
self._results["out_extrapolated"]
)
return runtime


def canonical_orientation(img):
"""Generate an alternative image aligned with the array axes."""
if isinstance(img, (str, Path)):
img = nb.load(img)

shape = np.array(img.shape[:3])
affine = np.diag(np.hstack((img.header.get_zooms()[:3], 1)))
affine[:3, 3] -= affine[:3, :3] @ (0.5 * (shape - 1))
nii = nb.Nifti1Image(img.dataobj, affine)
nii.header.set_xyzt_units(*img.header.get_xyzt_units())
return nii


def bspline_grid(img, control_zooms_mm=DEFAULT_ZOOMS_MM):
"""Calculate a Nifti1Image object encoding the location of control points."""
if isinstance(img, (str, Path)):
Expand All @@ -187,7 +203,8 @@ def bspline_grid(img, control_zooms_mm=DEFAULT_ZOOMS_MM):
dir_cos = img.affine[:3, :3] / im_zooms

# Initialize the affine of the B-Spline grid
bs_affine = np.diag(np.hstack((np.array(control_zooms_mm) @ dir_cos, 1)))
bs_affine = np.eye(4)
bs_affine[:3, :3] = np.array(control_zooms_mm) * dir_cos
bs_zooms = nb.affines.voxel_sizes(bs_affine)

# Calculate the shape of the B-Spline grid
Expand All @@ -202,17 +219,31 @@ def bspline_grid(img, control_zooms_mm=DEFAULT_ZOOMS_MM):
return nb.Nifti1Image(np.zeros(bs_shape, dtype="float32"), bs_affine)


def grid_coords(img, control_zooms_mm=None, dtype="float32"):
"""Create a linear space of physical coordinates."""
if isinstance(img, (str, Path)):
img = nb.load(img)

grid = np.array(
np.meshgrid(*[range(s) for s in img.shape[:3]]), dtype=dtype
).reshape(3, -1)
coords = (img.affine @ np.vstack((grid, np.ones(grid.shape[-1])))).T[..., :3]

if control_zooms_mm is not None:
coords /= np.array(control_zooms_mm)
def bspline_weights(points, level):
"""Calculate the tensor-product cubic B-Spline weights for a list of 3D points."""
ctl_spacings = [float(sp) for sp in level.header.get_zooms()[:3]]
ncoeff = level.dataobj.size
ctl_points = (
level.affine.astype("float32") @ (
np.vstack((
np.argwhere(np.asanyarray(level.dataobj) == 0.0).astype("float32").T,
np.ones((1, ncoeff), dtype="float32")
))
)
)[:3].T

weights = np.ones((ncoeff, points.shape[0]), dtype="float32")
for i in range(3):
d = np.abs(
(ctl_points[:, np.newaxis, i] - points[np.newaxis, :, i])[weights > 1e-6]
) / ctl_spacings[i]
weights[weights > 1e-6] *= np.piecewise(
d,
[d >= 2.0, d < 1.0, (d >= 1.0) & (d < 2)],
[0.,
lambda d: (4. - 6. * d ** 2 + 3. * d ** 3) / 6.,
lambda d: (2. - d) ** 3 / 6.]
)

return coords.astype(dtype)
weights[weights < 1e-6] = 0.0
return weights
14 changes: 9 additions & 5 deletions sdcflows/models/fieldmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def init_fmap_wf(omp_nthreads, debug=False, mode="phasediff", name="fmap_wf"):
"""
from ..interfaces.bspline import (
BSplineApprox, DEFAULT_LF_ZOOMS_MM, DEFAULT_HF_ZOOMS_MM
BSplineApprox, DEFAULT_LF_ZOOMS_MM, DEFAULT_HF_ZOOMS_MM, DEFAULT_ZOOMS_MM
)

workflow = Workflow(name=name)
Expand All @@ -99,10 +99,13 @@ def init_fmap_wf(omp_nthreads, debug=False, mode="phasediff", name="fmap_wf"):
)

magnitude_wf = init_magnitude_wf(omp_nthreads=omp_nthreads)
bs_filter = pe.Node(BSplineApprox(
bs_spacing=[DEFAULT_LF_ZOOMS_MM] if debug else [DEFAULT_LF_ZOOMS_MM, DEFAULT_HF_ZOOMS_MM],
), n_procs=omp_nthreads, name="bs_filter")
bs_filter = pe.Node(BSplineApprox(), n_procs=omp_nthreads, name="bs_filter")
bs_filter.interface._always_run = debug
bs_filter.inputs.bs_spacing = (
[DEFAULT_LF_ZOOMS_MM, DEFAULT_HF_ZOOMS_MM] if not debug
else [DEFAULT_ZOOMS_MM]
)
bs_filter.inputs.extrapolate = not debug

# fmt: off
workflow.connect([
Expand All @@ -112,7 +115,8 @@ def init_fmap_wf(omp_nthreads, debug=False, mode="phasediff", name="fmap_wf"):
("outputnode.fmap_mask", "fmap_mask"),
("outputnode.fmap_ref", "fmap_ref"),
]),
(bs_filter, outputnode, [("out_field", "fmap")]),
(bs_filter, outputnode, [
("out_extrapolated" if not debug else "out_field", "fmap")]),
])
# fmt: on

Expand Down

0 comments on commit f66b148

Please sign in to comment.