diff --git a/jupyter_server/base/handlers.py b/jupyter_server/base/handlers.py index 6f4977145d..3909c70638 100644 --- a/jupyter_server/base/handlers.py +++ b/jupyter_server/base/handlers.py @@ -535,6 +535,8 @@ def check_xsrf_cookie(self) -> None: # Servers without authentication are vulnerable to XSRF return None try: + if not self.check_origin(): + raise web.HTTPError(404) return super().check_xsrf_cookie() except web.HTTPError as e: if self.request.method in {"GET", "HEAD"}: diff --git a/tests/base/test_handlers.py b/tests/base/test_handlers.py index 6d8dce90da..2b5490e90f 100644 --- a/tests/base/test_handlers.py +++ b/tests/base/test_handlers.py @@ -8,6 +8,7 @@ from tornado.httpclient import HTTPClientError from tornado.httpserver import HTTPRequest from tornado.httputil import HTTPHeaders +from tornado.web import HTTPError from jupyter_server.auth import AllowAllAuthorizer, IdentityProvider, User from jupyter_server.auth.decorator import allow_unauthenticated @@ -137,6 +138,112 @@ async def test_jupyter_handler_auth_required(jp_serverapp, jp_fetch): assert exception.value.code == 403 +@pytest.mark.parametrize( + "token_authenticated, disable_check_xsrf, method, check_origin, expected_result", + [ + (True, False, "POST", True, None), # Token-authenticated requests bypass XSRF check + (False, True, "POST", True, None), # XSRF check disabled + (False, False, "GET", True, None), # GET requests don't require XSRF check + (False, False, "POST", True, HTTPError), # Non-authenticated POST should raise HTTPError + (False, False, "POST", False, HTTPError), # Failed origin check should raise HTTPError + ], +) +async def test_check_xsrf_cookie( + jp_serverapp, token_authenticated, disable_check_xsrf, method, check_origin, expected_result +): + class MockHandler(JupyterHandler): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._token_authenticated = token_authenticated + self.request.method = method + self.settings["disable_check_xsrf"] = disable_check_xsrf + self.settings["xsrf_cookies"] = True + self._current_user = True + + # Initialize headers if not present + if not hasattr(self.request, "headers"): + self.request.headers = {} + + # For POST requests that should fail XSRF check + if method == "POST" and not token_authenticated and not disable_check_xsrf: + # Explicitly set mismatched tokens for failing case + self._xsrf_token = "server_token" + self.request.headers["_xsrf"] = "different_token" + self._cookies = {"_xsrf": MagicMock(value="server_token")} + else: + # For passing cases, set matching tokens + self._xsrf_token = "mock_xsrf_token" + self.request.headers["_xsrf"] = "mock_xsrf_token" + self._cookies = {"_xsrf": MagicMock(value="mock_xsrf_token")} + + # Add referer header for GET requests + if method == "GET": + self.request.headers["Referer"] = "http://localhost" + + @property + def token_authenticated(self): + return self._token_authenticated + + @property + def current_user(self): + return self._current_user + + def check_origin(self): + return check_origin + + def check_referer(self): + return True + + def get_cookie(self, name, default=None): + if hasattr(self, "_cookies") and name in self._cookies: + return self._cookies[name].value + return default + + def check_xsrf_cookie(self): + if self.token_authenticated or self.settings.get("disable_check_xsrf", False): + return None + + if not self.check_origin(): + raise HTTPError(404) + + if ( + self.request.method not in {"GET", "HEAD", "OPTIONS"} + and not self.token_authenticated + ): + # Get the cookie + cookie_token = self.get_cookie("_xsrf") + # Get the token from header + header_token = self.request.headers.get("_xsrf") + + if not cookie_token: + raise HTTPError(403, "'_xsrf' cookie not present") + if not header_token: + raise HTTPError(403, "'_xsrf' argument missing") + if cookie_token != header_token: + raise HTTPError(403, "XSRF cookie does not match") + + return None + + # Set up the request + request = HTTPRequest(method) + request.connection = MagicMock() + request.headers = {} + + # Set up the application + app = jp_serverapp + app.web_app.settings["xsrf_cookies"] = True + + # Create and initialize the handler + handler = MockHandler(app.web_app, request) + + if expected_result is None: + # Should not raise an exception + handler.check_xsrf_cookie() + else: + with pytest.raises(expected_result): + handler.check_xsrf_cookie() + + @pytest.mark.parametrize( "jp_server_config", [{"ServerApp": {"allow_unauthenticated_access": False}}] )