Skip to content

Commit

Permalink
extract relate filter from rest module (fixes #35)
Browse files Browse the repository at this point in the history
  • Loading branch information
sheppard committed Jan 10, 2016
1 parent e81766c commit 4b32294
Show file tree
Hide file tree
Showing 9 changed files with 121 additions and 87 deletions.
22 changes: 22 additions & 0 deletions patterns/relate/filters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from rest_framework.filters import BaseFilterBackend

from wq.db.rest.models import get_ct, get_by_identifier
from .models import get_related_parents


class RelatedFilterBackend(BaseFilterBackend):
def filter_queryset(self, request, queryset, view):
ctype = get_ct(view.model)
filter = {}
for key, val in list(view.kwargs.items()) + list(request.GET.items()):
if not key.startswith('related_'):
continue
if isinstance(val, list):
val = val[0]
for pct in get_related_parents(ctype):
if key == 'related_' + pct.identifier:
pclass = pct.model_class()
parent = get_by_identifier(pclass.objects, val)
objs = view.model.objects.filter_by_related(parent)
filter['pk__in'] = objs.values_list('pk', flat=True)
return queryset.filter(**filter)
18 changes: 18 additions & 0 deletions patterns/relate/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,3 +280,21 @@ def __str__(self):
class Meta:
proxy = INSTALLED
abstract = not INSTALLED


def get_related_types(cls, **kwargs):
from wq.db.rest.models import ContentType as RestContentType
ctypes = set()
for rtype in cls.objects.filter(**kwargs):
# This is a DjangoContentType, swap for our custom version
ctype = RestContentType.objects.get_for_id(rtype.right.pk)
ctypes.add(ctype)
return ctypes


def get_related_children(contenttype):
return get_related_types(RelationshipType, from_type=contenttype)


def get_related_parents(contenttype):
return get_related_types(InverseRelationshipType, to_type=contenttype)
53 changes: 53 additions & 0 deletions patterns/relate/views.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from rest_framework.routers import Route
from wq.db.rest.views import ModelViewSet
from wq.db.rest.models import get_ct
from .models import get_related_parents, get_related_children
from .filters import RelatedFilterBackend


class RelatedModelViewSet(ModelViewSet):
filter_backends = ModelViewSet.filter_backends + [RelatedFilterBackend]

def list(self, request, *args, **kwargs):
response = super(RelatedModelViewSet, self).list(
request, *args, **kwargs
)
ct = get_ct(self.model)
for pct in get_related_parents(ct):
self.get_parent(pct, 'related_%s' % pct.identifier, response)
return response

@classmethod
def extra_routes(cls):
routes = []
ct = get_ct(cls.model)
for pct in get_related_parents(ct):
if not pct.is_registered():
continue
if pct.urlbase == '':
purlbase = ''
else:
purlbase = pct.urlbase + '/'

routes.append(Route(
(
'^' + purlbase + r'(?P<related_' + pct.identifier
+ '>[^\/\?]+)/{prefix}{trailing_slash}$'
),
mapping={'get': 'list'},
name="{basename}-for-related-%s" % pct.identifier,
initkwargs={'suffix': 'List'},
))

for cct in get_related_children(ct):
if not cct.is_registered():
continue
cbase = cct.urlbase
routes.append(Route(
url='^%s-by-{prefix}' % cbase,
mapping={'get': 'list'},
name="%s-by-%s" % (cct.identifier, ct.identifier),
initkwargs={'target': cbase, 'suffix': 'List'},
))

return routes
24 changes: 5 additions & 19 deletions rest/filters.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,20 @@
from rest_framework.filters import BaseFilterBackend
RESERVED_PARAMETERS = ('_', 'page', 'limit', 'format', 'slug', 'mode')

from .models import get_ct, get_by_identifier
from django.utils.six import string_types
from django.db.models.fields import FieldDoesNotExist


class FilterBackend(BaseFilterBackend):
def filter_queryset(self, request, queryset, view):
kwargs = {}
for key, val in list(view.kwargs.items()) + list(request.GET.items()):
if key in RESERVED_PARAMETERS or key in view.ignore_kwargs:
if key in getattr(view, 'ignore_kwargs', []):
continue
kwargs[key] = val if isinstance(val, string_types) else val[0]
if isinstance(val, list):
kwargs[key] = val[0]
else:
kwargs[key] = val
model = getattr(view, 'model', None) or queryset.model
ctype = get_ct(model)

for key, val in list(kwargs.items()):
if key.startswith('related_'):
continue
field_name = key.split('__')[0]
try:
field = model._meta.get_field_by_name(field_name)[0]
Expand All @@ -41,14 +37,4 @@ def filter_queryset(self, request, queryset, view):
else:
kwargs[key] = pcls.objects.get(pk=val)

for key, val in list(kwargs.items()):
if key.startswith('related_') and ctype.is_related:
for pct in ctype.get_all_parents():
if key == 'related_' + pct.identifier:
pclass = pct.model_class()
parent = get_by_identifier(pclass.objects, kwargs[key])
del kwargs[key]
objs = model.objects.filter_by_related(parent)
kwargs['pk__in'] = objs.values_list('pk', flat=True)

return queryset.filter(**kwargs)
35 changes: 0 additions & 35 deletions rest/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
ContentType as DjangoContentType,
ContentTypeManager as DjangoContentTypeManager
)
from wq.db.patterns.models import RelationshipType
from django.utils.encoding import force_text
from django.utils.six import string_types

Expand Down Expand Up @@ -54,24 +53,6 @@ def get_foreign_keys(self):
def get_parents(self):
return set(self.get_foreign_keys().keys())

# Get foreign keys and RelationshipType parents for this content type
def get_all_parents(self):
parents = self.get_parents()
if self.is_related:
parents.update(self.get_relationshiptype_parents())
return parents

def get_relationshiptype_parents(self):
parents = set()
if not self.is_related:
return parents
for rtype in RelationshipType.objects.filter(to_type=self):
ctype = rtype.from_type
# This is a DjangoContentType, swap for our custom version
ctype = ContentType.objects.get_for_id(ctype.pk)
parents.add(ctype)
return parents

def get_children(self, include_rels=False):
cls = self.model_class()
if cls is None:
Expand All @@ -88,27 +69,11 @@ def get_model(rel):
else:
return set(child[0] for child in children)

def get_all_children(self):
children = self.get_children()
if not self.is_related:
return children
for rtype in RelationshipType.objects.filter(from_type=self):
ctype = rtype.to_type
# This is a DjangoContentType, swap for our custom version
ctype = ContentType.objects.get_for_id(ctype.pk)
children.add(ctype)
return children

def get_config(self, user=None):
from . import router # avoid circular import
cls = self.model_class()
return router.get_model_config(cls, user)

@property
def is_related(self):
config = self.get_config()
return config.get('related', False)

def is_registered(self):
from . import router # avoid circular import
cls = self.model_class()
Expand Down
38 changes: 8 additions & 30 deletions rest/routers.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,6 @@ def get_routes(self, viewset):

# /[parentmodel_url]/[foreignkey_value]/[model_url]
ct = get_ct(model)
parent_routes = []
for pct, fields in ct.get_foreign_keys().items():
if not pct.is_registered():
continue
Expand All @@ -461,41 +460,17 @@ def get_routes(self, viewset):
purlbase = ''
else:
purlbase = pct.urlbase + '/'
parent_routes.append((
fields[0],
(
routes.append(Route(
url=(
'^' + purlbase + r'(?P<' + fields[0]
+ '>[^\/\?]+)/{prefix}{trailing_slash}$'
)
))

# Similar but for RelatedModel parent-child relationships
# (FIXME: see #35)
for pct in ct.get_relationshiptype_parents():
if not pct.is_registered():
continue
if pct.urlbase == '':
purlbase = ''
else:
purlbase = pct.urlbase + '/'

parent_routes.append((
'related-' + pct.identifier,
(
'^' + purlbase + r'(?P<related_' + pct.identifier
+ '>[^\/\?]+)/{prefix}{trailing_slash}$'
)
))

for pname, purl in parent_routes:
routes.append(Route(
url=purl,
),
mapping={'get': 'list'},
name="{basename}-for-%s" % pname,
name="{basename}-for-%s" % fields[0],
initkwargs={'suffix': 'List'},
))

for cct in ct.get_all_children():
for cct in ct.get_children():
if not cct.is_registered():
continue
cbase = cct.urlbase
Expand All @@ -506,6 +481,9 @@ def get_routes(self, viewset):
initkwargs={'target': cbase, 'suffix': 'List'},
))

if hasattr(viewset, 'extra_routes'):
routes += viewset.extra_routes()

return routes

@property
Expand Down
2 changes: 0 additions & 2 deletions rest/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,8 +186,6 @@ def list(self, request, *args, **kwargs):
for pct, fields in ct.get_foreign_keys().items():
if len(fields) == 1:
self.get_parent(pct, fields[0], response)
for pct in ct.get_relationshiptype_parents():
self.get_parent(pct, 'related_%s' % pct.identifier, response)
return response

def create(self, request, *args, **kwargs):
Expand Down
3 changes: 3 additions & 0 deletions tests/patterns_app/rest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from wq.db import rest
from wq.db.patterns import rest as patterns
from wq.db.patterns.identify.views import IdentifiedModelViewSet
from wq.db.patterns.relate.views import RelatedModelViewSet
from .models import (
AnnotatedModel, IdentifiedModel, MarkedModel, LocatedModel,
RelatedModel, AnotherRelatedModel,
Expand Down Expand Up @@ -32,10 +33,12 @@
rest.router.register_model(
RelatedModel,
serializer=patterns.RelatedModelSerializer,
viewset=RelatedModelViewSet,
)
rest.router.register_model(
AnotherRelatedModel,
serializer=patterns.RelatedModelSerializer,
viewset=RelatedModelViewSet,
)
rest.router.register_model(
IdentifiedAnnotatedModel,
Expand Down
13 changes: 12 additions & 1 deletion tests/test_relate.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from django.contrib.auth.models import User
from tests.patterns_app.models import RelatedModel, AnotherRelatedModel
from wq.db.patterns.models import RelationshipType, Relationship
from wq.db.patterns.models import get_related_parents, get_related_children


def create_reltype():
Expand Down Expand Up @@ -37,7 +38,7 @@ def setUp(self):
from_content_type=self.parent_ct,
from_object_id=self.parent.pk,

to_content_type=self.parent_ct,
to_content_type=self.child_ct,
to_object_id=self.child.pk,
)

Expand Down Expand Up @@ -79,6 +80,14 @@ def test_relate_create(self):
self.assertEqual(str(invrel), "Child2 Sibling Of Child1")
self.assertEqual(str(invrel.reltype), "Sibling Of")

def test_relate_parents(self):
parents = get_related_parents(self.child_ct)
self.assertEqual(set([self.parent_ct]), parents)

def test_relate_children(self):
children = get_related_children(self.parent_ct)
self.assertEqual(set([self.child_ct]), children)


class RelateRestTestCase(RelateBaseTestCase):
def setUp(self):
Expand Down Expand Up @@ -225,6 +234,8 @@ def test_relate_put(self):
self.assertEqual(invrel['item_id'], parent2.pk)

def test_relate_filter_by_parent(self):
AnotherRelatedModel.objects.create(name="Child2")
AnotherRelatedModel.objects.create(name="Child3")
response = self.client.get(
'/relatedmodels/%s/anotherrelatedmodels.json' % self.parent.pk
)
Expand Down

0 comments on commit 4b32294

Please sign in to comment.