From 4b3229482bf493a0d148ebbc334597ac07943b86 Mon Sep 17 00:00:00 2001 From: "S. Andrew Sheppard" Date: Sat, 9 Jan 2016 19:46:53 -0600 Subject: [PATCH] extract relate filter from rest module (fixes #35) --- patterns/relate/filters.py | 22 ++++++++++++++++ patterns/relate/models.py | 18 +++++++++++++ patterns/relate/views.py | 53 ++++++++++++++++++++++++++++++++++++++ rest/filters.py | 24 ++++------------- rest/models.py | 35 ------------------------- rest/routers.py | 38 ++++++--------------------- rest/views.py | 2 -- tests/patterns_app/rest.py | 3 +++ tests/test_relate.py | 13 +++++++++- 9 files changed, 121 insertions(+), 87 deletions(-) create mode 100644 patterns/relate/filters.py create mode 100644 patterns/relate/views.py diff --git a/patterns/relate/filters.py b/patterns/relate/filters.py new file mode 100644 index 00000000..d35b3093 --- /dev/null +++ b/patterns/relate/filters.py @@ -0,0 +1,22 @@ +from rest_framework.filters import BaseFilterBackend + +from wq.db.rest.models import get_ct, get_by_identifier +from .models import get_related_parents + + +class RelatedFilterBackend(BaseFilterBackend): + def filter_queryset(self, request, queryset, view): + ctype = get_ct(view.model) + filter = {} + for key, val in list(view.kwargs.items()) + list(request.GET.items()): + if not key.startswith('related_'): + continue + if isinstance(val, list): + val = val[0] + for pct in get_related_parents(ctype): + if key == 'related_' + pct.identifier: + pclass = pct.model_class() + parent = get_by_identifier(pclass.objects, val) + objs = view.model.objects.filter_by_related(parent) + filter['pk__in'] = objs.values_list('pk', flat=True) + return queryset.filter(**filter) diff --git a/patterns/relate/models.py b/patterns/relate/models.py index 0a127912..45ee1c7c 100644 --- a/patterns/relate/models.py +++ b/patterns/relate/models.py @@ -280,3 +280,21 @@ def __str__(self): class Meta: proxy = INSTALLED abstract = not INSTALLED + + +def get_related_types(cls, **kwargs): + from wq.db.rest.models import ContentType as RestContentType + ctypes = set() + for rtype in cls.objects.filter(**kwargs): + # This is a DjangoContentType, swap for our custom version + ctype = RestContentType.objects.get_for_id(rtype.right.pk) + ctypes.add(ctype) + return ctypes + + +def get_related_children(contenttype): + return get_related_types(RelationshipType, from_type=contenttype) + + +def get_related_parents(contenttype): + return get_related_types(InverseRelationshipType, to_type=contenttype) diff --git a/patterns/relate/views.py b/patterns/relate/views.py new file mode 100644 index 00000000..7ea8b077 --- /dev/null +++ b/patterns/relate/views.py @@ -0,0 +1,53 @@ +from rest_framework.routers import Route +from wq.db.rest.views import ModelViewSet +from wq.db.rest.models import get_ct +from .models import get_related_parents, get_related_children +from .filters import RelatedFilterBackend + + +class RelatedModelViewSet(ModelViewSet): + filter_backends = ModelViewSet.filter_backends + [RelatedFilterBackend] + + def list(self, request, *args, **kwargs): + response = super(RelatedModelViewSet, self).list( + request, *args, **kwargs + ) + ct = get_ct(self.model) + for pct in get_related_parents(ct): + self.get_parent(pct, 'related_%s' % pct.identifier, response) + return response + + @classmethod + def extra_routes(cls): + routes = [] + ct = get_ct(cls.model) + for pct in get_related_parents(ct): + if not pct.is_registered(): + continue + if pct.urlbase == '': + purlbase = '' + else: + purlbase = pct.urlbase + '/' + + routes.append(Route( + ( + '^' + purlbase + r'(?P[^\/\?]+)/{prefix}{trailing_slash}$' + ), + mapping={'get': 'list'}, + name="{basename}-for-related-%s" % pct.identifier, + initkwargs={'suffix': 'List'}, + )) + + for cct in get_related_children(ct): + if not cct.is_registered(): + continue + cbase = cct.urlbase + routes.append(Route( + url='^%s-by-{prefix}' % cbase, + mapping={'get': 'list'}, + name="%s-by-%s" % (cct.identifier, ct.identifier), + initkwargs={'target': cbase, 'suffix': 'List'}, + )) + + return routes diff --git a/rest/filters.py b/rest/filters.py index c8b66c44..f1de2dca 100644 --- a/rest/filters.py +++ b/rest/filters.py @@ -1,8 +1,4 @@ from rest_framework.filters import BaseFilterBackend -RESERVED_PARAMETERS = ('_', 'page', 'limit', 'format', 'slug', 'mode') - -from .models import get_ct, get_by_identifier -from django.utils.six import string_types from django.db.models.fields import FieldDoesNotExist @@ -10,15 +6,15 @@ class FilterBackend(BaseFilterBackend): def filter_queryset(self, request, queryset, view): kwargs = {} for key, val in list(view.kwargs.items()) + list(request.GET.items()): - if key in RESERVED_PARAMETERS or key in view.ignore_kwargs: + if key in getattr(view, 'ignore_kwargs', []): continue - kwargs[key] = val if isinstance(val, string_types) else val[0] + if isinstance(val, list): + kwargs[key] = val[0] + else: + kwargs[key] = val model = getattr(view, 'model', None) or queryset.model - ctype = get_ct(model) for key, val in list(kwargs.items()): - if key.startswith('related_'): - continue field_name = key.split('__')[0] try: field = model._meta.get_field_by_name(field_name)[0] @@ -41,14 +37,4 @@ def filter_queryset(self, request, queryset, view): else: kwargs[key] = pcls.objects.get(pk=val) - for key, val in list(kwargs.items()): - if key.startswith('related_') and ctype.is_related: - for pct in ctype.get_all_parents(): - if key == 'related_' + pct.identifier: - pclass = pct.model_class() - parent = get_by_identifier(pclass.objects, kwargs[key]) - del kwargs[key] - objs = model.objects.filter_by_related(parent) - kwargs['pk__in'] = objs.values_list('pk', flat=True) - return queryset.filter(**kwargs) diff --git a/rest/models.py b/rest/models.py index 1bf80147..e8ac1bfb 100644 --- a/rest/models.py +++ b/rest/models.py @@ -2,7 +2,6 @@ ContentType as DjangoContentType, ContentTypeManager as DjangoContentTypeManager ) -from wq.db.patterns.models import RelationshipType from django.utils.encoding import force_text from django.utils.six import string_types @@ -54,24 +53,6 @@ def get_foreign_keys(self): def get_parents(self): return set(self.get_foreign_keys().keys()) - # Get foreign keys and RelationshipType parents for this content type - def get_all_parents(self): - parents = self.get_parents() - if self.is_related: - parents.update(self.get_relationshiptype_parents()) - return parents - - def get_relationshiptype_parents(self): - parents = set() - if not self.is_related: - return parents - for rtype in RelationshipType.objects.filter(to_type=self): - ctype = rtype.from_type - # This is a DjangoContentType, swap for our custom version - ctype = ContentType.objects.get_for_id(ctype.pk) - parents.add(ctype) - return parents - def get_children(self, include_rels=False): cls = self.model_class() if cls is None: @@ -88,27 +69,11 @@ def get_model(rel): else: return set(child[0] for child in children) - def get_all_children(self): - children = self.get_children() - if not self.is_related: - return children - for rtype in RelationshipType.objects.filter(from_type=self): - ctype = rtype.to_type - # This is a DjangoContentType, swap for our custom version - ctype = ContentType.objects.get_for_id(ctype.pk) - children.add(ctype) - return children - def get_config(self, user=None): from . import router # avoid circular import cls = self.model_class() return router.get_model_config(cls, user) - @property - def is_related(self): - config = self.get_config() - return config.get('related', False) - def is_registered(self): from . import router # avoid circular import cls = self.model_class() diff --git a/rest/routers.py b/rest/routers.py index add6a06c..ea4755cd 100644 --- a/rest/routers.py +++ b/rest/routers.py @@ -449,7 +449,6 @@ def get_routes(self, viewset): # /[parentmodel_url]/[foreignkey_value]/[model_url] ct = get_ct(model) - parent_routes = [] for pct, fields in ct.get_foreign_keys().items(): if not pct.is_registered(): continue @@ -461,41 +460,17 @@ def get_routes(self, viewset): purlbase = '' else: purlbase = pct.urlbase + '/' - parent_routes.append(( - fields[0], - ( + routes.append(Route( + url=( '^' + purlbase + r'(?P<' + fields[0] + '>[^\/\?]+)/{prefix}{trailing_slash}$' - ) - )) - - # Similar but for RelatedModel parent-child relationships - # (FIXME: see #35) - for pct in ct.get_relationshiptype_parents(): - if not pct.is_registered(): - continue - if pct.urlbase == '': - purlbase = '' - else: - purlbase = pct.urlbase + '/' - - parent_routes.append(( - 'related-' + pct.identifier, - ( - '^' + purlbase + r'(?P[^\/\?]+)/{prefix}{trailing_slash}$' - ) - )) - - for pname, purl in parent_routes: - routes.append(Route( - url=purl, + ), mapping={'get': 'list'}, - name="{basename}-for-%s" % pname, + name="{basename}-for-%s" % fields[0], initkwargs={'suffix': 'List'}, )) - for cct in ct.get_all_children(): + for cct in ct.get_children(): if not cct.is_registered(): continue cbase = cct.urlbase @@ -506,6 +481,9 @@ def get_routes(self, viewset): initkwargs={'target': cbase, 'suffix': 'List'}, )) + if hasattr(viewset, 'extra_routes'): + routes += viewset.extra_routes() + return routes @property diff --git a/rest/views.py b/rest/views.py index 73c24487..d157d3f8 100644 --- a/rest/views.py +++ b/rest/views.py @@ -186,8 +186,6 @@ def list(self, request, *args, **kwargs): for pct, fields in ct.get_foreign_keys().items(): if len(fields) == 1: self.get_parent(pct, fields[0], response) - for pct in ct.get_relationshiptype_parents(): - self.get_parent(pct, 'related_%s' % pct.identifier, response) return response def create(self, request, *args, **kwargs): diff --git a/tests/patterns_app/rest.py b/tests/patterns_app/rest.py index de992ade..fdeeca81 100644 --- a/tests/patterns_app/rest.py +++ b/tests/patterns_app/rest.py @@ -1,6 +1,7 @@ from wq.db import rest from wq.db.patterns import rest as patterns from wq.db.patterns.identify.views import IdentifiedModelViewSet +from wq.db.patterns.relate.views import RelatedModelViewSet from .models import ( AnnotatedModel, IdentifiedModel, MarkedModel, LocatedModel, RelatedModel, AnotherRelatedModel, @@ -32,10 +33,12 @@ rest.router.register_model( RelatedModel, serializer=patterns.RelatedModelSerializer, + viewset=RelatedModelViewSet, ) rest.router.register_model( AnotherRelatedModel, serializer=patterns.RelatedModelSerializer, + viewset=RelatedModelViewSet, ) rest.router.register_model( IdentifiedAnnotatedModel, diff --git a/tests/test_relate.py b/tests/test_relate.py index a698efe7..84fd4106 100644 --- a/tests/test_relate.py +++ b/tests/test_relate.py @@ -4,6 +4,7 @@ from django.contrib.auth.models import User from tests.patterns_app.models import RelatedModel, AnotherRelatedModel from wq.db.patterns.models import RelationshipType, Relationship +from wq.db.patterns.models import get_related_parents, get_related_children def create_reltype(): @@ -37,7 +38,7 @@ def setUp(self): from_content_type=self.parent_ct, from_object_id=self.parent.pk, - to_content_type=self.parent_ct, + to_content_type=self.child_ct, to_object_id=self.child.pk, ) @@ -79,6 +80,14 @@ def test_relate_create(self): self.assertEqual(str(invrel), "Child2 Sibling Of Child1") self.assertEqual(str(invrel.reltype), "Sibling Of") + def test_relate_parents(self): + parents = get_related_parents(self.child_ct) + self.assertEqual(set([self.parent_ct]), parents) + + def test_relate_children(self): + children = get_related_children(self.parent_ct) + self.assertEqual(set([self.child_ct]), children) + class RelateRestTestCase(RelateBaseTestCase): def setUp(self): @@ -225,6 +234,8 @@ def test_relate_put(self): self.assertEqual(invrel['item_id'], parent2.pk) def test_relate_filter_by_parent(self): + AnotherRelatedModel.objects.create(name="Child2") + AnotherRelatedModel.objects.create(name="Child3") response = self.client.get( '/relatedmodels/%s/anotherrelatedmodels.json' % self.parent.pk )