diff --git a/common/filters.py b/common/filters.py index 8ab93a0d..346f76f1 100644 --- a/common/filters.py +++ b/common/filters.py @@ -1,6 +1,8 @@ from abc import ABC, abstractmethod import logging +from common.exceptions import BadRequestError + log = logging.getLogger() @@ -27,6 +29,13 @@ def __init__(self, field, value, operation): self.value = value self.operation = operation + if self.operation == "in": + if not isinstance(self.value, list): + raise BadRequestError( + "When using the 'in' operation for a WHERE filter, the values must" + " be in a list format e.g. [1, 2, 3]" + ) + def _extract_filter_fields(self, field): fields = field.split(".") include_depth = len(fields) diff --git a/common/icat/filters.py b/common/icat/filters.py index b80dab45..5d6d99f4 100644 --- a/common/icat/filters.py +++ b/common/icat/filters.py @@ -35,7 +35,11 @@ def apply_filter(self, query): elif self.operation == "gte": where_filter = self.create_condition(self.field, ">=", self.value) elif self.operation == "in": - where_filter = self.create_condition(self.field, "in", tuple(self.value)) + # Convert self.value into a string with brackets equivalent to tuple format. + # Cannot convert straight to tuple as single element tuples contain a + # trailing comma which Python ICAT/JPQL doesn't accept + self.value = str(self.value).replace("[", "(").replace("]", ")") + where_filter = self.create_condition(self.field, "in", self.value) else: raise FilterError(f"Bad operation given to where filter: {self.operation}") @@ -65,9 +69,10 @@ def create_condition(attribute_name, operator, value): """ conditions = {} - # Removing quote marks when doing conditions with IN expressions + # Removing quote marks when doing conditions with IN expressions or when a + # distinct filter is used in a request jpql_value = ( - f"{value}" if isinstance(value, tuple) or operator == "!=" else f"'{value}'" + f"{value}" if operator == "in" or operator == "!=" else f"'{value}'" ) conditions[attribute_name] = f"{operator} {jpql_value}" log.debug("Conditions in ICAT where filter, %s", conditions)