Skip to content

Commit

Permalink
[#2537] Introduce a proxy class to invoke multiple ZGW clients in par…
Browse files Browse the repository at this point in the history
…allel
  • Loading branch information
swrichards committed Jun 27, 2024
1 parent 8bc269f commit 2d3ad96
Show file tree
Hide file tree
Showing 3 changed files with 256 additions and 1 deletion.
114 changes: 113 additions & 1 deletion src/open_inwoner/openzaak/clients.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import base64
import concurrent.futures
import logging
import warnings
from dataclasses import dataclass
from datetime import date
from typing import Literal, Mapping, Type, TypeAlias
from functools import cached_property
from typing import Any, Literal, Mapping, Type, TypeAlias, TypeVar

from django.conf import settings
from django.core.files.uploadedfile import InMemoryUploadedFile
Expand All @@ -14,11 +17,13 @@
from zgw_consumers.api_models.catalogi import Catalogus
from zgw_consumers.api_models.constants import RolOmschrijving, RolTypes
from zgw_consumers.client import build_client
from zgw_consumers.concurrent import parallel
from zgw_consumers.constants import APITypes
from zgw_consumers.models import Service
from zgw_consumers.service import pagination_helper

from open_inwoner.openzaak.api_models import InformatieObject
from open_inwoner.openzaak.exceptions import MultiZgwClientProxyError
from open_inwoner.utils.api import ClientError, get_json_response

from ..utils.decorators import cache as cache_result
Expand Down Expand Up @@ -51,6 +56,9 @@ def __init__(self, *args, **kwargs):
self.configured_from = kwargs.pop("configured_from")
super().__init__(*args, **kwargs)

def __str__(self):
return f"Client {self.__class__.__name__} for {self.base_url}"


class ZakenClient(ZgwAPIClient):
def fetch_cases(
Expand Down Expand Up @@ -694,6 +702,110 @@ def fetch_open_tasks(self, bsn: str) -> list[OpenTask]:
return results


TClient = TypeVar("TClient", bound=APIClient)


@dataclass(frozen=True)
class ZgwClientResponse:
"""A single response in a MultiZgwClientResult."""

client: TClient
result: Any
exception: Exception | None = None


@dataclass(frozen=True)
class MultiZgwClientProxyResult:
"""Container for a multi-backend responses"""

responses: list[ZgwClientResponse]

@cached_property
def has_errors(self) -> bool:
return any(r.exception is not None for r in self.responses)

@cached_property
def failing_responses(self) -> list[ZgwClientResponse]:
return list(r for r in self if r.exception is not None)

@cached_property
def successful_responses(self) -> list[ZgwClientResponse]:
return list(r for r in self if r.exception is None)

def raise_on_failures(self):
"""Raise a MultiZgwClientProxyError wrapping all errors raised by the clients."""
if not self.has_errors:
return

raise MultiZgwClientProxyError([r.exception for r in self.failing_responses])

def join_results(self):
"""Joins results for all successful responses in a list."""
return list(
result for row in self.successful_responses for result in row.result
)

def __iter__(self):
yield from self.responses


class MultiZgwClientProxy:
"""A proxy to call the same method on multiple ZGW clients in parallel."""

clients: list[TClient] = []

def __init__(self, clients: list[TClient]):
self.clients = clients

if len(clients) == 0:
raise ValueError("You must specify at least one client")

def _call_method(self, method, *args, **kwargs) -> MultiZgwClientProxyResult:
if not all(hasattr(client, method) for client in self.clients):
raise AttributeError(f"Method `{method}` does not exist on the clients")

with parallel() as executor:
futures_mapping: Mapping[concurrent.futures.Future, TClient] = {}
for client in self.clients:
future = executor.submit(
getattr(client, method),
*args,
**kwargs,
)
# Remember which future corresponds to which client,
# so we can associate them in the response
futures_mapping[future] = client

responses: list[ZgwClientResponse] = []
for task in concurrent.futures.as_completed(futures_mapping.keys()):
result: Any | None = None
exception: Exception | None = None
try:
result: Any = task.result()
except BaseException:
exception = task.exception()

responses.append(
ZgwClientResponse(
result=result, exception=exception, client=futures_mapping[task]
)
)

# Ensure the response list is deterministic, based on the client order.
# This is mainly useful for testing but also generally promotes consistent
# behavior.
responses.sort(
key=lambda r: self.clients.index(r.client),
)
return MultiZgwClientProxyResult(responses=responses)

def __getattr__(self, name):
def wrapper(*args, **kwargs):
return self._call_method(name, *args, **kwargs)

return wrapper


ZgwClientType = Literal["zaak", "catalogi", "document", "form"]
ZgwClientFactoryReturn: TypeAlias = (
ZakenClient | CatalogiClient | DocumentenClient | FormClient
Expand Down
18 changes: 18 additions & 0 deletions src/open_inwoner/openzaak/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from typing import Sequence


class InvalidAuth(Exception):
pass

Expand All @@ -12,3 +15,18 @@ class InvalidAuthForClientID(InvalidAuth):
def __init__(self, client_id):
self.client_id = client_id
super().__init__(f"secret invalid for subscription client_id'{client_id}'")


class MultiZgwClientProxyError(ExceptionGroup):
"""A container for exceptions raised within individual client requests."""

def __new__(cls, exceptions: Sequence[Exception]):
self = super().__new__(
MultiZgwClientProxyError,
"One or more ZGW clients raised an error",
exceptions,
)
return self

def derive(self, exceptions: Sequence[Exception]) -> "MultiZgwClientProxyError":
return MultiZgwClientProxyError(exceptions)
125 changes: 125 additions & 0 deletions src/open_inwoner/openzaak/tests/test_clients.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
from unittest import TestCase as PlainTestCase

from django.test import TestCase

import requests
import requests_mock
from zgw_consumers.constants import APITypes

from open_inwoner.openzaak.clients import (
MultiZgwClientProxy,
MultiZgwClientProxyResult,
ZgwClientResponse,
build_catalogi_client,
build_catalogi_clients,
build_documenten_client,
Expand All @@ -12,6 +19,7 @@
build_zaken_client,
build_zaken_clients,
)
from open_inwoner.openzaak.exceptions import MultiZgwClientProxyError
from open_inwoner.openzaak.tests.factories import ZGWApiGroupConfigFactory


Expand Down Expand Up @@ -56,3 +64,120 @@ def test_originating_service_is_persisted_on_all_clients(self):
getattr(self.api_groups[i], api_group_field),
)
self.assertEqual(client.configured_from.api_type, api_type)


@requests_mock.Mocker()
class MultiZgwClientProxyTests(PlainTestCase):
class SimpleClient:
def __init__(self, url):
self.url = url

def fetch_rows(self):
resp = requests.get(self.url)
return resp.json()

def setUp(self):
self.a_client = self.SimpleClient("http://foo/bar/rows")
self.another_client = self.SimpleClient("http://bar/baz/rows")

def test_accessing_non_existent_methods_raises(self, m):
proxy = MultiZgwClientProxy([self.a_client, self.another_client])

with self.assertRaises(AttributeError) as cm:
proxy.non_existent()

self.assertEqual(
str(cm.exception), "Method `non_existent` does not exist on the clients"
)

def test_all_successful_responses_are_returned(self, m):
m.get(self.a_client.url, json=["foo", "bar"])
m.get(self.another_client.url, json=["bar", "baz"])
proxy = MultiZgwClientProxy([self.a_client, self.another_client])

result = proxy.fetch_rows()

self.assertEqual(
result,
MultiZgwClientProxyResult(
responses=[
ZgwClientResponse(
client=self.a_client, result=["foo", "bar"], exception=None
),
ZgwClientResponse(
client=self.another_client,
result=["bar", "baz"],
exception=None,
),
]
),
)
self.assertEqual(result.responses, result.successful_responses)
self.assertEqual(result.failing_responses, [])
self.assertEqual(result.join_results(), ["foo", "bar", "bar", "baz"])
self.assertEqual(
result.raise_on_failures(),
None,
msg="raise_on_failures is a noop if all responses are successful",
)

def test_partial_exceptions_are_returned(self, m):
m.get(self.a_client.url, json=["foo", "bar"])
# Second client will raise an exception
m.get(self.another_client.url, exc=requests.exceptions.Timeout)

proxy = MultiZgwClientProxy([self.a_client, self.another_client])

result = proxy.fetch_rows()
successful_response, failing_response = result.responses
self.assertEqual(
successful_response,
ZgwClientResponse(
client=self.a_client, result=["foo", "bar"], exception=None
),
)
self.assertEqual(
result.join_results(),
["foo", "bar"],
msg="Only successful results should be joined",
)

# It's non-trivial to compare exceptions on equality, so we have to validate the
# object in steps
self.assertEqual(
(failing_response.client, failing_response.result),
(self.another_client, None),
)
self.assertIsInstance(failing_response.exception, requests.exceptions.Timeout)
self.assertEqual(result.has_errors, True)

with self.assertRaises(MultiZgwClientProxyError) as cm:
result.raise_on_failures()

self.assertTrue(
[e.__class__ for e in cm.exception.exceptions],
[requests.exceptions.Timeout],
)
self.assertEqual(len(result.failing_responses), 1)

def test_result_can_be_used_as_response_iterator(self, m):
m.get(self.a_client.url, json=["foo", "bar"])
m.get(self.another_client.url, json=["bar", "baz"])
proxy = MultiZgwClientProxy([self.a_client, self.another_client])

result = proxy.fetch_rows()

# Verify that each invocation yields a fresh generator with a clean iteration state
for _ in range(2):
self.assertEqual(
[row.result for row in result], [["foo", "bar"], ["bar", "baz"]]
)

def test_response_iterator_includes_failing_responses(self, m):
m.get(self.a_client.url, json=["foo", "bar"])
m.get(self.another_client.url, exc=requests.exceptions.Timeout)

proxy = MultiZgwClientProxy([self.a_client, self.another_client])
result = proxy.fetch_rows()

self.assertEqual([row.exception is None for row in result], [True, False])

0 comments on commit 2d3ad96

Please sign in to comment.