Skip to content

Commit

Permalink
Merge pull request #40 from ral-facilities/32_improve_include_filtering
Browse files Browse the repository at this point in the history
Improve include filtering
  • Loading branch information
keiranjprice101 authored Aug 16, 2019
2 parents 7b6adef + b60e2d1 commit 60c668d
Showing 1 changed file with 26 additions and 24 deletions.
50 changes: 26 additions & 24 deletions common/models/db_models.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from sqlalchemy import Index, Column, BigInteger, String, DateTime, ForeignKey, Integer, Float, FetchedValue
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import relationship
from sqlalchemy.orm.collections import InstrumentedList

from common.exceptions import BadFilterError

Base = declarative_base()

Expand All @@ -21,36 +22,37 @@ def to_dict(self):
dictionary[column.name] = str(getattr(self, column.name))
return dictionary

def to_nested_dict(self, included_relations):
def to_nested_dict(self, includes):
"""
Given related models return a nested dictionary with the child or parent rows nested.
:param included_relations: string/list/dict - The related models to include.
:return: A nested dictionary with the included models
"""
dictionary = {}
for column in self.__table__.columns:
dictionary[column.name] = str(getattr(self, column.name))
if type(included_relations) is not dict:
for attr in dir(self):
if attr in included_relations:
relation = getattr(self, attr)
if isinstance(relation, EntityHelper):
dictionary[attr + "_ID"] = relation.to_dict()
elif isinstance(relation, InstrumentedList): # Instrumented list is when the inclusion is a child
dictionary[attr + "_ID"] = []
for entity in getattr(self, attr):
dictionary[attr + "_ID"].append(entity.to_dict())
else:
for attr in dir(self):
if attr == list(included_relations.keys())[0]:
dictionary[attr + "_ID"] = getattr(self, attr).to_nested_dict(list(included_relations.values()))

dictionary = {k: v for k, v in dictionary.items() if
"ID" in k or k != "MOD_ID" or k != "CREATE_ID" or k != "ID"}
dictionary = self.to_dict()
try:
includes = includes if type(includes) is list else [includes]
for include in includes:
if type(include) is str:
related_entity = self.get_related_entity(include)
dictionary[related_entity.__tablename__] = related_entity.to_dict()
elif type(include) is dict:
related_entity = self.get_related_entity(list(include)[0])
dictionary[related_entity.__tablename__] = related_entity.to_nested_dict(include[list(include)[0]])
except TypeError:
raise BadFilterError(f" Bad include relations provided: {includes}")
return dictionary

def get_related_entity(self, entity):
"""
Given a string for the related entity name, return the related entity
:param entity: String - The name of the entity
:return: The entity
"""
try:
return getattr(self, entity)
except AttributeError:
raise BadFilterError(f" No related entity: {entity}")

def update_from_dict(self, dictionary):
"""
Given a dictionary containing field names and variables, updates the entity from the given dictionary
Expand Down

0 comments on commit 60c668d

Please sign in to comment.