Skip to content

Commit

Permalink
feat: allow to pass a custom client
Browse files Browse the repository at this point in the history
  • Loading branch information
hartungstenio committed Jan 28, 2025
1 parent eed7d63 commit 8e86acf
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 18 deletions.
37 changes: 23 additions & 14 deletions src/loafer/ext/aws/bases.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
from contextlib import asynccontextmanager

from aiobotocore.session import get_session

Expand All @@ -10,20 +11,28 @@
class _BotoClient:
boto_service_name = None

def __init__(self, **client_options):
self._client_options = {
"api_version": client_options.get("api_version"),
"aws_access_key_id": client_options.get("aws_access_key_id"),
"aws_secret_access_key": client_options.get("aws_secret_access_key"),
"aws_session_token": client_options.get("aws_session_token"),
"endpoint_url": client_options.get("endpoint_url"),
"region_name": client_options.get("region_name"),
"use_ssl": client_options.get("use_ssl", True),
"verify": client_options.get("verify"),
}

def get_client(self):
return session.create_client(self.boto_service_name, **self._client_options)
def __init__(self, *, client=None, **client_options):
if client:
self._client = client
else:
self._client_options = {
"api_version": client_options.get("api_version"),
"aws_access_key_id": client_options.get("aws_access_key_id"),
"aws_secret_access_key": client_options.get("aws_secret_access_key"),
"aws_session_token": client_options.get("aws_session_token"),
"endpoint_url": client_options.get("endpoint_url"),
"region_name": client_options.get("region_name"),
"use_ssl": client_options.get("use_ssl", True),
"verify": client_options.get("verify"),
}

@asynccontextmanager
async def get_client(self):
if hasattr(self, "_client"):
yield self._client
else:
async with session.create_client(self.boto_service_name, **self._client_options) as client:
yield client


class BaseSQSClient(_BotoClient):
Expand Down
29 changes: 25 additions & 4 deletions tests/ext/aws/test_bases.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,22 @@ async def test_get_queue_url_when_queue_name_is_url(mock_boto_session_sqs, boto_
@pytest.mark.asyncio
async def test_sqs_get_client(mock_boto_session_sqs, base_sqs_client, boto_client_sqs):
with mock_boto_session_sqs as mock_session:
client_generator = base_sqs_client.get_client()
async with base_sqs_client.get_client() as client:
assert boto_client_sqs is client

assert mock_session.called
async with client_generator as client:


@pytest.mark.asyncio
async def test_sqs_get_client_with_custom_client(mock_boto_session_sqs, boto_client_sqs):
base_sqs_client = BaseSQSClient(client=boto_client_sqs)

with mock_boto_session_sqs as mock_session:
async with base_sqs_client.get_client() as client:
assert boto_client_sqs is client

mock_session.assert_not_called()


@pytest.fixture
def base_sns_client():
Expand All @@ -70,7 +81,17 @@ async def test_cache_get_topic_arn_with_arn(base_sns_client):
@pytest.mark.asyncio
async def test_sns_get_client(mock_boto_session_sns, base_sns_client, boto_client_sns):
with mock_boto_session_sns as mock_session:
client_generator = base_sns_client.get_client()
async with base_sns_client.get_client() as client:
assert boto_client_sns is client
assert mock_session.called
async with client_generator as client:


@pytest.mark.asyncio
async def test_sns_get_client_with_custom_client(mock_boto_session_sns, boto_client_sns):
base_sns_client = BaseSNSClient(client=boto_client_sns)

with mock_boto_session_sns as mock_session:
async with base_sns_client.get_client() as client:
assert boto_client_sns is client

mock_session.assert_not_called()

0 comments on commit 8e86acf

Please sign in to comment.