Skip to content

Commit

Permalink
Merge branch 'feature/query-params-search-api-#259' into filter-input…
Browse files Browse the repository at this point in the history
…-conversion
  • Loading branch information
MRichards99 committed Dec 8, 2021
2 parents dd23cce + b0a4c47 commit 75f8a91
Show file tree
Hide file tree
Showing 15 changed files with 2,497 additions and 46 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ venv/
.idea/
*.pyc
logs.log*
config.json
config.json*
.vscode/
.nox/
.python-version
Expand Down
17 changes: 17 additions & 0 deletions datagateway_api/src/common/base_query_filter_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from abc import abstractstaticmethod


class QueryFilterFactory(object):
@abstractstaticmethod
def get_query_filter(request_filter, entity_name=None): # noqa: B902, N805
"""
Given a filter, return a matching Query filter object
:param request_filter: The filter to create the QueryFilter for
:type request_filter: :class:`dict`
:param entity_name: Entity name of the endpoint, optional (only used for search
API, not DataGateway API)
:type entity_name: :class:`str`
:return: The QueryFilter object created
"""
pass
29 changes: 25 additions & 4 deletions datagateway_api/src/common/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
MissingCredentialsError,
)
from datagateway_api.src.datagateway_api.database import models
from datagateway_api.src.datagateway_api.query_filter_factory import QueryFilterFactory
from datagateway_api.src.resources.entities.entity_endpoint_dict import endpoints

log = logging.getLogger()
Expand Down Expand Up @@ -88,20 +87,42 @@ def is_valid_json(string):
return True


def get_filters_from_query_string():
def get_filters_from_query_string(api_type, entity_name=None):
"""
Gets a list of filters from the query_strings arg,value pairs, and returns a list of
QueryFilter Objects
:param api_type: Type of API this function is being used for i.e. DataGateway API or
Search API
:type api_type: :class:`str`
:param entity_name: Entity name of the endpoint, optional (only used for search
API, not DataGateway API)
:type entity_name: :class:`str`
:raises ApiError: If `api_type` isn't a valid value
:return: The list of filters
"""
if api_type == "search_api":
from datagateway_api.src.search_api.query_filter_factory import (
SearchAPIQueryFilterFactory as QueryFilterFactory,
)
elif api_type == "datagateway_api":
from datagateway_api.src.datagateway_api.query_filter_factory import (
DataGatewayAPIQueryFilterFactory as QueryFilterFactory,
)
else:
raise ApiError(
"Incorrect api_type passed into `get_filter_from_query_string(): "
f"{api_type}",
)
log.info(" Getting filters from query string")
try:
filters = []
for arg in request.args:
for value in request.args.getlist(arg):
filters.append(
QueryFilterFactory.get_query_filter({arg: json.loads(value)}),
filters.extend(
QueryFilterFactory.get_query_filter(
{arg: json.loads(value)}, entity_name,
),
)
return filters
except Exception as e:
Expand Down
22 changes: 14 additions & 8 deletions datagateway_api/src/datagateway_api/query_filter_factory.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging

from datagateway_api.src.common.base_query_filter_factory import QueryFilterFactory
from datagateway_api.src.common.config import Config
from datagateway_api.src.common.exceptions import (
ApiError,
Expand All @@ -9,9 +10,9 @@
log = logging.getLogger()


class QueryFilterFactory(object):
class DataGatewayAPIQueryFilterFactory(QueryFilterFactory):
@staticmethod
def get_query_filter(request_filter):
def get_query_filter(request_filter, entity_name=None):
"""
Given a filter, return a matching Query filter object
Expand All @@ -22,6 +23,11 @@ def get_query_filter(request_filter):
:param request_filter: The filter to create the QueryFilter for
:type request_filter: :class:`dict`
:param entity_name: Not utilised in DataGateway API implementation of this
static function, used in the search API. It is part of the method signature
as the same function call (called in `get_filters_from_query_string()`) is
used for both implementations
:type entity_name: :class:`str`
:return: The QueryFilter object created
:raises ApiError: If the backend type contains an invalid value
:raises FilterError: If the filter name is not recognised
Expand Down Expand Up @@ -57,18 +63,18 @@ def get_query_filter(request_filter):
field = list(request_filter[filter_name].keys())[0]
operation = list(request_filter[filter_name][field].keys())[0]
value = request_filter[filter_name][field][operation]
return WhereFilter(field, value, operation)
return [WhereFilter(field, value, operation)]
elif filter_name == "order":
field = request_filter["order"].split(" ")[0]
direction = request_filter["order"].split(" ")[1]
return OrderFilter(field, direction)
return [OrderFilter(field, direction)]
elif filter_name == "skip":
return SkipFilter(request_filter["skip"])
return [SkipFilter(request_filter["skip"])]
elif filter_name == "limit":
return LimitFilter(request_filter["limit"])
return [LimitFilter(request_filter["limit"])]
elif filter_name == "include":
return IncludeFilter(request_filter["include"])
return [IncludeFilter(request_filter["include"])]
elif filter_name == "distinct":
return DistinctFieldFilter(request_filter["distinct"])
return [DistinctFieldFilter(request_filter["distinct"])]
else:
raise FilterError(f" Bad filter: {request_filter}")
6 changes: 3 additions & 3 deletions datagateway_api/src/resources/entities/entity_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def get(self):
backend.get_with_filters(
get_session_id_from_auth_header(),
entity_type,
get_filters_from_query_string(),
get_filters_from_query_string("datagateway_api"),
**kwargs,
),
200,
Expand Down Expand Up @@ -321,7 +321,7 @@ def get_count_endpoint(name, entity_type, backend, **kwargs):

class CountEndpoint(Resource):
def get(self):
filters = get_filters_from_query_string()
filters = get_filters_from_query_string("datagateway_api")
return (
backend.count_with_filters(
get_session_id_from_auth_header(), entity_type, filters, **kwargs,
Expand Down Expand Up @@ -380,7 +380,7 @@ def get_find_one_endpoint(name, entity_type, backend, **kwargs):

class FindOneEndpoint(Resource):
def get(self):
filters = get_filters_from_query_string()
filters = get_filters_from_query_string("datagateway_api")
return (
backend.get_one_with_filters(
get_session_id_from_auth_header(), entity_type, filters, **kwargs,
Expand Down
17 changes: 17 additions & 0 deletions datagateway_api/src/resources/search_api_endpoints.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import logging

from flask_restful import Resource

from datagateway_api.src.search_api.helpers import (
Expand All @@ -7,6 +9,9 @@
get_search,
get_with_id,
)
from datagateway_api.src.common.helpers import get_filters_from_query_string

log = logging.getLogger()


def get_search_endpoint(name):
Expand All @@ -16,6 +21,8 @@ def get_search_endpoint(name):

class Endpoint(Resource):
def get(self):
filters = get_filters_from_query_string("search_api", name)
log.debug("Filters: %s", filters)
return get_search(name), 200

# TODO - Add `get.__doc__`
Expand All @@ -31,6 +38,8 @@ def get_single_endpoint(name):

class EndpointWithID(Resource):
def get(self, pid):
filters = get_filters_from_query_string("search_api", name)
log.debug("Filters: %s", filters)
return get_with_id(name, pid), 200

# TODO - Add `get.__doc__`
Expand All @@ -46,6 +55,9 @@ def get_number_count_endpoint(name):

class CountEndpoint(Resource):
def get(self):
# Only WHERE included on count endpoints
filters = get_filters_from_query_string("search_api", name)
log.debug("Filters: %s", filters)
return get_count(name), 200

# TODO - Add `get.__doc__`
Expand All @@ -61,6 +73,8 @@ def get_files_endpoint(name):

class FilesEndpoint(Resource):
def get(self, pid):
filters = get_filters_from_query_string("search_api", name)
log.debug("Filters: %s", filters)
return get_files(name), 200

# TODO - Add `get.__doc__`
Expand All @@ -76,6 +90,9 @@ def get_number_count_files_endpoint(name):

class CountFilesEndpoint(Resource):
def get(self, pid):
# Only WHERE included on count endpoints
filters = get_filters_from_query_string("search_api", name)
log.debug("Filters: %s", filters)
return get_files_count(name, pid)

# TODO - Add `get.__doc__`
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def get(self, id_):
backend.get_facility_cycles_for_instrument_with_filters(
get_session_id_from_auth_header(),
id_,
get_filters_from_query_string(),
get_filters_from_query_string("datagateway_api"),
**kwargs,
),
200,
Expand Down Expand Up @@ -126,7 +126,7 @@ def get(self, id_):
backend.get_facility_cycles_for_instrument_count_with_filters(
get_session_id_from_auth_header(),
id_,
get_filters_from_query_string(),
get_filters_from_query_string("datagateway_api"),
**kwargs,
),
200,
Expand Down Expand Up @@ -202,7 +202,7 @@ def get(self, instrument_id, cycle_id):
get_session_id_from_auth_header(),
instrument_id,
cycle_id,
get_filters_from_query_string(),
get_filters_from_query_string("datagateway_api"),
**kwargs,
),
200,
Expand Down Expand Up @@ -272,7 +272,7 @@ def get(self, instrument_id, cycle_id):
get_session_id_from_auth_header(),
instrument_id,
cycle_id,
get_filters_from_query_string(),
get_filters_from_query_string("datagateway_api"),
**kwargs,
),
200,
Expand Down
22 changes: 22 additions & 0 deletions datagateway_api/src/search_api/filters.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from icat.client import Client
from icat.query import Query

from datagateway_api.src.datagateway_api.icat.filters import (
PythonICATIncludeFilter,
PythonICATLimitFilter,
Expand All @@ -16,6 +19,25 @@ def __init__(self, field, value, operation):
def apply_filter(self, query):
return super().apply_filter(query)

def __str__(self):
# TODO - replace with `SessionHandler.client` when that work is merged
client = Client("https://localhost.localdomain:8181", checkCert=False)
client.login("simple", {"username": "root", "password": "pw"})

# TODO - can't just hardcode investigation entity. Might need `icat_entity_name`
# to be passed into init
query = Query(client, "Investigation")
query.addConditions(self.create_filter())
str_conds = query.get_conditions_as_str()

return str_conds[0]

def __repr__(self):
return (
f"Field: '{self.field}', Value: '{self.value}', Operation:"
f" '{self.operation}'"
)


class SearchAPISkipFilter(PythonICATSkipFilter):
def __init__(self, skip_value):
Expand Down
53 changes: 53 additions & 0 deletions datagateway_api/src/search_api/nested_where_filters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
class NestedWhereFilters:
def __init__(self, lhs, rhs, joining_operator):
"""
Class to represent nested conditions that use different boolean operators e.g.
`(A OR B) AND (C OR D)`. This works by joining the two conditions with a boolean
operator
:param lhs: Left hand side of the condition - either a string condition, WHERE
filter or instance of this class
:type lhs: Any class that has `__str__()` implemented, but use cases will be for
:class:`str` or :class:`SearchAPIWhereFilter` or :class:`NestedWhereFilters`
:param rhs: Right hand side of the condition - either a string condition, WHERE
filter or instance of this class
:type rhs: Any class that has `__str__()` implemented, but use cases will be for
:class:`str` or :class:`SearchAPIWhereFilter` or :class:`NestedWhereFilters`
:param joining_operator: Boolean operator used to join the conditions of `lhs`
`rhs` (e.g. `AND` or `OR`)
:type joining_operator: :class:`str`
"""

# Ensure each side is in a list for consistency for string conversion
if not isinstance(lhs, list):
lhs = [lhs]
if not isinstance(rhs, list):
rhs = [rhs]

self.lhs = lhs
self.rhs = rhs
self.joining_operator = joining_operator

def __str__(self):
"""
Join the condition on the left with the one on the right with the boolean
operator
"""
boolean_algebra_list = [self.lhs, self.rhs]
try:
boolean_algebra_list.remove([None])
except ValueError:
# If neither side contains `None`, we should continue as normal
pass

# If either side contains a list of WHERE filter objects, flatten the conditions
conditions = [str(m) for n in (i for i in boolean_algebra_list) for m in n]
operator = f" {self.joining_operator} "

return f"({operator.join(conditions)})"

def __repr__(self):
return (
f"LHS: {repr(self.lhs)}, RHS: {repr(self.rhs)}, Operator:"
f" {repr(self.joining_operator)}"
)
Loading

0 comments on commit 75f8a91

Please sign in to comment.