Skip to content

Commit

Permalink
feat: add support for scoring results #398
Browse files Browse the repository at this point in the history
  • Loading branch information
VKTB committed Feb 10, 2023
1 parent c328a0d commit dd58a8b
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 35 deletions.
35 changes: 16 additions & 19 deletions datagateway_api/src/resources/search_api_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,16 @@

from flask_restful import Resource

from datagateway_api.src.common.config import Config
from datagateway_api.src.common.helpers import get_filters_from_query_string
from datagateway_api.src.search_api.filters import SearchAPIScoringFilter
from datagateway_api.src.search_api.helpers import (
add_scores_to_entities,
add_scores_to_results,
get_count,
get_files,
get_files_count,
get_score,
get_search,
get_search_api_query_filter_list,
get_with_pid,
is_query_filter,
search_api_error_handling,
)

Expand All @@ -35,21 +33,20 @@ class Endpoint(Resource):
@search_api_error_handling
def get(self):
filters = get_filters_from_query_string("search_api", entity_name)
# in case there is no query filter then we processed as usual
if not is_query_filter(filters):
return get_search(entity_name, filters), 200
else:
query = get_search_api_query_filter_list(filters)[0].value
entities = get_search(
entity_name,
filters,
"LOWER(o.summary) like '%" + query.lower() + "%'",
)

if Config.config.search_api.scoring_enabled:
scores = get_score(entities, query)
entities = add_scores_to_entities(entities, scores)
return entities, 200
results = get_search(entity_name, filters)
scoring_filter = next(
(
filter_
for filter_ in filters
if isinstance(filter_, SearchAPIScoringFilter)
),
None,
)
if scoring_filter:
scores = get_score(results, scoring_filter.value)
results = add_scores_to_results(results, scores)

return results, 200

get.__doc__ = f"""
---
Expand Down
16 changes: 0 additions & 16 deletions datagateway_api/src/search_api/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
SearchAPIIncludeFilter,
SearchAPIWhereFilter,
)
from datagateway_api.src.search_api.filters import SearchAPIScoringFilter
import datagateway_api.src.search_api.models as models
from datagateway_api.src.search_api.query import SearchAPIQuery
from datagateway_api.src.search_api.session_handler import (
Expand Down Expand Up @@ -283,18 +282,3 @@ def get_files_count(entity_name, filters, pid):

filters.append(SearchAPIWhereFilter("dataset.pid", pid, "eq"))
return get_count(entity_name, filters)


def get_search_api_query_filter_list(filters):
"""
Returns the list of SearchAPIQueryFilter that are in the filters array
"""
return list(filter(lambda x: isinstance(x, SearchAPIScoringFilter), filters))


@client_manager
def is_query_filter(filters):
"""
Checks if there is a SearchAPIQueryFilter in the list of filters
"""
return len(get_search_api_query_filter_list(filters)) == 1

0 comments on commit dd58a8b

Please sign in to comment.