diff --git a/src/handyllm/endpoint_manager.py b/src/handyllm/endpoint_manager.py index b9e0cc4..b0ae878 100644 --- a/src/handyllm/endpoint_manager.py +++ b/src/handyllm/endpoint_manager.py @@ -1,36 +1,41 @@ +from __future__ import annotations + __all__ = [ "Endpoint", "EndpointManager", ] +import os +from json import JSONDecodeError from threading import Lock from collections.abc import MutableSequence -from typing import Iterable, Mapping, Optional, Union +from typing import Iterable, Mapping, Optional, Union, cast from .types import PathType from ._utils import isiterable -from ._io import yaml_load +from ._io import yaml_load, json_loads +from ._constants import TYPE_API_TYPES class Endpoint: def __init__( self, - name=None, - api_key=None, - organization=None, - api_base=None, - api_type=None, - api_version=None, - model_engine_map=None, - dest_url=None, + name: Optional[str] = None, + api_key: Optional[str] = None, + organization: Optional[str] = None, + api_base: Optional[str] = None, + api_type: Optional[TYPE_API_TYPES] = None, + api_version: Optional[str] = None, + model_engine_map: Optional[Mapping[str, str]] = None, + dest_url: Optional[str] = None, ): self.name = name if name else f"ep_{id(self)}" + self.api_type: Optional[TYPE_API_TYPES] = api_type + self.api_base = api_base self.api_key = api_key self.organization = organization - self.api_base = api_base - self.api_type = api_type self.api_version = api_version - self.model_engine_map = model_engine_map if model_engine_map else {} + self.model_engine_map = model_engine_map self.dest_url = dest_url def __str__(self) -> str: @@ -62,6 +67,45 @@ def get_api_info(self): self.dest_url, ) + def merge(self, other: Endpoint, override=False): + if not isinstance(other, Endpoint): + raise ValueError(f"Cannot merge with {type(other)}") + if self.api_key is None or override: + self.api_key = other.api_key + if self.organization is None or override: + self.organization = other.organization + if self.api_base is None or override: + self.api_base = other.api_base + if self.api_type is None or override: + self.api_type = other.api_type + if self.api_version is None or override: + self.api_version = other.api_version + if self.model_engine_map is None or override: + self.model_engine_map = other.model_engine_map + if self.dest_url is None or override: + self.dest_url = other.dest_url + + def merge_from_env(self): + if self.api_key is None: + self.api_key = os.environ.get("OPENAI_API_KEY") + if self.organization is None: + self.organization = os.environ.get("OPENAI_ORGANIZATION") or os.environ.get( + "OPENAI_ORG_ID" + ) + if self.api_base is None: + self.api_base = os.environ.get("OPENAI_API_BASE") + if self.api_type is None: + self.api_type = cast(TYPE_API_TYPES, os.environ.get("OPENAI_API_TYPE")) + if self.api_version is None: + self.api_version = os.environ.get("OPENAI_API_VERSION") + if self.model_engine_map is None: + json_str = os.environ.get("MODEL_ENGINE_MAP") + if json_str: + try: + self.model_engine_map = json_loads(json_str) + except JSONDecodeError: + pass + class EndpointManager(MutableSequence): def __init__( diff --git a/src/handyllm/openai_client.py b/src/handyllm/openai_client.py index c913250..6e5a418 100644 --- a/src/handyllm/openai_client.py +++ b/src/handyllm/openai_client.py @@ -5,9 +5,8 @@ "ClientMode", ] +import copy from typing import Dict, Iterable, Mapping, Optional, TypeVar, Union -import os -from json import JSONDecodeError import time from enum import Enum, auto import asyncio @@ -35,7 +34,7 @@ TYPE_API_TYPES, ) from .types import PathType -from ._io import yaml_load, json_loads +from ._io import yaml_load RequestorType = TypeVar("RequestorType", bound="Requestor") @@ -53,34 +52,82 @@ class ClientMode(Enum): class OpenAIClient: - # set this to your API type; - # or environment variable OPENAI_API_TYPE will be used; - # can be None (roll back to default). - api_type: Optional[TYPE_API_TYPES] - - # set this to your API base; - # or environment variable OPENAI_API_BASE will be used. - # can be None (roll back to default). - api_base: Optional[str] - - # set this to your API key; - # or environment variable OPENAI_API_KEY will be used. - api_key: Optional[str] - - # set this to your organization ID; - # or environment variable OPENAI_ORGANIZATION / OPENAI_ORG_ID will be used; - # can be None. - organization: Optional[str] - - # set this to your API version; - # or environment variable OPENAI_API_VERSION will be used; - # cannot be None if using Azure API. - api_version: Optional[str] - - # set this to your model-engine map; - # or environment variable MODEL_ENGINE_MAP will be used; - # can be None. - model_engine_map: Optional[Dict[str, str]] + @property + def api_type(self) -> Optional[TYPE_API_TYPES]: + """ + Set this to your API type; + or environment variable OPENAI_API_TYPE will be used; + can be None (roll back to default). + """ + return self._endpoint.api_type + + @api_type.setter + def api_type(self, value: Optional[TYPE_API_TYPES]) -> None: + self._endpoint.api_type = value + + @property + def api_base(self) -> Optional[str]: + """ + Set this to your API base; + or environment variable OPENAI_API_BASE will be used; + can be None (roll back to default). + """ + return self._endpoint.api_base + + @api_base.setter + def api_base(self, value: Optional[str]) -> None: + self._endpoint.api_base = value + + @property + def api_key(self) -> Optional[str]: + """ + Set this to your API key; + or environment variable OPENAI_API_KEY will be used. + """ + return self._endpoint.api_key + + @api_key.setter + def api_key(self, value: Optional[str]) -> None: + self._endpoint.api_key = value + + @property + def organization(self) -> Optional[str]: + """ + Set this to your organization ID; + or environment variable OPENAI_ORGANIZATION / OPENAI_ORG_ID will be used; + can be None. + """ + return self._endpoint.organization + + @organization.setter + def organization(self, value: Optional[str]) -> None: + self._endpoint.organization = value + + @property + def api_version(self) -> Optional[str]: + """ + Set this to your API version; + or environment variable OPENAI_API_VERSION will be used; + cannot be None if using Azure API. + """ + return self._endpoint.api_version + + @api_version.setter + def api_version(self, value: Optional[str]) -> None: + self._endpoint.api_version = value + + @property + def model_engine_map(self) -> Optional[Mapping[str, str]]: + """ + Set this to your model-engine map; + or environment variable MODEL_ENGINE_MAP will be used; + can be None. + """ + return self._endpoint.model_engine_map + + @model_engine_map.setter + def model_engine_map(self, value: Optional[Mapping[str, str]]) -> None: + self._endpoint.model_engine_map = value # set this to your endpoint manager endpoint_manager: Optional[EndpointManager] = None @@ -102,6 +149,7 @@ def __init__( endpoint_manager: Optional[EndpointManager] = None, endpoints: Optional[Iterable] = None, load_path: Optional[PathType] = None, + ensure_client_credentials: bool = False, ) -> None: self._sync_client = None self._async_client = None @@ -127,12 +175,16 @@ def __init__( self._async_client = httpx.AsyncClient() - self.api_base = api_base - self.api_key = api_key - self.organization = organization - self.api_type = api_type - self.api_version = api_version - self.model_engine_map = model_engine_map + self._endpoint = Endpoint( + api_key=api_key, + organization=organization, + api_base=api_base, + api_type=api_type, + api_version=api_version, + model_engine_map=model_engine_map, + ) + + self.ensure_client_credentials = ensure_client_credentials if endpoint_manager: if not isinstance(endpoint_manager, EndpointManager): @@ -155,29 +207,23 @@ def load_from(self, path: PathType, encoding="utf-8", override=False): def load_from_obj(self, obj: Mapping, override=False): if not isinstance(obj, Mapping): raise ValueError("obj must be a mapping (dict, etc.)") - api_base = obj.get("api_base", None) - api_key = obj.get("api_key", None) - organization = obj.get("organization", None) - api_type = obj.get("api_type", None) - api_version = obj.get("api_version", None) - model_engine_map = obj.get("model_engine_map", None) + tmp_endpoint = Endpoint( + api_base=obj.get("api_base", None), + api_key=obj.get("api_key", None), + organization=obj.get("organization", None), + api_type=obj.get("api_type", None), + api_version=obj.get("api_version", None), + model_engine_map=obj.get("model_engine_map", None), + ) + self._endpoint.merge(tmp_endpoint, override=override) item = obj.get("endpoints", None) - if api_base and (override or not self.api_base): - self.api_base = api_base - if api_key and (override or not self.api_key): - self.api_key = api_key - if organization and (override or not self.organization): - self.organization = organization - if api_type and (override or not self.api_type): - self.api_type = api_type - if api_version and (override or not self.api_version): - self.api_version = api_version - if model_engine_map and (override or not self.model_engine_map): - self.model_engine_map = model_engine_map if item and (override or not self.endpoint_manager): if self.endpoint_manager is None: self.endpoint_manager = EndpointManager() self.endpoint_manager.load_from_list(item, override=override) + ensure_client_credentials = obj.get("ensure_client_credentials", None) + if ensure_client_credentials is not None and override: + self.ensure_client_credentials = ensure_client_credentials def __enter__(self): return self @@ -222,109 +268,75 @@ async def aclose(self): except Exception: pass - def _infer_api_key(self, api_key=None): - return api_key or self.api_key or os.environ.get("OPENAI_API_KEY") - - def _infer_organization(self, organization=None): - return ( - organization - or self.organization - or os.environ.get("OPENAI_ORGANIZATION") - or os.environ.get("OPENAI_ORG_ID") - ) - - def _infer_api_base(self, api_base=None): - return ( - api_base - or self.api_base - or os.environ.get("OPENAI_API_BASE") - or API_BASE_OPENAI - ) - - def _infer_api_type(self, api_type=None): - return ( - api_type - or self.api_type - or os.environ.get("OPENAI_API_TYPE") - or API_TYPE_OPENAI - ).lower() - - def _infer_api_version(self, api_version=None): - return api_version or self.api_version or os.environ.get("OPENAI_API_VERSION") - - def _infer_model_engine_map(self, model_engine_map=None): - if model_engine_map: - return model_engine_map - if self.model_engine_map: - return self.model_engine_map - json_str = os.environ.get("MODEL_ENGINE_MAP") - if not json_str: - return None - try: - return json_loads(json_str) - except JSONDecodeError: - return None - def _consume_kwargs(self, kwargs): - api_key = organization = api_base = api_type = api_version = engine = ( - model_engine_map - ) = dest_url = endpoint_manager = None - - # read API info from endpoint_manager - endpoints = kwargs.pop("endpoints", None) - if endpoints: - endpoint_manager = EndpointManager(endpoints=endpoints) - endpoint_manager = ( - kwargs.pop("endpoint_manager", endpoint_manager) or self.endpoint_manager + # consume arguments + transient_endpoint = Endpoint( + api_key=kwargs.pop("api_key", None), + organization=kwargs.pop("organization", None), + api_base=kwargs.pop("api_base", None), + api_type=kwargs.pop("api_type", None), + api_version=kwargs.pop("api_version", None), + model_engine_map=kwargs.pop("model_engine_map", None), + dest_url=kwargs.pop("dest_url", None), ) - if endpoint_manager is not None and not kwargs.get( - "__endpoint_manager_used__", False - ): - if not isinstance(endpoint_manager, EndpointManager): - raise Exception( - "endpoint_manager must be an instance of EndpointManager" - ) - # get_next_endpoint() will be called once for each request - ( - api_key, - organization, - api_base, - api_type, - api_version, - model_engine_map, - dest_url, - ) = endpoint_manager.get_next_endpoint().get_api_info() - kwargs["__endpoint_manager_used__"] = True - - # read API info from endpoint (override API info from endpoint_manager) + endpoints = kwargs.pop("endpoints", None) + endpoint_manager = kwargs.pop("endpoint_manager", None) endpoint = kwargs.pop("endpoint", None) - if endpoint is not None: - if not isinstance(endpoint, Endpoint): - endpoint = Endpoint(**endpoint) - ( - api_key, - organization, - api_base, - api_type, - api_version, - model_engine_map, - dest_url, - ) = endpoint.get_api_info() - - # read API info from kwargs, class variables, and environment variables - api_key = self._infer_api_key(kwargs.pop("api_key", api_key)) - organization = self._infer_organization( - kwargs.pop("organization", organization) - ) - api_base = self._infer_api_base(kwargs.pop("api_base", api_base)) - api_type = self._infer_api_type(kwargs.pop("api_type", api_type)) - api_version = self._infer_api_version(kwargs.pop("api_version", api_version)) - model_engine_map = self._infer_model_engine_map( - kwargs.pop("model_engine_map", model_engine_map) - ) - deployment_id = kwargs.pop("deployment_id", None) engine = kwargs.pop("engine", deployment_id) + + if self.ensure_client_credentials: + if self.endpoint_manager is not None: + # read API info from endpoint_manager + if not kwargs.get("__endpoint_manager_used__", False): + transient_endpoint = copy.deepcopy(self._endpoint) + # get_next_endpoint() will be called once for each request + transient_endpoint.merge(self.endpoint_manager.get_next_endpoint()) + kwargs["__endpoint_manager_used__"] = True + else: + # read API info from endpoint + transient_endpoint = copy.deepcopy(self._endpoint) + else: + # merge endpoint from 'endpoint' parameter + if endpoint is not None: + if not isinstance(endpoint, Endpoint): + endpoint = Endpoint(**endpoint) + transient_endpoint.merge(endpoint) + + # merge endpoint from 'endpoints' / 'endpoint_manager' parameters + if endpoints and not endpoint_manager: + endpoint_manager = EndpointManager(endpoints=endpoints) + endpoint_manager = endpoint_manager or self.endpoint_manager + if endpoint_manager is not None and not kwargs.get( + "__endpoint_manager_used__", False + ): + if not isinstance(endpoint_manager, EndpointManager): + raise Exception( + "endpoint_manager must be an instance of EndpointManager" + ) + # get_next_endpoint() will be called once for each request + transient_endpoint.merge(endpoint_manager.get_next_endpoint()) + kwargs["__endpoint_manager_used__"] = True + + # merge endpoint from client-wide credentials + transient_endpoint.merge(self._endpoint) + # merge endpoint from environment variables + transient_endpoint.merge_from_env() + + ( + api_key, + organization, + api_base, + api_type, + api_version, + model_engine_map, + dest_url, + ) = transient_endpoint.get_api_info() + + # ensure api_base and api_type are set + api_base = api_base or API_BASE_OPENAI + api_type = (api_type or API_TYPE_OPENAI).lower() + # if using Azure and engine not provided, try to get it from model parameter if api_type and api_type in API_TYPES_AZURE: if not engine: @@ -339,14 +351,6 @@ def _consume_kwargs(self, kwargs): engine = model_engine_map.get(model, model) else: engine = model - dest_url = kwargs.pop("dest_url", dest_url) - - if self.ensure_client_credentials: - api_key = self.api_key - organization = self.organization - api_base = self.api_base or API_BASE_OPENAI - api_type = self.api_type or API_TYPE_OPENAI - api_version = self.api_version return api_key, organization, api_base, api_type, api_version, engine, dest_url def _make_requestor( diff --git a/tests/test_client.py b/tests/test_client.py index 6f35d20..e446b83 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -145,8 +145,7 @@ def test_chat_stream(): def test_ensure_client_credentials(): - client = OpenAIClient() - client.api_key = "client_key" + client = OpenAIClient(api_key="client_key") assert ( client.chat(messages=[], api_key="should_be_used").api_key == "should_be_used" ) @@ -154,3 +153,47 @@ def test_ensure_client_credentials(): assert ( client.chat(messages=[], api_key="should_not_be_used").api_key == "client_key" ) + + +def test_ensure_client_credentials2(): + client2 = OpenAIClient(ensure_client_credentials=True, api_key="client_key") + assert ( + client2.chat(messages=[], api_key="should_not_be_used").api_key == "client_key" + ) + client2.ensure_client_credentials = False + assert ( + client2.chat(messages=[], api_key="should_be_used").api_key == "should_be_used" + ) + + +def test_ensure_client_credentials3(): + client3 = OpenAIClient( + ensure_client_credentials=True, + endpoints=[ + { + "api_key": "client_key1", + }, + { + "api_key": "client_key2", + }, + { + "api_key": "client_key3", + }, + ], + ) + assert ( + client3.chat(messages=[], api_key="should_not_be_used").api_key == "client_key1" + ) + assert ( + client3.chat(messages=[], api_key="should_not_be_used").api_key == "client_key2" + ) + assert ( + client3.chat(messages=[], api_key="should_not_be_used").api_key == "client_key3" + ) + assert ( + client3.chat(messages=[], api_key="should_not_be_used").api_key == "client_key1" + ) + client3.ensure_client_credentials = False + assert ( + client3.chat(messages=[], api_key="should_be_used").api_key == "should_be_used" + )