diff --git a/common/icat/helpers.py b/common/icat/helpers.py index 3b47d216..f0ca33be 100644 --- a/common/icat/helpers.py +++ b/common/icat/helpers.py @@ -177,8 +177,12 @@ def update_attributes(old_entity, new_entity): f" {old_entity.BeanName} entity" ) + return old_entity + + +def push_data_updates_to_icat(entity): try: - old_entity.update() + entity.update() except (ICATValidationError, ICATInternalError) as e: raise PythonICATError(e) @@ -263,7 +267,8 @@ def update_entity_by_id(client, table_name, id_, new_data): # There will only ever be one record associated with a single ID - if a record with # the specified ID cannot be found, it'll be picked up by the MissingRecordError in # get_entity_by_id() - update_attributes(entity_id_data, new_data) + updated_icat_entity = update_attributes(entity_id_data, new_data) + push_data_updates_to_icat(updated_icat_entity) # The record is re-obtained from Python ICAT (rather than using entity_id_data) to # show to the user whether the change has actually been applied @@ -388,18 +393,31 @@ def update_entities(client, table_name, data_to_update): if not isinstance(data_to_update, list): data_to_update = [data_to_update] - for entity in data_to_update: + updated_icat_data = [] + + for entity_request in data_to_update: try: - updated_result = update_entity_by_id( - client, table_name, entity["id"], entity + entity_data = get_entity_by_id( + client, + table_name, + entity_request["id"], + False, + return_related_entities=True, ) - updated_data.append(updated_result) + + updated_entity_data = update_attributes(entity_data, entity_request) + updated_icat_data.append(updated_entity_data) except KeyError: raise BadRequestError( "The new data in the request body must contain the ID (using the key:" " 'id') of the entity you wish to update" ) + # This separates the local data updates from pushing these updates to icatdb + for entity in updated_icat_data: + push_data_updates_to_icat(entity) + updated_data.append(get_entity_by_id(client, table_name, entity.id, True)) + return updated_data diff --git a/common/icat/query.py b/common/icat/query.py index dc35c32a..4dd708a9 100644 --- a/common/icat/query.py +++ b/common/icat/query.py @@ -333,4 +333,4 @@ def flatten_query_included_fields(self, includes): ICAT query """ - return [m for n in (field.split(".") for field in includes) for m in n] \ No newline at end of file + return [m for n in (field.split(".") for field in includes) for m in n]