Skip to content

Commit

Permalink
Merge pull request #399 from ral-facilities/398-adapt-search-query-fi…
Browse files Browse the repository at this point in the history
…lter-and-scoring-to-v5

Query filter and scoring
  • Loading branch information
VKTB authored Feb 24, 2023
2 parents dd73f90 + a4f833e commit a09748c
Show file tree
Hide file tree
Showing 24 changed files with 479 additions and 97 deletions.
18 changes: 18 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -802,6 +802,24 @@ adding conditions via dictionaries. This is needed where queries are joined with
or `OR`. This collates all the work from `NestedWhereFilters` so all required types of
conditions can be supported.

### Search Scoring

Search scoring allows for the results returned by the Search API to be scored in terms of
relevancy. The config option `enabled` from the `search_scoring` object in `config.yaml`
can be used to enable or disable search scoring. When enabled, the Search API handles the
`query` filter provided in the requests sent by the [Federated Photon and Neutron Search Service](https://github.com/panosc-eu/panosc-federated-search-service);
otherwise, it returns an error to indicate that the `query` filter is not supported.
For this functionality to work, an instance of the [PaNOSC Search Scoring Service](https://github.com/panosc-eu/panosc-search-scoring/)
is needed which has been configured and populated as per the instructions in its
repository and can return scores. The full URL of its `/score` endpoint must be provided
in the config option `api_url` of the `search_scoring` object in `config.yaml` so that the
Search API knows where to send its results from ICAT, along with the value of the `query`
filter, for scoring.

The [European Photon and Neutron Open Data Search Portal](https://data.panosc.eu/)
requires all Search APIs that want to be integrated with the portal to support search
scoring.

## Generating the OpenAPI Specification

When the config option `generate_swagger` is set to true in `config.yaml`, a YAML
Expand Down
6 changes: 6 additions & 0 deletions datagateway_api/config.yaml.example
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,12 @@ search_api:
mechanism: "anon"
username: ""
password: ""
search_scoring:
enabled: false
api_url: "http://localhost:9000/score"
api_request_timeout: 5
group: "documents" # Corresponds to the group defined in the scoring app: https://github.com/panosc-eu/panosc-search-scoring/blob/master/docs/md/PaNOSC_Federated_Search_Results_Scoring_API.md#model
limit: 1000
flask_reloader: false
log_level: "DEBUG"
log_location: "/home/runner/work/datagateway/datagateway/datagateway-api/datagateway_api/logs.log"
Expand Down
14 changes: 7 additions & 7 deletions datagateway_api/src/api_start_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,9 +287,9 @@ def create_api_endpoints(flask_app, api, specs):
)
search_api_extension = Config.config.search_api.extension
search_api_entity_endpoints = {
"datasets": "Dataset",
"documents": "Document",
"instruments": "Instrument",
"Datasets": "Dataset",
"Documents": "Document",
"Instruments": "Instrument",
}

for endpoint_name, entity_name in search_api_entity_endpoints.items():
Expand Down Expand Up @@ -320,8 +320,8 @@ def create_api_endpoints(flask_app, api, specs):
get_files_endpoint_resource = get_files_endpoint("File")
api.add_resource(
get_files_endpoint_resource,
f"{search_api_extension}/datasets/<string:pid>/files",
endpoint="search_api_get_dataset_files",
f"{search_api_extension}/Datasets/<string:pid>/files",
endpoint="search_api_get_Dataset_files",
)
search_api_spec.path(resource=get_files_endpoint_resource, api=api)

Expand All @@ -330,8 +330,8 @@ def create_api_endpoints(flask_app, api, specs):
)
api.add_resource(
get_number_count_files_endpoint_resource,
f"{search_api_extension}/datasets/<string:pid>/files/count",
endpoint="search_api_count_dataset_files",
f"{search_api_extension}/Datasets/<string:pid>/files/count",
endpoint="search_api_count_Dataset_files",
)
search_api_spec.path(resource=get_number_count_files_endpoint_resource, api=api)

Expand Down
9 changes: 9 additions & 0 deletions datagateway_api/src/common/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,14 @@ class Config:
validate_assignment = True


class SearchScoring(BaseModel):
    """
    Configuration model class that implements pydantic's BaseModel class to hold the
    options of the `search_scoring` object, which is nested inside the `search_api`
    object of `config.yaml`
    """

    # Whether search scoring is enabled; the `query` filter is only accepted when it is
    enabled: StrictBool
    # Full URL to the `/score` endpoint of the PaNOSC Search Scoring Service
    api_url: StrictStr
    # Passed as the `timeout` argument of the request sent to the scoring service
    api_request_timeout: StrictInt
    # Sent as `group` in the scoring request; corresponds to the group defined in the
    # scoring app
    group: StrictStr
    # Sent as `limit` in the scoring request
    limit: StrictInt


class SearchAPI(BaseModel):
"""
Configuration model class that implements pydantic's BaseModel class to allow for
Expand All @@ -133,6 +141,7 @@ class SearchAPI(BaseModel):
mechanism: StrictStr
username: StrictStr
password: StrictStr
search_scoring: SearchScoring

_validate_extension = validator("extension", allow_reuse=True)(validate_extension)

Expand Down
6 changes: 6 additions & 0 deletions datagateway_api/src/common/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,9 @@ class SearchAPIError(ApiError):
def __init__(self, msg="Search API error", *args, **kwargs):
super().__init__(msg, *args, **kwargs)
self.status_code = 500


class ScoringAPIError(ApiError):
    """
    Raised when an error occurs while interacting with the Search Scoring API. The
    status code of the exception is set to 500.

    :param msg: Message describing the error, defaults to "Scoring API error"
    :type msg: :class:`str`
    """

    def __init__(self, msg="Scoring API error", *args, **kwargs):
        super().__init__(msg, *args, **kwargs)
        # Assigned after the parent initialiser has run so that this is the value left
        # on the instance
        self.status_code = 500
28 changes: 21 additions & 7 deletions datagateway_api/src/resources/search_api_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from flask_restful import Resource

from datagateway_api.src.common.helpers import get_filters_from_query_string
from datagateway_api.src.search_api.filters import SearchAPIScoringFilter
from datagateway_api.src.search_api.helpers import (
get_count,
get_files,
Expand All @@ -11,6 +12,7 @@
get_with_pid,
search_api_error_handling,
)
from datagateway_api.src.search_api.search_scoring import SearchScoring

log = logging.getLogger()

Expand All @@ -19,7 +21,7 @@ def get_search_endpoint(entity_name):
"""
Given an entity name, generate a flask_restful `Resource` class. In
`create_api_endpoints()`, these generated classes are registered with the API e.g.
`api.add_resource(get_search_endpoint("Dataset"), "/datasets")`
`api.add_resource(get_search_endpoint("Dataset"), "/Datasets")`
:param entity_name: Name of the entity
:type entity_name: :class:`str`
Expand All @@ -30,8 +32,20 @@ class Endpoint(Resource):
@search_api_error_handling
def get(self):
filters = get_filters_from_query_string("search_api", entity_name)
log.debug("Filters: %s", filters)
return get_search(entity_name, filters), 200
results = get_search(entity_name, filters)
scoring_filter = next(
(
filter_
for filter_ in filters
if isinstance(filter_, SearchAPIScoringFilter)
),
None,
)
if scoring_filter:
scores = SearchScoring.get_score(scoring_filter.value)
results = SearchScoring.add_scores_to_results(results, scores)

return results, 200

get.__doc__ = f"""
---
Expand Down Expand Up @@ -66,7 +80,7 @@ def get_single_endpoint(entity_name):
"""
Given an entity name, generate a flask_restful `Resource` class. In
`create_api_endpoints()`, these generated classes are registered with the API e.g.
`api.add_resource(get_single_endpoint("Dataset"), "/datasets/<string:pid>")`
`api.add_resource(get_single_endpoint("Dataset"), "/Datasets/<string:pid>")`
:param entity_name: Name of the entity
:type entity_name: :class:`str`
Expand Down Expand Up @@ -116,7 +130,7 @@ def get_number_count_endpoint(entity_name):
"""
Given an entity name, generate a flask_restful `Resource` class. In
`create_api_endpoints()`, these generated classes are registered with the API e.g.
`api.add_resource(get_number_count_endpoint("Dataset"), "/datasets/count")`
`api.add_resource(get_number_count_endpoint("Dataset"), "/Datasets/count")`
:param entity_name: Name of the entity
:type entity_name: :class:`str`
Expand Down Expand Up @@ -161,7 +175,7 @@ def get_files_endpoint(entity_name):
"""
Given an entity name, generate a flask_restful `Resource` class. In
`create_api_endpoints()`, these generated classes are registered with the API e.g.
`api.add_resource(get_files_endpoint("Dataset"), "/datasets/<string:pid>/files")`
`api.add_resource(get_files_endpoint("Dataset"), "/Datasets/<string:pid>/files")`
:param entity_name: Name of the entity
:type entity_name: :class:`str`
Expand Down Expand Up @@ -217,7 +231,7 @@ def get_number_count_files_endpoint(entity_name):
Given an entity name, generate a flask_restful `Resource` class. In
`create_api_endpoints()`, these generated classes are registered with the API e.g.
`api.add_resource(get_number_count_files_endpoint("Dataset"),
"/datasets<string:pid>/files/count")`
"/Datasets/<string:pid>/files/count")`
:param entity_name: Name of the entity
:type entity_name: :class:`str`
Expand Down
13 changes: 13 additions & 0 deletions datagateway_api/src/search_api/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,19 @@ def apply_filter(self, query):
return super().apply_filter(query.icat_query.query)


class SearchAPIScoringFilter(SearchAPIWhereFilter):
    """
    Where filter created from the value of the Search API `query` filter. It is a
    case-insensitive pattern match (`ilike`) of that value against the `summary` field.
    """

    def __init__(self, query_value):
        """
        :param query_value: The value given to the `query` filter in the request
        :type query_value: :class:`str`
        :raises ValueError: If `query_value` is not a string
        """
        if isinstance(query_value, str):
            # Scoring is only supported on the Document entity/endpoint, hence the
            # field searched for the query_value pattern (summary) is hard coded
            where_filter_kwargs = {
                "field": "summary",
                "value": query_value,
                "operation": "ilike",
            }
            super().__init__(**where_filter_kwargs)
        else:
            raise ValueError("The value of the query filter must be a string")

    def apply_filter(self, query):
        """
        Apply the filter by handing `query`, unchanged, to the parent class'
        `apply_filter()`

        :param query: The query object passed on to the parent implementation
        :return: Whatever the parent implementation returns
        """
        applied = super().apply_filter(query)
        return applied


class SearchAPIIncludeFilter(PythonICATIncludeFilter):
def __init__(self, included_filters, panosc_entity_name):
self.included_filters = included_filters
Expand Down
3 changes: 2 additions & 1 deletion datagateway_api/src/search_api/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from datagateway_api.src.common.exceptions import (
BadRequestError,
MissingRecordError,
ScoringAPIError,
SearchAPIError,
)
from datagateway_api.src.common.filter_order_handler import FilterOrderHandler
Expand Down Expand Up @@ -39,7 +40,7 @@ def search_api_error_handling(method):
def wrapper_error_handling(*args, **kwargs):
try:
return method(*args, **kwargs)
except ValidationError as e:
except (ValidationError, ScoringAPIError) as e:
log.exception(msg=e.args)
assign_status_code(e, 500)
raise SearchAPIError(create_error_message(e))
Expand Down
1 change: 0 additions & 1 deletion datagateway_api/src/search_api/panosc_mappings.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ def get_icat_mapping(self, panosc_entity_name, field_name):

try:
icat_mapping = self.mappings[panosc_entity_name][field_name]
log.debug("ICAT mapping/translation found: %s", icat_mapping)
except KeyError as e:
raise FilterError(f"Bad PaNOSC to ICAT mapping: {e.args}")

Expand Down
13 changes: 13 additions & 0 deletions datagateway_api/src/search_api/query_filter_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@


from datagateway_api.src.common.base_query_filter_factory import QueryFilterFactory
from datagateway_api.src.common.config import Config
from datagateway_api.src.common.exceptions import FilterError, SearchAPIError
from datagateway_api.src.search_api.filters import (
SearchAPIIncludeFilter,
SearchAPILimitFilter,
SearchAPIScoringFilter,
SearchAPISkipFilter,
SearchAPIWhereFilter,
)
Expand Down Expand Up @@ -62,6 +64,17 @@ def get_query_filter(request_filter, entity_name=None, related_entity_name=None)
elif filter_name == "skip":
log.info("skip JSON object found")
query_filters.append(SearchAPISkipFilter(filter_input))
elif (
filter_name == "query"
and entity_name == "Document"
and Config.config.search_api.search_scoring.enabled
):
# We are only supporting scoring on the Document entity/ endpoint
# so the query filter is not accepted on other entities/ endpoints.
# Scoring must be enabled in order for the query filter to be
# accepted.
log.info("query JSON object found")
query_filters.append(SearchAPIScoringFilter(filter_input))
else:
raise FilterError(
"No valid filter name given within filter query param:"
Expand Down
61 changes: 61 additions & 0 deletions datagateway_api/src/search_api/search_scoring.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import requests
from requests import RequestException

from datagateway_api.src.common.config import Config
from datagateway_api.src.common.exceptions import ScoringAPIError


class SearchScoring:
    """
    Groups the functions used to score Search API results: fetching the scores from
    the Search Scoring API and attaching them to the results
    """

    @staticmethod
    def get_score(query):
        """
        Gets the score for all the items in the scoring API according to the query
        value provided.

        :param query: The term to use in the relevancy scoring
        :type query: :class:`str`
        :return: Returns the scores
        :raises ScoringAPIError: If an error occurs while interacting with the Search
            Scoring API, or if its response is not JSON containing a `scores` entry
        """
        scoring_config = Config.config.search_api.search_scoring
        # NOTE: `itemIds` is deliberately not sent. With it, the scoring server
        # returned a 400 error (cause unknown).
        data = {
            "query": query,
            "group": scoring_config.group,
            "limit": scoring_config.limit,
        }

        try:
            response = requests.post(
                scoring_config.api_url,
                json=data,
                timeout=scoring_config.api_request_timeout,
            )
            response.raise_for_status()
            return response.json()["scores"]
        except (RequestException, ValueError, KeyError, TypeError) as e:
            # ValueError, KeyError and TypeError cover a response body that is not
            # valid JSON or that does not have the expected `scores` entry, so callers
            # only ever have to deal with ScoringAPIError
            raise ScoringAPIError(
                "An error occurred while trying to score the results",
            ) from e

    @staticmethod
    def add_scores_to_results(results, scores):
        """
        Add the scores to all the results returned from the metadata catalogue. It only
        adds the score if it finds a match, otherwise the score is set to -1
        (arbitrarily chosen). The results are modified in place.

        :param results: List of results that have been retrieved from the metadata
            catalogue, each one having a `pid` key
        :type results: :class:`list`
        :param scores: List of items retrieved from the scoring application, each one
            having `itemId` and `score` keys
        :type scores: :class:`list`
        :return: Returns the results with scores
        """
        # Built once so that each result is a single lookup rather than a scan of
        # every score. `setdefault()` keeps the first score seen for an item ID.
        score_by_item_id = {}
        for score in scores:
            score_by_item_id.setdefault(score["itemId"], score["score"])

        for result in results:
            result["score"] = score_by_item_id.get(result["pid"], -1)

        return results
Loading

0 comments on commit a09748c

Please sign in to comment.