From dd58a8b4bef706b343144ab8fe14a8b4737b3781 Mon Sep 17 00:00:00 2001 From: Viktor Bozhinov Date: Fri, 10 Feb 2023 16:03:55 +0000 Subject: [PATCH] feat: add support for scoring results #398 --- .../src/resources/search_api_endpoints.py | 35 +++++++++---------- datagateway_api/src/search_api/helpers.py | 16 --------- 2 files changed, 16 insertions(+), 35 deletions(-) diff --git a/datagateway_api/src/resources/search_api_endpoints.py b/datagateway_api/src/resources/search_api_endpoints.py index 1901133b..afdf82f3 100644 --- a/datagateway_api/src/resources/search_api_endpoints.py +++ b/datagateway_api/src/resources/search_api_endpoints.py @@ -2,18 +2,16 @@ from flask_restful import Resource -from datagateway_api.src.common.config import Config from datagateway_api.src.common.helpers import get_filters_from_query_string +from datagateway_api.src.search_api.filters import SearchAPIScoringFilter from datagateway_api.src.search_api.helpers import ( - add_scores_to_entities, + add_scores_to_results, get_count, get_files, get_files_count, get_score, get_search, - get_search_api_query_filter_list, get_with_pid, - is_query_filter, search_api_error_handling, ) @@ -35,21 +33,20 @@ class Endpoint(Resource): @search_api_error_handling def get(self): filters = get_filters_from_query_string("search_api", entity_name) - # in case there is no query filter then we processed as usual - if not is_query_filter(filters): - return get_search(entity_name, filters), 200 - else: - query = get_search_api_query_filter_list(filters)[0].value - entities = get_search( - entity_name, - filters, - "LOWER(o.summary) like '%" + query.lower() + "%'", - ) - - if Config.config.search_api.scoring_enabled: - scores = get_score(entities, query) - entities = add_scores_to_entities(entities, scores) - return entities, 200 + results = get_search(entity_name, filters) + scoring_filter = next( + ( + filter_ + for filter_ in filters + if isinstance(filter_, SearchAPIScoringFilter) + ), + None, + ) + if scoring_filter: + scores = get_score(results, scoring_filter.value) + results = add_scores_to_results(results, scores) + + return results, 200 get.__doc__ = f""" --- diff --git a/datagateway_api/src/search_api/helpers.py b/datagateway_api/src/search_api/helpers.py index 161dc215..2eba5250 100644 --- a/datagateway_api/src/search_api/helpers.py +++ b/datagateway_api/src/search_api/helpers.py @@ -18,7 +18,6 @@ SearchAPIIncludeFilter, SearchAPIWhereFilter, ) -from datagateway_api.src.search_api.filters import SearchAPIScoringFilter import datagateway_api.src.search_api.models as models from datagateway_api.src.search_api.query import SearchAPIQuery from datagateway_api.src.search_api.session_handler import ( @@ -283,18 +282,3 @@ def get_files_count(entity_name, filters, pid): filters.append(SearchAPIWhereFilter("dataset.pid", pid, "eq")) return get_count(entity_name, filters) - - -def get_search_api_query_filter_list(filters): - """ - Returns the list of SearchAPIQueryFilter that are in the filters array - """ - return list(filter(lambda x: isinstance(x, SearchAPIScoringFilter), filters)) - - -@client_manager -def is_query_filter(filters): - """ - Checks if there is a SearchAPIQueryFilter in the list of filters - """ - return len(get_search_api_query_filter_list(filters)) == 1