diff --git a/starlette/requests.py b/starlette/requests.py index 08dbd84d4..9c6776f0c 100644 --- a/starlette/requests.py +++ b/starlette/requests.py @@ -19,6 +19,7 @@ if typing.TYPE_CHECKING: + from starlette.applications import Starlette from starlette.routing import Router @@ -175,8 +176,10 @@ def state(self) -> State: return self._state def url_for(self, name: str, /, **path_params: typing.Any) -> URL: - router: Router = self.scope["router"] - url_path = router.url_path_for(name, **path_params) + url_path_provider: Router | Starlette | None = self.scope.get("router") or self.scope.get("app") + if url_path_provider is None: + raise RuntimeError("The `url_for` method can only be used inside a Starlette application or with a router.") + url_path = url_path_provider.url_path_for(name, **path_params) return url_path.make_absolute_url(base_url=self.base_url) diff --git a/tests/middleware/test_base.py b/tests/middleware/test_base.py index e2375a7b9..041cc7ce2 100644 --- a/tests/middleware/test_base.py +++ b/tests/middleware/test_base.py @@ -2,12 +2,7 @@ import contextvars from contextlib import AsyncExitStack -from typing import ( - Any, - AsyncGenerator, - AsyncIterator, - Generator, -) +from typing import Any, AsyncGenerator, AsyncIterator, Generator import anyio import pytest diff --git a/tests/test_requests.py b/tests/test_requests.py index 2f173713e..f0494e751 100644 --- a/tests/test_requests.py +++ b/tests/test_requests.py @@ -6,7 +6,7 @@ import anyio import pytest -from starlette.datastructures import Address, State +from starlette.datastructures import URL, Address, State from starlette.requests import ClientDisconnect, Request from starlette.responses import JSONResponse, PlainTextResponse, Response from starlette.types import Message, Receive, Scope, Send @@ -592,3 +592,44 @@ async def rcv() -> Message: assert await s2.__anext__() with pytest.raises(StopAsyncIteration): await s1.__anext__() + + +def test_request_url_outside_starlette_context(test_client_factory: TestClientFactory) -> None: + async def app(scope: Scope, receive: Receive, send: Send) -> None: + request = Request(scope, receive) + request.url_for("index") + + client = test_client_factory(app) + with pytest.raises( + RuntimeError, + match="The `url_for` method can only be used inside a Starlette application or with a router.", + ): + client.get("/") + + +def test_request_url_starlette_context(test_client_factory: TestClientFactory) -> None: + from starlette.applications import Starlette + from starlette.middleware import Middleware + from starlette.routing import Route + from starlette.types import ASGIApp + + url_for = None + + async def homepage(request: Request) -> Response: + return PlainTextResponse("Hello, world!") + + class CustomMiddleware: + def __init__(self, app: ASGIApp) -> None: + self.app = app + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + nonlocal url_for + request = Request(scope, receive) + url_for = request.url_for("homepage") + await self.app(scope, receive, send) + + app = Starlette(routes=[Route("/home", homepage)], middleware=[Middleware(CustomMiddleware)]) + + client = test_client_factory(app) + client.get("/home") + assert url_for == URL("http://testserver/home")