Skip to content

Commit 21fac6b

Browse files
✨ Use DRF field info to extract filter param model field
1 parent 693aeec commit 21fac6b

File tree

2 files changed

+42
-8
lines changed

2 files changed

+42
-8
lines changed

vng_api_common/inspectors/query.py

+2-8
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
from django.core.exceptions import FieldDoesNotExist
21
from django.db import models
32
from django.utils.encoding import force_text
43
from django.utils.translation import ugettext as _
@@ -10,6 +9,7 @@
109

1110
from ..filters import URLModelChoiceFilter
1211
from ..utils import underscore_to_camel
12+
from .utils import get_target_field
1313

1414

1515
class FilterInspector(CoreAPICompatInspector):
@@ -29,13 +29,7 @@ def get_filter_parameters(self, filter_backend):
2929

3030
for parameter in fields:
3131
filter_field = filter_class.base_filters[parameter.name]
32-
33-
try:
34-
model_field = queryset.model._meta.get_field(
35-
parameter.name.split("__")[0]
36-
)
37-
except FieldDoesNotExist:
38-
model_field = None
32+
model_field = get_target_field(queryset.model, parameter.name)
3933

4034
help_text = filter_field.extra.get(
4135
"help_text", model_field.help_text if model_field else ""

vng_api_common/inspectors/utils.py

+40
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
from typing import Optional, Type
2+
3+
from django.db import models
4+
5+
from rest_framework.utils.model_meta import get_field_info
6+
7+
8+
def get_target_field(model: Type[models.Model], field: str) -> Optional[models.Field]:
9+
"""
10+
Retrieve the end-target that ``field`` points to.
11+
12+
:param field: A string containing a lookup, potentially spanning relations. E.g.:
13+
foo__bar__lte.
14+
:return: A Django model field instance or `None`
15+
"""
16+
17+
start, *remaining = field.split("__")
18+
field_info = get_field_info(model)
19+
20+
# simple, non relational field?
21+
if start in field_info.fields:
22+
return field_info.fields[start]
23+
24+
# simple relational field?
25+
if start in field_info.forward_relations:
26+
relation_info = field_info.forward_relations[start]
27+
if not remaining:
28+
return relation_info.model_field
29+
else:
30+
return get_target_field(relation_info.related_model, "__".join(remaining))
31+
32+
# check the reverse relations - note that the model name is used instead of model_name_set
33+
# in the queries -> we can't just test for containment in field_info.reverse_relations
34+
for relation_info in field_info.reverse_relations.values():
35+
# not sure about this - what if there are more relations with different related names?
36+
if relation_info.related_model._meta.model_name != start:
37+
continue
38+
return get_target_field(relation_info.related_model, "__".join(remaining))
39+
40+
return None

0 commit comments

Comments
 (0)