diff --git a/operationsgateway_api/src/records/record.py b/operationsgateway_api/src/records/record.py index ef3499a9..f219779c 100644 --- a/operationsgateway_api/src/records/record.py +++ b/operationsgateway_api/src/records/record.py @@ -518,6 +518,7 @@ async def _apply_function( variable_transformer.evaluate(function["expression"]) variables = variable_transformer.variables channels_to_fetch = set() + bit_depths = [] log.debug("Attempting to extract %s from %s", variables, record["channels"]) for variable in variables: @@ -526,6 +527,7 @@ async def _apply_function( record["_id"], variable, record["channels"][variable], + bit_depths=bit_depths, ) else: channels_to_fetch.add(variable) @@ -536,6 +538,7 @@ async def _apply_function( variable_data, channels_to_fetch, skip_functions=variable_transformer.skip_functions, + bit_depths=bit_depths, ) if missing_channels: # Remove any missing channels so we don't skip future functions which @@ -556,6 +559,7 @@ async def _apply_function( result=result, return_thumbnails=return_thumbnails, truncate=truncate, + bit_depths=bit_depths, ) variable_data[function["name"]] = result @@ -565,6 +569,7 @@ async def _fetch_channels( variable_data: dict, channels_to_fetch: "set[str]", skip_functions: "set[str]", + bit_depths: "list[int]", ) -> "set[str]": """Fetches `channels_to_fetch`, returning known channels missing for this record and raising an exception if the channel is not known at all.""" @@ -594,6 +599,7 @@ async def _fetch_channels( record["_id"], name, channel_value, + bit_depths=bit_depths, ) @staticmethod @@ -601,6 +607,7 @@ async def _extract_variable( record_id: str, name: str, channel_value: dict, + bit_depths: "list[int]", ) -> "np.ndarray | WaveformVariable | float": """ Extracts and returns the relevant data from `channel_value`, handling @@ -620,6 +627,10 @@ async def _extract_variable( channel_name=name, channel_value=channel_value, ) + if raw_bit_depth is not None: + # Modify in place to store for each channel, 
@staticmethod
def _bit_shift_to_storage(
    img_array: np.ndarray,
    raw_bit_depth: "int | None",
) -> "tuple[np.ndarray, int]":
    """Shift the bits of a calculated image from numerically accurate positions to
    most significant bits for display/storage.

    Args:
        img_array (np.ndarray): Calculated image as a np.ndarray.
        raw_bit_depth (int | None): Original specified bit depth of the raw data,
            or None when no bit depth is known.

    Returns:
        tuple[np.ndarray, int]:
            Input image with the bits shifted to storage/display positions,
            and the value of this storage bit depth. When `raw_bit_depth` is
            None, 8 or 16 the input array is returned as is (no copy, no cast).
    """
    if raw_bit_depth is None:
        # Unknown depth: previously this fell through to `None < 8` and raised a
        # TypeError. Use the maximum storage depth without shifting, so no
        # information is lost.
        return img_array, 16

    if raw_bit_depth in (8, 16):
        # If bit depth exactly matches a storage depth, no shift is needed
        return img_array, raw_bit_depth

    if raw_bit_depth < 8:
        # Bit depths < 8 would have been stored as 8 bit, so shift up to storage
        storage_dtype, storage_bit_depth = np.uint8, 8
    else:
        # Bit depths > 8 would have been stored as 16 bit, so shift up to storage
        storage_dtype, storage_bit_depth = np.uint16, 16

    # NOTE(review): values above 2 ** raw_bit_depth - 1 will overflow the storage
    # dtype in the multiplication below - confirm callers guarantee the range.
    scale = 2 ** (storage_bit_depth - raw_bit_depth)
    return img_array.astype(storage_dtype) * scale, storage_bit_depth
returns image bytes, either for a thumbnail or full image. """ - # We do not track the bit depth of inputs, so keep maximum depth to - # avoid losing information - storage_bit_depth = 16 + if len(bit_depths) == 0: + # We have no information about input bit depths, so set to max supported + # This will not lose any information, but may make the image very dark + overall_bit_depth = 16 + else: + # Otherwise, take the highest depth encountered. There may be more than one + # if the function depends on multiple channels with different depths, + # in which case, we should try and store all information which means + # choosing the highest bit depth needed + overall_bit_depth = max(bit_depths) + + result, storage_bit_depth = Record._bit_shift_to_storage( + img_array=result, + raw_bit_depth=overall_bit_depth, + ) if return_thumbnails: metadata = { "channel_dtype": "image",