Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🔧 Loaders and util for H5 and NIfTI transforms (propagation of contributions by @sgiavasis) #213

Open
wants to merge 15 commits into
base: master
Choose a base branch
from
Open
129 changes: 107 additions & 22 deletions nitransforms/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#
### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
"""Common interface for transforms."""

from pathlib import Path
import numpy as np
import h5py
Expand Down Expand Up @@ -146,13 +147,13 @@ def from_arrays(cls, coordinates, triangles):
darrays = [
nb.gifti.GiftiDataArray(
coordinates.astype(np.float32),
intent=nb.nifti1.intent_codes['NIFTI_INTENT_POINTSET'],
datatype=nb.nifti1.data_type_codes['NIFTI_TYPE_FLOAT32'],
intent=nb.nifti1.intent_codes["NIFTI_INTENT_POINTSET"],
datatype=nb.nifti1.data_type_codes["NIFTI_TYPE_FLOAT32"],
),
nb.gifti.GiftiDataArray(
triangles.astype(np.int32),
intent=nb.nifti1.intent_codes['NIFTI_INTENT_TRIANGLE'],
datatype=nb.nifti1.data_type_codes['NIFTI_TYPE_INT32'],
intent=nb.nifti1.intent_codes["NIFTI_INTENT_TRIANGLE"],
datatype=nb.nifti1.data_type_codes["NIFTI_TYPE_INT32"],
),
]
gii = nb.gifti.GiftiImage(darrays=darrays)
Expand Down Expand Up @@ -251,14 +252,57 @@ class TransformBase:
__slots__ = (
"_reference",
"_ndim",
"_affine",
"_shape",
"_header",
"_grid",
"_mapping",
Comment on lines +255 to +259
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is looking like TransformBase is becoming an affine transform. I will try to get it back to a generic, abstract base class.

"_hdf5_dct",
"_x5_dct",
)

def __init__(self, reference=None):
x5_struct = {
"TransformGroup/0": {
"Type": None,
"Transform": None,
"Metadata": None,
"Inverse": None,
},
"TransformGroup/0/Domain": {"Grid": None, "Size": None, "Mapping": None},
"TransformGroup/1": {},
"TransformChain": {},
}
Comment on lines +264 to +274
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We probably want to make this struct some sort of config option. Is there such a concept in nibabel, @effigies?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think so, but what kind of config option are you thinking? I'm very wary of putting mutable global state in a library, as that just asks for nasty interactions with multiple libraries (and threads). We could think about making a context manager to update state, along with ContextVars to manage state, if needed.

But possibly you're talking about something else? This looks like a place for a schema, but I haven't studied this code yet.


def __init__(
self,
x5=None,
hdf5=None,
nifti=None,
shape=None,
affine=None,
header=None,
reference=None,
):
"""Instantiate a transform."""

self._reference = None
if reference:
self.reference = reference

if nifti is not None:
self._x5_dct = self.init_x5_structure(nifti)
elif hdf5:
self.update_x5_structure(hdf5)
elif x5:
self.update_x5_structure(x5)
self._shape = shape
self._affine = affine
self._header = header

# TO-DO
self._grid = None
self._mapping = None

Comment on lines +292 to +305
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To linear transforms base

def __call__(self, x, inverse=False):
"""Apply y = f(x)."""
return self.map(x, inverse=inverse)
Expand Down Expand Up @@ -295,6 +339,12 @@ def ndim(self):
"""Access the dimensions of the reference space."""
raise TypeError("TransformBase has no dimensions")

def init_x5_structure(self, xfm_data=None):
self.x5_struct["TransformGroup/0/Transform"] = xfm_data

def update_x5_structure(self, hdf5_struct=None):
self.x5_struct.update(hdf5_struct)

def map(self, x, inverse=False):
r"""
Apply :math:`y = f(x)`.
Expand All @@ -316,33 +366,68 @@ def map(self, x, inverse=False):
"""
return x

def to_filename(self, filename, fmt="X5"):
"""Store the transform in BIDS-Transforms HDF5 file format (.x5)."""
with h5py.File(filename, "w") as out_file:
out_file.attrs["Format"] = "X5"
out_file.attrs["Version"] = np.uint16(1)
root = out_file.create_group("/0")
self._to_hdf5(root)

return filename

def _to_hdf5(self, x5_root):
"""Serialize this object into the x5 file format."""
raise NotImplementedError

def apply(self, *args, **kwargs):
"""Apply the transform to a dataset.

Deprecated. Please use ``nitransforms.resampling.apply`` instead.
"""
message = (
"The `apply` method is deprecated. Please use `nitransforms.resampling.apply` instead."
)
message = "The `apply` method is deprecated. Please use `nitransforms.resampling.apply` instead."
warnings.warn(message, DeprecationWarning, stacklevel=2)
from .resampling import apply

return apply(self, *args, **kwargs)

def _to_hdf5(self, x5_root):
"""Serialize this object into the x5 file format."""
transform_group = x5_root.create_group("TransformGroup")

"""Group '0' containing Affine transform"""
transform_0 = transform_group.create_group("0")

transform_0.attrs["Type"] = "Affine"
transform_0.create_dataset("Transform", data=self._matrix)
transform_0.create_dataset("Inverse", data=np.linalg.inv(self._matrix))

metadata = {"key": "value"}
transform_0.attrs["Metadata"] = str(metadata)

"""sub-group 'Domain' contained within group '0' """
domain_group = transform_0.create_group("Domain")
domain_group.attrs["Grid"] = self.grid
domain_group.create_dataset("Size", data=_as_homogeneous(self._reference.shape))
domain_group.create_dataset("Mapping", data=self.map)

raise NotImplementedError

def read_x5(self, x5_root):
variables = {}
with h5py.File(x5_root, "r") as f:
f.visititems(
lambda filename, x5_root: self._from_hdf5(filename, x5_root, variables)
)

_transform = variables["TransformGroup/0/Transform"]
_inverse = variables["TransformGroup/0/Inverse"]
_size = variables["TransformGroup/0/Domain/Size"]
_map = variables["TransformGroup/0/Domain/Mapping"]

return _transform, _inverse, _size, _map

def _from_hdf5(self, name, x5_root, storage):
if isinstance(x5_root, h5py.Dataset):
storage[name] = {
"type": "dataset",
"attrs": dict(x5_root.attrs),
"shape": x5_root.shape,
"data": x5_root[()], # Read the data
}
elif isinstance(x5_root, h5py.Group):
storage[name] = {
"type": "group",
"attrs": dict(x5_root.attrs),
"members": {},
}

Comment on lines +402 to +430
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should go into specific loaders :)


def _as_homogeneous(xyz, dtype="float32", dim=3):
"""
Expand Down
67 changes: 60 additions & 7 deletions nitransforms/cli.py
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

skipping this file for now

Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@
import os
from textwrap import dedent

from nitransforms.base import TransformBase
from nitransforms.io.base import xfm_loader
from nitransforms.linear import load as linload
from nitransforms.nonlinear import load as nlinload
from nitransforms.resampling import apply

from .linear import load as linload
from .nonlinear import load as nlinload
from .resampling import apply

import pprint

def cli_apply(pargs):
"""
Expand All @@ -32,8 +34,8 @@ def cli_apply(pargs):

xfm = (
nlinload(pargs.transform, fmt=fmt)
if pargs.nonlinear else
linload(pargs.transform, fmt=fmt)
if pargs.nonlinear
else linload(pargs.transform, fmt=fmt)
)

# ensure a reference is set
Expand All @@ -47,8 +49,43 @@ def cli_apply(pargs):
cval=pargs.cval,
prefilter=pargs.prefilter,
)
moved.to_filename(pargs.out or f"nt_{os.path.basename(pargs.moving)}")
# moved.to_filename(pargs.out or f"nt_{os.path.basename(pargs.moving)}")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# moved.to_filename(pargs.out or f"nt_{os.path.basename(pargs.moving)}")
moved.to_filename(pargs.out or f"nt_{os.path.basename(pargs.moving)}")



def cli_xfm_util(pargs):
""" """

xfm_data = xfm_loader(pargs.transform)
xfm_x5 = TransformBase(**xfm_data)

if pargs.info:
pprint.pprint(xfm_x5.x5_struct)
print(f"Shape:\n{xfm_x5._shape}")
print(f"Affine:\n{xfm_x5._affine}")

if pargs.x5:
filename = f"{os.path.basename(pargs.transform).split('.')[0]}.x5"
xfm_x5.to_filename(filename)
print(f"Writing out {filename}")


def cli_xfm_util(pargs):
"""
"""

xfm_data = xfm_loader(pargs.transform)
xfm_x5 = TransformBase(**xfm_data)

if pargs.info:
pprint.pprint(xfm_x5.x5_struct)
print(f"Shape:\n{xfm_x5._shape}")
print(f"Affine:\n{xfm_x5._affine}")

if pargs.x5:
filename = f"{os.path.basename(pargs.transform).split('.')[0]}.x5"
xfm_x5.to_filename(filename)
print(f"Writing out {filename}")


def get_parser():
desc = dedent(
Expand All @@ -58,6 +95,7 @@ def get_parser():
Commands:

apply Apply a transformation to an image
xfm_util Assorted transform utilities

For command specific information, use 'nt <command> -h'.
"""
Expand Down Expand Up @@ -122,6 +160,17 @@ def _add_subparser(name, description):
help="Determines if the image's data array is prefiltered with a spline filter before "
"interpolation (default: True)",
)

xfm_util = _add_subparser("xfm_util", cli_xfm_util.__doc__)
xfm_util.set_defaults(func=cli_xfm_util)
xfm_util.add_argument("transform", help="The transform file")
xfm_util.add_argument(
"--info", action="store_true", help="Get information about the transform"
)
xfm_util.add_argument(
"--x5", action="store_true", help="Convert transform to .x5 file format."
)

return parser, subparsers


Expand All @@ -135,3 +184,7 @@ def main(pargs=None):
subparser = subparsers.choices[pargs.command]
subparser.print_help()
raise (e)


if __name__ == "__main__":
main()
3 changes: 2 additions & 1 deletion nitransforms/io/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# emacs: -*- mode: python-mode; py-indent-offset: 4; indent-tabs-mode: nil -*-
# vi: set ft=python sts=4 ts=4 sw=4 et:
"""Read and write transforms."""
from nitransforms.io import afni, fsl, itk, lta
from nitransforms.io import afni, fsl, itk, lta, x5
from nitransforms.io.base import TransformIOError, TransformFileError

__all__ = [
Expand All @@ -22,6 +22,7 @@
"fs": (lta, "FSLinearTransform"),
"fsl": (fsl, "FSLLinearTransform"),
"afni": (afni, "AFNILinearTransform"),
"x5": (x5, "X5Transform"),
}


Expand Down
Loading
Loading