From 2b801b4d8f659bb60ae5859b6678e9f987501c38 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Carlos=20Mart=C3=ADn?=
Date: Sun, 1 Sep 2024 21:45:17 +0200
Subject: [PATCH] :refactor: Move tenacity logic into its own decorator to
 reduce duplication (DRY) and improve code legibility

---
 src/dns_synchub/mappers/cloudflare.py | 107 +++++++++++++++-----------
 src/dns_synchub/pollers/docker.py     |  26 ++++---
 src/dns_synchub/pollers/traefik.py    |  20 +++--
 3 files changed, 89 insertions(+), 64 deletions(-)

diff --git a/src/dns_synchub/mappers/cloudflare.py b/src/dns_synchub/mappers/cloudflare.py
index b166e54..96b616f 100644
--- a/src/dns_synchub/mappers/cloudflare.py
+++ b/src/dns_synchub/mappers/cloudflare.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import asyncio
+from functools import partial, wraps
 from logging import Logger
 from typing import Any, Awaitable, Callable, cast
 
@@ -24,6 +25,49 @@ class CloudFlareException(Exception):
     pass
 
 
+def dry_run(func: Callable[..., Awaitable[Any]]) -> Callable[..., Awaitable[Any]]:
+    @wraps(func)
+    async def wrapper(self: CloudFlareMapper, zone_id: str, *args: Any, **data: Any) -> Any:
+        if self.dry_run:
+            self.logger.info(f"DRY-RUN: {func.__name__} in zone {zone_id}: {data}")
+            return {**data, "zone_id": zone_id}
+        return await func(self, zone_id, *args, **data)
+
+    return wrapper
+
+
+def retry(func: Callable[..., Awaitable[Any]]) -> Callable[..., Awaitable[Any]]:
+    def log_before_sleep(logger, retry_state: RetryCallState):
+        assert retry_state.next_action
+        sleep_time = retry_state.next_action.sleep
+        logger.warning(f"Max Rate limit reached. Retry in {sleep_time} seconds...")
+
+    @wraps(func)
+    async def wrapper(self: CloudFlareMapper, *args: Any, **kwargs: Any) -> Any:
+        assert isinstance(self, CloudFlareMapper)
+
+        retry = AsyncRetrying(
+            stop=stop_after_attempt(self.config["stop"]),
+            wait=wait_exponential(multiplier=self.config["wait"], max=self.tout_sec),
+            retry=retry_if_exception_message(match="Rate limited"),
+            before_sleep=partial(log_before_sleep, self.logger),
+        )
+        try:
+            async for attempt_ctx in retry:
+                with attempt_ctx:
+                    try:
+                        return await func(self, *args, **kwargs)
+                    except Exception as err:
+                        att = attempt_ctx.retry_state.attempt_number
+                        self.logger.debug(f"CloudFlare {func.__name__} attempt {att} failed: {err}")
+                        raise
+        except RetryError as err:
+            last_error = err.last_attempt.result()
+            raise CloudFlareException("Operation failed") from last_error
+
+    return wrapper
+
+
 class CloudFlareMapper(Mapper[CloudFlare]):
     def __init__(self, logger: Logger, *, settings: Settings, client: CloudFlare | None = None):
         if client is None:
@@ -54,57 +98,28 @@ async def __call__(self, hosts: list[str], source: PollerSourceType):
                 task.cancel()
             raise
 
-    async def _retry(self, func: Callable[..., Awaitable[Any]], *args: Any, **kwargs: Any) -> Any:
-        def log_before_sleep(retry_state: RetryCallState):
-            sleep_time = retry_state.upcoming_sleep
-            self.logger.warning(f"Max Rate limit reached. Retry in {sleep_time} seconds...")
-
-        retry = AsyncRetrying(
-            stop=stop_after_attempt(self.config["stop"]),
-            wait=wait_exponential(multiplier=self.config["wait"], max=self.tout_sec),
-            retry=retry_if_exception_message(match="Rate limited"),
-            before_sleep=log_before_sleep,
-        )
-        try:
-            async for attempt in retry:
-                with attempt:
-                    return await func(*args, **kwargs)
-        except RetryError as err:
-            last_error = err.last_attempt.result()
-            raise CloudFlareException("Operation failed") from last_error
-
+    @retry
     async def get_records(self, zone_id: str, **filter: str) -> list[dict[str, Any]]:
-        async def _get() -> list[dict[str, Any]]:
-            assert self.client is not None
-            return await asyncio.to_thread(
-                self.client.zones.dns_records.get, zone_id, params=filter
-            )
-
-        return await self._retry(_get)
+        assert self.client is not None
+        return await asyncio.to_thread(self.client.zones.dns_records.get, zone_id, params=filter)
 
+    @dry_run
+    @retry
     async def post_record(self, zone_id: str, **data: str) -> dict[str, Any]:
-        async def _post() -> dict[str, Any]:
-            if self.dry_run:
-                self.logger.info(f"DRY-RUN: Create new record in zone {zone_id}:, {data}")
-                return {**data, "zone_id": zone_id}
-            result = await asyncio.to_thread(self.client.zones.dns_records.post, zone_id, data=data)
-            self.logger.info(f"Created new record in zone {zone_id}: {result}")
-            return cast(dict[str, Any], result)
-
-        return await self._retry(_post)
+        assert self.client is not None
+        result = await asyncio.to_thread(self.client.zones.dns_records.post, zone_id, data=data)
+        self.logger.info(f"Created new record in zone {zone_id}: {result}")
+        return result
 
+    @dry_run
+    @retry
     async def put_record(self, zone_id: str, record_id: str, **data: str) -> dict[str, Any]:
-        async def _put() -> dict[str, Any]:
-            if self.dry_run:
-                self.logger.info(f"DRY-RUN: Update record {record_id } in zone {zone_id}:, {data}")
-                return {**data, "zone_id": zone_id}
-            result = await asyncio.to_thread(
-                self.client.zones.dns_records.put, zone_id, record_id, data=data
-            )
-            self.logger.info(f"Updated record {record_id} in zone {zone_id} with data {data}")
-            return cast(dict[str, Any], result)
-
-        return await self._retry(_put)
+        assert self.client is not None
+        result = await asyncio.to_thread(
+            self.client.zones.dns_records.put, zone_id, record_id, data=data
+        )
+        self.logger.info(f"Updated record {record_id} in zone {zone_id} with data {data}")
+        return result
 
     # Start Program to update the Cloudflare
     @override
diff --git a/src/dns_synchub/pollers/docker.py b/src/dns_synchub/pollers/docker.py
index 04513f8..bbcce61 100644
--- a/src/dns_synchub/pollers/docker.py
+++ b/src/dns_synchub/pollers/docker.py
@@ -162,19 +162,23 @@ async def _watch(self):
 
     @override
     async def fetch(self) -> tuple[list[str], PollerSourceType]:
+        filters = {"status": "running"}
         stop = stop_after_attempt(self.config["stop"])
         wait = wait_exponential(multiplier=self.config["wait"], max=self.tout_sec)
-        rawdata = []
-        filters = {"status": "running"}
+        result = []
         try:
-            async for attempt in AsyncRetrying(stop=stop, wait=wait):
-                with attempt:
-                    raw_data = cast(
-                        list[Container],
-                        await asyncio.to_thread(self.client.containers.list, filters=filters),
-                    )
-                    rawdata = [DockerContainer(c, logger=self.logger) for c in raw_data]
+            async for attempt_ctx in AsyncRetrying(stop=stop, wait=wait):
+                with attempt_ctx:
+                    try:
+                        containers = self.client.containers
+                        raw_data = await asyncio.to_thread(containers.list, filters=filters)
+                        result = [DockerContainer(c, logger=self.logger) for c in raw_data]
+                    except Exception as err:
+                        att = attempt_ctx.retry_state.attempt_number
+                        self.logger.debug(f"Docker.fetch attempt {att} failed: {err}")
+                        raise
         except RetryError as err:
-            self.logger.critical(f"Could not fetch containers: {err}")
+            last_error = err.last_attempt.result()
+            self.logger.critical(f"Could not fetch containers: {last_error}")
         # Return a collection of routes
-        return self._validate(rawdata)
+        return self._validate(result)
diff --git a/src/dns_synchub/pollers/traefik.py b/src/dns_synchub/pollers/traefik.py
index 41cb4ed..2af63d7 100644
--- a/src/dns_synchub/pollers/traefik.py
+++ b/src/dns_synchub/pollers/traefik.py
@@ -95,15 +95,21 @@ async def _watch(self, timeout: float | None = None):
     async def fetch(self) -> tuple[list[str], PollerSourceType]:
         stop = stop_after_attempt(self.config["stop"])
         wait = wait_exponential(multiplier=self.config["wait"], max=self.tout_sec)
-        rawdata: Any = []
+        rawdata = []
         assert self._client
         try:
-            async for attempt in AsyncRetrying(stop=stop, wait=wait):
-                with attempt:
-                    response = await asyncio.to_thread(self._client.get, self.poll_url)
-                    response.raise_for_status()
-                    rawdata = response.json()
+            async for attempt_ctx in AsyncRetrying(stop=stop, wait=wait):
+                with attempt_ctx:
+                    try:
+                        response = await asyncio.to_thread(self._client.get, self.poll_url)
+                        response.raise_for_status()
+                        rawdata = response.json()
+                    except Exception as err:
+                        att = attempt_ctx.retry_state.attempt_number
+                        self.logger.debug(f"Traefik.fetch attempt {att} failed: {err}")
+                        raise
         except RetryError as err:
-            self.logger.critical(f"Failed to fetch route from Traefik API: {err}")
+            last_error = err.last_attempt.result()
+            self.logger.critical(f"Failed to fetch route from Traefik API: {last_error}")
         # Return a collection of routes
         return self._validate(rawdata)