Skip to content

Commit

Permalink
Update string equality tests for h5py 3.1 compatibility (#263)
Browse files Browse the repository at this point in the history
By converting nexus attribute values to `str`s. These can now sometimes
be numpy.string_, sometimes str, depending on how the value was written.
Rather than converting everything to numpy.string_, convert to str
immediately when read, so we don't need to convert literals all over the
place. In the best case, this is a no-op.

Fixes #267 

Co-authored-by: Nicholas Devenish <ndevenish@gmail.com>
  • Loading branch information
dwpaley and ndevenish authored Dec 15, 2020
1 parent 3ba32c0 commit 687e231
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 67 deletions.
136 changes: 69 additions & 67 deletions format/nexus.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import itertools
import math
import os
from builtins import range
from typing import Union

import h5py
import numpy
Expand Down Expand Up @@ -39,6 +39,19 @@
raise


def h5str(h5_value: Union[str, numpy.bytes_, bytes]) -> str:
    """
    Convert a value returned from an h5py attribute to str.

    h5py can return either a bytes-like (numpy.bytes_) or str object
    for attribute values depending on whether the value was written as
    fixed or variable length. This function collapses the two to str.

    Args:
        h5_value: The value read through h5py; either already a str or a
            bytes-like object exposing a ``decode`` method.

    Returns:
        The UTF-8 decoded text if the value has a ``decode`` method,
        otherwise the value itself, unchanged.
    """
    # The annotation uses numpy.bytes_ rather than the numpy.string_ alias:
    # the alias was removed in NumPy 2.0, and because annotations are
    # evaluated at definition time it would break importing this module.
    if hasattr(h5_value, "decode"):
        return h5_value.decode("utf-8")
    return h5_value


def dataset_as_flex(dataset, selection):
if numpy.issubdtype(dataset.dtype, numpy.integer):
return dataset_as_flex_int(dataset.id.id, selection)
Expand Down Expand Up @@ -106,29 +119,25 @@ def find_entries(nx_file, entry):

def visitor(name, obj):
if "NX_class" in obj.attrs:
if numpy.string_(obj.attrs["NX_class"]) in [
numpy.string_("NXentry"),
numpy.string_("NXsubentry"),
]:
if h5str(obj.attrs["NX_class"]) in ["NXentry", "NXsubentry"]:
if "definition" in obj:
if obj["definition"][()] == numpy.string_("NXmx"):
if h5str(obj["definition"][()]) == "NXmx":
hits.append(obj)

visitor(entry, nx_file[entry])
local_visit(nx_file, visitor)
return hits


def find_class(nx_file, nx_class):
def find_class(nx_file, nx_class: str):
"""
Find a given NXclass
"""
hits = []
nx_class = numpy.string_(nx_class)

def visitor(name, obj):
if numpy.string_("NX_class") in obj.attrs:
if numpy.string_(obj.attrs["NX_class"]) == nx_class:
if "NX_class" in obj.attrs:
if h5str(obj.attrs["NX_class"]) == nx_class:
hits.append(obj)

local_visit(nx_file, visitor)
Expand Down Expand Up @@ -181,8 +190,9 @@ def visit_dependencies(nx_file, item, visitor=None):
if os.path.basename(item) == "depends_on":
depends_on = nx_file[item][()]
else:
depends_on = nx_file[item].attrs["depends_on"]
while not depends_on == numpy.string_("."):
depends_on = h5str(nx_file[item].attrs["depends_on"])

while not depends_on == ".":
if visitor:
visitor(nx_file, depends_on)
if depends_on in dependency_chain:
Expand All @@ -193,7 +203,7 @@ def visit_dependencies(nx_file, item, visitor=None):
raise RuntimeError("'%s' is missing from nx_file" % depends_on)
dependency_chain.add(depends_on)
try:
depends_on = nx_file[depends_on].attrs["depends_on"]
depends_on = h5str(nx_file[depends_on].attrs["depends_on"])
except Exception:
raise RuntimeError("'%s' contains no depends_on attribute" % depends_on)

Expand All @@ -210,20 +220,20 @@ def __init__(self, vector):
def visit(self, nx_file, depends_on):
item = nx_file[depends_on]
value = item[()]
units = item.attrs["units"]
ttype = item.attrs["transformation_type"]
units = h5str(item.attrs["units"])
ttype = h5str(item.attrs["transformation_type"])
vector = matrix.col(item.attrs["vector"])
if ttype == numpy.string_("translation"):
if ttype == "translation":
value = convert_units(value, units, "mm")
if hasattr(value, "__iter__") and len(value) == 1:
value = value[0]
self.vector = vector * value + self.vector
elif ttype == numpy.string_("rotation"):
elif ttype == "rotation":
if hasattr(value, "__iter__") and len(value):
value = value[0]
if numpy.string_(units) == numpy.string_("rad"):
if units == "rad":
deg = False
elif numpy.string_(units) == numpy.string_("deg"):
elif units == "deg":
deg = True
else:
raise RuntimeError("Invalid units: %s" % units)
Expand All @@ -233,15 +243,15 @@ def visit(self, nx_file, depends_on):

if vector is None:
value = nx_file[item][()]
units = nx_file[item].attrs["units"]
ttype = nx_file[item].attrs["transformation_type"]
units = h5str(nx_file[item].attrs["units"])
ttype = h5str(nx_file[item].attrs["transformation_type"])
vector = nx_file[item].attrs["vector"]
if "offset" in nx_file[item].attrs:
offset = nx_file[item].attrs["offset"]
offset = convert_units(offset, units, "mm")
else:
offset = vector * 0.0
if ttype == numpy.string_("translation"):
if ttype == "translation":
value = convert_units(value, units, "mm")
try:
vector = vector * value
Expand Down Expand Up @@ -269,21 +279,17 @@ def __init__(self):
def visit(self, nx_file, depends_on):
item = nx_file[depends_on]
value = item[()]
units = item.attrs["units"]
ttype = item.attrs["transformation_type"]
units = h5str(item.attrs["units"])
ttype = h5str(item.attrs["transformation_type"])
vector = [float(v) for v in item.attrs["vector"]]
if ttype == numpy.string_("translation"):
if ttype == "translation":
return
elif ttype == numpy.string_("rotation"):
elif ttype == "rotation":
if hasattr(value, "__iter__") and len(value):
value = value[0]
if units == numpy.string_("rad"):
if units == "rad":
value *= 180 / math.pi
elif units not in [
numpy.string_("deg"),
numpy.string_("degree"),
numpy.string_("degrees"),
]:
elif units not in ["deg", "degree", "degrees"]:
raise RuntimeError("Invalid units: %s" % units)

# is the axis moving? Check the values for this axis
Expand Down Expand Up @@ -320,14 +326,14 @@ def result(self):

if vector is None:
value = nx_file[item][()]
units = nx_file[item].attrs["units"]
ttype = nx_file[item].attrs["transformation_type"]
units = h5str(nx_file[item].attrs["units"])
ttype = h5str(nx_file[item].attrs["transformation_type"])
vector = nx_file[item].attrs["vector"]
if "offset" in nx_file[item].attrs:
offset = nx_file[item].attrs["offset"]
else:
offset = vector * 0.0
if ttype == numpy.string_("translation"):
if ttype == "translation":
value = convert_units(value, units, "mm")
try:
vector = vector * value
Expand Down Expand Up @@ -634,11 +640,11 @@ def get_change_of_basis(transformation):
# Change of basis to convert from NeXus to IUCr/ImageCIF convention
n2i_cob = sqr((-1, 0, 0, 0, 1, 0, 0, 0, -1))

axis_type = numpy.string_(transformation.attrs["transformation_type"])
axis_type = h5str(transformation.attrs["transformation_type"])

vector = n2i_cob * col(transformation.attrs["vector"]).normalize()
setting = transformation[0]
units = numpy.string_(transformation.attrs["units"])
units = h5str(transformation.attrs["units"])

if "offset" in transformation.attrs:
offset = n2i_cob * col(transformation.attrs["offset"])
Expand All @@ -653,14 +659,10 @@ def get_change_of_basis(transformation):
# 4x4 change of basis matrix (homogeneous coordinates)
cob = None

if axis_type == numpy.string_("rotation"):
if units == numpy.string_("rad"):
if axis_type == "rotation":
if units == "rad":
deg = False
elif units in [
numpy.string_("deg"),
numpy.string_("degree"),
numpy.string_("degrees"),
]:
elif units in ["deg", "degree", "degrees"]:
deg = True
else:
raise RuntimeError("Invalid units: %s" % units)
Expand All @@ -685,7 +687,7 @@ def get_change_of_basis(transformation):
1,
)
)
elif axis_type == numpy.string_("translation"):
elif axis_type == "translation":
setting = convert_units(setting, units, "mm")
translation = offset + (vector * setting)
cob = sqr(
Expand Down Expand Up @@ -725,9 +727,9 @@ def get_depends_on_chain_using_equipment_components(transformation):
current = transformation

while True:
parent_id = numpy.string_(current.attrs["depends_on"])
parent_id = h5str(current.attrs["depends_on"])

if parent_id == numpy.string_("."):
if parent_id == ".":
return chain
parent = current.parent[parent_id]

Expand All @@ -752,9 +754,9 @@ def get_cumulative_change_of_basis(transformation):

cob = get_change_of_basis(transformation)

parent_id = numpy.string_(transformation.attrs["depends_on"])
parent_id = h5str(transformation.attrs["depends_on"])

if parent_id == numpy.string_("."):
if parent_id == ".":
return None, cob
parent = transformation.parent[parent_id]

Expand Down Expand Up @@ -942,13 +944,13 @@ def set_frame(pg, transformation):

# Get the detector material
if "sensor_material" in nx_detector.handle:
value = numpy.string_(nx_detector.handle["sensor_material"][()])
value = h5str(nx_detector.handle["sensor_material"][()])
material = {
numpy.string_("Si"): "Si",
numpy.string_("Silicon"): "Si",
numpy.string_("Sillicon"): "Si",
numpy.string_("CdTe"): "CdTe",
numpy.string_("GaAs"): "GaAs",
"Si": "Si",
"Silicon": "Si",
"Sillicon": "Si",
"CdTe": "CdTe",
"GaAs": "GaAs",
}.get(value)
if not material:
raise RuntimeError("Unknown material: %s" % value)
Expand Down Expand Up @@ -1009,12 +1011,12 @@ def __init__(self, obj, beam, shape=None):

# Get the detector material
material = {
numpy.string_("Si"): "Si",
numpy.string_("Silicon"): "Si",
numpy.string_("Sillicon"): "Si",
numpy.string_("CdTe"): "CdTe",
numpy.string_("GaAs"): "GaAs",
}.get(nx_detector["sensor_material"][()])
"Si": "Si",
"Silicon": "Si",
"Sillicon": "Si",
"CdTe": "CdTe",
"GaAs": "GaAs",
}.get(h5str(nx_detector["sensor_material"][()]))
if not material:
raise RuntimeError(
"Unknown material: %s" % nx_detector["sensor_material"][()]
Expand Down Expand Up @@ -1112,7 +1114,7 @@ class GoniometerFactory(object):
"""

def __init__(self, obj):
if obj.handle["depends_on"][()] == ".":
if h5str(obj.handle["depends_on"][()]) == ".":
self.model = None
else:
axes, angles, axis_names, scan_axis = construct_axes(
Expand All @@ -1130,13 +1132,13 @@ def __init__(self, obj):


def find_goniometer_rotation(obj):
if obj.handle["depends_on"][()] == ".":
if h5str(obj.handle["depends_on"][()]) == ".":
return
thing = obj.handle.file[obj.handle["depends_on"][()]]
tree = get_depends_on_chain_using_equipment_components(thing)
for t in tree:
o = obj.handle.file[t.name]
if o.attrs["transformation_type"] == numpy.string_("rotation"):
if h5str(o.attrs["transformation_type"]) == "rotation":
# if this is changing, assume is scan axis
v = o[()]
if min(v) < max(v):
Expand All @@ -1145,7 +1147,7 @@ def find_goniometer_rotation(obj):


def find_scanning_axis(obj):
if obj.handle["depends_on"][()] == ".":
if h5str(obj.handle["depends_on"][()]) == ".":
return
thing = obj.handle.file[obj.handle["depends_on"][()]]
tree = get_depends_on_chain_using_equipment_components(thing)
Expand All @@ -1159,7 +1161,7 @@ def generate_scan_model(obj, detector_obj):
"""
Create a scan model from NXmx stuff.
"""
if obj.handle["depends_on"][()] == ".":
if h5str(obj.handle["depends_on"][()]) == ".":
return

# Get the image and oscillation range - need to search for rotations
Expand All @@ -1176,7 +1178,7 @@ def generate_scan_model(obj, detector_obj):
num_images = len(scan_axis)
image_range = (1, num_images)

rotn = scan_axis.attrs["transformation_type"] == numpy.string_("rotation")
rotn = h5str(scan_axis.attrs["transformation_type"]) == "rotation"

if num_images > 1 and rotn:
oscillation = (float(scan_axis[0]), float(scan_axis[1] - scan_axis[0]))
Expand Down
1 change: 1 addition & 0 deletions newsfragments/267.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix errors introduced by moving to h5py 3.1+

0 comments on commit 687e231

Please sign in to comment.