Skip to content

Commit

Permalink
Merge branch 'master' into feature/query-params-search-api-#259
Browse files Browse the repository at this point in the history
  • Loading branch information
Viktor Bozhinov committed Dec 9, 2021
2 parents b0a4c47 + dd23cce commit 7ae6e77
Show file tree
Hide file tree
Showing 47 changed files with 486 additions and 285 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@

<!--next-version-placeholder-->

## v3.1.0 (2021-12-06)
### Feature
* Implement session/client handling for search API #258 ([`46a1539`](https://github.com/ral-facilities/datagateway-api/commit/46a1539398f63e9c8a6539d703a164dd7c8749e7))

## v3.0.1 (2021-11-24)
### Fix
* Allow blank extensions and slash extension to be valid ([`70ddb7a`](https://github.com/ral-facilities/datagateway-api/commit/70ddb7a4fd89ba10b06cd71c3ab2a98648cfb773))
Expand Down
4 changes: 1 addition & 3 deletions datagateway_api/config.json.example
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,7 @@
"search_api": {
"extension": "/search-api",
"icat_url": "https://localhost:8181",
"icat_check_cert": false,
"client_pool_init_size": 2,
"client_pool_max_size": 5
"icat_check_cert": false
},
"flask_reloader": false,
"log_level": "WARN",
Expand Down
28 changes: 15 additions & 13 deletions datagateway_api/src/api_start_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@
from flask_restful import Api
from flask_swagger_ui import get_swaggerui_blueprint

from datagateway_api.src.common.config import config
from datagateway_api.src.common.config import Config

# Only attempt to create a DataGateway API backend if the datagateway_api object
# is present in the config. This ensures that the API does not error on startup
# due to an AttributeError exception being thrown if the object is missing.
if config.datagateway_api is not None:
if Config.config.datagateway_api is not None:
from datagateway_api.src.datagateway_api.backends import create_backend
from datagateway_api.src.datagateway_api.database.helpers import db # noqa: I202
from datagateway_api.src.datagateway_api.icat.icat_client_pool import create_client_pool
Expand Down Expand Up @@ -73,15 +73,17 @@ def create_app_infrastructure(flask_app):
flask_app.url_map.strict_slashes = False
api = CustomErrorHandledApi(flask_app)

if config.datagateway_api is not None:
if Config.config.datagateway_api is not None:
try:
backend_type = flask_app.config["TEST_BACKEND"]
config.datagateway_api.set_backend_type(backend_type)
Config.config.datagateway_api.set_backend_type(backend_type)
except KeyError:
backend_type = config.datagateway_api.backend
backend_type = Config.config.datagateway_api.backend

if backend_type == "db":
flask_app.config["SQLALCHEMY_DATABASE_URI"] = config.datagateway_api.db_url
flask_app.config[
"SQLALCHEMY_DATABASE_URI"
] = Config.config.datagateway_api.db_url
flask_app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False
db.init_app(flask_app)

Expand All @@ -92,12 +94,12 @@ def create_app_infrastructure(flask_app):

def create_api_endpoints(flask_app, api, spec):
# DataGateway API endpoints
if config.datagateway_api is not None:
if Config.config.datagateway_api is not None:
try:
backend_type = flask_app.config["TEST_BACKEND"]
config.datagateway_api.set_backend_type(backend_type)
Config.config.datagateway_api.set_backend_type(backend_type)
except KeyError:
backend_type = config.datagateway_api.backend
backend_type = Config.config.datagateway_api.backend

backend = create_backend(backend_type)

Expand All @@ -106,7 +108,7 @@ def create_api_endpoints(flask_app, api, spec):
# Create client pool
icat_client_pool = create_client_pool()

datagateway_api_extension = config.datagateway_api.extension
datagateway_api_extension = Config.config.datagateway_api.extension
for entity_name in endpoints:
get_endpoint_resource = get_endpoint(
entity_name,
Expand Down Expand Up @@ -220,8 +222,8 @@ def create_api_endpoints(flask_app, api, spec):
spec.path(resource=ping_resource, api=api)

# Search API endpoints
if config.search_api is not None:
search_api_extension = config.search_api.extension
if Config.config.search_api is not None:
search_api_extension = Config.config.search_api.extension
search_api_entity_endpoints = ["datasets", "documents", "instruments"]

for entity_name in search_api_entity_endpoints:
Expand Down Expand Up @@ -271,7 +273,7 @@ def create_api_endpoints(flask_app, api, spec):
def openapi_config(spec):
# Reorder paths (e.g. get, patch, post) so openapi.yaml only changes when there's a
# change to the Swagger docs, rather than changing on each startup
if config.generate_swagger:
if Config.config.generate_swagger:
log.debug("Reordering OpenAPI docs to alphabetical order")
for entity_data in spec._paths.values():
for endpoint_name in sorted(entity_data.keys()):
Expand Down
7 changes: 4 additions & 3 deletions datagateway_api/src/common/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,6 @@ class SearchAPI(BaseModel):
validation of the SearchAPI config data using Python type annotations.
"""

client_pool_init_size: StrictInt
client_pool_max_size: StrictInt
extension: StrictStr
icat_check_cert: StrictBool
icat_url: StrictStr
Expand Down Expand Up @@ -216,4 +214,7 @@ def validate_api_extensions(cls, value, values): # noqa: B902, N805
return value


config = APIConfig.load()
class Config:
"""Class containing config as a class variable so it can mocked during testing"""

config = APIConfig.load()
8 changes: 4 additions & 4 deletions datagateway_api/src/common/logger_setup.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import logging.config
from pathlib import Path

from datagateway_api.src.common.config import config
from datagateway_api.src.common.config import Config

LOG_FILE_NAME = Path(config.log_location)
LOG_FILE_NAME = Path(Config.config.log_location)
logger_config = {
"version": 1,
"formatters": {
Expand All @@ -14,15 +14,15 @@
},
"handlers": {
"default": {
"level": config.log_level,
"level": Config.config.log_level,
"formatter": "default",
"class": "logging.handlers.RotatingFileHandler",
"filename": LOG_FILE_NAME,
"maxBytes": 5000000,
"backupCount": 10,
},
},
"root": {"level": config.log_level, "handlers": ["default"]},
"root": {"level": Config.config.log_level, "handlers": ["default"]},
}


Expand Down
18 changes: 13 additions & 5 deletions datagateway_api/src/datagateway_api/icat/filters.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging

from datagateway_api.src.common.config import config
from datagateway_api.src.common.config import Config
from datagateway_api.src.common.exceptions import FilterError
from datagateway_api.src.common.filters import (
DistinctFieldFilter,
Expand Down Expand Up @@ -210,13 +210,21 @@ def apply_filter(self, query):


class PythonICATSkipFilter(SkipFilter):
def __init__(self, skip_value):
def __init__(self, skip_value, filter_use="datagateway_api"):
super().__init__(skip_value)
self.filter_use = filter_use

def apply_filter(self, query):
icat_properties = get_icat_properties(
config.datagateway_api.icat_url, config.datagateway_api.icat_check_cert,
)
if self.filter_use == "datagateway_api":
icat_properties = get_icat_properties(
Config.config.datagateway_api.icat_url,
Config.config.datagateway_api.icat_check_cert,
)
else:
icat_properties = get_icat_properties(
Config.config.search_api.icat_url,
Config.config.search_api.icat_check_cert,
)
icat_set_limit(query, self.skip_value, icat_properties["maxEntities"])


Expand Down
21 changes: 13 additions & 8 deletions datagateway_api/src/datagateway_api/icat/icat_client_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,24 @@
from icat.client import Client
from object_pool import ObjectPool

from datagateway_api.src.common.config import config
from datagateway_api.src.common.config import Config

log = logging.getLogger()


class ICATClient(Client):
"""Wrapper class to allow an object pool of client objects to be created"""

def __init__(self):
super().__init__(
config.datagateway_api.icat_url,
checkCert=config.datagateway_api.icat_check_cert,
)
def __init__(self, client_use="datagateway_api"):
if client_use == "datagateway_api":
icat_url = Config.config.datagateway_api.icat_url
icat_check_cert = Config.config.datagateway_api.icat_check_cert
else:
# Search API use cases
icat_url = Config.config.search_api.icat_url
icat_check_cert = Config.config.search_api.icat_check_cert

super().__init__(icat_url, checkCert=icat_check_cert)
# When clients are cleaned up, sessions won't be logged out
self.autoLogout = False

Expand All @@ -36,8 +41,8 @@ def create_client_pool():

return ObjectPool(
ICATClient,
min_init=config.datagateway_api.client_pool_init_size,
max_capacity=config.datagateway_api.client_pool_max_size,
min_init=Config.config.datagateway_api.client_pool_init_size,
max_capacity=Config.config.datagateway_api.client_pool_max_size,
max_reusable=0,
expires=0,
)
4 changes: 2 additions & 2 deletions datagateway_api/src/datagateway_api/icat/lru_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from cachetools.lru import LRUCache

from datagateway_api.src.common.config import config
from datagateway_api.src.common.config import Config

log = logging.getLogger()

Expand All @@ -19,7 +19,7 @@ class ExtendedLRUCache(LRUCache):
"""

def __init__(self):
super().__init__(maxsize=config.datagateway_api.client_cache_size)
super().__init__(maxsize=Config.config.datagateway_api.client_cache_size)

def popitem(self):
key, client = super().popitem()
Expand Down
4 changes: 2 additions & 2 deletions datagateway_api/src/datagateway_api/query_filter_factory.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging

from datagateway_api.src.common.base_query_filter_factory import QueryFilterFactory
from datagateway_api.src.common.config import config
from datagateway_api.src.common.config import Config
from datagateway_api.src.common.exceptions import (
ApiError,
FilterError,
Expand Down Expand Up @@ -33,7 +33,7 @@ def get_query_filter(request_filter, entity_name=None):
:raises FilterError: If the filter name is not recognised
"""

backend_type = config.datagateway_api.backend
backend_type = Config.config.datagateway_api.backend
if backend_type == "db":
from datagateway_api.src.datagateway_api.database.filters import (
DatabaseDistinctFieldFilter as DistinctFieldFilter,
Expand Down
10 changes: 5 additions & 5 deletions datagateway_api/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
create_openapi_endpoint,
openapi_config,
)
from datagateway_api.src.common.config import config
from datagateway_api.src.common.config import Config
from datagateway_api.src.common.logger_setup import setup_logger

setup_logger()
Expand All @@ -23,8 +23,8 @@

if __name__ == "__main__":
app.run(
host=config.host,
port=config.port,
debug=config.debug_mode,
use_reloader=config.flask_reloader,
host=Config.config.host,
port=Config.config.port,
debug=Config.config.debug_mode,
use_reloader=Config.config.flask_reloader,
)
53 changes: 12 additions & 41 deletions datagateway_api/src/resources/search_api_endpoints.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,22 @@
import logging

from flask_restful import Resource

from datagateway_api.src.common.helpers import get_filters_from_query_string

log = logging.getLogger()
from datagateway_api.src.search_api.helpers import (
get_count,
get_files,
get_files_count,
get_search,
get_with_id,
)


# TODO - Might need kwargs on get_search_endpoint(), get_single_endpoint(),
# get_number_count_endpoint(), get_files_endpoint(), get_number_count_files_endpoint()
# for client handling?
def get_search_endpoint(name):
"""
TODO - Add docstring
"""

class Endpoint(Resource):
def get(self):
filters = get_filters_from_query_string("search_api", name)
log.debug("Filters: %s", filters)
"""
TODO - Need to return similar to
return (
backend.get_with_filters(
get_session_id_from_auth_header(),
entity_type,
get_filters_from_query_string(),
**kwargs,
),
200,
)
"""
pass
return get_search(name), 200

# TODO - Add `get.__doc__`

Expand All @@ -46,10 +31,7 @@ def get_single_endpoint(name):

class EndpointWithID(Resource):
def get(self, pid):
filters = get_filters_from_query_string("search_api", name)
log.debug("Filters: %s", filters)
# TODO - Add return
pass
return get_with_id(name, pid), 200

# TODO - Add `get.__doc__`

Expand All @@ -64,11 +46,7 @@ def get_number_count_endpoint(name):

class CountEndpoint(Resource):
def get(self):
# Only WHERE included on count endpoints
filters = get_filters_from_query_string("search_api", name)
log.debug("Filters: %s", filters)
# TODO - Add return
pass
return get_count(name), 200

# TODO - Add `get.__doc__`

Expand All @@ -83,10 +61,7 @@ def get_files_endpoint(name):

class FilesEndpoint(Resource):
def get(self, pid):
filters = get_filters_from_query_string("search_api", name)
log.debug("Filters: %s", filters)
# TODO - Add return
pass
return get_files(name), 200

# TODO - Add `get.__doc__`

Expand All @@ -101,11 +76,7 @@ def get_number_count_files_endpoint(name):

class CountFilesEndpoint(Resource):
def get(self, pid):
# Only WHERE included on count endpoints
filters = get_filters_from_query_string("search_api", name)
log.debug("Filters: %s", filters)
# TODO - Add return
pass
return get_files_count(name, pid)

# TODO - Add `get.__doc__`

Expand Down
2 changes: 1 addition & 1 deletion datagateway_api/src/search_api/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def __repr__(self):

class SearchAPISkipFilter(PythonICATSkipFilter):
def __init__(self, skip_value):
super().__init__(skip_value)
super().__init__(skip_value, filter_use="search_api")

def apply_filter(self, query):
return super().apply_filter(query)
Expand Down
Loading

0 comments on commit 7ae6e77

Please sign in to comment.