Skip to content

Commit

Permalink
Save output with the same format as brainreg (#66)
Browse files Browse the repository at this point in the history
Co-authored-by: Adam Tyson <code@adamltyson.com>
  • Loading branch information
IgorTatarnikov and adamltyson authored Jan 10, 2025
1 parent 123b25b commit 408a3e6
Show file tree
Hide file tree
Showing 29 changed files with 1,122 additions and 117 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test_and_deploy.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ name: tests
on:
push:
branches:
- '*'
- 'main'
tags:
- '*'
pull_request:
Expand Down
235 changes: 195 additions & 40 deletions brainglobe_registration/elastix/register.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,22 @@
from typing import List, Tuple
from pathlib import Path
from typing import List, Optional, Tuple

import itk
import numpy as np
import numpy.typing as npt
from brainglobe_atlasapi import BrainGlobeAtlas


def get_atlas_by_name(atlas_name: str) -> BrainGlobeAtlas:
"""
Get a BrainGlobeAtlas object by its name.
Parameters
----------
atlas_name : str
The name of the atlas.
Returns
-------
BrainGlobeAtlas
The BrainGlobeAtlas object.
"""
atlas = BrainGlobeAtlas(atlas_name)

return atlas
from brainglobe_registration.utils.utils import (
convert_atlas_labels,
restore_atlas_labels,
)


def run_registration(
atlas_image,
moving_image,
annotation_image,
atlas_image: npt.NDArray,
moving_image: npt.NDArray,
parameter_lists: List[Tuple[str, dict]],
) -> Tuple[npt.NDArray, itk.ParameterObject, npt.NDArray]:
output_directory: Optional[Path] = None,
) -> Tuple[npt.NDArray, itk.ParameterObject]:
"""
Run the registration process on the given images.
Expand All @@ -40,19 +26,17 @@ def run_registration(
The atlas image.
moving_image : npt.NDArray
The moving image.
annotation_image : npt.NDArray
The annotation image.
parameter_lists : List[tuple[str, dict]], optional
The list of parameter lists, by default None
parameter_lists : List[tuple[str, dict]]
The list of registration parameters, one for each transform.
output_directory : Optional[Path], optional
The output directory for the registration results, by default None
Returns
-------
npt.NDArray
The result image.
itk.ParameterObject
The result transform parameters.
npt.NDArray
The transformed annotation image.
"""
# convert to ITK, view only
atlas_image = itk.GetImageViewFromArray(atlas_image).astype(itk.F)
Expand All @@ -66,33 +50,204 @@ def run_registration(
parameter_object = setup_parameter_object(parameter_lists=parameter_lists)

elastix_object.SetParameterObject(parameter_object)

# update filter object
elastix_object.UpdateLargestPossibleRegion()

# get results
result_image = elastix_object.GetOutput()
result_transform_parameters = elastix_object.GetTransformParameterObject()
temp_interp_order = result_transform_parameters.GetParameter(

if output_directory:
file_names = [
f"{output_directory}/TransformParameters.{i}.txt"
for i in range(len(parameter_lists))
]

itk.ParameterObject.WriteParameterFile(
result_transform_parameters, file_names
)

return (
np.asarray(result_image),
result_transform_parameters,
)


def transform_annotation_image(
annotation_image: npt.NDArray[np.uint32],
transform_parameters: itk.ParameterObject,
) -> npt.NDArray[np.uint32]:
"""
Transform the annotation image using the given transform parameters.
Sets the FinalBSplineInterpolationOrder to 0 to avoid interpolation.
Resets the FinalBSplineInterpolationOrder to its original value after
transforming the annotation image.
Parameters
----------
annotation_image : npt.NDArray
The annotation image.
transform_parameters : itk.ParameterObject
The transform parameters.
Returns
-------
npt.NDArray
The transformed annotation image.
"""
adjusted_annotation_image, mapping = convert_atlas_labels(annotation_image)

annotation_image = itk.GetImageFromArray(adjusted_annotation_image).astype(
itk.F
)
temp_interp_order = transform_parameters.GetParameter(
0, "FinalBSplineInterpolationOrder"
)
result_transform_parameters.SetParameter(
"FinalBSplineInterpolationOrder", "0"
transform_parameters.SetParameter("FinalBSplineInterpolationOrder", "0")

transformix_object = itk.TransformixFilter.New(annotation_image)
transformix_object.SetTransformParameterObject(transform_parameters)
transformix_object.UpdateLargestPossibleRegion()

transformed_annotation = transformix_object.GetOutput()

transform_parameters.SetParameter(
"FinalBSplineInterpolationOrder", temp_interp_order
)
transformed_annotation_array = np.asarray(transformed_annotation).astype(
np.uint32
)

annotation_image_transformix = itk.transformix_filter(
annotation_image.astype(np.float32, copy=False),
result_transform_parameters,
transformed_annotation_array = restore_atlas_labels(
transformed_annotation_array, mapping
)

return transformed_annotation_array


def transform_image(
image: npt.NDArray,
transform_parameters: itk.ParameterObject,
) -> npt.NDArray:
"""
Transform the image using the given transform parameters.
Parameters
----------
image: npt.NDArray
The image to transform.
transform_parameters: itk.ParameterObject
The transform parameters.
Returns
-------
npt.NDArray
The transformed image.
"""
image = itk.GetImageViewFromArray(image).astype(itk.F)

transformix_object = itk.TransformixFilter.New(image)
transformix_object.SetTransformParameterObject(transform_parameters)
transformix_object.UpdateLargestPossibleRegion()

transformed_image = transformix_object.GetOutput()

return np.asarray(transformed_image)


def calculate_deformation_field(
moving_image: npt.NDArray,
transform_parameters: itk.ParameterObject,
debug: bool = False,
) -> npt.NDArray:
"""
Calculate the deformation field for the moving image using the given
transform parameters.
Parameters
----------
moving_image : npt.NDArray
The moving image.
transform_parameters : itk.ParameterObject
The transform parameters.
debug : bool, optional
Whether to save extra files for debugging, by default False
Returns
-------
npt.NDArray
The deformation field.
"""
transformix_object = itk.TransformixFilter.New(
itk.GetImageViewFromArray(moving_image).astype(itk.F),
transform_parameters,
)
transformix_object.SetComputeDeformationField(True)

transformix_object.UpdateLargestPossibleRegion()

# Change from ITK to numpy axes ordering
deformation_field = itk.GetArrayFromImage(
transformix_object.GetOutputDeformationField()
)[..., ::-1]

if not debug:
# Cleanup files generated by elastix
(Path.cwd() / "DeformationField.tiff").unlink(missing_ok=True)

return deformation_field


def invert_transformation(
fixed_image: npt.NDArray,
parameter_list: List[Tuple[str, dict]],
transform_parameters: itk.ParameterObject,
output_directory: Optional[Path] = None,
) -> itk.ParameterObject:

fixed_image = itk.GetImageFromArray(fixed_image).astype(itk.F)

elastix_object = itk.ElastixRegistrationMethod.New(
fixed_image, fixed_image
)

parameter_object_inverse = setup_parameter_object(parameter_list)

elastix_object.SetInitialTransformParameterObject(transform_parameters)

elastix_object.SetParameterObject(parameter_object_inverse)

elastix_object.UpdateLargestPossibleRegion()

num_initial_transforms = transform_parameters.GetNumberOfParameterMaps()

result_image = elastix_object.GetOutput()
out_parameters = elastix_object.GetTransformParameterObject()
result_transform_parameters = itk.ParameterObject.New()

for i in range(
num_initial_transforms, out_parameters.GetNumberOfParameterMaps()
):
result_transform_parameters.AddParameterMap(
out_parameters.GetParameterMap(i)
)

result_transform_parameters.SetParameter(
"FinalBSplineInterpolationOrder", temp_interp_order
0, "InitialTransformParameterFileName", "NoInitialTransform"
)

if output_directory:
file_names = [
f"{output_directory}/InverseTransformParameters.{i}.txt"
for i in range(len(parameter_list))
]

itk.ParameterObject.WriteParameterFiles(
result_transform_parameters, file_names
)

return (
np.asarray(result_image),
result_transform_parameters,
np.asarray(annotation_image_transformix),
)


Expand Down
4 changes: 2 additions & 2 deletions brainglobe_registration/parameters/ara_tools/affine.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@

//ImageTypes
(FixedInternalImagePixelType "float")
(FixedImageDimension 3)
(FixedImageDimension 2)
(MovingInternalImagePixelType "float")
(MovingImageDimension 3)
(MovingImageDimension 2)

//Components
(Registration "MultiResolutionRegistration")
Expand Down
4 changes: 2 additions & 2 deletions brainglobe_registration/parameters/ara_tools/bspline.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

//ImageTypes
(FixedInternalImagePixelType "float")
(FixedImageDimension 3)
(FixedImageDimension 2)
(MovingInternalImagePixelType "float")
(MovingImageDimension 3)
(MovingImageDimension 2)

//Components
(Registration "MultiResolutionRegistration")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
(Registration "MultiResolutionRegistration")
(ResampleInterpolator "FinalBSplineInterpolator")
(Resampler "DefaultResampler")
(ResultImageFormat "nii")
(ResultImageFormat "tiff")
(Transform "AffineTransform")
(WriteIterationInfo "false")
(WriteResultImage "true")
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
(Registration "MultiMetricMultiResolutionRegistration")
(ResampleInterpolator "FinalBSplineInterpolator")
(Resampler "DefaultResampler")
(ResultImageFormat "nii")
(ResultImageFormat "tiff")
(Transform "BSplineTransform")
(WriteIterationInfo "false")
(WriteResultImage "true")
24 changes: 0 additions & 24 deletions brainglobe_registration/parameters/elastix_default/rigid.txt

This file was deleted.

Loading

0 comments on commit 408a3e6

Please sign in to comment.