Skip to content
This repository has been archived by the owner on Jan 22, 2025. It is now read-only.

Cleanup of authentication tests #285

Merged
merged 2 commits into from
Apr 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions backend/tests/endpoints/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,17 @@ def data_map(course: Course) -> dict[str, Any]:
}

@fixture
def auth_test(request: FixtureRequest, client: FlaskClient, data_map: dict[str, Any]) -> tuple:
def auth_test(
request: FixtureRequest, client: FlaskClient, data_map: dict[str, Any]
) -> tuple[str, Any, str, bool]:
"""Add concrete test data to auth"""
# endpoint, method, token, allowed
endpoint, method, token, *other = request.param
endpoint, method, token, allowed = request.param

for k, v in data_map.items():
endpoint = endpoint.replace(k, str(v))
csrf = get_csrf_from_login(client, token)
return endpoint, getattr(client, method), csrf, *other
csrf = get_csrf_from_login(client, token) if token else None

return endpoint, getattr(client, method), csrf, allowed



Expand Down
13 changes: 5 additions & 8 deletions backend/tests/endpoints/course/courses_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,13 @@ class TestCourseEndpoint(TestEndpoint):
### AUTHENTICATION ###
# Where is login required
authentication_tests = \
authentication_tests("/courses", ["get", "post"], ["login"], ["0123456789", ""]) + \
authentication_tests("/courses/@course_id", ["get", "patch", "delete"],
["login"], ["0123456789", ""]) + \
authentication_tests("/courses/@course_id/students", ["get", "post", "delete"],
["login"], ["0123456789", ""]) + \
authentication_tests("/courses/@course_id/admins", ["get", "post", "delete"],
["login"], ["0123456789", ""])
authentication_tests("/courses", ["get", "post"]) + \
authentication_tests("/courses/@course_id", ["get", "patch", "delete"]) + \
authentication_tests("/courses/@course_id/students", ["get", "post", "delete"]) + \
authentication_tests("/courses/@course_id/admins", ["get", "post", "delete"])

@mark.parametrize("auth_test", authentication_tests, indirect=True)
def test_authentication(self, auth_test: tuple[str, Any]):
def test_authentication(self, auth_test: tuple[str, Any, str, bool]):
"""Test the authentication"""
super().authentication(auth_test)

Expand Down
21 changes: 12 additions & 9 deletions backend/tests/endpoints/endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,17 @@
from typing import Any
from pytest import param

def authentication_tests(endpoint: str, methods: list[str],
allowed_tokens: list[str], disallowed_tokens: list[str]) -> list[Any]:
def authentication_tests(endpoint: str, methods: list[str]) -> list[Any]:
"""Transform the format to single authentication tests"""
tests = []

for token in (allowed_tokens + disallowed_tokens):
allowed: bool = token in allowed_tokens
for method in methods:
for method in methods:
for token in [None, "0123456789", "login"]:
allowed = token == "login"
tests.append(param(
(endpoint, method, token, allowed),
id = f"{endpoint} {method.upper()} ({token} {'allowed' if allowed else 'disallowed'})"
(endpoint, method, token, allowed),
id = f"{endpoint} {method.upper()} " \
f"({token} {'allowed' if allowed else 'disallowed'})"
))

return tests
Expand Down Expand Up @@ -84,12 +84,15 @@ def query_parameter_tests(
class TestEndpoint:
"""Base class for endpoint tests"""

def authentication(self, auth_test: tuple[str, Any]):
def authentication(self, auth_test: tuple[str, Any, str, bool]):
"""Test if the authentication for the given endpoint works"""

endpoint, method, csrf, allowed = auth_test

response = method(endpoint, headers = {"X-CSRF-TOKEN":csrf})
if csrf:
response = method(endpoint, headers = {"X-CSRF-TOKEN":csrf})
else:
response = method(endpoint)
assert allowed == (response.status_code != 401)

def authorization(self, auth_test: tuple[str, Any, str, bool]):
Expand Down