Skip to content

Commit

Permalink
Support external aiohttp session (#41)
Browse files Browse the repository at this point in the history
* Enhance iRail API client to support external aiohttp session management and add ETag cache clearing method

* Add optional session management to iRail API client and improve logging
  • Loading branch information
tjorim authored Jan 19, 2025
1 parent e25c390 commit 8ae3b05
Showing 1 changed file with 28 additions and 8 deletions.
36 changes: 28 additions & 8 deletions pyrail/irail.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ class iRail:
Attributes:
lang (str): The language for API responses ('nl', 'fr', 'en', 'de').
session (ClientSession, optional): The HTTP session used for API requests.
If not provided, a new session will be created and managed internally.
If provided, the session lifecycle must be managed externally.
Endpoints:
stations: Retrieve all stations.
Expand All @@ -54,36 +57,48 @@ class iRail:
"disturbances": {"optional": ["lineBreakCharacter"]},
}

def __init__(self, lang: str = "en") -> None:
def __init__(self, lang: str = "en", session: ClientSession | None = None) -> None:
"""Initialize the iRail API client.
Args:
lang (str): The language for API responses. Default is 'en'.
session (ClientSession, optional): An existing aiohttp session. Defaults to None.
"""
self.lang: str = lang
self.tokens: int = 3
self.burst_tokens: int = 5
self.last_request_time: float = time.time()
self.lock: Lock = Lock()
self.session: ClientSession | None = None
self.session: ClientSession | None = session
self._owns_session = session is None # Track ownership
self.etag_cache: Dict[str, str] = {}
logger.info("iRail instance created")

async def __aenter__(self) -> "iRail":
"""Initialize and return the aiohttp client session when entering the async context."""
self.session = ClientSession()
"""Initialize the aiohttp client session when entering the async context."""
if self.session and not self.session.closed:
logger.debug("Using externally provided session")
elif not self.session:
logger.debug("Creating new internal aiohttp session")
self.session = ClientSession()
return self

async def __aexit__(
self, exc_type: Type[BaseException] | None, exc: BaseException | None, tb: TracebackType | None
) -> None:
"""Close the aiohttp client session when exiting the async context."""
if self.session:
try:
await self.session.close()
except Exception as e:
logger.error("Error closing session: %s", e)
if self.session.closed:
logger.debug("Session is already closed, skipping closure")
elif not self._owns_session:
logger.debug("Session is externally provided; not closing it")
else:
logger.debug("Closing aiohttp session")
try:
await self.session.close()
except Exception as e:
logger.error("Error while closing aiohttp session: %s", e)

@property
def lang(self) -> str:
Expand All @@ -109,6 +124,11 @@ def lang(self, value: str) -> None:
else:
self.__lang = "en"

def clear_etag_cache(self) -> None:
"""Clear the ETag cache."""
self.etag_cache.clear()
logger.info("ETag cache cleared")

def _refill_tokens(self) -> None:
"""Refill tokens for rate limiting based on elapsed time.
Expand Down

0 comments on commit 8ae3b05

Please sign in to comment.