diff --git a/.github/workflows/run-tests.yml b/.github/workflows/run-tests.yml new file mode 100644 index 0000000..b2412f6 --- /dev/null +++ b/.github/workflows/run-tests.yml @@ -0,0 +1,45 @@ +name: JupyterHealth SMART on FHIR test suite + +on: + push: + branches: [ main ] + pull_request: + +jobs: + test: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: [3.12] # extend if needed + + steps: + - uses: actions/checkout@v4 + - name: Use Node.js + uses: actions/setup-node@v4 + with: + node-version: '18.x' + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + cache: "pip" + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -e .[testing] + - name: Install SMART sandbox + run: | + git clone https://github.com/smart-on-fhir/smart-launcher-v2.git + cd smart-launcher-v2 + git switch -c aa0f3b1 # Fix the version we use for the sandbox + npm ci + npm run build + env: + PORT: 5555 + - name: Run tests + run: | + pytest tests/ + env: + SANDBOX_DIR: ${{ github.workspace }}/smart-launcher-v2 + + \ No newline at end of file diff --git a/jupyter_smart_on_fhir/auth.py b/jupyter_smart_on_fhir/auth.py index 82e3228..e756442 100644 --- a/jupyter_smart_on_fhir/auth.py +++ b/jupyter_smart_on_fhir/auth.py @@ -60,7 +60,7 @@ def get_jwks_from_key(key_file: Path, key_id: str = "1") -> str: jwk = alg.to_jwk(key, as_dict=True) jwk.update({"alg": "RS256", "kid": key_id}) jwks_smart = {"keys": [jwk]} - jwks_smart_str = json.dumps(jwks_smart, indent=2) + jwks_smart_str = json.dumps(jwks_smart) return jwks_smart_str diff --git a/jupyter_smart_on_fhir/hub_service.py b/jupyter_smart_on_fhir/hub_service.py index 3d08ef1..21a7196 100755 --- a/jupyter_smart_on_fhir/hub_service.py +++ b/jupyter_smart_on_fhir/hub_service.py @@ -8,8 +8,15 @@ import time import secrets from functools import wraps -from flask import Flask, Response, make_response, redirect, request, session -from jupyterhub.services.auth import HubOAuth +from flask import ( + Flask, + Response, + make_response, + redirect, + request, + session, + current_app, +) import requests from urllib.parse import urlencode import jwt @@ -18,15 +25,60 @@ from cryptography.fernet import Fernet, InvalidToken prefix = os.environ.get("JUPYTERHUB_SERVICE_PREFIX", "/") -auth = HubOAuth(api_token=os.environ["JUPYTERHUB_API_TOKEN"], cache_max_age=60) -app = Flask(__name__) -# encryption key for session cookies -secret_key = base64.urlsafe_b64encode(secrets.token_bytes(32)) -app.config["fernet"] = Fernet(secret_key) -# settings passed from the Hub -app.config["client_id"] = os.environ["CLIENT_ID"] -app.config["keys"] = validate_keys() + +def create_app(): + """Create the Flask app with configuration""" + app = Flask(__name__) + # encryption key for session cookies + secret_key = secrets.token_bytes(32) + app.secret_key = secret_key + app.config["fernet"] = Fernet(base64.urlsafe_b64encode(secret_key)) + # settings passed from the Hub + app.config["client_id"] = os.environ["CLIENT_ID"] + app.config["keys"] = validate_keys() + + @app.route(prefix) + @authenticated + def fetch_data(token: str) -> Response: + """Fetch data from a FHIR endpoint""" + headers = { + "Authorization": f"Bearer {token}", + "Accept": "application/fhir+json", + "User-Agent": "JupyterHub", + } + url = f"{session['smart_config']['fhir_url']}/Condition" # Endpoint with data + f = requests.get(url, headers=headers) + return Response(f.text, mimetype="application/json") + + @app.route(prefix + "oauth_callback") + def callback() -> Response: + """Callback endpoint to finish OAuth flow""" + state_id = get_encrypted_cookie("state_id") + if not state_id: + return Response("No local state ID cookie found", status=400) + next_url = get_encrypted_cookie("next_url") or "/" + + if error := request.args.get("error", False): + return Response( + f"Error in OAuth: {request.args.get('error_description', error)}", + status=400, + ) + code = request.args.get("code") + if not code: + return Response( + "OAuth callback did not return authorization code", status=400 + ) + arg_state = request.args.get("state", None) + if arg_state != state_id: + return Response( + "OAuth state does not match. Try logging in again.", status=403 + ) + token = token_for_code(code) + set_encrypted_cookie("smart_token", token) + return make_response(redirect(next_url)) + + return app def get_encrypted_cookie(key: str) -> str | None: @@ -34,7 +86,7 @@ def get_encrypted_cookie(key: str) -> str | None: cookie = session.get(key) if cookie: try: - return app.config["fernet"].decrypt(cookie).decode("ascii") + return current_app.config["fernet"].decrypt(cookie).decode("ascii") except InvalidToken: pass # maybe warn return None @@ -42,19 +94,19 @@ def get_encrypted_cookie(key: str) -> str | None: def set_encrypted_cookie(key: str, value: str): """Store an encrypted cookie""" - session[key] = app.config["fernet"].encrypt(value.encode("ascii")) + session[key] = current_app.config["fernet"].encrypt(value.encode("ascii")) def generate_jwt() -> str: """Generate a JWT for the SMART asymmetric client authentication""" jwt_dict = { - "iss": app.config["client_id"], - "sub": app.config["client_id"], + "iss": current_app.config["client_id"], + "sub": current_app.config["client_id"], "aud": session["smart_config"]["token_url"], "jti": "jwt_id", "exp": int(time.time() + 3600), } - ((key_id, private_key_path),) = app.config["keys"].items() + ((key_id, private_key_path),) = current_app.config["keys"].items() with open(private_key_path, "rb") as f: private_key = f.read() headers = {"kid": key_id} @@ -64,7 +116,7 @@ def generate_jwt() -> str: def token_for_code(code: str) -> str: """Exchange an authorization code for an access token""" data = dict( - client_id=app.config["client_id"], + client_id=current_app.config["client_id"], grant_type="authorization_code", code=code, redirect_uri=session["smart_config"]["base_url"] + "oauth_callback", @@ -88,10 +140,14 @@ def authenticated(f): @wraps(f) def decorated(*args, **kwargs): + if "iss" not in request.args: + return Response( + "GET request misses 'iss' argument. Was service launched from EHR?", + status=400, + ) if token := get_encrypted_cookie("smart_token"): return f(token, *args, **kwargs) - else: session["smart_config"] = SMARTConfig.from_url( request.args["iss"], @@ -116,47 +172,9 @@ def start_oauth_flow(state_id: str, scopes: list[str] | None = None) -> Response "state": state_id, "redirect_uri": redirect_uri, "launch": request.args["launch"], - "client_id": app.config["client_id"], + "client_id": current_app.config["client_id"], "response_type": "code", "scopes": " ".join(scopes), } auth_url = f"{config.auth_url}?{urlencode(headers)}" return redirect(auth_url) - - -@app.route(prefix) -@authenticated -def fetch_data(token: str) -> Response: - """Fetch data from a FHIR endpoint""" - headers = { - "Authorization": f"Bearer {token}", - "Accept": "application/fhir+json", - "User-Agent": "JupyterHub", - } - url = f"{session['smart_config']['fhir_url']}/Condition" # Endpoint with data - f = requests.get(url, headers=headers) - return Response(f.text, mimetype="application/json") - - -@app.route(prefix + "oauth_callback") -def callback() -> Response: - """Callback endpoint to finish OAuth flow""" - state_id = get_encrypted_cookie("state_id") - if not state_id: - return Response("No local state ID cookie found", status=400) - next_url = get_encrypted_cookie("next_url") or "/" - - if error := request.args.get("error", False): - return Response( - f"Error in OAuth: {request.args.get('error_description', error)}", - status=400, - ) - code = request.args.get("code") - if not code: - return Response("OAuth callback made without a token", status=400) - arg_state = request.args.get("state", None) - if arg_state != state_id: - return Response("OAuth state does not match. Try logging in again.", status=403) - token = token_for_code(code) - set_encrypted_cookie("smart_token", token) - return make_response(redirect(next_url)) diff --git a/jupyter_smart_on_fhir/server_extension.py b/jupyter_smart_on_fhir/server_extension.py index 7e57a8a..7143555 100644 --- a/jupyter_smart_on_fhir/server_extension.py +++ b/jupyter_smart_on_fhir/server_extension.py @@ -31,7 +31,7 @@ class SMARTExtensionApp(ExtensionApp): ).tag(config=True) client_id = Unicode( - help="""Client ID for the SMART application""", default_value="test-id" + help="""Client ID for the SMART application""", default_value="test_id" ).tag(config=True) def initialize_settings(self): @@ -88,7 +88,10 @@ class SMARTLoginHandler(JupyterHandler): @tornado.web.authenticated def get(self): state = generate_state() - self.set_secure_cookie(**state) + self.set_secure_cookie("state_id", state["state_id"]) + if state["next_url"]: + self.set_secure_cookie("next_url", state["next_url"]) + scopes = self.settings["scopes"] smart_config = self.settings["smart_config"] auth_url = smart_config.auth_url @@ -98,7 +101,7 @@ def get(self): code_challenge = base64.urlsafe_b64encode(code_challenge_b).rstrip(b"=") headers = { "aud": smart_config.fhir_url, - "state": state["id"], + "state": state["state_id"], "launch": self.settings["launch"], "redirect_uri": urljoin(self.request.full_url(), callback_path), "client_id": self.settings["client_id"], diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..c40c22b --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,96 @@ +import pytest +import os +import subprocess +import requests +import time +from dataclasses import asdict, field, dataclass +import base64 +import json +from urllib import parse + + +@pytest.fixture(scope="function") # module? +def sandbox(): + port = 5555 + os.environ["PORT"] = str(port) + url = f"http://localhost:{port}" + with subprocess.Popen( + ["npm", "run", "start:prod"], cwd=os.environ.get("SANDBOX_DIR", ".") + ) as sandbox_proc: + wait_for_server(url) + yield url + sandbox_proc.terminate() + + +def wait_for_server(url): + for _ in range(10): + try: + response = requests.get(url) + if response.status_code == 200: + break + except requests.ConnectionError: + pass + time.sleep(1) # Wait for 1 second before retrying + else: + raise requests.ConnectionError(f"Cannot connect to {url}") + + +@dataclass +class SandboxConfig: + """Taken from smart-on-fhir/smart-launcher-v2.git:src/isomorphic/LaunchOptions.ts. + The sandbox reads its configuration from the url it is launched with. + This means we don't have to introduce a webdriver to manipulate its behaviour, + but instead we can reverse engineer the parameters to set the necessary parameters + """ + + # Caveat: make sure not to change the order + # the smart sandbox uses list indices to evaluate which value is which property + # Assumes client identity validation = true + launch_type: int = 0 # provider EHR launch + patient_ids: list[str] = field(default_factory=list) + provider_ids: list[str] = field(default_factory=list) + encounter_type: str = "AUTO" + misc_skip_login: int = 0 # not compatible with provider EHR launch + misc_skip_auth: int = 0 # not compatible with provider EHR launch + misc_simulate_launch: int = 0 # don't simulate launch within EHR UI + allowed_scopes: set[str] = field(default_factory=set) + redirect_uris: list[str] = field(default_factory=list) + client_id: str = "client_id" + client_secret: str = "" + auth_error: int = 0 # simulate no error + jwks_url: str = "" + jwks: str = "" + client_type: int = 0 # 0 (public), 1 (symmetric), 2 (asymmetric) + pkce_validation: int = 1 # 0 (none), 1 (auto), 2 (always) + fhir_base_url: str = "" # arranged server side + + def get_launch_code(self) -> str: + """The sandbox settings are encoded in a base64 JSON object. + Enforcing settings/procedures needs to be done here + """ + + attr_list = [] + for val in asdict(self).values(): + if isinstance(val, int) or isinstance(val, str): + attr_list.append(val) + elif isinstance(val, list): + attr_list.append(", ".join(val)) + elif isinstance(val, set): + attr_list.append(" ".join(val)) + attr_repr = json.dumps(attr_list) + return base64.b64encode(attr_repr.encode("utf-8")) + + def get_url_query( + self, launch_url: str, validation: bool = True, fhir_version: str = "r4" + ) -> str: + """Provide the entire URL query that loads the sandbox with the given settings. + Requires appending with base url""" + query = parse.urlencode( + { + "launch": self.get_launch_code(), + "launch_url": launch_url, + "validation": int(validation), + "fhir_version": fhir_version, + } + ) + return query diff --git a/tests/test_hub_service.py b/tests/test_hub_service.py index e69de29..72edeaf 100644 --- a/tests/test_hub_service.py +++ b/tests/test_hub_service.py @@ -0,0 +1,160 @@ +import pytest +from flask import session +from jupyter_smart_on_fhir.auth import get_jwks_from_key +from jupyter_smart_on_fhir.hub_service import ( + create_app, + set_encrypted_cookie, + get_encrypted_cookie, + prefix, + token_for_code, +) +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric import rsa +import os +import requests +from conftest import SandboxConfig +from urllib import parse + + +@pytest.fixture(scope="module") +def keys(tmp_path_factory): + tmp_path = tmp_path_factory.mktemp("keys") + private_key = rsa.generate_private_key(public_exponent=65537, key_size=4096) + + private_pem = private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.TraditionalOpenSSL, + encryption_algorithm=serialization.NoEncryption(), + ) + public_pem = private_key.public_key().public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ) + + key_name = "jwtRS256.key" + private_key_path = tmp_path / key_name + public_key_path = tmp_path / f"{key_name}.pub" + private_key_path.write_bytes(private_pem) + public_key_path.write_bytes(public_pem) + return {"SSH_KEY_PATH": str(private_key_path), "SSH_KEY_ID": "test_key"} + + +@pytest.fixture(scope="function") +def mock_env(monkeypatch, keys): + env = { + "JUPYTERHUB_API_TOKEN": os.getenv("JUPYTERHUB_API_TOKEN", "API_TOKEN"), + "CLIENT_ID": os.getenv("CLIENT_ID", "CLIENT_ID"), + "SCOPES": os.getenv("SCOPES", "launch profile patient/*.*"), + } | keys + for key, value in env.items(): + monkeypatch.setenv(key, value) + return env + + +@pytest.fixture(scope="function") +def test_app(mock_env): + app = create_app() + app.config["TESTING"] = True + return app + + +@pytest.fixture(scope="function") +def asymmetric_auth(keys): + jwks = get_jwks_from_key(keys["SSH_KEY_PATH"], keys["SSH_KEY_ID"]) + return SandboxConfig(client_id=os.environ["CLIENT_ID"], jwks=jwks, client_type=2) + + +@pytest.fixture(scope="function") +def client(test_app): + return test_app.test_client() + + +def test_ehr_launch(client): + response = client.get("/?token=hello") + assert response.status_code == 400 + + +@pytest.mark.parametrize( + "key,value", + [ + ("test_key", "test_value"), + ("user_id", "12345"), + ("session_token", "abcdef123456"), + ("empty_value", ""), + ("special_chars", "!@#$%^&*()_+"), + ], +) +def test_encrypted_cookie(test_app, key, value): + with test_app.test_request_context(): + session.clear() + # Set the encrypted cookie + set_encrypted_cookie(key, value) + # Verify the cookie is in the session and encrypted + assert key in session + assert session[key] != value + # Get the decrypted cookie value + decrypted_value = get_encrypted_cookie(key) + # Verify the decrypted value matches the original + assert decrypted_value == value + + +def test_get_nonexistent_cookie(test_app): + with test_app.test_request_context(): + session.clear() + value = get_encrypted_cookie("nonexistent_key") + assert value is None + + +def test_invalid_token(test_app): + with test_app.test_request_context(): + session.clear() + session["invalid_key"] = b"invalid_token" + value = get_encrypted_cookie("invalid_key") + assert value is None + + +def test_access_sandbox(sandbox): + f = requests.get(sandbox) + print(f.status_code, f.text) + assert f.status_code == 200 + + +def test_to_auth_url(sandbox, client, asymmetric_auth): + # start with oauth flow + query = {"iss": f"{sandbox}/v/r4/fhir", "launch": asymmetric_auth.get_launch_code()} + scopes = ["launch", "profile"] + os.environ["SCOPES"] = " ".join(scopes) + # set up test context + with client.application.test_request_context(): + # Launch request from sandbox with given settings + response = client.get(f"/?{parse.urlencode(query)}") + # Expecting a redirect to the login page + assert response.status_code == 302 + + auth_url = response.headers["Location"] + # Ensure auth url has correct domain and scopes + assert auth_url.startswith(sandbox) + assert "+".join(scopes) in auth_url + + # Check if auth_url passes strict client validation and provides code + f = requests.get(auth_url, allow_redirects=False) + assert f.status_code == 302 + + callback_url = f.headers["Location"] + assert "code" in callback_url + assert prefix + "oauth_callback" in callback_url + parsed_url = parse.urlparse(callback_url) + qs = parse.parse_qs(parsed_url.query) + code = qs["code"][0] + # Test if we can exchange the code for a token with asymmetric validation + # with client.session_transaction() as sess: + # sess["smart_config"] = { + # "base_url": sandbox, + # "fhir_url": f"{sandbox}/v/r4/fhir", + # "token_url": f"{sandbox}/token", + # "auth_url": f"{sandbox}/authorize", + # "scopes": scopes, + # } + # token = token_for_code(code) + # assert isinstance(token, str) + # FIXME: The session seems to be empty, but only for the token_for_code method. Confusing diff --git a/tests/test_server_extension.py b/tests/test_server_extension.py index e69de29..a86dd78 100644 --- a/tests/test_server_extension.py +++ b/tests/test_server_extension.py @@ -0,0 +1,97 @@ +import os +import subprocess +import requests +import pytest +from conftest import wait_for_server, SandboxConfig +from jupyter_smart_on_fhir.server_extension import smart_path, login_path, callback_path + +PORT = os.getenv("TEST_PORT", 18888) +ext_url = f"http://localhost:{PORT}" + + +def request_api(url, session=None, params=None, **kwargs): + query_args = {"token": "secret"} + query_args.update(params or {}) + session = session or requests.Session() + return session.get(url, params=query_args, **kwargs) + + +@pytest.fixture +def jupyterdir(tmpdir): + path = tmpdir.join("jupyter") + path.mkdir() + return str(path) + + +@pytest.fixture +def jupyter_server(tmpdir, jupyterdir): + client_id = os.environ["CLIENT_ID"] = "client_id" + env = os.environ.copy() + # avoid interacting with user configuration, state + env["JUPYTER_CONFIG_DIR"] = str(tmpdir / "dotjupyter") + env["JUPYTER_RUNTIME_DIR"] = str(tmpdir / "runjupyter") + + extension_command = ["jupyter", "server", "extension"] + command = [ + "jupyter-server", + "--ServerApp.token=secret", + "--SMARTExtensionApp.client_id={}".format(client_id), + "--port={}".format(PORT), + ] + subprocess.check_call( + extension_command + ["enable", "jupyter_smart_on_fhir.server_extension"], + env=env, + ) + + # launch the server + with subprocess.Popen(command, cwd=jupyterdir, env=env) as jupyter_proc: + wait_for_server(ext_url) + yield jupyter_proc + jupyter_proc.terminate() + + +def test_uninformed_endpoint(jupyter_server): + response = request_api(ext_url + smart_path) + assert response.status_code == 400 + + +@pytest.fixture(scope="function") +def public_client(): + return SandboxConfig( + client_id=os.environ["CLIENT_ID"], + client_type=0, + pkce_validation=2, + # setting IDs so we omit login screen in sandbox; unsure I would test that flow + patient_ids=["6bb97c2b-8762-4763-ad16-2d88db590b74"], + provider_ids=["63003abb-3924-46df-a75a-0a1f42733189"], + ) + + +def test_login_handler(jupyter_server, sandbox, public_client): + """I think this test can be splitted in three with some engineering. Perhaps useful, not sure""" + session = requests.Session() + # Try endpoint and get redirected to login + query = {"iss": f"{sandbox}/v/r4/fhir", "launch": public_client.get_launch_code()} + response = request_api( + ext_url + smart_path, params=query, allow_redirects=False, session=session + ) + assert response.status_code == 302 + assert response.headers["Location"] == login_path + + # Login with headers and get redirected to auth url + response = request_api(ext_url + login_path, session=session, allow_redirects=False) + assert response.status_code == 302 + auth_url = response.headers["Location"] + assert auth_url.startswith(sandbox) + + # Internally, get redirected to provider-auth + response = request_api(auth_url, session=session, allow_redirects=False) + assert response.status_code == 302 + callback_url = response.headers["Location"] + assert callback_url.startswith(ext_url + callback_path) + assert "code=" in callback_url + response = request_api(callback_url, session=session) + assert response.status_code == 200 + assert response.url.startswith(ext_url + smart_path) + + # TODO: Should I test token existence? And how?