diff --git a/src/open_inwoner/openzaak/clients.py b/src/open_inwoner/openzaak/clients.py index 72b2d52226..b8976c8f58 100644 --- a/src/open_inwoner/openzaak/clients.py +++ b/src/open_inwoner/openzaak/clients.py @@ -1,8 +1,11 @@ import base64 +import concurrent.futures import logging import warnings +from dataclasses import dataclass from datetime import date -from typing import Literal, Mapping, Type, TypeAlias +from functools import cached_property +from typing import Any, Literal, Mapping, Type, TypeAlias, TypeVar from django.conf import settings from django.core.files.uploadedfile import InMemoryUploadedFile @@ -14,11 +17,13 @@ from zgw_consumers.api_models.catalogi import Catalogus from zgw_consumers.api_models.constants import RolOmschrijving, RolTypes from zgw_consumers.client import build_client +from zgw_consumers.concurrent import parallel from zgw_consumers.constants import APITypes from zgw_consumers.models import Service from zgw_consumers.service import pagination_helper from open_inwoner.openzaak.api_models import InformatieObject +from open_inwoner.openzaak.exceptions import MultiZgwClientProxyError from open_inwoner.utils.api import ClientError, get_json_response from ..utils.decorators import cache as cache_result @@ -51,6 +56,9 @@ def __init__(self, *args, **kwargs): self.configured_from = kwargs.pop("configured_from") super().__init__(*args, **kwargs) + def __str__(self): + return f"Client {self.__class__.__name__} for {self.base_url}" + class ZakenClient(ZgwAPIClient): def fetch_cases( @@ -694,6 +702,110 @@ def fetch_open_tasks(self, bsn: str) -> list[OpenTask]: return results +TClient = TypeVar("TClient", bound=APIClient) + + +@dataclass(frozen=True) +class ZgwClientResponse: + """A single response in a MultiZgwClientResult.""" + + client: TClient + result: Any + exception: Exception | None = None + + +@dataclass(frozen=True) +class MultiZgwClientProxyResult: + """Container for a multi-backend responses""" + + responses: list[ZgwClientResponse] + + @cached_property + def has_errors(self) -> bool: + return any(r.exception is not None for r in self.responses) + + @cached_property + def failing_responses(self) -> list[ZgwClientResponse]: + return list(r for r in self if r.exception is not None) + + @cached_property + def successful_responses(self) -> list[ZgwClientResponse]: + return list(r for r in self if r.exception is None) + + def raise_on_failures(self): + """Raise a MultiZgwClientProxyError wrapping all errors raised by the clients.""" + if not self.has_errors: + return + + raise MultiZgwClientProxyError([r.exception for r in self.failing_responses]) + + def join_results(self): + """Joins results for all successful responses in a list.""" + return list( + result for row in self.successful_responses for result in row.result + ) + + def __iter__(self): + yield from self.responses + + +class MultiZgwClientProxy: + """A proxy to call the same method on multiple ZGW clients in parallel.""" + + clients: list[TClient] = [] + + def __init__(self, clients: list[TClient]): + self.clients = clients + + if len(clients) == 0: + raise ValueError("You must specify at least one client") + + def _call_method(self, method, *args, **kwargs) -> MultiZgwClientProxyResult: + if not all(hasattr(client, method) for client in self.clients): + raise AttributeError(f"Method `{method}` does not exist on the clients") + + with parallel() as executor: + futures_mapping: Mapping[concurrent.futures.Future, TClient] = {} + for client in self.clients: + future = executor.submit( + getattr(client, method), + *args, + **kwargs, + ) + # Remember which future corresponds to which client, + # so we can associate them in the response + futures_mapping[future] = client + + responses: list[ZgwClientResponse] = [] + for task in concurrent.futures.as_completed(futures_mapping.keys()): + result: Any | None = None + exception: Exception | None = None + try: + result: Any = task.result() + except BaseException: + exception = task.exception() + + responses.append( + ZgwClientResponse( + result=result, exception=exception, client=futures_mapping[task] + ) + ) + + # Ensure the response list is deterministic, based on the client order. + # This is mainly useful for testing but also generally promotes consistent + # behavior. + responses.sort( + key=lambda r: self.clients.index(r.client), + ) + return MultiZgwClientProxyResult(responses=responses) + + def __getattr__(self, name): + def wrapper(*args, **kwargs): + return self._call_method(name, *args, **kwargs) + + return wrapper + + ZgwClientType = Literal["zaak", "catalogi", "document", "form"] ZgwClientFactoryReturn: TypeAlias = ( ZakenClient | CatalogiClient | DocumentenClient | FormClient diff --git a/src/open_inwoner/openzaak/exceptions.py b/src/open_inwoner/openzaak/exceptions.py index 63de439d02..ff20d4b4f5 100644 --- a/src/open_inwoner/openzaak/exceptions.py +++ b/src/open_inwoner/openzaak/exceptions.py @@ -1,3 +1,6 @@ +from typing import Sequence + + class InvalidAuth(Exception): pass @@ -12,3 +15,18 @@ class InvalidAuthForClientID(InvalidAuth): def __init__(self, client_id): self.client_id = client_id super().__init__(f"secret invalid for subscription client_id'{client_id}'") + + +class MultiZgwClientProxyError(ExceptionGroup): + """A container for exceptions raised within individual client requests.""" + + def __new__(cls, exceptions: Sequence[Exception]): + self = super().__new__( + MultiZgwClientProxyError, + "One or more ZGW clients raised an error", + exceptions, + ) + return self + + def derive(self, exceptions: Sequence[Exception]) -> "MultiZgwClientProxyError": + return MultiZgwClientProxyError(exceptions) diff --git a/src/open_inwoner/openzaak/tests/test_clients.py b/src/open_inwoner/openzaak/tests/test_clients.py index 67dfae0f0b..06c28441fa 100644 --- a/src/open_inwoner/openzaak/tests/test_clients.py +++ b/src/open_inwoner/openzaak/tests/test_clients.py @@ -1,8 +1,15 @@ +from unittest import TestCase as PlainTestCase + from django.test import TestCase +import requests +import requests_mock from zgw_consumers.constants import APITypes from open_inwoner.openzaak.clients import ( + MultiZgwClientProxy, + MultiZgwClientProxyResult, + ZgwClientResponse, build_catalogi_client, build_catalogi_clients, build_documenten_client, @@ -12,6 +19,7 @@ build_zaken_client, build_zaken_clients, ) +from open_inwoner.openzaak.exceptions import MultiZgwClientProxyError from open_inwoner.openzaak.tests.factories import ZGWApiGroupConfigFactory @@ -56,3 +64,120 @@ def test_originating_service_is_persisted_on_all_clients(self): getattr(self.api_groups[i], api_group_field), ) self.assertEqual(client.configured_from.api_type, api_type) + + +@requests_mock.Mocker() +class MultiZgwClientProxyTests(PlainTestCase): + class SimpleClient: + def __init__(self, url): + self.url = url + + def fetch_rows(self): + resp = requests.get(self.url) + return resp.json() + + def setUp(self): + self.a_client = self.SimpleClient("http://foo/bar/rows") + self.another_client = self.SimpleClient("http://bar/baz/rows") + + def test_accessing_non_existent_methods_raises(self, m): + proxy = MultiZgwClientProxy([self.a_client, self.another_client]) + + with self.assertRaises(AttributeError) as cm: + proxy.non_existent() + + self.assertEqual( + str(cm.exception), "Method `non_existent` does not exist on the clients" + ) + + def test_all_successful_responses_are_returned(self, m): + m.get(self.a_client.url, json=["foo", "bar"]) + m.get(self.another_client.url, json=["bar", "baz"]) + proxy = MultiZgwClientProxy([self.a_client, self.another_client]) + + result = proxy.fetch_rows() + + self.assertEqual( + result, + MultiZgwClientProxyResult( + responses=[ + ZgwClientResponse( + client=self.a_client, result=["foo", "bar"], exception=None + ), + ZgwClientResponse( + client=self.another_client, + result=["bar", "baz"], + exception=None, + ), + ] + ), + ) + self.assertEqual(result.responses, result.successful_responses) + self.assertEqual(result.failing_responses, []) + self.assertEqual(result.join_results(), ["foo", "bar", "bar", "baz"]) + self.assertEqual( + result.raise_on_failures(), + None, + msg="raise_on_failures is a noop if all responses are successful", + ) + + def test_partial_exceptions_are_returned(self, m): + m.get(self.a_client.url, json=["foo", "bar"]) + # Second client will raise an exception + m.get(self.another_client.url, exc=requests.exceptions.Timeout) + + proxy = MultiZgwClientProxy([self.a_client, self.another_client]) + + result = proxy.fetch_rows() + successful_response, failing_response = result.responses + self.assertEqual( + successful_response, + ZgwClientResponse( + client=self.a_client, result=["foo", "bar"], exception=None + ), + ) + self.assertEqual( + result.join_results(), + ["foo", "bar"], + msg="Only successful results should be joined", + ) + + # It's non-trivial to compare exceptions on equality, so we have to validate the + # object in steps + self.assertEqual( + (failing_response.client, failing_response.result), + (self.another_client, None), + ) + self.assertIsInstance(failing_response.exception, requests.exceptions.Timeout) + self.assertEqual(result.has_errors, True) + + with self.assertRaises(MultiZgwClientProxyError) as cm: + result.raise_on_failures() + + self.assertTrue( + [e.__class__ for e in cm.exception.exceptions], + [requests.exceptions.Timeout], + ) + self.assertEqual(len(result.failing_responses), 1) + + def test_result_can_be_used_as_response_iterator(self, m): + m.get(self.a_client.url, json=["foo", "bar"]) + m.get(self.another_client.url, json=["bar", "baz"]) + proxy = MultiZgwClientProxy([self.a_client, self.another_client]) + + result = proxy.fetch_rows() + + # Verify that each invocation yields a fresh generator with a clean iteration state + for _ in range(2): + self.assertEqual( + [row.result for row in result], [["foo", "bar"], ["bar", "baz"]] + ) + + def test_response_iterator_includes_failing_responses(self, m): + m.get(self.a_client.url, json=["foo", "bar"]) + m.get(self.another_client.url, exc=requests.exceptions.Timeout) + + proxy = MultiZgwClientProxy([self.a_client, self.another_client]) + result = proxy.fetch_rows() + + self.assertEqual([row.exception is None for row in result], [True, False])