From 8ae3b050fed520dbcf8a58c8b6ce5e8eac92cd09 Mon Sep 17 00:00:00 2001 From: Jorim Tielemans Date: Sun, 19 Jan 2025 10:05:52 +0100 Subject: [PATCH] Support external aiohttp session (#41) * Enhance iRail API client to support external aiohttp session management and add ETag cache clearing method * Add optional session management to iRail API client and improve logging --- pyrail/irail.py | 36 ++++++++++++++++++++++++++++-------- 1 file changed, 28 insertions(+), 8 deletions(-) diff --git a/pyrail/irail.py b/pyrail/irail.py index da53317..ce4c5f2 100644 --- a/pyrail/irail.py +++ b/pyrail/irail.py @@ -28,6 +28,9 @@ class iRail: Attributes: lang (str): The language for API responses ('nl', 'fr', 'en', 'de'). + session (ClientSession, optional): The HTTP session used for API requests. + If not provided, a new session will be created and managed internally. + If provided, the session lifecycle must be managed externally. Endpoints: stations: Retrieve all stations. @@ -54,11 +57,12 @@ class iRail: "disturbances": {"optional": ["lineBreakCharacter"]}, } - def __init__(self, lang: str = "en") -> None: + def __init__(self, lang: str = "en", session: ClientSession | None = None) -> None: """Initialize the iRail API client. Args: lang (str): The language for API responses. Default is 'en'. + session (ClientSession, optional): An existing aiohttp session. Defaults to None. """ self.lang: str = lang @@ -66,13 +70,18 @@ def __init__(self, lang: str = "en") -> None: self.burst_tokens: int = 5 self.last_request_time: float = time.time() self.lock: Lock = Lock() - self.session: ClientSession | None = None + self.session: ClientSession | None = session + self._owns_session = session is None # Track ownership self.etag_cache: Dict[str, str] = {} logger.info("iRail instance created") async def __aenter__(self) -> "iRail": - """Initialize and return the aiohttp client session when entering the async context.""" - self.session = ClientSession() + """Initialize the aiohttp client session when entering the async context.""" + if self.session and not self.session.closed: + logger.debug("Using externally provided session") + elif not self.session: + logger.debug("Creating new internal aiohttp session") + self.session = ClientSession() return self async def __aexit__( @@ -80,10 +89,16 @@ async def __aexit__( ) -> None: """Close the aiohttp client session when exiting the async context.""" if self.session: - try: - await self.session.close() - except Exception as e: - logger.error("Error closing session: %s", e) + if self.session.closed: + logger.debug("Session is already closed, skipping closure") + elif not self._owns_session: + logger.debug("Session is externally provided; not closing it") + else: + logger.debug("Closing aiohttp session") + try: + await self.session.close() + except Exception as e: + logger.error("Error while closing aiohttp session: %s", e) @property def lang(self) -> str: @@ -109,6 +124,11 @@ def lang(self, value: str) -> None: else: self.__lang = "en" + def clear_etag_cache(self) -> None: + """Clear the ETag cache.""" + self.etag_cache.clear() + logger.info("ETag cache cleared") + def _refill_tokens(self) -> None: """Refill tokens for rate limiting based on elapsed time.