Skip to content

Commit

Permalink
Close session of ChallengePolicyClient in context manager (Azure#20047)
Browse files Browse the repository at this point in the history
  • Loading branch information
xiangyan99 authored Aug 3, 2021
1 parent d681e2c commit 2cb6357
Show file tree
Hide file tree
Showing 7 changed files with 28 additions and 8 deletions.
2 changes: 2 additions & 0 deletions sdk/containerregistry/azure-containerregistry/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

### Bugs Fixed

- Close session of `ChallengePolicyClient` in context manager #20000

### Other Changes

- Bumped dependency on `msrest` to `>=0.6.21`
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,10 @@ def on_challenge(self, request, response, challenge):
access_token = self._exchange_client.get_acr_access_token(challenge)
request.http_request.headers["Authorization"] = "Bearer " + access_token
return access_token is not None

def __enter__(self):
self._exchange_client.__enter__()
return self

def __exit__(self, *args):
self._exchange_client.__exit__(*args)
Original file line number Diff line number Diff line change
Expand Up @@ -34,20 +34,22 @@ class ContainerRegistryBaseClient(object):

def __init__(self, endpoint, credential, **kwargs):
# type: (str, Optional[TokenCredential], Dict[str, Any]) -> None
auth_policy = ContainerRegistryChallengePolicy(credential, endpoint, **kwargs)
self._auth_policy = ContainerRegistryChallengePolicy(credential, endpoint, **kwargs)
self._client = ContainerRegistry(
credential=credential,
url=endpoint,
sdk_moniker=USER_AGENT,
authentication_policy=auth_policy,
authentication_policy=self._auth_policy,
**kwargs
)

def __enter__(self):
self._auth_policy.__enter__()
self._client.__enter__()
return self

def __exit__(self, *args):
self._auth_policy.__exit__(*args)
self._client.__exit__(*args)

def close(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,11 @@ async def exchange_refresh_token_for_access_token(
return access_token.access_token

async def __aenter__(self):
self._client.__aenter__()
await self._client.__aenter__()
return self

async def __aexit__(self, *args):
self._client.__aexit__(*args)
await self._client.__aexit__(*args)

async def close(self) -> None:
"""Close sockets opened by the client.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,10 @@ async def on_challenge(self, request, response, challenge):
access_token = await self._exchange_client.get_acr_access_token(challenge)
request.http_request.headers["Authorization"] = "Bearer " + access_token
return access_token is not None

async def __aenter__(self):
await self._exchange_client.__aenter__()
return self

async def __aexit__(self, *args):
await self._exchange_client.__aexit__()
Original file line number Diff line number Diff line change
Expand Up @@ -33,20 +33,22 @@ class ContainerRegistryBaseClient(object):
"""

def __init__(self, endpoint: str, credential: Optional["AsyncTokenCredential"] = None, **kwargs) -> None:
auth_policy = ContainerRegistryChallengePolicy(credential, endpoint, **kwargs)
self._auth_policy = ContainerRegistryChallengePolicy(credential, endpoint, **kwargs)
self._client = ContainerRegistry(
credential=credential,
url=endpoint,
sdk_moniker=USER_AGENT,
authentication_policy=auth_policy,
authentication_policy=self._auth_policy,
**kwargs
)

async def __aenter__(self):
await self._auth_policy.__aenter__()
await self._client.__aenter__()
return self

async def __aexit__(self, *args):
await self._auth_policy.__aexit__(*args)
await self._client.__aexit__(*args)

async def close(self) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,11 @@ async def exchange_refresh_token_for_access_token(
return access_token.access_token

async def __aenter__(self):
self._client.__aenter__()
await self._client.__aenter__()
return self

async def __aexit__(self, *args):
self._client.__aexit__(*args)
await self._client.__aexit__(*args)

async def close(self) -> None:
"""Close sockets opened by the client.
Expand Down

0 comments on commit 2cb6357

Please sign in to comment.