From 30777581b8fce3442363cfa80632b5fcbd2c1140 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mart=C3=ADn?= Date: Fri, 13 Sep 2024 20:52:35 +0200 Subject: [PATCH] Add missing annotations --- .../actions/setup-python-with-uv/action.yml | 11 +- .pre-commit-config.yaml | 7 +- pyproject.toml | 43 +++- scripts/versioning.py | 35 ++-- src/dns_synchub/__about__.py | 2 +- src/dns_synchub/__init__.py | 64 +++--- src/dns_synchub/__main__.py | 73 ++++++- src/dns_synchub/cli.py | 51 ++--- src/dns_synchub/events.py | 150 ++++++-------- src/dns_synchub/logger.py | 127 +++++------- src/dns_synchub/manager.py | 89 -------- src/dns_synchub/mappers/__init__.py | 53 ++--- src/dns_synchub/mappers/cloudflare.py | 166 ++++++++------- src/dns_synchub/pollers/__init__.py | 144 +++++++------ src/dns_synchub/pollers/docker.py | 96 ++++----- src/dns_synchub/pollers/traefik.py | 69 +++---- src/dns_synchub/settings.py | 93 ++++----- src/dns_synchub/types.py | 93 +++++++++ tests/test_cloudflare.py | 192 +++++++++--------- tests/test_docker.py | 174 ++++++++-------- tests/test_mananger.py | 131 ------------ tests/test_traefik.py | 28 +-- types/CloudFlare/__init__.pyi | 15 ++ types/CloudFlare/exceptions.pyi | 23 +++ 24 files changed, 978 insertions(+), 951 deletions(-) delete mode 100644 src/dns_synchub/manager.py create mode 100644 src/dns_synchub/types.py delete mode 100644 tests/test_mananger.py create mode 100644 types/CloudFlare/__init__.pyi create mode 100644 types/CloudFlare/exceptions.pyi diff --git a/.github/actions/setup-python-with-uv/action.yml b/.github/actions/setup-python-with-uv/action.yml index 9aa8ba2..92e7f74 100644 --- a/.github/actions/setup-python-with-uv/action.yml +++ b/.github/actions/setup-python-with-uv/action.yml @@ -26,7 +26,8 @@ runs: python-version: ${{ inputs.python-version }} - name: Install uv - uses: astral-sh/setup-uv@v2 + run: curl -LsSf https://astral.sh/uv/install.sh | sh + shell: bash - name: Pin Python Version run: | @@ -34,10 +35,6 @@ runs: uv python pin ${{ steps.extract-python-version.outputs.python-version }} shell: bash - - name: Install dependencies - run: uv sync ${{ inputs.uv-sync-options}} - shell: bash - - name: Normalize UV Sync Options id: normalize-uv-sync-options run: | @@ -50,3 +47,7 @@ runs: with: path: ~/.cache/uv key: ${{ runner.os }}-python-${{ steps.extract-python-version.outputs.python-version }}-uv-sync-${{ env.normalized-uv-sync-options }} + + - name: Install dependencies + run: uv sync ${{ inputs.uv-sync-options}} + shell: bash diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e7e7374..6cb6f59 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -37,7 +37,12 @@ repos: pass_filenames: false stages: [commit] files: ^pyproject\.toml$ - additional_dependencies: [toml] + + - id: mypy + name: mypy + language: system + types: [python] + entry: uv run mypy --strict - id: pyupgrade name: Pyupgrade diff --git a/pyproject.toml b/pyproject.toml index f864f2c..ab107d1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,6 +40,8 @@ dns-synchub = "dns_synchub:main" [tool.uv] dev-dependencies = [ + # lint + "mypy>=1.11.2", "pre-commit>=3.7.1", "pyupgrade>=3.16.0", "ruff>=0.5.4", @@ -69,7 +71,9 @@ filterwarnings = [ [tool.coverage.run] branch = true -source = 'dns_synchub' +source = [ + 'dns_synchub', +] context = '${CONTEXT}' [tool.coverage.report] @@ -83,7 +87,9 @@ exclude_lines = [ ] [tool.coverage.paths] -source = 'src/dns_synchub/' +source = [ + 'src/dns_synchub/', +] [tool.ruff] # set max-line-length to 100 @@ -150,3 +156,36 @@ keep-runtime-typing = true [tool.deptry] root = 'src/dns_synchub' +[tool.mypy] +mypy_path = "types" +files = [ + "src/", + "tests/", +] +check_untyped_defs = true +disallow_any_generics = true +disallow_incomplete_defs = true +disallow_subclassing_any = true +disallow_untyped_calls = true +disallow_untyped_decorators = true +namespace_packages = true +no_implicit_reexport = true +python_version = '3.11' +show_error_codes = true +strict_optional = true +warn_redundant_casts = true +warn_unused_configs = true +warn_unused_ignores = true + +# for strict mypy: (this is the tricky one :-)) +disallow_untyped_defs = true + +# remaining arguments from `mypy --strict` which cause errors +# no_implicit_optional = true +# warn_return_any = true + +[[tool.mypy.overrides]] +module = [ + 'dotenv.*', +] +ignore_missing_imports = true diff --git a/scripts/versioning.py b/scripts/versioning.py index 9096aea..e8aba2e 100755 --- a/scripts/versioning.py +++ b/scripts/versioning.py @@ -3,14 +3,13 @@ import argparse import os import sys +import tomllib -import toml +TOML_FILE = 'pyproject.toml' +DATA_FILE = 'src/dns_synchub/__about__.py' -TOML_FILE = "pyproject.toml" -DATA_FILE = "src/dns_synchub/__about__.py" - -def update_version(toml_file: str | None = None, data_file: str | None = None): +def update_version(toml_file: str | None = None, data_file: str | None = None) -> None: # Get the current working directory current_dir = os.getcwd() @@ -23,39 +22,39 @@ def update_version(toml_file: str | None = None, data_file: str | None = None): sys.stderr.write(f"Source file '{toml_file}' not found\n") sys.exit(1) if not os.path.exists(data_file): - sys.stderr.write(f"Destination File {data_file} not found\n") + sys.stderr.write(f'Destination File {data_file} not found\n') sys.exit(1) # Read the version from pyproject.toml - with open(toml_file, "r") as f: - pyproject_data = toml.load(f) + with open(toml_file, 'rb') as f: + pyproject_data = tomllib.load(f) # Read the current contents of __about__.py - with open(data_file, "r") as f: + with open(data_file) as f: lines = f.readlines() # Update the version in __about__.py - with open(data_file, "w") as f: - version = pyproject_data["project"]["version"] + with open(data_file, 'w') as f: + version = pyproject_data['project']['version'] for line in lines: - line = f'__version__ = "{version}"\n' if line.startswith("__version__") else line + line = f"__version__ = '{version}'\n" if line.startswith('__version__') else line f.write(line) -if __name__ == "__main__": +if __name__ == '__main__': parser = argparse.ArgumentParser( - description="Update version in destination file from source file.", + description='Update version in destination file from source file.', ) parser.add_argument( - "--source", + '--source', type=str, default=TOML_FILE, - help=f"PyProject file to read version from ({TOML_FILE})", + help=f'PyProject file to read version from ({TOML_FILE})', ) parser.add_argument( - "--target", + '--target', type=str, default=DATA_FILE, - help=f"Python Destination file to update version in ({DATA_FILE})", + help=f'Python Destination file to update version in ({DATA_FILE})', ) args = parser.parse_args() diff --git a/src/dns_synchub/__about__.py b/src/dns_synchub/__about__.py index 3dc1f76..b794fd4 100644 --- a/src/dns_synchub/__about__.py +++ b/src/dns_synchub/__about__.py @@ -1 +1 @@ -__version__ = "0.1.0" +__version__ = '0.1.0' diff --git a/src/dns_synchub/__init__.py b/src/dns_synchub/__init__.py index cde68ce..8db5ba4 100644 --- a/src/dns_synchub/__init__.py +++ b/src/dns_synchub/__init__.py @@ -1,38 +1,26 @@ -from __future__ import annotations - -import asyncio -import sys - -from pydantic import ValidationError - -import dns_synchub.cli as cli -import dns_synchub.logger as logger -import dns_synchub.settings as settings - - -def main(): - try: - # Load environment variables from the specified env file - cli.parse_args() - - # Load settings - options = settings.Settings() - - # Check for uppercase docker secrets or env variables - assert options.cf_token - assert options.target_domain - assert len(options.domains) > 0 - - except ValidationError as e: - print(f"Unable to load settings: {e}", file=sys.stderr) - sys.exit(1) - - # Set up logging and dump runtime settings - log = logger.report_current_status_and_settings(logger.get_logger(options), options) - try: - asyncio.run(cli.main(log, settings=options)) - except KeyboardInterrupt: - # asyncio.run will cancel any task pending when the main function exits - log.info("Cancel by user.") - log.info("Exiting...") - sys.exit(1) +# src/dns_synchub/__init__.py + +from .__about__ import __version__ as VERSION +from .logger import get_logger, initialize_logger +from .mappers import CloudFlareMapper +from .pollers import DockerPoller, TraefikPoller +from .settings import Settings + +__version__ = VERSION + +__all__ = [ + # logger subpackage + 'get_logger', + 'initialize_logger', + # settings subpackage + 'Settings', + # pollers subpackage + 'DockerPoller', + 'TraefikPoller', + # mappers subpackage + 'CloudFlareMapper', +] + + +def __dir__() -> 'list[str]': + return list(__all__) diff --git a/src/dns_synchub/__main__.py b/src/dns_synchub/__main__.py index a48a23c..d588ad4 100644 --- a/src/dns_synchub/__main__.py +++ b/src/dns_synchub/__main__.py @@ -1,16 +1,81 @@ #!/usr/bin/env python3 +import asyncio +from logging import Logger import pathlib +import re import sys -import dns_synchub +from pydantic import ValidationError -if __name__ == "__main__": +import dns_synchub.cli as cli +import dns_synchub.logger as logger +import dns_synchub.settings as settings + + +def report_state(settings: settings.Settings) -> Logger: + log = logger.initialize_logger(settings=settings) + + settings.dry_run and log.info(f'Dry Run: {settings.dry_run}') # type: ignore + log.debug(f'Default TTL: {settings.default_ttl}') + log.debug(f'Refresh Entries: {settings.refresh_entries}') + + log.debug(f"Traefik Polling Mode: {'On' if settings.enable_traefik_poll else 'Off'}") + if settings.enable_traefik_poll: + if settings.traefik_poll_url and re.match(r'^\w+://[^/?#]+', settings.traefik_poll_url): + log.debug(f'Traefik Poll Url: {settings.traefik_poll_url}') + log.debug(f'Traefik Poll Seconds: {settings.traefik_poll_seconds}') + else: + settings.enable_traefik_poll = False + log.error(f'Traefik polling disabled: Bad url: {settings.traefik_poll_url}') + + log.debug(f"Docker Polling Mode: {'On' if settings.enable_docker_poll else 'Off'}") + log.debug(f'Docker Poll Seconds: {settings.docker_timeout_seconds}') + + for dom in settings.domains: + log.debug(f'Domain Configuration: {dom.name}') + log.debug(f' Target Domain: {dom.target_domain}') + log.debug(f' TTL: {dom.ttl}') + log.debug(f' Record Type: {dom.rc_type}') + log.debug(f' Proxied: {dom.proxied}') + log.debug(f' Excluded Subdomains: {dom.excluded_sub_domains}') + + return log + + +def main() -> int: + try: + # Load environment variables from the specified env file + cli.parse_args() + # Load settings + options = settings.Settings() + # Check for uppercase docker secrets or env variables + assert options.cf_token + assert options.target_domain + assert len(options.domains) > 0 + except ValidationError as e: + print(f'Unable to load settings: {e}', file=sys.stderr) + return 1 + + # Set up logging and dump runtime settings + log = report_state(options) + try: + asyncio.run(cli.main(log, settings=options)) + except KeyboardInterrupt: + # asyncio.run will cancel any task pending when the main function exits + log.info('Cancel by user.') + log.info('Exiting...') + + # Exit grqacefully + return 0 + + +if __name__ == '__main__': # If the script is run as a module, use the directory name as the script name script_name = pathlib.Path(sys.argv[0]).stem if sys.argv[0] == __file__: script_path = pathlib.Path(sys.argv[0]) - script_name = script_path.parent.name.replace("_", "-") + script_name = script_path.parent.name.replace('_', '-') # Set the script name to the first argument and invoke the main function sys.argv[0] = script_name - sys.exit(dns_synchub.main()) + sys.exit(main()) diff --git a/src/dns_synchub/cli.py b/src/dns_synchub/cli.py index d29e34c..580699c 100755 --- a/src/dns_synchub/cli.py +++ b/src/dns_synchub/cli.py @@ -1,42 +1,45 @@ -from __future__ import annotations - import argparse -import logging +import asyncio +from logging import Logger +from typing import Any import dotenv from dns_synchub.__about__ import __version__ as VERSION -from dns_synchub.manager import DataManager -from dns_synchub.mappers import CloudFlareMapper -from dns_synchub.pollers import DockerPoller, TraefikPoller +from dns_synchub.mappers.cloudflare import CloudFlareMapper +from dns_synchub.pollers import Poller +from dns_synchub.pollers.docker import DockerPoller +from dns_synchub.pollers.traefik import TraefikPoller from dns_synchub.settings import Settings -def parse_args(): - parser = argparse.ArgumentParser(description="Cloudflare Companion") - parser.add_argument("--env-file", type=str, help="Path to the .env file") - parser.add_argument("--version", action="version", version=f"%(prog)s {VERSION}") +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description='Cloudflare Companion') + parser.add_argument('--env-file', type=str, help='Path to the .env file') + parser.add_argument('--version', action='version', version=f'%(prog)s {VERSION}') args = parser.parse_args() - args.env_file and dotenv.load_dotenv(args.env_file) # type: ignore + dotenv.load_dotenv(args.env_file) return args -async def main(log: logging.Logger, *, settings: Settings): - # Init mnager - manager = DataManager(logger=log) - +async def main(log: Logger, *, settings: Settings) -> None: # Add Cloudflarte mapper - cf = CloudFlareMapper(log, settings=settings) - await manager.add_mapper(cf, backoff=5) + dns = CloudFlareMapper(log, settings=settings) # Add Pollers + pollers: list[Poller[Any]] = [] if settings.enable_traefik_poll: - poller = TraefikPoller(log, settings=settings) - manager.add_poller(poller) + pollers.append(TraefikPoller(log, settings=settings)) if settings.enable_docker_poll: - poller = DockerPoller(log, settings=settings) - manager.add_poller(poller) - - # Start manager - await manager.start() + pollers.append(DockerPoller(log, settings=settings)) + + # Start Pollers + try: + async with asyncio.TaskGroup() as tg: + for poller in pollers: + await poller.events.subscribe(dns) + tg.create_task(poller.start()) + except asyncio.CancelledError: + for poller in pollers: + await poller.stop() diff --git a/src/dns_synchub/events.py b/src/dns_synchub/events.py index db98fe7..155730f 100644 --- a/src/dns_synchub/events.py +++ b/src/dns_synchub/events.py @@ -1,137 +1,101 @@ -from __future__ import annotations - import asyncio -import time +from collections.abc import Iterator from logging import Logger +import time from typing import ( - Callable, - Coroutine, Generic, - Protocol, - TypeAlias, TypeVar, ) -from dns_synchub.settings import PollerSourceType - -PollerSourceEvent: TypeAlias = tuple[list[str], PollerSourceType] - - -class EventSubscriber(Protocol): - def __call__(self, hosts: list[str], source: PollerSourceType) -> None: ... - - -class AsyncEventSubscriber(Protocol): - async def __call__(self, hosts: list[str], source: PollerSourceType) -> None: ... - - -T = TypeVar( - "T", - bound=Callable[[list[str], PollerSourceType], None] - | Coroutine[None, None, None] - | EventSubscriber - | AsyncEventSubscriber, +from dns_synchub.types import ( + Event, + EventSubscriber, + EventSubscriberDataType, + EventSubscriberType, ) -EventSubscriberType = tuple[asyncio.Queue[PollerSourceEvent], float, float] +T_co = TypeVar('T_co') -class EventEmitter(Generic[T]): - def __init__(self, logger: Logger, *, name: str): +class EventEmitter(Generic[T_co]): + def __init__(self, logger: Logger, *, origin: str): self.logger = logger - self.logger_name = name + self.origin = origin # Subscribers - self._subscribers: dict[T, EventSubscriberType] = {} + self._subscribers: dict[EventSubscriberType[T_co], EventSubscriberDataType[T_co]] = {} - def __iter__(self): + def __iter__(self) -> Iterator[EventSubscriberDataType[T_co]]: return iter(self._subscribers.values()) def __len__(self) -> int: return len(self._subscribers) - async def subscribe(self, callback: T, backoff: float = 0): - """ - Subscribes to events from this Poller. - - Args: - callback (Callable): The callback function to be called when an event is emitted. - backoff (float): The backoff time in seconds to wait before calling the callback again. - """ + async def subscribe(self, callback: EventSubscriberType[T_co], backoff: float = 0) -> None: # Check if callback is already subscribed assert callback not in self._subscribers - assert callable(callback) # Register subscriber - self._subscribers[callback] = (asyncio.Queue(), backoff, time.time()) - - def unsubscribe(self, callback: T): - """ - Unsubscribes from events. + self._subscribers[callback] = (asyncio.Queue[Event[T_co]](), backoff, time.time()) - Args: - callback (Callable): The callback function to be removed from subscribers. - """ + def unsubscribe(self, callback: EventSubscriberType[T_co]) -> None: self._subscribers.pop(callback, None) - async def emit(self, timeout: float | None = None): - """ - Triggers an event and notifies all subscribers. - Calls each subscriber's callback with the data. - """ - + async def emit(self, timeout: float | None = None) -> None: async def invoke( - callback: T, - queue: asyncio.Queue[PollerSourceEvent], - backoff: float, - last_called: float, - ) -> tuple[T, EventSubscriberType]: + callback: EventSubscriberType[T_co], + data: EventSubscriberDataType[T_co], + ) -> tuple[EventSubscriberType[T_co], EventSubscriberDataType[T_co]]: + # Unpack data + queue, backoff, last_called = data + # Emit data until queue is empty while not queue.empty(): - current_time = time.time() - if current_time - last_called >= backoff: - # Get callback function - func: T = getattr(callback, "__call__", callback) - assert callable(func) - # Get data from queue - data: PollerSourceEvent = await queue.get() - # Invoke - if asyncio.iscoroutinefunction(func): - await func(*data) - else: - func(*data) - else: - # Wait for backoff time and try emit again - await asyncio.sleep(backoff - (current_time - last_called)) + # Wait for backoff time + await asyncio.sleep(max(0, backoff - (time.time() - last_called))) + # Get callback function + func = callback + if isinstance(callback, EventSubscriber): + func = callback.__call__ + assert callable(func) + # Get data from queue + event: Event[T_co] = await queue.get() + # Invoke + assert asyncio.iscoroutinefunction(func) + await func(event) + # Update last called time + last_called = time.time() + return callback, (queue, backoff, last_called) - tasks: list[asyncio.Task[tuple[T, EventSubscriberType]]] = [] + tasks = [] for callback, args in self._subscribers.items(): - task = asyncio.create_task(invoke(callback, *args)) + task = asyncio.create_task(invoke(callback, args)) tasks.append(task) try: # Await for tasks to complete - for task in asyncio.as_completed(tasks, timeout=timeout): - callback, data = await task + for completed in asyncio.as_completed(tasks, timeout=timeout): + callback, data = await completed self._subscribers[callback] = data - - except asyncio.TimeoutError: - self.logger.warning(f"{self.logger_name}: Emit timeout reached.") + except TimeoutError: + self.logger.warning(f'{self.origin}: Emit timeout reached.') # Cancel all tasks [task.cancel() for task in tasks] asyncio.gather(*tasks, return_exceptions=True) - pass # Data related methods - def set_data(self, data: PollerSourceEvent, *, callback: T | None = None): - if callback is None: + def set_data(self, data: T_co, *, callback: EventSubscriberType[T_co] | None = None) -> None: + event = Event[T_co](data=data) + if callback: + assert callback in self._subscribers + queue, _, _ = self._subscribers[callback] + queue.put_nowait(event) + else: + # Broadcast data to all subscribers for queue, _, _ in self._subscribers.values(): - queue.put_nowait(data) - return - assert callback in self._subscribers - queue, _, _ = self._subscribers[callback] - queue.put_nowait(data) + queue.put_nowait(event) - def has_data(self, callback: T): + def has_data(self, callback: EventSubscriberType[T_co]) -> bool: return callback in self._subscribers and not self._subscribers[callback][0].empty() - def get_data(self, callback: T) -> PollerSourceEvent: + def get_data(self, callback: EventSubscriberType[T_co]) -> T_co | None: queue, _, _ = self._subscribers[callback] - return queue.get_nowait() + event = queue.get_nowait() + return event.data if event else None diff --git a/src/dns_synchub/logger.py b/src/dns_synchub/logger.py index 3c20227..ba5f54d 100644 --- a/src/dns_synchub/logger.py +++ b/src/dns_synchub/logger.py @@ -1,91 +1,70 @@ -from __future__ import annotations - +from functools import lru_cache import logging -import re import sys -from .settings import Settings +from dns_synchub.settings import Settings -logger = None +@lru_cache +def console_log_handler(*, formatter: logging.Formatter) -> logging.Handler: + handler = logging.StreamHandler(sys.stderr) + handler.setFormatter(formatter) + return handler -# set up logging -def initialize_logger(settings: Settings): - global logger - assert logger is None, "Logger already initialized" +@lru_cache +def file_log_handler(filename: str, *, formatter: logging.Formatter) -> logging.Handler: + try: + handler = logging.FileHandler(filename) + handler.setFormatter(formatter) + except OSError as err: + logging.critical(f"Could not open log file '{err.filename}': {err.strerror}") + return handler - # Extract attributes from settings and convert to uppercase - log_level = settings.log_level.upper() - log_type = settings.log_type.upper() - log_file = settings.log_file +# set up logging +@lru_cache +def initialize_logger(settings: Settings) -> logging.Logger: # Set up logging - logger = logging.getLogger(__name__) - - fmt = None - if log_level == "DEBUG": - logger.setLevel(logging.DEBUG) - fmt = "%(asctime)s %(levelname)s %(lineno)d | %(message)s" - - if log_level == "VERBOSE": - logger.setLevel(logging.DEBUG) - fmt = "%(asctime)s %(levelname)s | %(message)s" - - if log_level in ("NOTICE", "INFO"): - logger.setLevel(logging.INFO) - fmt = "%(asctime)s %(levelname)s | %(message)s" - - formatter = logging.Formatter(fmt, "%Y-%m-%dT%H:%M:%S%z") - if log_type in ("CONSOLE", "BOTH"): - ch = logging.StreamHandler(sys.stdout) - ch.setFormatter(formatter) - logger.addHandler(ch) - - if log_type in ("FILE", "BOTH"): + logger = logging.getLogger(settings.service_name) + + # remove all existing handlers, if any + for handler in logger.handlers[:]: + logger.removeHandler(handler) + + # Set the log level + logger.setLevel(settings.log_level) + + # Set up console logging + if 'stderr' in settings.log_handlers: + handler = console_log_handler(formatter=settings.log_formatter) + logger.addHandler(handler) + + # Set up file logging + if 'file' in settings.log_handlers: + handler = file_log_handler(settings.log_file, formatter=settings.log_formatter) + logger.addHandler(handler) + + # Set up telemetry + telemetry_options = { + 'use_otlp_console_handler': 'otlp_console' in settings.log_handlers, + 'use_otlp_handler': 'otlp' in settings.log_handlers, + } + if any(telemetry_options.values()): try: - fh = logging.FileHandler(log_file) - fh.setFormatter(formatter) - logger.addHandler(fh) - except OSError as e: - logger.error(f"Could not open log file '{e.filename}': {e.strerror}") - - return logger - - -def report_current_status_and_settings(logger: logging.Logger, settings: Settings): - settings.dry_run and logger.info(f"Dry Run: {settings.dry_run}") # type: ignore - logger.debug(f"Default TTL: {settings.default_ttl}") - logger.debug(f"Refresh Entries: {settings.refresh_entries}") - - logger.debug(f"Traefik Polling Mode: {'On' if settings.enable_traefik_poll else 'Off'}") - if settings.enable_traefik_poll: - if settings.traefik_poll_url and re.match(r"^\w+://[^/?#]+", settings.traefik_poll_url): - logger.debug(f"Traefik Poll Url: {settings.traefik_poll_url}") - logger.debug(f"Traefik Poll Seconds: {settings.traefik_poll_seconds}") - else: - settings.enable_traefik_poll = False - logger.error(f"Traefik polling disabled: Bad url: {settings.traefik_poll_url}") - - logger.debug(f"Docker Polling Mode: {'On' if settings.enable_docker_poll else 'Off'}") - logger.debug(f"Docker Poll Seconds: {settings.docker_timeout_seconds}") + from dns_synchub.telemetry import telemetry_log_handler - for dom in settings.domains: - logger.debug(f"Domain Configuration: {dom.name}") - logger.debug(f" Target Domain: {dom.target_domain}") - logger.debug(f" TTL: {dom.ttl}") - logger.debug(f" Record Type: {dom.rc_type}") - logger.debug(f" Proxied: {dom.proxied}") - logger.debug(f" Excluded Subdomains: {dom.excluded_sub_domains}") + handler = telemetry_log_handler(settings.service_name, **telemetry_options) + logger.addHandler(handler) + except ImportError: + logger.warning('Telemetry module not found. Logging to console only.') return logger -def get_logger(settings: Settings | None = None) -> logging.Logger: - global logger - if logger is None and settings is None: - raise ValueError("Logger has not been initialized") - # Init logger if needed - assert settings is not None, "Settings must be provided if logger is not initialized" - logger = logger or initialize_logger(settings) +def get_logger(name: str | Settings) -> logging.Logger: + if isinstance(name, str): + return logging.getLogger(name) + logger = logging.getLogger(name.service_name) + initialize_logger(name) return logger diff --git a/src/dns_synchub/manager.py b/src/dns_synchub/manager.py deleted file mode 100644 index 2e52328..0000000 --- a/src/dns_synchub/manager.py +++ /dev/null @@ -1,89 +0,0 @@ -from __future__ import annotations - -import asyncio -import logging -from typing import Any - -from dns_synchub.settings import PollerSourceType - -from .events import EventEmitter -from .mappers import Mapper -from .pollers import Poller - - -class DataManager: - def __init__(self, *, logger: logging.Logger): - self.logger = logger - self.tasks: list[asyncio.Task[None]] = [] - - # Subscribers - self.pollers: set[tuple[Poller[Any], float]] = set() - self.mappers: EventEmitter[Mapper[Any]] = EventEmitter(logger, name="Manager") - - # Data - self.data: dict[PollerSourceType, list[str]] = {} - - async def __call__(self, names: list[str], source: PollerSourceType): - # Store new data for mappers in each mappers queue - self.mappers.set_data((names, source)) - # Combine data previously received from pollers - self._combine_data({source: names}) - await self.mappers.emit() - - def _combine_data(self, data: dict[PollerSourceType, list[str]]): - """Combine data from multiple pollers.""" - for source, values in data.items(): - assert isinstance(values, list) - self.data.setdefault(source, []) - self.data[source].extend(values) - self.data[source] = list(set(self.data[source])) - - def add_poller(self, poller: Poller[Any], backoff: float = 0): - """Add a DataPoller to the manager.""" - assert not any(poller == p for p, _ in self.pollers) - self.pollers.add((poller, backoff)) - - async def add_mapper(self, mapper: Mapper[Any], backoff: float = 0): - """Add a Mapper to the manager.""" - await self.mappers.subscribe(mapper, backoff=backoff) - - async def start(self, timeout: float | None = None): - """Start all pollers by fetching initial data and subscribing to events.""" - assert len(self.tasks) == 0 - # Loop pollers - for poller, backoff in self.pollers: - # Register itelf to be called when new data is available - await poller.events.subscribe(self, backoff=backoff) - # Ask poller to start monitoring data - self.tasks.append(asyncio.create_task(poller.run(timeout=timeout))) - # Add mappers emission to tasks that mast run concurrently - if len(self.mappers) > 0: - self.tasks.append(asyncio.create_task(self.mappers.emit(timeout=timeout))) - try: - # wait until timeout is reached or tasks are canceled - await asyncio.gather(*self.tasks) - except asyncio.CancelledError: - # Gracefully stop monitoring - await self.stop() - finally: - # Clear tasks - self.tasks.clear() - - async def stop(self): - """Unsubscribe all pollers from their event systems.""" - # This could be extended to stop any running background tasks if needed - if pending := [task for task in self.tasks if task.cancel()]: - self.logger.info("Stopping running pollers...") - await asyncio.gather(*pending, return_exceptions=True) - self.tasks.clear() - - def aggregate_data(self): - """Aggregate and return the latest data from all pollers.""" - for poller, _ in self.pollers: - try: - names, source = poller.events.get_data(self) - self._combine_data({source: names}) - except asyncio.QueueEmpty: - pass - # Return the combined data - return self.data diff --git a/src/dns_synchub/mappers/__init__.py b/src/dns_synchub/mappers/__init__.py index b06cb92..5bf2b8d 100644 --- a/src/dns_synchub/mappers/__init__.py +++ b/src/dns_synchub/mappers/__init__.py @@ -1,8 +1,18 @@ -import logging -from abc import ABC, abstractmethod -from typing import Generic, TypedDict, TypeVar +from abc import ABC +from logging import Logger +from typing import Generic, Protocol, TypedDict, TypeVar, runtime_checkable -from dns_synchub.settings import DomainsModel, PollerSourceType, Settings +from dns_synchub.settings import Settings +from dns_synchub.types import DomainsModel, EventSubscriber + +T = TypeVar('T') # Client backemd +E = TypeVar('E') # Event type accepted +R = TypeVar('R') # Result type + + +@runtime_checkable +class MapperProtocol(EventSubscriber[E], Protocol[E, R]): + async def sync(self, data: E) -> list[R] | None: ... class MapperConfig(TypedDict): @@ -14,31 +24,21 @@ class MapperConfig(TypedDict): """Delay in seconds before syncing mappings""" -class BaseMapper(ABC): +class BaseMapper(ABC, MapperProtocol[E, R], Generic[E, R]): config: MapperConfig = { - "stop": 3, - "wait": 4, - "delay": 0, + 'stop': 3, + 'wait': 4, + 'delay': 0, } - def __init__(self, logger: logging.Logger): + def __init__(self, logger: Logger): self.logger = logger - self.mappings = {} - - @abstractmethod - async def __call__(self, hosts: list[str], source: PollerSourceType): ... - @abstractmethod - async def sync(self, host: str, source: PollerSourceType) -> DomainsModel | None: ... - -T = TypeVar("T") - - -class Mapper(BaseMapper, Generic[T]): - def __init__(self, logger: logging.Logger, *, settings: Settings, client: T | None = None): +class Mapper(BaseMapper[E, DomainsModel], Generic[E, T]): + def __init__(self, logger: Logger, *, settings: Settings, client: T | None = None): # init client - self.client: T | None = client + self._client: T | None = client # Domain defaults self.dry_run = settings.dry_run @@ -50,9 +50,14 @@ def __init__(self, logger: logging.Logger, *, settings: Settings, client: T | No self.included_hosts = settings.included_hosts self.excluded_hosts = settings.excluded_hosts - super(Mapper, self).__init__(logger) + super().__init__(logger) + + @property + def client(self) -> T: + assert self._client is not None, 'Client is not initialized' + return self._client from dns_synchub.mappers.cloudflare import CloudFlareMapper # noqa: E402 -__all__ = ["CloudFlareMapper"] +__all__ = ['CloudFlareMapper'] diff --git a/src/dns_synchub/mappers/cloudflare.py b/src/dns_synchub/mappers/cloudflare.py index 96b616f..4d3b0bb 100644 --- a/src/dns_synchub/mappers/cloudflare.py +++ b/src/dns_synchub/mappers/cloudflare.py @@ -1,13 +1,15 @@ from __future__ import annotations import asyncio +from collections.abc import Awaitable, Callable from functools import partial, wraps from logging import Logger -from typing import Any, Awaitable, Callable, cast +import time +from typing import Any, cast -from CloudFlare import CloudFlare # type: ignore -from CloudFlare import exceptions as CloudFlareExceptions # type: ignore -from tenacity import ( # type: ignore +from CloudFlare import CloudFlare +from CloudFlare import exceptions as CloudFlareExceptions +from tenacity import ( AsyncRetrying, RetryCallState, RetryError, @@ -18,7 +20,9 @@ from typing_extensions import override from dns_synchub.mappers import Mapper -from dns_synchub.settings import DomainsModel, PollerSourceType, Settings +from dns_synchub.pollers import PollerData +from dns_synchub.settings import Settings +from dns_synchub.types import DomainsModel, Event, PollerSourceType class CloudFlareException(Exception): @@ -29,27 +33,27 @@ def dry_run(func: Callable[..., Awaitable[Any]]) -> Callable[..., Awaitable[Any] @wraps(func) async def wrapper(self: CloudFlareMapper, zone_id: str, *args: Any, **data: Any) -> Any: if self.dry_run: - self.logger.info(f"DRY-RUN: {func.__name__} in zone {zone_id}:, {data}") - return {**data, "zone_id": zone_id} + self.logger.info(f'DRY-RUN: {func.__name__} in zone {zone_id}: {data}') + return {**data, 'zone_id': zone_id} return await func(self, zone_id, *args, **data) return wrapper def retry(func: Callable[..., Awaitable[Any]]) -> Callable[..., Awaitable[Any]]: - def log_before_sleep(logger, retry_state: RetryCallState): + def log_before_sleep(logger: Logger, retry_state: RetryCallState) -> None: assert retry_state.next_action sleep_time = retry_state.next_action.sleep - logger.warning(f"Max Rate limit reached. Retry in {sleep_time} seconds...") + logger.warning(f'Max Rate limit reached. Retry in {sleep_time} seconds...') @wraps(func) async def wrapper(self: CloudFlareMapper, *args: Any, **kwargs: Any) -> Any: assert isinstance(self, CloudFlareMapper) retry = AsyncRetrying( - stop=stop_after_attempt(self.config["stop"]), - wait=wait_exponential(multiplier=self.config["wait"], max=self.tout_sec), - retry=retry_if_exception_message(match="Rate limited"), + stop=stop_after_attempt(self.config['stop']), + wait=wait_exponential(multiplier=self.config['wait'], max=self.tout_sec), + retry=retry_if_exception_message(match='Rate limited'), before_sleep=partial(log_before_sleep, self.logger), ) try: @@ -59,44 +63,41 @@ async def wrapper(self: CloudFlareMapper, *args: Any, **kwargs: Any) -> Any: return await func(self, *args, **kwargs) except Exception as err: att = attempt_ctx.retry_state.attempt_number - self.logger.debug(f"CloduFlare {func.__name__} attempt {att} failed:{err}") + self.logger.debug(f'CloudFlare {func.__name__} attempt {att} failed: {err}') raise except RetryError as err: last_error = err.last_attempt.result() - raise CloudFlareException("Operation failed") from last_error + raise CloudFlareException('Operation failed') from last_error return wrapper -class CloudFlareMapper(Mapper[CloudFlare]): +class CloudFlareMapper(Mapper[PollerData[PollerSourceType], CloudFlare]): def __init__(self, logger: Logger, *, settings: Settings, client: CloudFlare | None = None): if client is None: assert settings.cf_token is not None client = CloudFlare( token=settings.cf_token, - debug=settings.log_level.upper() == "VERBOSE", + debug=settings.log_level == settings.verbose, ) - logger.debug("CloudFlare Scoped API client started") + logger.debug('CloudFlare Scoped API client started') self.tout_sec = settings.cf_timeout_seconds self.sync_sec = settings.cf_sync_seconds + self.lastcall = 0.0 # Initialize the parent class - super(CloudFlareMapper, self).__init__(logger, settings=settings, client=client) + super().__init__(logger, settings=settings, client=client) @override - async def __call__(self, hosts: list[str], source: PollerSourceType): - tasks = [asyncio.create_task(self.sync(host, source)) for host in hosts] - try: - _, pending = await asyncio.wait(tasks, timeout=self.tout_sec) - if pending: - for task in pending: - task.cancel() - self.logger.warning("Timeout reached. Cancelling pending tasks...") - except asyncio.CancelledError: - for task in tasks: - task.cancel() - raise + async def __call__(self, event: Event[PollerData[PollerSourceType]]) -> None: + while True: + if backoff := (self.lastcall + self.sync_sec) - time.time() <= 0: + break + await asyncio.sleep(backoff) + # Reset sync time + self.lastcall = time.time() + await self.sync(event.data) @retry async def get_records(self, zone_id: str, **filter: str) -> list[dict[str, Any]]: @@ -108,7 +109,7 @@ async def get_records(self, zone_id: str, **filter: str) -> list[dict[str, Any]] async def post_record(self, zone_id: str, **data: str) -> dict[str, Any]: assert self.client is not None result = await asyncio.to_thread(self.client.zones.dns_records.post, zone_id, data=data) - self.logger.info(f"Created new record in zone {zone_id}: {result}") + self.logger.info(f'Created new record in zone {zone_id}: {result}') return result @dry_run @@ -118,59 +119,78 @@ async def put_record(self, zone_id: str, record_id: str, **data: str) -> dict[st result = await asyncio.to_thread( self.client.zones.dns_records.put, zone_id, record_id, data=data ) - self.logger.info(f"Updated record {record_id} in zone {zone_id} with data {data}") + self.logger.info(f'Updated record {record_id} in zone {zone_id} with data {data}') return result # Start Program to update the Cloudflare @override - async def sync(self, host: str, source: PollerSourceType) -> DomainsModel | None: - def is_domain_excluded(host: str, domain: DomainsModel): + async def sync(self, data: PollerData[PollerSourceType]) -> list[DomainsModel] | None: # noqa: C901 + def is_domain_excluded(host: str, domain: DomainsModel) -> bool: for sub_dom in domain.excluded_sub_domains: - if f"{sub_dom}.{domain.name}" in host: - self.logger.info(f"Ignoring {host}: Match excluded sub domain: {sub_dom}") + if f'{sub_dom}.{domain.name}' in host: + self.logger.info(f'Ignoring {host}: Match excluded sub domain: {sub_dom}') return True return False - for domain_info in self.domains: - # Don't update the domain if it's the same as the target domain, which sould be used on tunnel - if host == domain_info.target_domain: - continue - # Skip if it's not a subdomain of the domain we're looking for - if host.find(domain_info.name) < 0: - continue - # Skip if the domain is in exclude list - if is_domain_excluded(host, domain_info): - continue - # Skip if already present and refresh entries is not required - records = await self.get_records(domain_info.zone_id, name=host) - if records and not self.refresh_entries: - assert len(records) == 1 - self.logger.info(f"Record {host} found. Not refreshing. Skipping...") - return DomainsModel(**records.pop()) - # Prepare data for the new record - data = cast( - dict[str, str], - { - "type": self.rc_type, - "name": host, - "content": domain_info.target_domain, - "ttl": str(domain_info.ttl) if domain_info.ttl is not None else "auto", - "proxied": str(domain_info.proxied), - "comment": domain_info.comment, - "tag": f"poller:{source}", - }, - ) - result = None - try: + tasks: list[Any] = [] + for host in data.hosts: + for domain_info in self.domains: + # Don't update the domain if it's the same as the target domain, which sould be used on tunnel + if host == domain_info.target_domain: + continue + # Skip if it's not a subdomain of the domain we're looking for + if host.find(domain_info.name) < 0: + continue + # Skip if the domain is in exclude list + if is_domain_excluded(host, domain_info): + continue + # Skip if already present and refresh entries is not required + records = await self.get_records(domain_info.zone_id, name=host) + if records and not self.refresh_entries: + assert len(records) == 1 + tasks.append(asyncio.create_task(asyncio.sleep(0, result=records.pop()))) + self.logger.info(f'Record {host} found. Not refreshing. Skipping...') + continue + # Prepare data for the new record + domain = cast( + dict[str, Any], + { + 'type': self.rc_type, + 'name': host, + 'content': domain_info.target_domain, + 'ttl': str(domain_info.ttl) if domain_info.ttl is not None else 'auto', + 'proxied': domain_info.proxied, + 'comment': domain_info.comment, + 'tag': f'poller:{data.source}', + }, + ) # Update the record if it already exists if records: assert len(records) == 1 assert self.refresh_entries - result = await self.put_record(domain_info.zone_id, records.pop()["id"], **data) + future = self.put_record(domain_info.zone_id, records.pop()['id'], **domain) # Create a new record if it doesn't exist yet else: - result = await self.post_record(domain_info.zone_id, **data) - except CloudFlareExceptions.CloudFlareAPIError as err: - self.logger.error(f"Sync Error for {host}: {str(err)} [Code {int(err)}]") - finally: - return DomainsModel(**result) if result else None + future = self.post_record(domain_info.zone_id, **domain) + # Append the task to the results + tasks.append(asyncio.ensure_future(future)) + break + + if not tasks: + return None + + results: list[DomainsModel] = [] + # run tasks concurrently + done, pending = await asyncio.wait(tasks, timeout=self.tout_sec) + # Cancel pending tasks + [task.cancel() for task in pending] + # Process Exceptions and get results + for task in done: + if err := task.exception(): + if isinstance(err, CloudFlareExceptions.CloudFlareAPIError): + self.logger.error(f"Sync failed for '{data.source}': [{int(err)}]") + self.logger.error(f'{str(err)}') + continue + results.append(DomainsModel(**task.result())) + # Return results + return results or None diff --git a/src/dns_synchub/pollers/__init__.py b/src/dns_synchub/pollers/__init__.py index 591c078..593be01 100644 --- a/src/dns_synchub/pollers/__init__.py +++ b/src/dns_synchub/pollers/__init__.py @@ -1,47 +1,64 @@ from __future__ import annotations -import asyncio from abc import ABC, abstractmethod +import asyncio +from dataclasses import dataclass from datetime import datetime, timedelta from logging import Logger -from typing import Any, Callable, Generic, TypedDict, TypeVar +from typing import ( + Any, + ClassVar, + Generic, + NotRequired, + Protocol, + TypedDict, + TypeVar, + runtime_checkable, +) from weakref import ref as WeakRef from typing_extensions import override from dns_synchub.events import EventEmitter -from dns_synchub.settings import PollerSourceType, Settings +from dns_synchub.settings import Settings +from dns_synchub.types import EventSubscriberType, PollerSourceType + +T = TypeVar('T') + + +@dataclass +class PollerData(Generic[T]): + hosts: list[str] + source: T + +@runtime_checkable +class PollerProtocol(Protocol[T]): + events: EventEmitter[PollerData[T]] -class PollerConfig(TypedDict): + async def fetch(self) -> PollerData[T]: ... + + async def start(self, timeout: float | None = None) -> None: ... + + async def stop(self) -> None: ... + + +class PollerConfig(TypedDict, Generic[T]): stop: int """Max number of retries to attempt before exponential backoff fails""" wait: int """Factor to multiply the backoff time by""" - source: PollerSourceType + source: NotRequired[T] """The source of the poller""" -class PollerEvents(EventEmitter[Callable[..., Any]]): - def __init__(self, logger: Logger, *, poller: Poller[Any]): - self.poller = WeakRef(poller) - assert "source" in poller.config - super(PollerEvents, self).__init__(logger, name=poller.config["source"]) - - # Event related methods - @override - async def subscribe(self, callback: Callable[..., Any], backoff: float = 0): - # Register subscriber - await super(PollerEvents, self).subscribe(callback, backoff=backoff) - # Fetch data and store locally if required - self.set_data(await self.poller().fetch(), callback=callback) # type: ignore - - -class BasePoller(ABC): - config: PollerConfig = { - "stop": 3, - "wait": 4, - "source": "manual", +class BasePoller(ABC, PollerProtocol[T], Generic[T]): + # Generic Typed ClassVars are not supported. + # See https://github.com/python/typing/discussions/1424 for + # open discussion about support + config: ClassVar[PollerConfig[T]] = { # type: ignore + 'stop': 3, + 'wait': 4, } def __init__(self, logger: Logger): @@ -53,18 +70,10 @@ def __init__(self, logger: Logger): client (Any): The client instance for making requests. """ self.logger = logger - - # Poller methods - @abstractmethod - async def fetch(self) -> tuple[list[str], PollerSourceType]: - """ - Abstract method to fetch data. - Must be implemented by subclasses. - """ - pass + self._wtask: asyncio.Task[None] @abstractmethod - async def _watch(self): + async def _watch(self) -> None: """ Abstract method to watch for changes. This method must emit signals whenever new data is available. @@ -73,7 +82,7 @@ async def _watch(self): """ pass - async def run(self, timeout: float | None = None): + async def start(self, timeout: float | None = None) -> None: """ Starts the Poller and watches for changes. @@ -82,62 +91,75 @@ async def run(self, timeout: float | None = None): the method will wait indefinitely. """ name = self.__class__.__name__ - self.logger.info(f"Starting {name}: Watching for changes") + self.logger.info(f'Starting {name}: Watching for changes') # self.fetch is called for the firstime, whehever a a client subscribe to # this poller, so there's no need to initialy fetch data - watch_task = asyncio.create_task(self._watch()) + self._wtask = asyncio.create_task(self._watch()) if timeout is not None: until = datetime.now() + timedelta(seconds=timeout) - self.logger.debug(f"{name}: Stop programed at {until}") + self.logger.debug(f'{name}: Stop programed at {until}') try: - await asyncio.wait_for(watch_task, timeout) - except asyncio.TimeoutError: + await asyncio.wait_for(self._wtask, timeout) + except TimeoutError: self.logger.info(f"{name}: Run timeout '{timeout}s' reached") except asyncio.CancelledError: - self.logger.info(f"{name}: Run was cancelled") + self.logger.info(f'{name}: Run was cancelled') finally: - watch_task.cancel() + await self.stop() + + async def stop(self) -> None: + name = self.__class__.__name__ + if self._wtask and not self._wtask.done(): + self.logger.info(f'Stopping {name}: Cancelling watch task') + self._wtask.cancel() try: - await watch_task + await self._wtask except asyncio.CancelledError: - self.logger.info(f"{name}: Watch task was cancelled") + self.logger.info(f'{name}: Watch task was cancelled') -T = TypeVar("T") +class PollerEventEmitter(EventEmitter[PollerData[PollerSourceType]]): + def __init__(self, logger: Logger, *, poller: Poller[Any]): + assert 'source' in poller.config + self.poller = WeakRef(poller) + super().__init__(logger, origin=poller.config['source']) + # Event related methods + @override + async def subscribe( + self, callback: EventSubscriberType[PollerData[PollerSourceType]], backoff: float = 0 + ) -> None: + # Register subscriber + await super().subscribe(callback, backoff=backoff) + # Fetch data and store locally if required + if poller := self.poller(): + self.set_data(await poller.fetch(), callback=callback) -class Poller(BasePoller, Generic[T]): - def __init__(self, logger: Logger, *, settings: Settings, client: T | None = None): - super(Poller, self).__init__(logger) +class Poller(BasePoller[PollerSourceType], Generic[T]): + def __init__(self, logger: Logger, *, settings: Settings, client: T | None = None): # init client self._client: T | None = client - self.events = PollerEvents(logger, poller=self) + self.events = PollerEventEmitter(logger, poller=self) # Computed from settings self.included_hosts = settings.included_hosts self.excluded_hosts = settings.excluded_hosts + super().__init__(logger) + @property def client(self) -> T: - assert self._client is not None, "Client is not initialized" + assert self._client is not None, 'Client is not initialized' return self._client - @abstractmethod - async def fetch(self) -> tuple[list[str], PollerSourceType]: - """ - Abstract method to fetch data. - Must be implemented by subclasses. - """ - pass - # ruff: noqa: E402 from dns_synchub.pollers.docker import DockerPoller from dns_synchub.pollers.traefik import TraefikPoller -# run: enable +# ruff: enable -__all__ = ["TraefikPoller", "DockerPoller"] +__all__ = ['TraefikPoller', 'DockerPoller'] diff --git a/src/dns_synchub/pollers/docker.py b/src/dns_synchub/pollers/docker.py index bbcce61..5dfc584 100644 --- a/src/dns_synchub/pollers/docker.py +++ b/src/dns_synchub/pollers/docker.py @@ -1,8 +1,8 @@ import asyncio -import logging -import re from datetime import datetime from functools import lru_cache +import logging +import re from typing import Any, cast import docker @@ -12,8 +12,9 @@ from tenacity import AsyncRetrying, RetryError, stop_after_attempt, wait_exponential from typing_extensions import override -from dns_synchub.pollers import Poller -from dns_synchub.settings import PollerSourceType, Settings +from dns_synchub.pollers import Poller, PollerData +from dns_synchub.settings import Settings +from dns_synchub.types import PollerSourceType class DockerError(Exception): @@ -29,13 +30,13 @@ def __init__(self, container: Container, *, logger: logging.Logger): @property def id(self) -> str | None: - return self.container.attrs.get("Id") + return self.container.attrs.get('Id') @property - def labels(self) -> dict[str, str]: - return self.container.attrs.get("Config", {}).get("Labels", {}) + def labels(self) -> Any: + return self.container.attrs.get('Config', {}).get('Labels', {}) - def __getattr__(self, name: str) -> str | None: + def __getattr__(self, name: str) -> Any: if name in self.container.attrs: return self.container.attrs[name] return self.labels.get(name) @@ -46,16 +47,16 @@ def hosts(self) -> list[str]: # Try to find traefik filter. If found, tray to parse for label, value in self.labels.items(): - if not re.match(r"traefik.*?\.rule", label): + if not re.match(r'traefik.*?\.rule', label): continue self.logger.debug(f"Found traefik label '{label}' from container {self.id}") - if "Host" not in value: - self.logger.debug(f"Malformed rule in container {self.id} - Missing Host") + if 'Host' not in value: + self.logger.debug(f'Malformed rule in container {self.id} - Missing Host') continue # Extract the domains from the rule # Host(`example.com`) => ['example.com'] - hosts = re.findall(r"Host\(`([^`]+)`\)", value) + hosts = re.findall(r'Host\(`([^`]+)`\)', value) self.logger.debug(f"Found service '{self.Name}' with hosts: {hosts}") return hosts @@ -63,7 +64,7 @@ def hosts(self) -> list[str]: class DockerPoller(Poller[DockerClient]): - config = {**Poller.config, "source": "docker"} + config = {**Poller.config, 'source': 'docker'} # type: ignore def __init__( self, @@ -72,9 +73,6 @@ def __init__( settings: Settings, client: DockerClient | None = None, ): - # Initialize the Poller - super(DockerPoller, self).__init__(logger, settings=settings, client=client) - # Computed from settings self.poll_sec = settings.docker_poll_seconds self.tout_sec = settings.docker_timeout_seconds @@ -83,24 +81,30 @@ def __init__( self.filter_label = settings.docker_filter_label self.filter_value = settings.docker_filter_value + self._client: DockerClient | None = client + + # Initialize the Poller + super().__init__(logger, settings=settings, client=client) + @property def client(self) -> DockerClient: if self._client is None: try: # Init Docker client if not provided - self.logger.debug("Connecting to Docker...") + self.logger.debug('Connecting to Docker...') self._client = docker.from_env(timeout=self.tout_sec) except Exception as err: - self.logger.error(f"Could not connect to Docker: {err}") - self.logger.error("Please make sure Docker is running and accessible") - raise DockerError("Could not connect to Docker", error=err) from err + self._client = None + self.logger.error(f'Could not connect to Docker: {err}') + self.logger.error('Please make sure Docker is running and accessible') + raise DockerError('Could not connect to Docker', error=err) from err else: # Get Docker Host info info = cast(dict[str, Any], self._client.info()) # type: ignore self.logger.debug(f"Connected to Docker Host at '{info.get('Name')}'") return self._client - def _is_enabled(self, container: DockerContainer): + def _is_enabled(self, container: DockerContainer) -> bool: # If no filter is set, return True if self.filter_label is None: return True @@ -114,57 +118,57 @@ def _is_enabled(self, container: DockerContainer): return True # A filter value is also set, check if it matches assert isinstance(self.filter_value, re.Pattern) - return self.filter_value.match(value) + return self.filter_value.match(value) is not None return False - def _validate(self, raw_data: list[DockerContainer]) -> tuple[list[str], PollerSourceType]: - data: list[str] = [] + def _validate(self, raw_data: list[DockerContainer]) -> PollerData[PollerSourceType]: + hosts: list[str] = [] for container in raw_data: # Check if container is enabled if not self._is_enabled(container): - self.logger.debug(f"Skipping container {container.id}") + self.logger.debug(f'Skipping container {container.id}') continue # Validate domain and queue for sync for host in container.hosts: - data.append(host) + hosts.append(host) # Return a collection of zones to sync - assert "source" in self.config - return data, self.config["source"] + assert 'source' in self.config + return PollerData[PollerSourceType](hosts, self.config['source']) @override - async def _watch(self): - until = datetime.now().strftime("%s") + async def _watch(self) -> None: + until = datetime.now().strftime('%s') while True: since = until - self.logger.debug("Fetching routers from Docker API") - until = datetime.now().strftime("%s") + self.logger.debug('Fetching routers from Docker API') + until = datetime.now().strftime('%s') # Ther's no swarm in podman engine, so remove Action filter - filter = {"Type": "service", "status": "start"} - kwargs = {"since": since, "until": until, "filters": filter, "decode": True} + filter = {'Type': 'service', 'status': 'start'} + kwargs = {'since': since, 'until': until, 'filters': filter, 'decode': True} events = None try: - events: Any = await asyncio.to_thread(self.client.events, **kwargs) # type: ignore + events = await asyncio.to_thread(self.client.events, **kwargs) for event in events: - if "id" not in event: - self.logger.warning("Container ID is None. Skipping container.") + if 'id' not in event: + self.logger.warning('Container ID is None. Skipping container.') continue - raw_data = await asyncio.to_thread(self.client.containers.get, event["id"]) + raw_data = await asyncio.to_thread(self.client.containers.get, event['id']) services = [DockerContainer(raw_data, logger=self.logger)] self.events.set_data(self._validate(services)) except NotFound: await self.events.emit() await asyncio.sleep(self.poll_sec) except asyncio.CancelledError: - self.logger.info("Dokcker polling cancelled. Performing cleanup.") - return + self.logger.info('Docker polling cancelled. Performing cleanup.') + raise finally: _ = events and await asyncio.to_thread(events.close) @override - async def fetch(self) -> tuple[list[str], PollerSourceType]: - filters = {"status": "running"} - stop = stop_after_attempt(self.config["stop"]) - wait = wait_exponential(multiplier=self.config["wait"], max=self.tout_sec) + async def fetch(self) -> PollerData[PollerSourceType]: + filters = {'status': 'running'} + stop = stop_after_attempt(self.config['stop']) + wait = wait_exponential(multiplier=self.config['wait'], max=self.tout_sec) raw_data = [] try: async for attempt_ctx in AsyncRetrying(stop=stop, wait=wait): @@ -175,10 +179,10 @@ async def fetch(self) -> tuple[list[str], PollerSourceType]: result = [DockerContainer(c, logger=self.logger) for c in raw_data] except Exception as err: att = attempt_ctx.retry_state.attempt_number - self.logger.debug(f"Docker.fetch attempt {att} failed: {err}") + self.logger.debug(f'Docker.fetch attempt {att} failed: {err}') raise except RetryError as err: last_error = err.last_attempt.result() - self.logger.critical(f"Could not fetch containers: {last_error}") + self.logger.critical(f'Could not fetch containers: {last_error}') # Return a collection of routes return self._validate(result) diff --git a/src/dns_synchub/pollers/traefik.py b/src/dns_synchub/pollers/traefik.py index 2af63d7..c7acfe1 100644 --- a/src/dns_synchub/pollers/traefik.py +++ b/src/dns_synchub/pollers/traefik.py @@ -1,14 +1,15 @@ import asyncio -import re from logging import Logger +import re from typing import Any -from requests import Session +from requests import Response, Session from tenacity import AsyncRetrying, RetryError, stop_after_attempt, wait_exponential from typing_extensions import override -from dns_synchub.pollers import Poller -from dns_synchub.settings import PollerSourceType, Settings +from dns_synchub.pollers import Poller, PollerData +from dns_synchub.settings import Settings +from dns_synchub.types import PollerSourceType class TimeoutSession(Session): @@ -16,85 +17,85 @@ def __init__(self, *, timeout: float | None = None): self.timeout = timeout super().__init__() - def request(self, *args: Any, **kwargs: dict[str, Any]): - if "timeout" not in kwargs and self.timeout: - kwargs["timeout"] = self.timeout # type: ignore - return super().request(*args, **kwargs) + def request(self, method: str | bytes, url: str | bytes, *params: Any, **data: Any) -> Response: + if 'timeout' not in data and self.timeout: + data['timeout'] = self.timeout + return super().request(method, url, *params, **data) class TraefikPoller(Poller[Session]): - config = {**Poller.config, "source": "traefik"} + config = {**Poller.config, 'source': 'traefik'} # type: ignore def __init__(self, logger: Logger, *, settings: Settings, client: Session | None = None): # Computed from settings self.poll_sec = settings.traefik_poll_seconds self.tout_sec = settings.traefik_timeout_seconds - self.poll_url = f"{settings.traefik_poll_url}/api/http/routers" + self.poll_url = f'{settings.traefik_poll_url}/api/http/routers' # Providers filtering self.excluded_providers = settings.traefik_excluded_providers # Initialize the Poller client = client or TimeoutSession(timeout=self.tout_sec) - super(TraefikPoller, self).__init__(logger, settings=settings, client=client) + super().__init__(logger, settings=settings, client=client) def _is_valid_route(self, route: dict[str, Any]) -> bool: # Computed from settings - required_keys = ["status", "name", "rule"] + required_keys = ['status', 'name', 'rule'] if any(key not in route for key in required_keys): - self.logger.debug(f"Traefik Router Name: {route} - Missing Key") + self.logger.debug(f'Traefik Router Name: {route} - Missing Key') return False - if route["status"] != "enabled": + if route['status'] != 'enabled': self.logger.debug(f"Traefik Router Name: {route['name']} - Not Enabled") return False - if "Host" not in route["rule"]: + if 'Host' not in route['rule']: self.logger.debug(f"Traefik Router Name: {route['name']} - Missing Host") # Route is valid and enabled return True def _is_valid_host(self, host: str) -> bool: if not any(pattern.match(host) for pattern in self.included_hosts): - self.logger.debug(f"Traefik Router Host: {host} - Not Match with Include Hosts") + self.logger.debug(f'Traefik Router Host: {host} - Not Match with Include Hosts') return False if any(pattern.match(host) for pattern in self.excluded_hosts): - self.logger.debug(f"Traefik Router Host: {host} - Match with Exclude Hosts") + self.logger.debug(f'Traefik Router Host: {host} - Match with Exclude Hosts') return False # Host is intended to be synced return True - def _validate(self, raw_data: list[dict[str, Any]]) -> tuple[list[str], PollerSourceType]: - data: list[str] = [] + def _validate(self, raw_data: list[dict[str, Any]]) -> PollerData[PollerSourceType]: + hosts: list[str] = [] for route in raw_data: # Check if route is whell formed if not self._is_valid_route(route): continue # Extract the domains from the rule - hosts = re.findall(r"Host\(`([^`]+)`\)", route["rule"]) - self.logger.debug(f"Traefik Router Name: {route['name']} domains: {hosts}") + host_rules = re.findall(r'Host\(`([^`]+)`\)', route['rule']) + self.logger.debug(f"Traefik Router Name: {route['name']} host: {host_rules}") # Validate domain and queue for sync - for host in (host for host in hosts if self._is_valid_host(host)): + for host in (host for host in host_rules if self._is_valid_host(host)): self.logger.info(f"Found Traefik Router: {route['name']} with Hostname {host}") - data.append(host) + hosts.append(host) # Return a collection of zones to sync - assert "source" in self.config - return data, self.config["source"] + assert 'source' in self.config + return PollerData[PollerSourceType](hosts, self.config['source']) @override - async def _watch(self, timeout: float | None = None): + async def _watch(self) -> None: try: while True: - self.logger.debug("Fetching routers from Traefik API") - self.events.set_data(await self.fetch()) # type: ignore + self.logger.debug('Fetching routers from Traefik API') + self.events.set_data(await self.fetch()) await self.events.emit() await asyncio.sleep(self.poll_sec) except asyncio.CancelledError: - self.logger.info("Traefik Polling cancelled. Performing cleanup.") + self.logger.info('Traefik Polling cancelled. Performing cleanup.') return @override - async def fetch(self) -> tuple[list[str], PollerSourceType]: - stop = stop_after_attempt(self.config["stop"]) - wait = wait_exponential(multiplier=self.config["wait"], max=self.tout_sec) + async def fetch(self) -> PollerData[PollerSourceType]: + stop = stop_after_attempt(self.config['stop']) + wait = wait_exponential(multiplier=self.config['wait'], max=self.tout_sec) rawdata = [] assert self._client try: @@ -106,10 +107,10 @@ async def fetch(self) -> tuple[list[str], PollerSourceType]: rawdata = response.json() except Exception as err: att = attempt_ctx.retry_state.attempt_number - self.logger.debug(f"Traefik.fetch attempt {att} failed: {err}") + self.logger.debug(f'Traefik.fetch attempt {att} failed: {err}') raise except RetryError as err: last_error = err.last_attempt.result() - self.logger.critical(f"Failed to fetch route from Traefik API: {last_error}") + self.logger.critical(f'Failed to fetch route from Traefik API: {last_error}') # Return a collection of routes return self._validate(rawdata) diff --git a/src/dns_synchub/settings.py b/src/dns_synchub/settings.py index a4625c9..cd6664f 100644 --- a/src/dns_synchub/settings.py +++ b/src/dns_synchub/settings.py @@ -1,54 +1,47 @@ +import logging import re -from typing import Annotated, Literal +from typing import Self -from pydantic import BaseModel, BeforeValidator, model_validator +from pydantic import model_validator from pydantic_settings import BaseSettings, SettingsConfigDict -from typing_extensions import Self -# Define the type alias -RecordType = Literal["A", "AAAA", "CNAME"] - - -def validate_ttl(value: int | Literal["auto"]) -> int | Literal["auto"]: - if isinstance(value, int) and value >= 30: - return value - if value == "auto": - return value - raise ValueError("TTL must be at least 30 seconds or 'auto'") - - -TTLType = Annotated[int | str, BeforeValidator(validate_ttl)] - -PollerSourceType = Literal["manual", "docker", "traefik"] - - -class DomainsModel(BaseModel): - name: str - zone_id: str - proxied: bool = True - ttl: TTLType | None = None - target_domain: str | None = None - comment: str | None = None - rc_type: RecordType | None = None - excluded_sub_domains: list[str] = [] +from dns_synchub.types import ( + DomainsModel, + LogHandlersType, + LogLevelType, + RecordType, + TTLType, +) class Settings(BaseSettings): model_config = SettingsConfigDict( validate_default=False, - extra="ignore", - secrets_dir="/var/run", - env_file=(".env", ".env.prod"), - env_file_encoding="utf-8", - env_nested_delimiter="__", + extra='ignore', + secrets_dir='/var/run', + env_file=('.env', '.env.prod'), + env_file_encoding='utf-8', + env_nested_delimiter='__', ) # Settings dry_run: bool = False - log_file: str = "/logs/tcc.log" - log_level: str = "INFO" - log_type: str = "BOTH" - refresh_entries: bool = False + verbose: bool = False + + # Telemetry Settings + service_name: str = 'dns-synchub' + log_level: LogLevelType = logging.INFO + log_handlers: set[LogHandlersType] = {'otlp', 'stderr'} + log_file: str = '/logs/dns-synchub.log' + + @property + def log_formatter(self) -> logging.Formatter: + fmt = '%(asctime)s | %(message)s' + if self.verbose: + fmt = '%(asctime)s %(levelname)s %(lineno)d | %(message)s' + elif logging.DEBUG == self.log_level: + fmt = '%(asctime)s %(levelname)s | %(message)s' + return logging.Formatter(fmt, '%Y-%m-%dT%H:%M:%S%z') # Poller Common settings @@ -64,14 +57,15 @@ class Settings(BaseSettings): traefik_poll_url: str | None = None traefik_poll_seconds: int = 30 # Polling interval in seconds traefik_timeout_seconds: int = 5 # Timeout for blocking requests operations - traefik_excluded_providers: list[str] = ["docker"] + traefik_excluded_providers: list[str] = ['docker'] # Mapper Settings target_domain: str | None = None zone_id: str | None = None - default_ttl: TTLType = "auto" + default_ttl: TTLType = 'auto' proxied: bool = True - rc_type: RecordType = "CNAME" + rc_type: RecordType = 'CNAME' + refresh_entries: bool = False included_hosts: list[re.Pattern[str]] = [] excluded_hosts: list[re.Pattern[str]] = [] @@ -83,7 +77,7 @@ class Settings(BaseSettings): domains: list[DomainsModel] = [] - @model_validator(mode="after") + @model_validator(mode='after') def update_domains(self) -> Self: for dom in self.domains: dom.ttl = dom.ttl or self.default_ttl @@ -92,20 +86,23 @@ def update_domains(self) -> Self: dom.proxied = dom.proxied or self.proxied return self - @model_validator(mode="after") + @model_validator(mode='after') def add_default_include_host(self) -> Self: if len(self.included_hosts) == 0: - self.included_hosts.append(re.compile(".*")) + self.included_hosts.append(re.compile('.*')) return self - @model_validator(mode="after") + @model_validator(mode='after') def sanity_options(self) -> Self: if self.enable_traefik_poll and not self.traefik_poll_url: - raise ValueError("Traefik Polling is enabled but no URL is set") + raise ValueError('Traefik Polling is enabled but no URL is set') return self - @model_validator(mode="after") + @model_validator(mode='after') def enforce_tokens(self) -> Self: if self.dry_run or self.cf_token: return self - raise ValueError("Missing Cloudflare API token. Provide it or enable dry-run mode.") + raise ValueError('Missing Cloudflare API token. Provide it or enable dry-run mode.') + + def __hash__(self) -> int: + return id(self) diff --git a/src/dns_synchub/types.py b/src/dns_synchub/types.py new file mode 100644 index 0000000..f5096a9 --- /dev/null +++ b/src/dns_synchub/types.py @@ -0,0 +1,93 @@ +from abc import abstractmethod +import asyncio +from collections.abc import Coroutine +from dataclasses import dataclass, field +import logging +from typing import ( + Annotated, + Generic, + Literal, + Protocol, + TypeVar, + cast, + runtime_checkable, +) + +from pydantic import BaseModel, BeforeValidator + +# Settings Types + +LogHandlersType = Literal['otlp_console', 'otlp', 'stderr', 'file'] + + +def validate_log_level(value: str | int) -> int: + if isinstance(value, str): + valid_str_levels = {'CRITICAL', 'ERROR', 'WARNING', 'INFO', 'DEBUG'} + if value.upper() not in valid_str_levels: + raise ValueError(f'Invalid log level: {value}. Must be one of {valid_str_levels}.') + return cast(int, getattr(logging, value.upper())) + else: + valid_int_levels = { + logging.CRITICAL, + logging.ERROR, + logging.WARNING, + logging.INFO, + logging.DEBUG, + } + if value not in valid_int_levels: + raise ValueError(f'Invalid log level: {value}. Must be one of {valid_int_levels}.') + return value + + +LogLevelType = Annotated[int, BeforeValidator(validate_log_level)] + +# Poller Types +PollerSourceType = Literal['manual', 'docker', 'traefik'] + +# Mapper Types +RecordType = Literal['A', 'AAAA', 'CNAME'] + + +def validate_ttl(value: int | Literal['auto']) -> int | Literal['auto']: + if isinstance(value, int) and value >= 30: + return value + if value == 'auto': + return value + raise ValueError("TTL must be at least 30 seconds or 'auto'") + + +TTLType = Annotated[int | str, BeforeValidator(validate_ttl)] + + +class DomainsModel(BaseModel): + name: str + zone_id: str + proxied: bool = True + ttl: TTLType | None = None + target_domain: str | None = None + comment: str | None = None + rc_type: RecordType | None = None + excluded_sub_domains: list[str] = [] + + +# Event Types +T = TypeVar('T') + + +@dataclass +class Event(Generic[T]): + klass: type[T] = field(init=False) + data: T + + def __post_init__(self) -> None: + self.klass = type(self.data) + + +@runtime_checkable +class EventSubscriber(Protocol[T]): + @abstractmethod + async def __call__(self, event: Event[T]) -> None: ... + + +EventSubscriberType = Coroutine[None, None, None] | EventSubscriber[T] +EventSubscriberDataType = tuple[asyncio.Queue[Event[T]], float, float] diff --git a/tests/test_cloudflare.py b/tests/test_cloudflare.py index a538141..adb4b9b 100644 --- a/tests/test_cloudflare.py +++ b/tests/test_cloudflare.py @@ -1,59 +1,62 @@ from copy import deepcopy from logging import Logger from typing import Any, cast -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, call, patch +from CloudFlare import CloudFlare +from CloudFlare.exceptions import CloudFlareAPIError import pytest -from CloudFlare import CloudFlare # type: ignore -from CloudFlare.exceptions import CloudFlareAPIError # type: ignore -from dns_synchub.mappers import CloudFlareMapper, Settings -from dns_synchub.settings import DomainsModel + +from dns_synchub.mappers.cloudflare import CloudFlareMapper +from dns_synchub.pollers import PollerData +from dns_synchub.settings import Settings +from dns_synchub.types import DomainsModel, Event, PollerSourceType @pytest.fixture -def settings(): +def settings() -> Settings: records: list[DomainsModel] = [] for i in range(1, 5): entry = DomainsModel( - zone_id=f"{i}", - name=f"region{i}.example.ltd", - target_domain=f"target{i}.example.ltd", - comment=f"Test comment {i}", + zone_id=f'{i}', + name=f'region{i}.example.ltd', + target_domain=f'target{i}.example.ltd', + comment=f'Test comment {i}', ) records.append(entry) - return Settings(cf_token="token", dry_run=True, domains=records) + return Settings(cf_token='token', dry_run=True, domains=records) @pytest.fixture -def mock_logger(): +def mock_logger() -> Logger: return MagicMock(spec=Logger) @pytest.fixture -def mock_cf_client(): +def mock_cf_client() -> list[dict[str, Any]]: def create_response(requests: dict[str, Any] | list[dict[str, Any]]) -> list[dict[str, Any]]: requests = requests if isinstance(requests, list) else [requests] response: dict[str, Any] = { - "success": True, - "result": [ + 'success': True, + 'result': [ { **( lambda req: ( - req.setdefault("zone_id", "default_zone_id"), - req.setdefault("ttl", "auto"), + req.setdefault('zone_id', 'default_zone_id'), + req.setdefault('ttl', 'auto'), deepcopy(req), )[2] )(request), - "created_on": "2014-01-01T05:20:00.12345Z", - "modified_on": "2014-01-01T05:20:00.12345Z", - "meta": {"auto_added": True, "source": "primary"}, - "proxiable": True, + 'created_on': '2014-01-01T05:20:00.12345Z', + 'modified_on': '2014-01-01T05:20:00.12345Z', + 'meta': {'auto_added': True, 'source': 'primary'}, + 'proxiable': True, } for request in requests ], } - return response["result"] + return cast(list[dict[str, Any]], response['result']) def filter_response( response: list[dict[str, Any]], params: dict[str, Any] @@ -70,26 +73,26 @@ def filter_response( # Example usage requests: list[dict[str, Any]] = [ { - "content": f"198.51.100.{i}", - "name": f"subdomain{i}.region{i}.example.ltd", - "proxied": False, - "type": "A", - "comment": "Domain verification record", - "id": f"023e105f4ecef8ad9ca31a8372d0c353{i}", - "tags": [], - "ttl": 60, + 'content': f'198.51.100.{i}', + 'name': f'subdomain{i}.region{i}.example.ltd', + 'proxied': False, + 'type': 'A', + 'comment': 'Domain verification record', + 'id': f'023e105f4ecef8ad9ca31a8372d0c353{i}', + 'tags': [], + 'ttl': 60, } for i in range(1, 5) ] - def get_side_effect(_, params: dict[str, Any]) -> list[dict[str, Any]]: + def get_side_effect(_: Any, params: dict[str, Any]) -> list[dict[str, Any]]: return filter_response(create_response(requests), params) def post_side_effect(zone_id: str, data: dict[str, Any]) -> dict[str, Any]: - return create_response({**data, "zone_id": zone_id, "id": "record_id"}).pop() + return create_response({**data, 'zone_id': zone_id, 'id': 'record_id'}).pop() def put_side_effect(zone_id: str, record_id: str, data: dict[str, Any]) -> dict[str, Any]: - return create_response({**data, "zone_id": zone_id, "id": record_id}).pop() + return create_response({**data, 'zone_id': zone_id, 'id': record_id}).pop() cf = MagicMock() cf.zones.dns_records.get.side_effect = get_side_effect @@ -98,13 +101,7 @@ def put_side_effect(zone_id: str, record_id: str, data: dict[str, Any]) -> dict[ return cf -@pytest.fixture(autouse=True) -def mock_tenacity(): - with patch("asyncio.sleep"): - yield - - -def test_init(mock_logger: MagicMock, settings: Settings): +def test_init(mock_logger: MagicMock, settings: Settings) -> None: mapper = CloudFlareMapper(mock_logger, settings=settings) assert mapper.dry_run == settings.dry_run assert mapper.rc_type == settings.rc_type @@ -113,71 +110,76 @@ def test_init(mock_logger: MagicMock, settings: Settings): assert mapper.tout_sec == settings.cf_timeout_seconds assert mapper.sync_sec == settings.cf_sync_seconds - assert isinstance(mapper.client, CloudFlare) + assert isinstance(mapper._client, CloudFlare) mock_logger.debug.assert_called_once() -def test_init_with_client(mock_logger: MagicMock, settings: Settings): - client = CloudFlare() +def test_init_with_client(mock_logger: MagicMock, settings: Settings) -> None: + client = CloudFlare(token='token') mapper = CloudFlareMapper(mock_logger, settings=settings, client=client) - assert mapper.client == client + assert mapper._client == client mock_logger.debug.assert_not_called() @pytest.mark.asyncio -async def test_call(mock_logger: MagicMock, settings: Settings, mock_cf_client: MagicMock): +async def test_call(mock_logger: MagicMock, settings: Settings, mock_cf_client: MagicMock) -> None: mapper = CloudFlareMapper(mock_logger, settings=settings, client=mock_cf_client) - events = (["subdomain.example.ltd"], "manual") - - with patch.object(mapper, "sync", new_callable=AsyncMock) as mock_sync: - await mapper(*events) - mock_sync.assert_called_once_with("subdomain.example.ltd", "manual") + events = PollerData[PollerSourceType](['subdomain.example.ltd'], 'manual') + with patch.object(mapper, 'sync', new_callable=AsyncMock) as mock_sync: + await mapper(Event[PollerData[PollerSourceType]](events)) + mock_sync.assert_called_once_with(events) @pytest.mark.asyncio -async def test_get_records(mock_logger: MagicMock, settings: Settings, mock_cf_client: MagicMock): - zone_id, name = ("zone_id", "example.ltd") +async def test_get_records( + mock_logger: MagicMock, settings: Settings, mock_cf_client: MagicMock +) -> None: + zone_id, name = ('zone_id', 'example.ltd') mapper = CloudFlareMapper(mock_logger, settings=settings, client=mock_cf_client) zones = await mapper.get_records(zone_id, name=name) - mock_cf_client.zones.dns_records.get.assert_called_with(zone_id, params={"name": name}) + mock_cf_client.zones.dns_records.get.assert_called_with(zone_id, params={'name': name}) assert len(zones) == 4 @pytest.mark.asyncio -async def test_post_record(mock_logger: MagicMock, settings: Settings, mock_cf_client: MagicMock): +async def test_post_record( + mock_logger: MagicMock, settings: Settings, mock_cf_client: MagicMock +) -> None: mapper = CloudFlareMapper(mock_logger, settings=settings, client=mock_cf_client) - zone_id, zone = "zone_id", {"type": "A", "name": "example.ltd", "content": "1.2.3.4"} + zone_id, zone = 'zone_id', {'type': 'A', 'name': 'example.ltd', 'content': '1.2.3.4'} # Dry run await mapper.post_record(zone_id, **zone) mock_cf_client.zones.dns_records.post.assert_not_called() cast(MagicMock, mapper.logger.info).assert_called_once() - with patch.object(mapper, "dry_run", False): + with patch('asyncio.sleep'), patch.object(mapper, 'dry_run', False): # Client call await mapper.post_record(zone_id, **zone) mock_cf_client.zones.dns_records.post.assert_called_with(zone_id, data=zone) # retry Call - with pytest.raises(CloudFlareAPIError, match="Rate limited"): - rate_error = CloudFlareAPIError(-1, "Rate limited") + with pytest.raises(CloudFlareAPIError, match='Rate limited'): + rate_error = CloudFlareAPIError(-1, 'Rate limited') cast(MagicMock, mock_cf_client.zones.dns_records.post).side_effect = rate_error await mapper.post_record(zone_id, **zone) @pytest.mark.asyncio -async def test_put_record(mock_logger: MagicMock, settings: Settings, mock_cf_client: MagicMock): +async def test_put_record( + mock_logger: MagicMock, settings: Settings, mock_cf_client: MagicMock +) -> None: mapper = CloudFlareMapper(mock_logger, settings=settings, client=mock_cf_client) - zone_id, record_id = "zone_id", "record_id" - zone = {"type": "A", "name": "example.ltd", "content": "1.2.3.4"} + zone_id, record_id = 'zone_id', 'record_id' + zone = {'type': 'A', 'name': 'example.ltd', 'content': '1.2.3.4'} # Dry run call await mapper.put_record(zone_id, record_id, **zone) mock_cf_client.zones.dns_records.put.assert_not_called() cast(MagicMock, mapper.logger.info).assert_called_once() - with patch.object(mapper, "dry_run", False): + with patch.object(mapper, 'dry_run', False): # Client call await mapper.put_record(zone_id, record_id, **zone) mock_cf_client.zones.dns_records.put.assert_called_with(zone_id, record_id, data=zone) @@ -186,12 +188,12 @@ async def test_put_record(mock_logger: MagicMock, settings: Settings, mock_cf_cl @pytest.mark.asyncio async def test_sync_with_target_domain( mock_logger: MagicMock, settings: Settings, mock_cf_client: MagicMock -): +) -> None: mapper = CloudFlareMapper(mock_logger, settings=settings, client=mock_cf_client) host = settings.domains[0].target_domain assert isinstance(host, str) - result = await mapper.sync(host, "manual") + result = await mapper.sync(PollerData[PollerSourceType]([host], 'manual')) assert result is None mock_logger.info.assert_not_called() @@ -199,11 +201,11 @@ async def test_sync_with_target_domain( @pytest.mark.asyncio async def test_sync_with_non_subdomain( mock_logger: MagicMock, settings: Settings, mock_cf_client: MagicMock -): +) -> None: mapper = CloudFlareMapper(mock_logger, settings=settings, client=mock_cf_client) - host = "nonexistent.example.ltd" + host = 'nonexistent.example.ltd' - result = await mapper.sync(host, "manual") + result = await mapper.sync(PollerData[PollerSourceType]([host], 'manual')) assert result is None mock_logger.info.assert_not_called() @@ -211,44 +213,44 @@ async def test_sync_with_non_subdomain( @pytest.mark.asyncio async def test_sync_with_excluded_domain( mock_logger: MagicMock, settings: Settings, mock_cf_client: MagicMock -): - settings.domains[0].excluded_sub_domains = ["excluded"] +) -> None: + settings.domains[0].excluded_sub_domains = ['excluded'] mapper = CloudFlareMapper(mock_logger, settings=settings, client=mock_cf_client) - host = f"excluded.{settings.domains[0].name}" + host = f'excluded.{settings.domains[0].name}' - result = await mapper.sync(host, "manual") + result = await mapper.sync(PollerData[PollerSourceType]([host], 'manual')) assert result is None - mock_logger.info.assert_called_with(f"Ignoring {host}: Match excluded sub domain: excluded") + mock_logger.info.assert_called_with(f'Ignoring {host}: Match excluded sub domain: excluded') @pytest.mark.asyncio async def test_sync_with_existing_record( mock_logger: MagicMock, settings: Settings, mock_cf_client: MagicMock -): +) -> None: mapper = CloudFlareMapper(mock_logger, settings=settings, client=mock_cf_client) host = settings.domains[0].name mapper.refresh_entries = False - result = await mapper.sync(host, "manual") + result = await mapper.sync(PollerData[PollerSourceType]([host], 'manual')) assert result is not None - mock_logger.info.assert_called_with(f"Record {host} found. Not refreshing. Skipping...") + mock_logger.info.assert_called_with(f'Record {host} found. Not refreshing. Skipping...') @pytest.mark.asyncio async def test_sync_with_record_creation( mock_logger: MagicMock, settings: Settings, mock_cf_client: MagicMock -): +) -> None: mapper = CloudFlareMapper(mock_logger, settings=settings, client=mock_cf_client) - host = "newsubdomain.region1.example.ltd" + host = 'newsubdomain.region1.example.ltd' mock_cf_client.zones.dns_records.get.return_value = [] # dry run - result = await mapper.sync(host, "manual") + result = await mapper.sync(PollerData[PollerSourceType]([host], 'manual')) mock_cf_client.zones.dns_records.post.assert_not_called() cast(MagicMock, mapper.logger.info).assert_called_once() - with patch.object(mapper, "dry_run", False): - result = await mapper.sync(host, "manual") + with patch.object(mapper, 'dry_run', False): + result = await mapper.sync(PollerData[PollerSourceType]([host], 'manual')) assert result is not None mock_cf_client.zones.dns_records.post.assert_called_once() @@ -256,31 +258,37 @@ async def test_sync_with_record_creation( @pytest.mark.asyncio async def test_sync_with_record_update( mock_logger: MagicMock, settings: Settings, mock_cf_client: MagicMock -): +) -> None: mapper = CloudFlareMapper(mock_logger, settings=settings, client=mock_cf_client) - host = f"subdomain{settings.domains[0].zone_id}.{settings.domains[0].name}" + host = f'subdomain{settings.domains[0].zone_id}.{settings.domains[0].name}' mapper.refresh_entries = True # dry run - result = await mapper.sync(host, "manual") + result = await mapper.sync(PollerData[PollerSourceType]([host], 'manual')) mock_cf_client.zones.dns_records.put.assert_not_called() cast(MagicMock, mapper.logger.info).assert_called_once() - with patch.object(mapper, "dry_run", False): - result = await mapper.sync(host, "manual") - assert isinstance(result, DomainsModel) + with patch.object(mapper, 'dry_run', False): + result = await mapper.sync(PollerData[PollerSourceType]([host], 'manual')) + assert result and isinstance(result.pop(), DomainsModel) mock_cf_client.zones.dns_records.put.assert_called_once() @pytest.mark.asyncio async def test_sync_with_cloudflare_api_error( mock_logger: MagicMock, settings: Settings, mock_cf_client: MagicMock -): +) -> None: mapper = CloudFlareMapper(mock_logger, settings=settings, client=mock_cf_client) - host = f"newsubdomain.{settings.domains[0].name}" + host = f'newsubdomain.{settings.domains[0].name}' - mock_cf_client.zones.dns_records.post.side_effect = CloudFlareAPIError(1000, "API Error") - with patch.object(mapper, "dry_run", False): - result = await mapper.sync(host, "manual") + mock_cf_client.zones.dns_records.post.side_effect = CloudFlareAPIError(1000, 'API Error') + with patch.object(mapper, 'dry_run', False): + result = await mapper.sync(PollerData[PollerSourceType]([host], 'manual')) assert result is None - mock_logger.error.assert_called_with(f"Sync Error for {host}: API Error [Code 1000]") + # Assert + expected_calls = [ + call("Sync failed for 'manual': [1000]"), + call('API Error'), + ] + mock_logger.error.assert_has_calls(expected_calls, any_order=False) + assert mock_logger.error.call_count == 2 diff --git a/tests/test_docker.py b/tests/test_docker.py index 2274866..39ccde3 100644 --- a/tests/test_docker.py +++ b/tests/test_docker.py @@ -1,16 +1,20 @@ import asyncio -import re +from collections.abc import Callable, Generator from logging import Logger +import re from typing import Any, cast -from unittest.mock import MagicMock, call, patch +from unittest.mock import AsyncMock, MagicMock, call, patch from urllib.parse import urlparse import docker import docker.client import docker.errors import pytest + +from dns_synchub.pollers import PollerData from dns_synchub.pollers.docker import DockerError, DockerPoller from dns_synchub.settings import Settings +from dns_synchub.types import Event, EventSubscriber class MockDockerEvents: @@ -19,26 +23,26 @@ def __init__(self, data: list[dict[str, str]]): self.close = MagicMock() self.reset() - def __iter__(self): + def __iter__(self) -> 'MockDockerEvents': return self - def __next__(self): + def __next__(self) -> dict[str, str]: try: return next(self.iter) except StopIteration: - raise docker.errors.NotFound("No more events") + raise docker.errors.NotFound('No more events') - def reset(self): + def reset(self) -> None: self.iter = iter(self.data) @pytest.fixture -def settings(): - return Settings(cf_token="token", dry_run=True) +def settings() -> Settings: + return Settings(cf_token='token', dry_run=True) @pytest.fixture -def logger(): +def logger() -> Logger: return MagicMock(spec=Logger) @@ -46,40 +50,44 @@ def logger(): def containers() -> dict[str, Any]: data: dict[str, dict[str, Any]] = { str(id_): { - "Id": id_, - "Config": { - "Labels": { - "traefik.http.routers.example.rule": f"Host(`subdomain{id_}.example.ltd`)" + 'Id': id_, + 'Config': { + 'Labels': { + 'traefik.http.routers.example.rule': f'Host(`subdomain{id_}.example.ltd`)' } }, } for id_ in range(1, 5) } - data["1"]["Config"]["Labels"]["traefik.constraint"] = "enable" - data["2"]["Config"]["Labels"]["traefik.constraint"] = "disable" + data['1']['Config']['Labels']['traefik.constraint'] = 'enable' + data['2']['Config']['Labels']['traefik.constraint'] = 'disable' return data @pytest.fixture(autouse=True) -def mock_requests_get(request: pytest.FixtureRequest, containers: dict[str, Any]): - if "skip_mock_requests_get" in request.keywords: - yield - return - with patch("requests.Session.get") as mock_get: - - def side_effect(url: str, *args: Any, **kwargs: dict[str, Any]): +def mock_requests_get( + request: pytest.FixtureRequest, containers: dict[str, Any] +) -> Generator[Any, None, Any] | Callable[..., Any]: + for mark in request.node.iter_markers(): + if mark.name == 'skip_fixture' and request.fixturename in mark.args: + yield + return + with patch('requests.Session.get') as mock_get: + + def side_effect(url: str, *args: Any, **kwargs: dict[str, Any]) -> MagicMock: + return_value: dict[str, Any] | list[dict[str, Any]] | None = None # Process URLs match urlparse(url).path: - case "/version": - return_value = {"ApiVersion": "1.41"} - case "/v1.41/info": - return_value = {"Name": "Mock Docker"} - case "/v1.41/containers/json": - return_value = [{"Id": id_} for id_ in containers.keys()] - case details if (match := re.search(r"/v1.41/containers/([^/]+)/json", details)): - return_value = containers[match.group(1)] # type: ignore + case '/version': + return_value = {'ApiVersion': '1.41'} + case '/v1.41/info': + return_value = {'Name': 'Mock Docker'} + case '/v1.41/containers/json': + return_value = [{'Id': id_} for id_ in containers.keys()] + case details if (match := re.search(r'/v1.41/containers/([^/]+)/json', details)): + return_value = containers[match.group(1)] case other: - raise AssertionError(f"Unexpected URL: {other}") + raise AssertionError(f'Unexpected URL: {other}') # Create a MagicMock object to mock the response response = MagicMock() @@ -91,27 +99,26 @@ def side_effect(url: str, *args: Any, **kwargs: dict[str, Any]): @pytest.fixture -@pytest.mark.usefixtures("mock_requests_get") -def docker_poller(logger: MagicMock, settings: Settings, containers: dict[str, Any]): - events = [{"status": "start", "id": id_} for id_ in containers.keys()] - docker_client = docker.DockerClient(base_url="unix:///") - with patch.object(docker_client, "events", return_value=MockDockerEvents(events)): +def docker_poller( + logger: MagicMock, settings: Settings, containers: dict[str, Any] +) -> Generator[DockerPoller, None, None]: + events = [{'status': 'start', 'id': id_} for id_ in containers.keys()] + docker_client = docker.DockerClient(base_url='unix:///') + with patch.object(docker_client, 'events', return_value=MockDockerEvents(events)): yield DockerPoller(logger, settings=settings, client=docker_client) -@pytest.mark.skip_mock_requests_get +@pytest.mark.skip_fixture('mock_requests_get') def test_docker_init_with_bad_engine( - logger: MagicMock, - settings: Settings, - monkeypatch: pytest.MonkeyPatch, -): + logger: MagicMock, settings: Settings, monkeypatch: pytest.MonkeyPatch +) -> None: with pytest.raises(DockerError) as err: - monkeypatch.setenv("DOCKER_HOST", "unix:///") + monkeypatch.setenv('DOCKER_HOST', 'unix:///') DockerPoller(logger, settings=settings).client - assert str(err.value) == "Could not connect to Docker" + assert str(err.value) == 'Could not connect to Docker' -def test_init(logger: MagicMock, settings: Settings): +def test_init(logger: MagicMock, settings: Settings) -> None: poller = DockerPoller(logger, settings=settings) assert poller.poll_sec == settings.docker_poll_seconds assert poller.tout_sec == settings.docker_timeout_seconds @@ -119,89 +126,96 @@ def test_init(logger: MagicMock, settings: Settings): assert poller.filter_value == settings.docker_filter_value -def test_init_from_env(logger: MagicMock, settings: Settings): +def test_init_from_env(logger: MagicMock, settings: Settings) -> None: poller = DockerPoller(logger, settings=settings) assert isinstance(poller.client, docker.DockerClient) -def test_init_from_client(logger: MagicMock, settings: Settings): - client = docker.DockerClient(base_url="unix:///") +def test_init_from_client(logger: MagicMock, settings: Settings) -> None: + client = docker.DockerClient(base_url='unix:///') poller = DockerPoller(logger, settings=settings, client=client) assert poller.client == client @pytest.mark.asyncio -async def test_fetch(docker_poller: DockerPoller): - hosts, source = await docker_poller.fetch() - assert source == "docker" - assert hosts == [f"subdomain{i}.example.ltd" for i in range(1, 5)] +async def test_fetch(docker_poller: DockerPoller) -> None: + data = await docker_poller.fetch() + assert data.source == 'docker' + assert data.hosts == [f'subdomain{i}.example.ltd' for i in range(1, 5)] @pytest.mark.asyncio -async def test_fetch_filter_by_label(docker_poller: DockerPoller): - docker_poller.filter_label = re.compile(r"traefik.constraint") - hosts, source = await docker_poller.fetch() - assert source == "docker" - assert hosts == [f"subdomain{i}.example.ltd" for i in range(1, 3)] +async def test_fetch_filter_by_label(docker_poller: DockerPoller) -> None: + docker_poller.filter_label = re.compile(r'traefik.constraint') + data = await docker_poller.fetch() + assert data.source == 'docker' + assert data.hosts == [f'subdomain{i}.example.ltd' for i in range(1, 3)] @pytest.mark.asyncio -async def test_fetch_filter_by_value(docker_poller: DockerPoller): - docker_poller.filter_label = re.compile(r"traefik.constraint") - docker_poller.filter_value = re.compile(r"enable") - hosts, source = await docker_poller.fetch() - assert source == "docker" - assert hosts == [f"subdomain{i}.example.ltd" for i in range(1, 2)] +async def test_fetch_filter_by_value(docker_poller: DockerPoller) -> None: + docker_poller.filter_label = re.compile(r'traefik.constraint') + docker_poller.filter_value = re.compile(r'enable') + data = await docker_poller.fetch() + assert data.source == 'docker' + assert data.hosts == [f'subdomain{i}.example.ltd' for i in range(1, 2)] @pytest.mark.asyncio -async def test_run(docker_poller: DockerPoller): - callback_mock = MagicMock() +async def test_run(docker_poller: DockerPoller) -> None: + callback_mock = MagicMock(spec=EventSubscriber) + callback_mock.__call__ = AsyncMock(return_value=None) # type: ignore + await docker_poller.events.subscribe(callback_mock) assert 0 == callback_mock.call_count - await docker_poller.run(timeout=0.1) - # Check timeout was reached + await docker_poller.start(timeout=0.1) logger = cast(MagicMock, docker_poller.logger) - assert any("Run timeout" in str(arg) for arg in logger.info.call_args_list) + assert any('Run timeout' in str(arg) for arg in logger.info.call_args_list) # Docker Client asserts - docker_client_events = cast(MagicMock, docker_poller.client.events) # type: ignore + await asyncio.gather(docker_poller.start(), docker_poller.stop()) + docker_client_events = cast(MagicMock, docker_poller.client.events) docker_client_events.assert_called_once() docker_client_events.return_value.close.assert_called_once() # Check callback calls. First run will fetch all containers plus events - expected_calls = [call([f"subdomain{i}.example.ltd"], "docker") for i in range(1, 5)] - expected_calls.insert(0, call([f"subdomain{i}.example.ltd" for i in range(1, 5)], "docker")) - assert callback_mock.call_count == len(expected_calls) - callback_mock.assert_has_calls(expected_calls, any_order=False) + expected_calls = ( + [] + + [call(Event(PollerData([f'subdomain{i}.example.ltd' for i in range(1, 5)], 'docker')))] + + [call(Event(PollerData([f'subdomain{i}.example.ltd'], 'docker'))) for i in range(1, 5)] + ) + assert callback_mock.__call__.call_count == len(expected_calls) + callback_mock.__call__.assert_has_calls(expected_calls, any_order=False) # Check the rest of the runs will not perform a fetch expected_calls.pop(0) - callback_mock.reset_mock() + callback_mock.__call__.reset_mock() docker_client_events.return_value.reset() - await docker_poller.run(timeout=0.1) - assert callback_mock.call_count == len(expected_calls) - callback_mock.assert_has_calls(expected_calls, any_order=False) + loop = asyncio.get_event_loop() + loop.call_later(0.1, lambda: asyncio.create_task(docker_poller.stop())) + await docker_poller.start() + assert callback_mock.__call__.call_count == len(expected_calls) + callback_mock.__call__.assert_has_calls(expected_calls, any_order=False) @pytest.mark.asyncio -async def test_run_canceled(docker_poller: DockerPoller): +async def test_run_canceled(docker_poller: DockerPoller) -> None: async def cancel(task: asyncio.Task[Any]) -> None: await asyncio.sleep(0.1) task.cancel() - poller_task = asyncio.create_task(docker_poller.run()) + poller_task = asyncio.create_task(docker_poller.start()) tasks = [poller_task, asyncio.create_task(cancel(poller_task))] await asyncio.gather(*tasks) # Check timeout was reached # Check timeout was reached logger = cast(MagicMock, docker_poller.logger) - logger.info.assert_any_call("DockerPoller: Run was cancelled") + logger.info.assert_any_call('DockerPoller: Run was cancelled') # Docker Client asserts - docker_client_events = cast(MagicMock, docker_poller.client.events) # type: ignore + docker_client_events = cast(MagicMock, docker_poller.client.events) docker_client_events.assert_called_once() docker_client_events.return_value.close.assert_called_once() diff --git a/tests/test_mananger.py b/tests/test_mananger.py deleted file mode 100644 index 450b77b..0000000 --- a/tests/test_mananger.py +++ /dev/null @@ -1,131 +0,0 @@ -import asyncio -from logging import Logger -from typing import cast -from unittest.mock import AsyncMock, MagicMock, Mock, patch - -import pytest - -from dns_synchub.events import EventEmitter -from dns_synchub.manager import DataManager -from dns_synchub.mappers import BaseMapper -from dns_synchub.pollers import BasePoller - - -@pytest.fixture -def mock_logger(): - return MagicMock(spec=Logger) - - -@pytest.fixture -def mock_poller(): - poller = MagicMock(spec=BasePoller) - poller.events = MagicMock(spec=EventEmitter) - poller.run = AsyncMock() - return poller - - -@pytest.fixture -def mock_poller_infinity(mock_poller: MagicMock): - async def run_indenfinitely(timeout: float): - while True: - await asyncio.sleep(1) - - # Patch the run method to run indefinitely - mock_poller.run = AsyncMock(side_effect=run_indenfinitely) - return mock_poller - - -@pytest.fixture -def mock_mapper(): - return MagicMock(spec=BaseMapper) - - -@pytest.fixture -def data_manager(mock_logger: MagicMock): - class MockList(list[asyncio.Task[None]]): - def clear(self): ... - - manager = DataManager(logger=mock_logger) - manager.tasks = MockList() - manager.tasks.clear = MagicMock(spec=list[asyncio.Task[None]].clear, return_value=None) - return manager - - -def test_initialization(data_manager: DataManager, mock_logger: Logger): - assert data_manager.logger == mock_logger - assert data_manager.tasks == [] - assert isinstance(data_manager.pollers, set) - assert isinstance(data_manager.mappers, EventEmitter) - - -def test_add_poller(data_manager: DataManager, mock_poller: AsyncMock): - data_manager.add_poller(mock_poller, backoff=5.0) - assert (mock_poller, 5.0) in data_manager.pollers - - -@pytest.mark.asyncio -async def test_add_mapper(data_manager: DataManager, mock_mapper: MagicMock): - data_manager.mappers = MagicMock(spec=EventEmitter) - await data_manager.add_mapper(mock_mapper, backoff=5.0) - data_manager.mappers.subscribe.assert_called_once_with(mock_mapper, backoff=5.0) - - -@pytest.mark.asyncio -async def test_start(data_manager: DataManager, mock_poller: AsyncMock, mock_mapper: MagicMock): - mock_event_emitter = MagicMock(spec=EventEmitter) - mock_event_emitter.__len__.return_value = 1 - data_manager.mappers = mock_event_emitter - await data_manager.add_mapper(mock_mapper, backoff=5.0) - data_manager.add_poller(mock_poller, backoff=5.0) - # Patch the clear method of the specific list instance - await data_manager.start(timeout=10.0) - mock_poller.run.assert_awaited_once_with(timeout=10.0) - assert len(data_manager.tasks) == 2 # One for poller and one for mappers.emit - cast(MagicMock, data_manager.tasks.clear).assert_called_once() - - -@pytest.mark.asyncio -async def test_start_no_mapper(data_manager: DataManager, mock_poller: AsyncMock): - data_manager.add_poller(mock_poller, backoff=5.0) - await data_manager.start(timeout=10.0) - mock_poller.run.assert_awaited_once_with(timeout=10.0) - assert len(data_manager.tasks) == 1 # Only one task for poller - - -@pytest.mark.asyncio -async def test_start_no_pollers(data_manager: DataManager, mock_mapper: MagicMock): - mock_event_emitter = MagicMock(spec=EventEmitter) - mock_event_emitter.__len__.return_value = 1 - data_manager.mappers = mock_event_emitter - - await data_manager.add_mapper(mock_mapper, backoff=5.0) - await data_manager.start(timeout=10.0) - data_manager.mappers.emit.assert_awaited_once_with(timeout=10.0) - assert len(data_manager.tasks) == 1 # Only one task for mappers.emit - - -@pytest.mark.asyncio -async def test_start_cancel_interrupt(data_manager: DataManager, mock_poller_infinity: AsyncMock): - data_manager.add_poller(mock_poller_infinity, backoff=5.0) - with patch.object(data_manager, "stop", new_callable=AsyncMock) as mock_stop: - with patch("asyncio.gather", side_effect=asyncio.exceptions.CancelledError): - await data_manager.start(timeout=10.0) - mock_stop.assert_awaited_once() - - -@pytest.mark.asyncio -async def test_stop(data_manager: DataManager, mock_poller_infinity: AsyncMock): - # Restore patched tasks list to a valid one - data_manager.tasks = [] - data_manager.add_poller(mock_poller_infinity, backoff=5.0) - asyncio.create_task(data_manager.start(timeout=10.0)) - await asyncio.sleep(0.1) - await data_manager.stop() - assert len(data_manager.tasks) == 0 - - -def test_aggregate_data(data_manager: DataManager, mock_poller: AsyncMock): - data_manager.add_poller(mock_poller, backoff=5.0) - mock_poller.events.get_data = Mock(return_value=(["value"], "source")) - combined_data = data_manager.aggregate_data() - assert combined_data == {"source": ["value"]} diff --git a/tests/test_traefik.py b/tests/test_traefik.py index cba3a7f..ca18dba 100644 --- a/tests/test_traefik.py +++ b/tests/test_traefik.py @@ -1,44 +1,46 @@ +from collections.abc import Generator from logging import Logger from unittest.mock import MagicMock, patch import pytest + from dns_synchub.pollers.traefik import TraefikPoller from dns_synchub.settings import Settings @pytest.fixture -def settings(): - return Settings(cf_token="token", dry_run=True) +def settings() -> Settings: + return Settings(cf_token='token', dry_run=True) @pytest.fixture -def mock_logger(): +def mock_logger() -> Logger: return MagicMock(spec=Logger) @pytest.fixture -def mock_api_no_routers(): - with patch("requests.Session.get") as mock_get: +def mock_api_no_routers() -> Generator[MagicMock, None, None]: + with patch('requests.Session.get') as mock_get: mock_get.return_value.ok = True mock_get.return_value.json.return_value = [] yield mock_get @pytest.fixture -def traefik_poller(mock_logger: MagicMock, settings: Settings): +def traefik_poller(mock_logger: MagicMock, settings: Settings) -> TraefikPoller: return TraefikPoller(mock_logger, settings=settings) -def test_init(mock_logger: MagicMock, settings: Settings): +def test_init(mock_logger: MagicMock, settings: Settings) -> None: poller = TraefikPoller(mock_logger, settings=settings) assert poller.poll_sec == settings.traefik_poll_seconds assert poller.tout_sec == settings.traefik_timeout_seconds - assert poller.poll_url == f"{settings.traefik_poll_url}/api/http/routers" - assert "docker" in poller.excluded_providers + assert poller.poll_url == f'{settings.traefik_poll_url}/api/http/routers' + assert 'docker' in poller.excluded_providers @pytest.mark.asyncio -async def test_no_routers(traefik_poller: TraefikPoller, mock_api_no_routers: MagicMock): - hosts, source = await traefik_poller.fetch() - assert source == "traefik" - assert hosts == [] +async def test_no_routers(traefik_poller: TraefikPoller, mock_api_no_routers: MagicMock) -> None: + data = await traefik_poller.fetch() + assert data.source == 'traefik' + assert data.hosts == [] diff --git a/types/CloudFlare/__init__.pyi b/types/CloudFlare/__init__.pyi new file mode 100644 index 0000000..fb521d5 --- /dev/null +++ b/types/CloudFlare/__init__.pyi @@ -0,0 +1,15 @@ +from typing import Any + +class CloudFlare: + def __init__(self, token: str, debug: bool = False) -> None: ... + + class zones: + class dns_records: + @staticmethod + def get(zone_id: str, params: dict[str, str] | None = None) -> list[dict[str, Any]]: ... + @staticmethod + def post(zone_id: str, data: dict[str, str]) -> dict[str, Any]: ... + @staticmethod + def put(zone_id: str, record_id: str, data: dict[str, str]) -> dict[str, Any]: ... + +__all__ = ['CloudFlare'] diff --git a/types/CloudFlare/exceptions.pyi b/types/CloudFlare/exceptions.pyi new file mode 100644 index 0000000..9a22485 --- /dev/null +++ b/types/CloudFlare/exceptions.pyi @@ -0,0 +1,23 @@ +from typing import Any + +class CloudFlareError(Exception): + class _CodeMessage: + def __init__(self, code: int, message: str) -> None: ... + def __int__(self) -> int: ... + + def __init__( + self, + code: int = 0, + message: str | Any | None = None, + error_chain: Any | None = None, + e: Exception | None = None, + ) -> None: ... + def __bool__(self) -> bool: ... + def __int__(self) -> int: ... + def __len__(self) -> int: ... + def __getitem__(self, ii: int) -> Any: ... + def __iter__(self) -> Any: ... + def __next__(self) -> Any: ... + +class CloudFlareAPIError(CloudFlareError): ... +class CloudFlareInternalError(CloudFlareError): ...