From d0f5fb528d7055f612797bda46179eb40a3a65c2 Mon Sep 17 00:00:00 2001 From: mitraan-deshaw <142438905+mitraan-deshaw@users.noreply.github.com> Date: Wed, 26 Jun 2024 14:10:04 -0400 Subject: [PATCH] feat(vertex): add credentials argument (#542) Co-authored-by: Anshuman Mitra --- src/anthropic/lib/vertex/_client.py | 30 +++++++++++++++-------------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/src/anthropic/lib/vertex/_client.py b/src/anthropic/lib/vertex/_client.py index 5d78756c..9d695524 100644 --- a/src/anthropic/lib/vertex/_client.py +++ b/src/anthropic/lib/vertex/_client.py @@ -128,6 +128,7 @@ def __init__( proxies: ProxiesTypes | None = None, # See httpx documentation for [limits](https://www.python-httpx.org/advanced/#pool-limit-configuration) connection_pool_limits: httpx.Limits | None = None, + credentials: GoogleCredentials | None = None, _strict_response_validation: bool = False, ) -> None: if not is_given(region): @@ -161,7 +162,7 @@ def __init__( self.region = region self.access_token = access_token - self._credentials: GoogleCredentials | None = None + self.credentials = credentials self.messages = Messages(self) @@ -179,18 +180,18 @@ def _ensure_access_token(self) -> str: if self.access_token is not None: return self.access_token - if not self._credentials: - self._credentials, project_id = load_auth(project_id=self.project_id) + if not self.credentials: + self.credentials, project_id = load_auth(project_id=self.project_id) if not self.project_id: self.project_id = project_id else: - refresh_auth(self._credentials) + refresh_auth(self.credentials) - if not self._credentials.token: + if not self.credentials.token: raise RuntimeError("Could not resolve API token from the environment") - assert isinstance(self._credentials.token, str) - return self._credentials.token + assert isinstance(self.credentials.token, str) + return self.credentials.token class AsyncAnthropicVertex(BaseVertexClient[httpx.AsyncClient, AsyncStream[Any]], AsyncAPIClient): @@ -215,6 +216,7 @@ def __init__( proxies: ProxiesTypes | None = None, # See httpx documentation for [limits](https://www.python-httpx.org/advanced/#pool-limit-configuration) connection_pool_limits: httpx.Limits | None = None, + credentials: GoogleCredentials | None = None, _strict_response_validation: bool = False, ) -> None: if not is_given(region): @@ -248,7 +250,7 @@ def __init__( self.region = region self.access_token = access_token - self._credentials: GoogleCredentials | None = None + self.credentials = credentials self.messages = AsyncMessages(self) @@ -266,15 +268,15 @@ async def _ensure_access_token(self) -> str: if self.access_token is not None: return self.access_token - if not self._credentials: - self._credentials, project_id = await asyncify(load_auth)(project_id=self.project_id) + if not self.credentials: + self.credentials, project_id = await asyncify(load_auth)(project_id=self.project_id) if not self.project_id: self.project_id = project_id else: - await asyncify(refresh_auth)(self._credentials) + await asyncify(refresh_auth)(self.credentials) - if not self._credentials.token: + if not self.credentials.token: raise RuntimeError("Could not resolve API token from the environment") - assert isinstance(self._credentials.token, str) - return self._credentials.token + assert isinstance(self.credentials.token, str) + return self.credentials.token \ No newline at end of file