Skip to content

Commit

Permalink
Store and use max bit depth for functions
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-austin committed Jan 29, 2025
1 parent a566fa6 commit bbf53cb
Showing 1 changed file with 56 additions and 3 deletions.
59 changes: 56 additions & 3 deletions operationsgateway_api/src/records/record.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,6 +518,7 @@ async def _apply_function(
variable_transformer.evaluate(function["expression"])
variables = variable_transformer.variables
channels_to_fetch = set()
bit_depths = []

log.debug("Attempting to extract %s from %s", variables, record["channels"])
for variable in variables:
Expand All @@ -526,6 +527,7 @@ async def _apply_function(
record["_id"],
variable,
record["channels"][variable],
bit_depths=bit_depths,
)
else:
channels_to_fetch.add(variable)
Expand All @@ -536,6 +538,7 @@ async def _apply_function(
variable_data,
channels_to_fetch,
skip_functions=variable_transformer.skip_functions,
bit_depths=bit_depths,
)
if missing_channels:
# Remove any missing channels so we don't skip future functions which
Expand All @@ -556,6 +559,7 @@ async def _apply_function(
result=result,
return_thumbnails=return_thumbnails,
truncate=truncate,
bit_depths=bit_depths,
)
variable_data[function["name"]] = result

Expand All @@ -565,6 +569,7 @@ async def _fetch_channels(
variable_data: dict,
channels_to_fetch: "set[str]",
skip_functions: "set[str]",
bit_depths: "list[int]",
) -> "set[str]":
"""Fetches `channels_to_fetch`, returning known channels missing for this
record and raising an exception if the channel is not known at all."""
Expand Down Expand Up @@ -594,13 +599,15 @@ async def _fetch_channels(
record["_id"],
name,
channel_value,
bit_depths=bit_depths,
)

@staticmethod
async def _extract_variable(
record_id: str,
name: str,
channel_value: dict,
bit_depths: "list[int]",
) -> "np.ndarray | WaveformVariable | float":
"""
Extracts and returns the relevant data from `channel_value`, handling
Expand All @@ -620,6 +627,10 @@ async def _extract_variable(
channel_name=name,
channel_value=channel_value,
)
if raw_bit_depth is not None:
# Modify in place to store for each channel, getting around static func
bit_depths.append(raw_bit_depth)

image_bytes = await Image.get_image(
record_id=record_id,
channel_name=name,
Expand Down Expand Up @@ -674,6 +685,33 @@ def _bit_shift_to_raw(
# Bit depths > 8 would have been stored as 16 bit, so shift back to raw
return img_array / 2 ** (16 - raw_bit_depth)

@staticmethod
def _bit_shift_to_storage(
img_array: np.ndarray,
raw_bit_depth: "int | None",
) -> "tuple[np.ndarray, int]":
"""Shift the bits of a calculated image from numerically accurate positions to
most significant bits for display/storage.
Args:
img_array (np.ndarray): Calculated image as a np.ndarray.
raw_bit_depth (int | None): Original specified bit depth of the raw data.
Returns:
tuple[np.ndarray, int]:
Input image with the bits shifted to storage/display positions,
and the value of this storage bit depth.
"""
if raw_bit_depth in (8, 16):
# If bit depth exactly matches a storage depth, no shift is needed
return img_array, raw_bit_depth
elif raw_bit_depth < 8:
# Bit depths < 8 would have been stored as 8 bit, so shift up to storage
return img_array.astype(np.uint8) * 2 ** (8 - raw_bit_depth), 8
else:
# Bit depths > 8 would have been stored as 16 bit, so shift up to storage
return img_array.astype(np.uint16) * 2 ** (16 - raw_bit_depth), 16

@staticmethod
def _parse_function_results(
record: dict,
Expand All @@ -684,6 +722,7 @@ def _parse_function_results(
colourmap_name: str,
function_name: str,
result: "np.ndarray | WaveformVariable | np.float64",
bit_depths: "list[int]",
return_thumbnails: bool = True,
truncate: bool = False,
) -> None:
Expand All @@ -703,6 +742,7 @@ def _parse_function_results(
colourmap_name=colourmap_name,
return_thumbnails=return_thumbnails,
truncate=truncate,
bit_depths=bit_depths,
)

elif isinstance(result, WaveformVariable):
Expand Down Expand Up @@ -737,13 +777,26 @@ def _parse_image_result(
colourmap_name: str,
return_thumbnails: bool,
truncate: bool,
bit_depths: "list[int]",
) -> dict:
"""Parses a numpy ndarray and returns image bytes, either for a thumbnail or
full image.
"""
# We do not track the bit depth of inputs, so keep maximum depth to
# avoid losing information
storage_bit_depth = 16
if len(bit_depths) == 0:
# We have no information about input bit depths, so set to max supported
# This will not lose any information, but may make the image very dark
overall_bit_depth = 16
else:
# Otherwise, take the highest depth encountered. There may be more than one
# if the function depends on multiple channels with different depths,
# in which case, we should try and store all information which means
# choosing the highest bit depth needed
overall_bit_depth = max(bit_depths)

result, storage_bit_depth = Record._bit_shift_to_storage(
img_array=result,
raw_bit_depth=overall_bit_depth,
)
if return_thumbnails:
metadata = {
"channel_dtype": "image",
Expand Down

0 comments on commit bbf53cb

Please sign in to comment.