diff --git a/examples/demonstration/router_api.py b/examples/demonstration/router_api.py index 984382e..6f18ddb 100644 --- a/examples/demonstration/router_api.py +++ b/examples/demonstration/router_api.py @@ -28,5 +28,6 @@ def sim_auth(request: Request): max_age=request.auth.expires, expires=request.auth.expires, httponly=request.auth.http, + samesite=request.auth.samesite, ) return response diff --git a/src/fastapi_oauth2/config.py b/src/fastapi_oauth2/config.py index 6bdcac9..954247c 100644 --- a/src/fastapi_oauth2/config.py +++ b/src/fastapi_oauth2/config.py @@ -10,6 +10,7 @@ class OAuth2Config: enable_ssr: bool allow_http: bool + samesite: str jwt_secret: str jwt_expires: int jwt_algorithm: str @@ -20,6 +21,7 @@ def __init__( *, enable_ssr: bool = True, allow_http: bool = False, + samesite: str = "lax", jwt_secret: str = "", jwt_expires: Union[int, str] = 900, jwt_algorithm: str = "HS256", @@ -29,6 +31,7 @@ def __init__( os.environ["OAUTHLIB_INSECURE_TRANSPORT"] = "1" self.enable_ssr = enable_ssr self.allow_http = allow_http + self.samesite = samesite self.jwt_secret = jwt_secret self.jwt_expires = int(jwt_expires) self.jwt_algorithm = jwt_algorithm diff --git a/src/fastapi_oauth2/core.py b/src/fastapi_oauth2/core.py index 1dbfaa3..c226e55 100644 --- a/src/fastapi_oauth2/core.py +++ b/src/fastapi_oauth2/core.py @@ -145,6 +145,7 @@ async def token_redirect(self, request: Request, **kwargs) -> RedirectResponse: expires=request.auth.expires, secure=not request.auth.http, httponly=True, + samesite=request.auth.samesite, ) return response diff --git a/src/fastapi_oauth2/middleware.py b/src/fastapi_oauth2/middleware.py index 76ee47e..6ff5eb8 100644 --- a/src/fastapi_oauth2/middleware.py +++ b/src/fastapi_oauth2/middleware.py @@ -37,6 +37,7 @@ class Auth(AuthCredentials): ssr: bool http: bool + samesite: str secret: str expires: int algorithm: str @@ -90,6 +91,7 @@ def __init__( ) -> None: Auth.ssr = config.enable_ssr Auth.http = config.allow_http + Auth.samesite = config.samesite Auth.secret = config.jwt_secret Auth.expires = config.jwt_expires Auth.algorithm = config.jwt_algorithm diff --git a/tests/conftest.py b/tests/conftest.py index 5766231..709b882 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -75,6 +75,7 @@ def auth(request: Request): max_age=request.auth.expires, expires=request.auth.expires, httponly=request.auth.http, + samesite=request.auth.samesite, ) return response