Skip to content

Commit

Permalink
Merge pull request #124 from ral-facilities/feature/improve-openpi-sp…
Browse files Browse the repository at this point in the history
…ec-gen-#123

Improve OpenAPI spec generation
  • Loading branch information
louise-davies authored May 29, 2020
2 parents 1d84c39 + 03d768f commit 130a736
Show file tree
Hide file tree
Showing 14 changed files with 12,411 additions and 8,153 deletions.
15 changes: 9 additions & 6 deletions common/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ def requires_session_id(method):
:returns a 403, "Forbidden" if a valid session_id is not provided with the request
"""


@wraps(method)
def wrapper_requires_session(*args, **kwargs):
log.info(" Authenticating consumer")
Expand All @@ -45,7 +44,6 @@ def wrapper_requires_session(*args, **kwargs):
except AuthenticationError:
return "Forbidden", 403


return wrapper_requires_session


Expand All @@ -54,6 +52,7 @@ def queries_records(method):
Decorator for endpoint resources that search for a record in a table
:param method: The method for the endpoint
:return: Will return a 404, "No such record" if a MissingRecordError is caught
:return: Will return a 400, "Error message" if other expected errors are caught
"""

@wraps(method)
Expand Down Expand Up @@ -93,11 +92,14 @@ def get_session_id_from_auth_header():
parser = reqparse.RequestParser()
parser.add_argument("Authorization", location="headers")
args = parser.parse_args()
auth_header = args["Authorization"].split(" ") if args["Authorization"] is not None else ""
auth_header = args["Authorization"].split(
" ") if args["Authorization"] is not None else ""
if auth_header == "":
raise MissingCredentialsError(f"No credentials provided in auth header")
raise MissingCredentialsError(
f"No credentials provided in auth header")
if len(auth_header) != 2 or auth_header[0] != "Bearer":
raise AuthenticationError(f" Could not authenticate consumer with auth header {auth_header}")
raise AuthenticationError(
f" Could not authenticate consumer with auth header {auth_header}")
return auth_header[1]


Expand Down Expand Up @@ -125,5 +127,6 @@ def get_filters_from_query_string():
filters = []
for arg in request.args:
for value in request.args.getlist(arg):
filters.append(QueryFilterFactory.get_query_filter({arg: json.loads(value)}))
filters.append(QueryFilterFactory.get_query_filter(
{arg: json.loads(value)}))
return filters
3 changes: 1 addition & 2 deletions dev-requirements.in
Original file line number Diff line number Diff line change
@@ -1,2 +1 @@
Faker == 2.0.2
pyyaml == 5.1.2
Faker == 2.0.2
3 changes: 1 addition & 2 deletions dev-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,9 @@
# This file is autogenerated by pip-compile
# To update, run:
#
# pip-compile '.\dev-requirements.in'
# pip-compile dev-requirements.in
#
faker==2.0.2
python-dateutil==2.8.0 # via faker
pyyaml==5.1.2
six==1.12.0 # via faker, python-dateutil
text-unidecode==1.3 # via faker
3 changes: 3 additions & 0 deletions requirements.in
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,6 @@ flask_restful == 0.3.7
sqlalchemy == 1.3.8
pymysql == 0.9.3
flask-cors == 3.0.8
apispec == 3.3.0
flask-swagger-ui == 3.25.0
pyyaml == 5.1.2
7 changes: 5 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,21 @@
# This file is autogenerated by pip-compile
# To update, run:
#
# pip-compile '.\requirements.in'
# pip-compile requirements.in
#
aniso8601==8.0.0 # via flask-restful
apispec==3.3.0
click==7.0 # via flask
flask-cors==3.0.8
flask==1.1.1 # via flask-cors, flask-restful
flask-swagger-ui==3.25.0
flask==1.1.1 # via flask-cors, flask-restful, flask-swagger-ui
flask_restful==0.3.7
itsdangerous==1.1.0 # via flask
jinja2==2.10.1 # via flask
markupsafe==1.1.1 # via jinja2
pymysql==0.9.3
pytz==2019.2 # via flask-restful
pyyaml==5.1.2
six==1.12.0 # via flask-cors, flask-restful
sqlalchemy==1.3.8
werkzeug==0.16.0 # via flask
81 changes: 71 additions & 10 deletions src/main.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from flask import Flask
from flask_cors import CORS
from flask_restful import Api
from flask_swagger_ui import get_swaggerui_blueprint

from common.config import config
from common.logger_setup import setup_logger
Expand All @@ -11,35 +12,95 @@
from src.resources.table_endpoints.table_endpoints import UsersInvestigations, UsersInvestigationsCount, \
InstrumentsFacilityCycles, InstrumentsFacilityCyclesCount, InstrumentsFacilityCyclesInvestigations, \
InstrumentsFacilityCyclesInvestigationsCount
from src.swagger.swagger_generator import swagger_gen

swagger_gen.write_swagger_spec()
from apispec import APISpec
from pathlib import Path
import json
from src.swagger.apispec_flask_restful import RestfulPlugin
from src.swagger.initialise_spec import initialise_spec


spec = APISpec(title="DataGateway API", version="1.0", openapi_version="3.0.3",
plugins=[RestfulPlugin()], security=[{"session_id": []}])

app = Flask(__name__)
cors = CORS(app)
app.url_map.strict_slashes = False
api = Api(app)

swaggerui_blueprint = get_swaggerui_blueprint(
"",
"/openapi.json",
config={
'app_name': "DataGateway API OpenAPI Spec"
},
)

app.register_blueprint(swaggerui_blueprint, url_prefix="/")

setup_logger()

initialise_spec(spec)

for entity_name in endpoints:
api.add_resource(get_endpoint(entity_name, endpoints[entity_name]), f"/{entity_name.lower()}")
api.add_resource(get_id_endpoint(entity_name, endpoints[entity_name]), f"/{entity_name.lower()}/<int:id>")
api.add_resource(get_count_endpoint(entity_name, endpoints[entity_name]), f"/{entity_name.lower()}/count")
api.add_resource(get_find_one_endpoint(entity_name, endpoints[entity_name]), f"/{entity_name.lower()}/findone")
get_endpoint_resource = get_endpoint(entity_name, endpoints[entity_name])
api.add_resource(get_endpoint_resource, f"/{entity_name.lower()}")
spec.path(resource=get_endpoint_resource, api=api)

get_id_endpoint_resource = get_id_endpoint(
entity_name, endpoints[entity_name])
api.add_resource(get_id_endpoint_resource,
f"/{entity_name.lower()}/<int:id>")
spec.path(resource=get_id_endpoint_resource, api=api)

get_count_endpoint_resource = get_count_endpoint(
entity_name, endpoints[entity_name])
api.add_resource(get_count_endpoint_resource,
f"/{entity_name.lower()}/count")
spec.path(resource=get_count_endpoint_resource, api=api)

get_find_one_endpoint_resource = get_find_one_endpoint(
entity_name, endpoints[entity_name])
api.add_resource(get_find_one_endpoint_resource,
f"/{entity_name.lower()}/findone")
spec.path(resource=get_find_one_endpoint_resource, api=api)


# Session endpoint
api.add_resource(Sessions, "/sessions")
spec.path(resource=Sessions, api=api)

# Table specific endpoints
api.add_resource(UsersInvestigations, "/users/<int:id>/investigations")
api.add_resource(UsersInvestigationsCount, "/users/<int:id>/investigations/count")
api.add_resource(InstrumentsFacilityCycles, "/instruments/<int:id>/facilitycycles")
api.add_resource(InstrumentsFacilityCyclesCount, "/instruments/<int:id>/facilitycycles/count")
spec.path(resource=UsersInvestigations, api=api)
api.add_resource(UsersInvestigationsCount,
"/users/<int:id>/investigations/count")
spec.path(resource=UsersInvestigationsCount, api=api)
api.add_resource(InstrumentsFacilityCycles,
"/instruments/<int:id>/facilitycycles")
spec.path(resource=InstrumentsFacilityCycles, api=api)
api.add_resource(InstrumentsFacilityCyclesCount,
"/instruments/<int:id>/facilitycycles/count")
spec.path(resource=InstrumentsFacilityCyclesCount, api=api)
api.add_resource(InstrumentsFacilityCyclesInvestigations,
"/instruments/<int:instrument_id>/facilitycycles/<int:cycle_id>/investigations")
spec.path(resource=InstrumentsFacilityCyclesInvestigations, api=api)
api.add_resource(InstrumentsFacilityCyclesInvestigationsCount,
"/instruments/<int:instrument_id>/facilitycycles/<int:cycle_id>/investigations/count")
spec.path(resource=InstrumentsFacilityCyclesInvestigationsCount, api=api)

openapi_spec_path = Path(__file__).parent / "swagger/openapi.yaml"
with open(openapi_spec_path, "w") as f:
f.write(spec.to_yaml())


@app.route("/openapi.json")
def specs():
resp = app.make_response(json.dumps(spec.to_dict(), indent=2))
resp.mimetype = "application/json"
return resp


if __name__ == "__main__":
app.run(host=config.get_host(), port=config.get_port(), debug=config.is_debug_mode())
app.run(host=config.get_host(), port=config.get_port(),
debug=config.is_debug_mode())
Loading

0 comments on commit 130a736

Please sign in to comment.