Skip to content

Commit

Permalink
GH-4: Using python-inject to inject the UserRepository
Browse files Browse the repository at this point in the history
  • Loading branch information
Sparrow0hawk committed Oct 13, 2023
1 parent 9dac19f commit 902a47a
Show file tree
Hide file tree
Showing 8 changed files with 58 additions and 35 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ dependencies = [
"flask~=2.3.0",
"govuk-frontend-jinja~=2.7.0",
"gunicorn~=21.2.0",
"inject~=5.0.0",
"python-dotenv~=1.0.0",
"requests~=2.31.0"
]
Expand Down
23 changes: 14 additions & 9 deletions schemes/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import os
from typing import Any, Mapping

import inject
from authlib.integrations.flask_client import OAuth
from authlib.oauth2.rfc7523 import PrivateKeyJWT
from flask import Flask, Response, request, url_for
from inject import Binder
from jinja2 import ChoiceLoader, FileSystemLoader, PackageLoader, PrefixLoader

from schemes import api, auth, home, start
Expand All @@ -19,10 +21,13 @@ def create_app(test_config: Mapping[str, Any] | None = None) -> Flask:
app.config.from_prefixed_env()
app.config.from_mapping(test_config)

inject.configure(_bindings)

_configure_basic_auth(app)
_configure_govuk_frontend(app)
_configure_oidc(app)
_configure_users(app)
if not app.testing:
_configure_users()

app.register_blueprint(start.bp)
app.register_blueprint(auth.bp, url_prefix="/auth")
Expand All @@ -33,6 +38,10 @@ def create_app(test_config: Mapping[str, Any] | None = None) -> Flask:
return app


def _bindings(binder: Binder) -> None:
binder.bind(UserRepository, UserRepository())


def _configure_basic_auth(app: Flask) -> None:
username = app.config.get("BASIC_AUTH_USERNAME")
password = app.config.get("BASIC_AUTH_PASSWORD")
Expand Down Expand Up @@ -74,11 +83,7 @@ def _configure_oidc(app: Flask) -> None:
)


def _configure_users(app: Flask) -> None:
users = UserRepository()

if not app.testing:
users.add(User("alex.coleman@activetravelengland.gov.uk"))
users.add(User("mark.hobson@activetravelengland.gov.uk"))

app.extensions["users"] = users
def _configure_users() -> None:
users = inject.instance(UserRepository)
users.add(User("alex.coleman@activetravelengland.gov.uk"))
users.add(User("mark.hobson@activetravelengland.gov.uk"))
11 changes: 6 additions & 5 deletions schemes/api.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,21 @@
from flask import Blueprint, Response, current_app, request
import inject
from flask import Blueprint, Response, request

from schemes.users import User, UserRepository

bp = Blueprint("api", __name__)


@bp.route("/users", methods=["POST"])
def add_user() -> Response:
@inject.autoparams()
def add_user(users: UserRepository) -> Response:
user = User(request.get_json()["email"])
users: UserRepository = current_app.extensions["users"]
users.add(user)
return Response(status=201)


@bp.route("/users", methods=["DELETE"])
def clear_users() -> Response:
users: UserRepository = current_app.extensions["users"]
@inject.autoparams()
def clear_users(users: UserRepository) -> Response:
users.clear()
return Response(status=204)
9 changes: 5 additions & 4 deletions schemes/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Callable, ParamSpec, TypeVar
from urllib.parse import urlencode, urlparse

import inject
from authlib.integrations.flask_client import OAuth
from authlib.oidc.core import UserInfo
from flask import (
Expand All @@ -21,12 +22,13 @@


@bp.route("")
def callback() -> BaseResponse:
@inject.autoparams()
def callback(users: UserRepository) -> BaseResponse:
oauth = _get_oauth()
token = oauth.govuk.authorize_access_token()
user = oauth.govuk.userinfo(token=token)

if not _is_authorized(user):
if not _is_authorized(users, user):
return redirect(url_for("auth.unauthorized"))

session["user"] = user
Expand Down Expand Up @@ -69,8 +71,7 @@ def decorated_function(*args: P.args, **kwargs: P.kwargs) -> T | Response:
return decorated_function


def _is_authorized(user: UserInfo) -> bool:
users: UserRepository = current_app.extensions["users"]
def _is_authorized(users: UserRepository, user: UserInfo) -> bool:
return users.get(user["email"]) is not None


Expand Down
2 changes: 2 additions & 0 deletions tests/e2e/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import sys
from typing import Any, Generator

import inject
import pytest
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.asymmetric import rsa
Expand Down Expand Up @@ -54,6 +55,7 @@ def app_fixture(oidc_server: LiveServer) -> Generator[Flask, Any, Any]:
)
)
yield app
inject.clear()
oidc_client.clear_clients()


Expand Down
8 changes: 5 additions & 3 deletions tests/integration/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Any, Mapping
from typing import Any, Generator, Mapping

import inject
import pytest
from flask import Flask
from flask.testing import FlaskClient
Expand All @@ -22,8 +23,9 @@ def config_fixture() -> Mapping[str, Any]:


@pytest.fixture(name="app")
def app_fixture(config: Mapping[str, Any]) -> Flask:
return create_app(config)
def app_fixture(config: Mapping[str, Any]) -> Generator[Flask, Any, Any]:
yield create_app(config)
inject.clear()


@pytest.fixture(name="client")
Expand Down
19 changes: 12 additions & 7 deletions tests/integration/test_api.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,31 @@
from typing import Any, Mapping

import inject
import pytest
from flask import current_app
from flask.testing import FlaskClient

from schemes.users import User
from schemes.users import User, UserRepository


def test_add_user(client: FlaskClient) -> None:
@pytest.fixture(name="users")
def users_fixture() -> UserRepository:
return inject.instance(UserRepository)


def test_add_user(users: UserRepository, client: FlaskClient) -> None:
response = client.post("/api/users", json={"email": "boardman@example.com"})

assert response.status_code == 201
assert current_app.extensions["users"].get("boardman@example.com")
assert users.get("boardman@example.com")


def test_clear_users(client: FlaskClient) -> None:
current_app.extensions["users"].add(User("boardman@example.com"))
def test_clear_users(users: UserRepository, client: FlaskClient) -> None:
users.add(User("boardman@example.com"))

response = client.delete("/api/users")

assert response.status_code == 204
assert not current_app.extensions["users"].get_all()
assert not users.get_all()


class TestProduction:
Expand Down
20 changes: 13 additions & 7 deletions tests/integration/test_auth.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from typing import Any, Mapping
from unittest.mock import Mock

import inject
import pytest
from authlib.integrations.flask_client import OAuth
from authlib.oidc.core import UserInfo
from flask import current_app, session
from flask.testing import FlaskClient

from schemes.users import User
from schemes.users import User, UserRepository
from tests.integration.pages import UnauthorizedPage


Expand All @@ -16,8 +17,13 @@ def config_fixture(config: Mapping[str, Any]) -> Mapping[str, Any]:
return config | {"GOVUK_END_SESSION_ENDPOINT": "https://example.com/logout"}


def test_callback_logs_in(client: FlaskClient) -> None:
current_app.extensions["users"].add(User("boardman@example.com"))
@pytest.fixture(name="users")
def users_fixture() -> UserRepository:
return inject.instance(UserRepository)


def test_callback_logs_in(users: UserRepository, client: FlaskClient) -> None:
users.add(User("boardman@example.com"))
_given_oidc_returns_token_response({"id_token": "jwt"})
_given_oidc_returns_user_info(UserInfo({"email": "boardman@example.com"}))

Expand All @@ -27,8 +33,8 @@ def test_callback_logs_in(client: FlaskClient) -> None:
assert session["user"] == UserInfo({"email": "boardman@example.com"}) and session["id_token"] == "jwt"


def test_callback_redirects_to_home(client: FlaskClient) -> None:
current_app.extensions["users"].add(User("boardman@example.com"))
def test_callback_redirects_to_home(users: UserRepository, client: FlaskClient) -> None:
users.add(User("boardman@example.com"))
_given_oidc_returns_token_response({"id_token": "jwt"})
_given_oidc_returns_user_info(UserInfo({"email": "boardman@example.com"}))

Expand All @@ -37,8 +43,8 @@ def test_callback_redirects_to_home(client: FlaskClient) -> None:
assert response.status_code == 302 and response.location == "/home"


def test_callback_when_unauthorized_redirects_to_unauthorized(client: FlaskClient) -> None:
current_app.extensions["users"].add(User("boardman@example.com"))
def test_callback_when_unauthorized_redirects_to_unauthorized(users: UserRepository, client: FlaskClient) -> None:
users.add(User("boardman@example.com"))
_given_oidc_returns_token_response({"id_token": "jwt"})
_given_oidc_returns_user_info(UserInfo({"email": "obree@example.com"}))

Expand Down

0 comments on commit 902a47a

Please sign in to comment.