From 71765a590d996b9935299ebfb3de305660542ca1 Mon Sep 17 00:00:00 2001 From: Viktor Bozhinov Date: Fri, 28 Jan 2022 13:41:40 +0000 Subject: [PATCH] refactor: refactor `from_icat` abstract method #265 --- datagateway_api/src/search_api/models.py | 44 +++++++++++------------- 1 file changed, 21 insertions(+), 23 deletions(-) diff --git a/datagateway_api/src/search_api/models.py b/datagateway_api/src/search_api/models.py index 81f8d568..f67e7f04 100644 --- a/datagateway_api/src/search_api/models.py +++ b/datagateway_api/src/search_api/models.py @@ -31,17 +31,17 @@ class PaNOSCAttribute(ABC, BaseModel): @classmethod @abstractmethod def from_icat(cls, icat_data, required_related_fields): # noqa: B902, N805 - model_fields = cls.__fields__ + entity_fields = cls.__fields__ - model_data = {} - for field in model_fields: + entity_data = {} + for entity_field in entity_fields: # Some fields have aliases so we must use them when creating a model # instance. If a field does not have an alias then the `alias` property # holds the name of the field - field_alias = cls.__fields__[field].alias + entity_field_alias = cls.__fields__[entity_field].alias - panosc_entity_name, icat_field_name = mappings.get_icat_mapping( - cls.__name__, field_alias, + entity_name, icat_field_name = mappings.get_icat_mapping( + cls.__name__, entity_field_alias, ) if not isinstance(icat_field_name, list): @@ -69,54 +69,52 @@ def from_icat(cls, icat_data, required_related_fields): # noqa: B902, N805 if not field_value: continue - if panosc_entity_name != cls.__name__: + if entity_name != cls.__name__: # If we are here, it means that the field references another model so # we have to get hold of its class definition and call its `from_icat` # method to create an instance of itself with the ICAT data provided. # Doing this allows for recursion. - data = field_value - if not isinstance(data, list): - data = [data] + data = ( + [field_value] if not isinstance(field_value, list) else field_value + ) required_related_fields_for_next_entity = [] for required_related_field in required_related_fields: required_related_field = required_related_field.split(".") if ( len(required_related_field) > 1 - and field_alias in required_related_field + and entity_field_alias in required_related_field ): required_related_fields_for_next_entity.extend( required_related_field[1:], ) - # Get the class of the referenced model - panosc_model_attr = getattr(sys.modules[__name__], panosc_entity_name) + # Get the class of the referenced entity + entity_attr = getattr(sys.modules[__name__], entity_name) field_value = [ - panosc_model_attr.from_icat( - d, required_related_fields_for_next_entity, - ) + entity_attr.from_icat(d, required_related_fields_for_next_entity) for d in data ] - field_outer_type = cls.__fields__[field].outer_type_ + entity_field_outer_type = cls.__fields__[entity_field].outer_type_ if ( - not hasattr(field_outer_type, "_name") - or field_outer_type._name != "List" + not hasattr(entity_field_outer_type, "_name") + or entity_field_outer_type._name != "List" ) and isinstance(field_value, list): # If the field does not hold list of values but `field_value` # is a list, then just get its first element field_value = field_value[0] - model_data[field_alias] = field_value + entity_data[entity_field_alias] = field_value for required_related_field in required_related_fields: required_related_field = required_related_field.split(".")[0] if ( - required_related_field in model_fields + required_related_field in entity_fields and required_related_field in cls._related_fields_with_min_cardinality_one - and required_related_field not in model_data + and required_related_field not in entity_data ): # If we are here, it means that a related entity, which has a minimum # cardinality of one, has been specified to be included as part of the @@ -128,7 +126,7 @@ def from_icat(cls, icat_data, required_related_fields): # noqa: B902, N805 ) raise ValidationError(errors=[error_wrapper], model=cls) - return cls(**model_data) + return cls(**entity_data) class Affiliation(PaNOSCAttribute):