Skip to content

Commit

Permalink
More explicit retry logic
Browse files Browse the repository at this point in the history
  • Loading branch information
doneholmes committed Jun 3, 2020
1 parent 395f497 commit a30bb4c
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 18 deletions.
15 changes: 9 additions & 6 deletions nuget_package_scanner/smart_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import aiohttp
from async_lru import alru_cache
from tenacity import retry, stop_after_attempt, wait_random
from tenacity import before_log, retry, retry_if_exception_type, stop_after_attempt, wait_random, TryAgain


class SmartClient:
Expand Down Expand Up @@ -67,22 +67,25 @@ async def get_as_json(self, url: str, ignore_404 = True, headers: Optional[dict]
response = await self.get(url, ignore_404, headers)
if response:
async with response:
return await response.json()
return await response.json()

@retry(stop=stop_after_attempt(3), wait=wait_random(min=1, max=2))
# Retry a few times in the event that it's some kind of connection error or 5xx error
# This method should not retry in the event of any 4xx errors
@retry(stop=stop_after_attempt(3), retry=retry_if_exception_type(TryAgain), \
wait=wait_random(min=1, max=3), before=before_log(logging.getLogger(), logging.DEBUG))
async def get(self, url: str, ignore_404 = True, headers: Optional[dict] = None) -> aiohttp.ClientResponse:
assert isinstance(url, str) and url, "url must be a non-empty string"
client = self.get_aiohttp_client(url)
try:
response = await client.get(url,headers=headers)
if ignore_404 and response.status == 404:
logging.debug(f'404 GET {url}')
return
if response.status != 200:
return
if response.status != 200:
raise response.raise_for_status()
logging.debug(f'200 GET {url}')
return response
except aiohttp.ClientResponseError as e:
logging.exception(e)
raise
raise e if e.status < 500 else TryAgain # Explicit call to retry for 5xx errors

38 changes: 26 additions & 12 deletions tests/smart_client_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,28 +7,26 @@

from nuget_package_scanner.smart_client import SmartClient


class TestSmartClient(IsolatedAsyncioTestCase):

async def asyncSetUp(self):
async def asyncSetUp(self):
self.sc = SmartClient()

async def asyncTearDown(self):
# Clear alru_cache after each test
# pylint: disable=no-member
self.sc.get_as_json.cache_clear()
self.sc.get_as_text.cache_clear()
# pylint: enable=no-member
await self.sc.close()

async def test_get_empty(self):
c = MagicMock(aiohttp.ClientSession)
c.get = AsyncMock()
self.sc.get_aiohttp_client = MagicMock(return_value=c)
with self.assertRaises(tenacity.RetryError):
await self.sc.get('')

async def test_retry_on_error(self):
c = MagicMock(aiohttp.ClientSession)
c.get = AsyncMock()
self.sc.get_aiohttp_client = MagicMock(return_value=c)
with self.assertRaises(tenacity.RetryError):
await self.sc.get('notreallyaurl')
self.assertEqual(c.get.await_count,3)
with self.assertRaises(AssertionError):
await self.sc.get('')

async def test_get_200(self):
c = MagicMock(aiohttp.ClientSession)
Expand Down Expand Up @@ -95,7 +93,23 @@ async def test_get_aiohttp_client_not_cached(self):
client = self.sc.get_aiohttp_client(url)
client2 = self.sc.get_aiohttp_client(url2)
self.assertNotEqual(client,client2)
self.assertEqual(len(self.sc.clients),2)
self.assertEqual(len(self.sc.clients),2)

async def test_get_is_retried_for_5xx(self):
c = MagicMock(aiohttp.ClientSession)
c.get = AsyncMock(side_effect=aiohttp.ClientResponseError(None,None,status=500))
self.sc.get_aiohttp_client = MagicMock(return_value=c)
with self.assertRaises(tenacity.RetryError):
await self.sc.get('https://a.url.here')
self.assertTrue(c.get.await_count > 1)

async def test_get_is_not_retried_for_4xx(self):
c = MagicMock(aiohttp.ClientSession)
c.get = AsyncMock(side_effect=aiohttp.ClientResponseError(None,None,status=422))
self.sc.get_aiohttp_client = MagicMock(return_value=c)
with self.assertRaises(aiohttp.ClientResponseError):
await self.sc.get('https://a.url.here')
c.get.assert_awaited_once()


if __name__ == '__main__':
Expand Down

0 comments on commit a30bb4c

Please sign in to comment.