Skip to content

Commit

Permalink
refactor: move search scoring logic to a separate module #398
Browse files Browse the repository at this point in the history
  • Loading branch information
VKTB committed Feb 14, 2023
1 parent 0472d96 commit 70cd754
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 67 deletions.
7 changes: 3 additions & 4 deletions datagateway_api/src/resources/search_api_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,14 @@
from datagateway_api.src.common.helpers import get_filters_from_query_string
from datagateway_api.src.search_api.filters import SearchAPIScoringFilter
from datagateway_api.src.search_api.helpers import (
add_scores_to_results,
get_count,
get_files,
get_files_count,
get_score,
get_search,
get_with_pid,
search_api_error_handling,
)
from datagateway_api.src.search_api.search_scoring import SearchScoring

log = logging.getLogger()

Expand Down Expand Up @@ -43,8 +42,8 @@ def get(self):
None,
)
if scoring_filter:
scores = get_score(results, scoring_filter.value)
results = add_scores_to_results(results, scores)
scores = SearchScoring.get_score(scoring_filter.value)
results = SearchScoring.add_scores_to_results(results, scores)

return results, 200

Expand Down
50 changes: 0 additions & 50 deletions datagateway_api/src/search_api/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,7 @@
import logging

from pydantic import ValidationError
import requests
from requests import RequestException

from datagateway_api.src.common.config import Config
from datagateway_api.src.common.exceptions import (
BadRequestError,
MissingRecordError,
Expand Down Expand Up @@ -78,53 +75,6 @@ def create_error_message(e):
return wrapper_error_handling


def get_score(entities, query):
"""
Gets the score on the given entities based in the query parameter
that is the term to be found
:param entities: List of entities that have been retrieved from one ICAT query.
:type entities: :class:`list`
:param query: String with the term to be searched by
:type query: :class:`str`
"""
try:
data = {
"query": query,
"group": Config.config.search_api.search_scoring.group,
"limit": Config.config.search_api.search_scoring.limit,
# With itemIds, scoring server returns a 400 error. No idea why.
# "itemIds": list(map(lambda entity: (entity["pid"]), entities)), #
}
response = requests.post(
Config.config.search_api.search_scoring.api_url,
json=data,
timeout=Config.config.search_api.search_scoring.api_request_timeout,
)
response.raise_for_status()
return response.json()["scores"]
except RequestException:
raise ScoringAPIError("An error occurred while trying to score the results")


def add_scores_to_results(results, scores):
"""
For each entity this function adds the score if it is found by matching
the score.item.itemsId with the pid of the entity
Otherwise the score is filled with -1 (arbitrarily chosen)
:param results: List of entities that have been retrieved from one ICAT query.
:type results: :class:`list`
:param scores: List of items retrieved from the scoring application
:type scores: :class:`list`
"""
for result in results:
result["score"] = next(
(score["score"] for score in scores if score["itemId"] == result["pid"]),
-1,
)

return results


@client_manager
def get_search(entity_name, filters, str_conditions=None):
"""
Expand Down
61 changes: 61 additions & 0 deletions datagateway_api/src/search_api/search_scoring.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import requests
from requests import RequestException

from datagateway_api.src.common.config import Config
from datagateway_api.src.common.exceptions import ScoringAPIError


class SearchScoring:
@staticmethod
def get_score(query):
"""
Gets the score for all the items in the scoring API according to the query
value provided.
:param query: The term to use in the relevancy scoring
:type query: :class:`str`
:return: Returns the scores
:raises ScoringAPIError: If an error occurs while interacting with the Search
Scoring API
"""
try:
data = {
"query": query,
"group": Config.config.search_api.search_scoring.group,
"limit": Config.config.search_api.search_scoring.limit,
# With itemIds, scoring server returns a 400 error. No idea why.
# "itemIds": list(map(lambda entity: (entity["pid"]), entities)), #
}
response = requests.post(
Config.config.search_api.search_scoring.api_url,
json=data,
timeout=Config.config.search_api.search_scoring.api_request_timeout,
)
response.raise_for_status()
return response.json()["scores"]
except RequestException:
raise ScoringAPIError("An error occurred while trying to score the results")

@staticmethod
def add_scores_to_results(results, scores):
"""
Add the scores to all the results returned from the metadata catalogue. It only
adds the score if it finds a match, otherwise the score is set to -1
(arbitrarily chosen).
:param results: List of results that have been retrieved from the metadata
catalogue
:type results: :class:`list`
:param scores: List of items retrieved from the scoring application
:type scores: :class:`list`
:return: Returns the results with scores
"""
for result in results:
result["score"] = next(
(
score["score"]
for score in scores
if score["itemId"] == result["pid"]
),
-1,
)

return results
23 changes: 10 additions & 13 deletions test/unit/search_api/test_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from datagateway_api.src.common.config import Config
from datagateway_api.src.common.exceptions import ScoringAPIError
from datagateway_api.src.search_api.helpers import add_scores_to_results, get_score
from datagateway_api.src.search_api.search_scoring import SearchScoring

SEARCH_SCORING_API_SCORES_DATA = {
"request": {
Expand Down Expand Up @@ -83,13 +83,13 @@ def test_get_score(self, post_mock):
scoring_query_filter_value = "My test query"
post_request_data = {
"query": scoring_query_filter_value,
"group": "Documents",
"limit": 1000,
"group": Config.config.search_api.search_scoring.group,
"limit": Config.config.search_api.search_scoring.limit,
}
post_mock.return_value.status_code = 200
post_mock.return_value.json.return_value = SEARCH_SCORING_API_SCORES_DATA

scores = get_score(SEARCH_API_DOCUMENT_RESULTS, scoring_query_filter_value)
scores = SearchScoring.get_score(scoring_query_filter_value)

post_mock.assert_called_once_with(
Config.config.search_api.search_scoring.api_url,
Expand All @@ -102,21 +102,18 @@ def test_get_score(self, post_mock):
def test_get_score_raises_scoring_api_error(self, post_mock):
post_mock.side_effect = RequestException
with pytest.raises(ScoringAPIError):
get_score(SEARCH_API_DOCUMENT_RESULTS, "My test query")
SearchScoring.get_score("My test query")

def test_add_score_to_results(self):
expected_search_api_document_results_with_scores = []
expected_results_with_scores = []
# Add scores to document results
scores = [0.7071067811865475, -1, 0.53843041]
for i, result in enumerate(SEARCH_API_DOCUMENT_RESULTS):
expected_search_api_document_results_with_scores.append(result.copy())
expected_search_api_document_results_with_scores[i]["score"] = scores[i]
expected_results_with_scores.append(result.copy())
expected_results_with_scores[i]["score"] = scores[i]

actual_search_api_document_results_with_scores = add_scores_to_results(
actual_results_with_scores = SearchScoring.add_scores_to_results(
SEARCH_API_DOCUMENT_RESULTS, SEARCH_SCORING_API_SCORES_DATA["scores"],
)

assert (
actual_search_api_document_results_with_scores
== expected_search_api_document_results_with_scores
)
assert actual_results_with_scores == expected_results_with_scores

0 comments on commit 70cd754

Please sign in to comment.