Skip to content

Commit

Permalink
Merge pull request #6 from jupyterhealth/sandbox-testing-in-ci
Browse files Browse the repository at this point in the history
Sandbox testing in CI
  • Loading branch information
minrk authored Sep 12, 2024
2 parents 73c9a98 + 8517682 commit bfc44a4
Show file tree
Hide file tree
Showing 7 changed files with 479 additions and 60 deletions.
45 changes: 45 additions & 0 deletions .github/workflows/run-tests.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
name: JupyterHealth SMART on FHIR test suite

on:
push:
branches: [ main ]
pull_request:

jobs:
test:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.12] # extend if needed

steps:
- uses: actions/checkout@v4
- name: Use Node.js
uses: actions/setup-node@v4
with:
node-version: '18.x'
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
cache: "pip"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -e .[testing]
- name: Install SMART sandbox
run: |
git clone https://github.com/smart-on-fhir/smart-launcher-v2.git
cd smart-launcher-v2
git switch -c aa0f3b1 # Fix the version we use for the sandbox
npm ci
npm run build
env:
PORT: 5555
- name: Run tests
run: |
pytest tests/
env:
SANDBOX_DIR: ${{ github.workspace }}/smart-launcher-v2


2 changes: 1 addition & 1 deletion jupyter_smart_on_fhir/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def get_jwks_from_key(key_file: Path, key_id: str = "1") -> str:
jwk = alg.to_jwk(key, as_dict=True)
jwk.update({"alg": "RS256", "kid": key_id})
jwks_smart = {"keys": [jwk]}
jwks_smart_str = json.dumps(jwks_smart, indent=2)
jwks_smart_str = json.dumps(jwks_smart)
return jwks_smart_str


Expand Down
130 changes: 74 additions & 56 deletions jupyter_smart_on_fhir/hub_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,15 @@
import time
import secrets
from functools import wraps
from flask import Flask, Response, make_response, redirect, request, session
from jupyterhub.services.auth import HubOAuth
from flask import (
Flask,
Response,
make_response,
redirect,
request,
session,
current_app,
)
import requests
from urllib.parse import urlencode
import jwt
Expand All @@ -18,43 +25,88 @@
from cryptography.fernet import Fernet, InvalidToken

prefix = os.environ.get("JUPYTERHUB_SERVICE_PREFIX", "/")
auth = HubOAuth(api_token=os.environ["JUPYTERHUB_API_TOKEN"], cache_max_age=60)
app = Flask(__name__)

# encryption key for session cookies
secret_key = base64.urlsafe_b64encode(secrets.token_bytes(32))
app.config["fernet"] = Fernet(secret_key)
# settings passed from the Hub
app.config["client_id"] = os.environ["CLIENT_ID"]
app.config["keys"] = validate_keys()

def create_app():
"""Create the Flask app with configuration"""
app = Flask(__name__)
# encryption key for session cookies
secret_key = secrets.token_bytes(32)
app.secret_key = secret_key
app.config["fernet"] = Fernet(base64.urlsafe_b64encode(secret_key))
# settings passed from the Hub
app.config["client_id"] = os.environ["CLIENT_ID"]
app.config["keys"] = validate_keys()

@app.route(prefix)
@authenticated
def fetch_data(token: str) -> Response:
"""Fetch data from a FHIR endpoint"""
headers = {
"Authorization": f"Bearer {token}",
"Accept": "application/fhir+json",
"User-Agent": "JupyterHub",
}
url = f"{session['smart_config']['fhir_url']}/Condition" # Endpoint with data
f = requests.get(url, headers=headers)
return Response(f.text, mimetype="application/json")

@app.route(prefix + "oauth_callback")
def callback() -> Response:
"""Callback endpoint to finish OAuth flow"""
state_id = get_encrypted_cookie("state_id")
if not state_id:
return Response("No local state ID cookie found", status=400)
next_url = get_encrypted_cookie("next_url") or "/"

if error := request.args.get("error", False):
return Response(
f"Error in OAuth: {request.args.get('error_description', error)}",
status=400,
)
code = request.args.get("code")
if not code:
return Response(
"OAuth callback did not return authorization code", status=400
)
arg_state = request.args.get("state", None)
if arg_state != state_id:
return Response(
"OAuth state does not match. Try logging in again.", status=403
)
token = token_for_code(code)
set_encrypted_cookie("smart_token", token)
return make_response(redirect(next_url))

return app


def get_encrypted_cookie(key: str) -> str | None:
"""Fetch and decrypt an encrypted cookie"""
cookie = session.get(key)
if cookie:
try:
return app.config["fernet"].decrypt(cookie).decode("ascii")
return current_app.config["fernet"].decrypt(cookie).decode("ascii")
except InvalidToken:
pass # maybe warn
return None


def set_encrypted_cookie(key: str, value: str):
"""Store an encrypted cookie"""
session[key] = app.config["fernet"].encrypt(value.encode("ascii"))
session[key] = current_app.config["fernet"].encrypt(value.encode("ascii"))


def generate_jwt() -> str:
"""Generate a JWT for the SMART asymmetric client authentication"""
jwt_dict = {
"iss": app.config["client_id"],
"sub": app.config["client_id"],
"iss": current_app.config["client_id"],
"sub": current_app.config["client_id"],
"aud": session["smart_config"]["token_url"],
"jti": "jwt_id",
"exp": int(time.time() + 3600),
}
((key_id, private_key_path),) = app.config["keys"].items()
((key_id, private_key_path),) = current_app.config["keys"].items()
with open(private_key_path, "rb") as f:
private_key = f.read()
headers = {"kid": key_id}
Expand All @@ -64,7 +116,7 @@ def generate_jwt() -> str:
def token_for_code(code: str) -> str:
"""Exchange an authorization code for an access token"""
data = dict(
client_id=app.config["client_id"],
client_id=current_app.config["client_id"],
grant_type="authorization_code",
code=code,
redirect_uri=session["smart_config"]["base_url"] + "oauth_callback",
Expand All @@ -88,10 +140,14 @@ def authenticated(f):

@wraps(f)
def decorated(*args, **kwargs):
if "iss" not in request.args:
return Response(
"GET request misses 'iss' argument. Was service launched from EHR?",
status=400,
)

if token := get_encrypted_cookie("smart_token"):
return f(token, *args, **kwargs)

else:
session["smart_config"] = SMARTConfig.from_url(
request.args["iss"],
Expand All @@ -116,47 +172,9 @@ def start_oauth_flow(state_id: str, scopes: list[str] | None = None) -> Response
"state": state_id,
"redirect_uri": redirect_uri,
"launch": request.args["launch"],
"client_id": app.config["client_id"],
"client_id": current_app.config["client_id"],
"response_type": "code",
"scopes": " ".join(scopes),
}
auth_url = f"{config.auth_url}?{urlencode(headers)}"
return redirect(auth_url)


@app.route(prefix)
@authenticated
def fetch_data(token: str) -> Response:
"""Fetch data from a FHIR endpoint"""
headers = {
"Authorization": f"Bearer {token}",
"Accept": "application/fhir+json",
"User-Agent": "JupyterHub",
}
url = f"{session['smart_config']['fhir_url']}/Condition" # Endpoint with data
f = requests.get(url, headers=headers)
return Response(f.text, mimetype="application/json")


@app.route(prefix + "oauth_callback")
def callback() -> Response:
"""Callback endpoint to finish OAuth flow"""
state_id = get_encrypted_cookie("state_id")
if not state_id:
return Response("No local state ID cookie found", status=400)
next_url = get_encrypted_cookie("next_url") or "/"

if error := request.args.get("error", False):
return Response(
f"Error in OAuth: {request.args.get('error_description', error)}",
status=400,
)
code = request.args.get("code")
if not code:
return Response("OAuth callback made without a token", status=400)
arg_state = request.args.get("state", None)
if arg_state != state_id:
return Response("OAuth state does not match. Try logging in again.", status=403)
token = token_for_code(code)
set_encrypted_cookie("smart_token", token)
return make_response(redirect(next_url))
9 changes: 6 additions & 3 deletions jupyter_smart_on_fhir/server_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class SMARTExtensionApp(ExtensionApp):
).tag(config=True)

client_id = Unicode(
help="""Client ID for the SMART application""", default_value="test-id"
help="""Client ID for the SMART application""", default_value="test_id"
).tag(config=True)

def initialize_settings(self):
Expand Down Expand Up @@ -88,7 +88,10 @@ class SMARTLoginHandler(JupyterHandler):
@tornado.web.authenticated
def get(self):
state = generate_state()
self.set_secure_cookie(**state)
self.set_secure_cookie("state_id", state["state_id"])
if state["next_url"]:
self.set_secure_cookie("next_url", state["next_url"])

scopes = self.settings["scopes"]
smart_config = self.settings["smart_config"]
auth_url = smart_config.auth_url
Expand All @@ -98,7 +101,7 @@ def get(self):
code_challenge = base64.urlsafe_b64encode(code_challenge_b).rstrip(b"=")
headers = {
"aud": smart_config.fhir_url,
"state": state["id"],
"state": state["state_id"],
"launch": self.settings["launch"],
"redirect_uri": urljoin(self.request.full_url(), callback_path),
"client_id": self.settings["client_id"],
Expand Down
96 changes: 96 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import pytest
import os
import subprocess
import requests
import time
from dataclasses import asdict, field, dataclass
import base64
import json
from urllib import parse


@pytest.fixture(scope="function") # module?
def sandbox():
port = 5555
os.environ["PORT"] = str(port)
url = f"http://localhost:{port}"
with subprocess.Popen(
["npm", "run", "start:prod"], cwd=os.environ.get("SANDBOX_DIR", ".")
) as sandbox_proc:
wait_for_server(url)
yield url
sandbox_proc.terminate()


def wait_for_server(url):
for _ in range(10):
try:
response = requests.get(url)
if response.status_code == 200:
break
except requests.ConnectionError:
pass
time.sleep(1) # Wait for 1 second before retrying
else:
raise requests.ConnectionError(f"Cannot connect to {url}")


@dataclass
class SandboxConfig:
"""Taken from smart-on-fhir/smart-launcher-v2.git:src/isomorphic/LaunchOptions.ts.
The sandbox reads its configuration from the url it is launched with.
This means we don't have to introduce a webdriver to manipulate its behaviour,
but instead we can reverse engineer the parameters to set the necessary parameters
"""

# Caveat: make sure not to change the order
# the smart sandbox uses list indices to evaluate which value is which property
# Assumes client identity validation = true
launch_type: int = 0 # provider EHR launch
patient_ids: list[str] = field(default_factory=list)
provider_ids: list[str] = field(default_factory=list)
encounter_type: str = "AUTO"
misc_skip_login: int = 0 # not compatible with provider EHR launch
misc_skip_auth: int = 0 # not compatible with provider EHR launch
misc_simulate_launch: int = 0 # don't simulate launch within EHR UI
allowed_scopes: set[str] = field(default_factory=set)
redirect_uris: list[str] = field(default_factory=list)
client_id: str = "client_id"
client_secret: str = ""
auth_error: int = 0 # simulate no error
jwks_url: str = ""
jwks: str = ""
client_type: int = 0 # 0 (public), 1 (symmetric), 2 (asymmetric)
pkce_validation: int = 1 # 0 (none), 1 (auto), 2 (always)
fhir_base_url: str = "" # arranged server side

def get_launch_code(self) -> str:
"""The sandbox settings are encoded in a base64 JSON object.
Enforcing settings/procedures needs to be done here
"""

attr_list = []
for val in asdict(self).values():
if isinstance(val, int) or isinstance(val, str):
attr_list.append(val)
elif isinstance(val, list):
attr_list.append(", ".join(val))
elif isinstance(val, set):
attr_list.append(" ".join(val))
attr_repr = json.dumps(attr_list)
return base64.b64encode(attr_repr.encode("utf-8"))

def get_url_query(
self, launch_url: str, validation: bool = True, fhir_version: str = "r4"
) -> str:
"""Provide the entire URL query that loads the sandbox with the given settings.
Requires appending with base url"""
query = parse.urlencode(
{
"launch": self.get_launch_code(),
"launch_url": launch_url,
"validation": int(validation),
"fhir_version": fhir_version,
}
)
return query
Loading

0 comments on commit bfc44a4

Please sign in to comment.