diff --git a/datagateway_api/common/database/filters.py b/datagateway_api/common/database/filters.py index 6bb41031..ffce968f 100644 --- a/datagateway_api/common/database/filters.py +++ b/datagateway_api/common/database/filters.py @@ -72,27 +72,30 @@ def _add_query_join(self, query): :return: Entity model of the field (usually the field relating to the endpoint the request is coming from) """ - try: - field = getattr(query.table, self.field) - except AttributeError: - raise FilterError( - f"Unknown attribute {self.field} on table {query.table.__name__}", - ) if self.related_related_field: - included_table = getattr(models, self.field) - included_included_table = getattr(models, self.related_field) + included_table = get_entity_object_from_name(self.field) + included_included_table = get_entity_object_from_name(self.related_field) query.base_query = query.base_query.join(included_table).join( included_included_table, ) - field = getattr(included_included_table, self.related_related_field) + field = self._get_field(included_included_table, self.related_related_field) elif self.related_field: included_table = get_entity_object_from_name(self.field) query.base_query = query.base_query.join(included_table) - field = getattr(included_table, self.related_field) + field = self._get_field(included_table, self.related_field) + else: + # No related fields + field = self._get_field(query.table, self.field) return field + def _get_field(self, table, field): + try: + return getattr(table, field) + except AttributeError: + raise FilterError(f"Unknown attribute {field} on table {table.__name__}") + class DatabaseWhereFilter(WhereFilter, DatabaseFilterUtilities): def __init__(self, field, value, operation):