Skip to content

Commit

Permalink
Merge pull request #48 from atomiechen/fix_ensure_client_credentials
Browse files Browse the repository at this point in the history
Fix ensure client credentials
  • Loading branch information
atomiechen authored Sep 16, 2024
2 parents 5d9c7b5 + 05a18d6 commit a6d8b95
Show file tree
Hide file tree
Showing 3 changed files with 267 additions and 176 deletions.
70 changes: 57 additions & 13 deletions src/handyllm/endpoint_manager.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,41 @@
from __future__ import annotations

__all__ = [
"Endpoint",
"EndpointManager",
]

import os
from json import JSONDecodeError
from threading import Lock
from collections.abc import MutableSequence
from typing import Iterable, Mapping, Optional, Union
from typing import Iterable, Mapping, Optional, Union, cast

from .types import PathType
from ._utils import isiterable
from ._io import yaml_load
from ._io import yaml_load, json_loads
from ._constants import TYPE_API_TYPES


class Endpoint:
def __init__(
self,
name=None,
api_key=None,
organization=None,
api_base=None,
api_type=None,
api_version=None,
model_engine_map=None,
dest_url=None,
name: Optional[str] = None,
api_key: Optional[str] = None,
organization: Optional[str] = None,
api_base: Optional[str] = None,
api_type: Optional[TYPE_API_TYPES] = None,
api_version: Optional[str] = None,
model_engine_map: Optional[Mapping[str, str]] = None,
dest_url: Optional[str] = None,
):
self.name = name if name else f"ep_{id(self)}"
self.api_type: Optional[TYPE_API_TYPES] = api_type
self.api_base = api_base
self.api_key = api_key
self.organization = organization
self.api_base = api_base
self.api_type = api_type
self.api_version = api_version
self.model_engine_map = model_engine_map if model_engine_map else {}
self.model_engine_map = model_engine_map
self.dest_url = dest_url

def __str__(self) -> str:
Expand Down Expand Up @@ -62,6 +67,45 @@ def get_api_info(self):
self.dest_url,
)

def merge(self, other: Endpoint, override=False):
if not isinstance(other, Endpoint):
raise ValueError(f"Cannot merge with {type(other)}")
if self.api_key is None or override:
self.api_key = other.api_key
if self.organization is None or override:
self.organization = other.organization
if self.api_base is None or override:
self.api_base = other.api_base
if self.api_type is None or override:
self.api_type = other.api_type
if self.api_version is None or override:
self.api_version = other.api_version
if self.model_engine_map is None or override:
self.model_engine_map = other.model_engine_map
if self.dest_url is None or override:
self.dest_url = other.dest_url

def merge_from_env(self):
if self.api_key is None:
self.api_key = os.environ.get("OPENAI_API_KEY")
if self.organization is None:
self.organization = os.environ.get("OPENAI_ORGANIZATION") or os.environ.get(
"OPENAI_ORG_ID"
)
if self.api_base is None:
self.api_base = os.environ.get("OPENAI_API_BASE")
if self.api_type is None:
self.api_type = cast(TYPE_API_TYPES, os.environ.get("OPENAI_API_TYPE"))
if self.api_version is None:
self.api_version = os.environ.get("OPENAI_API_VERSION")
if self.model_engine_map is None:
json_str = os.environ.get("MODEL_ENGINE_MAP")
if json_str:
try:
self.model_engine_map = json_loads(json_str)
except JSONDecodeError:
pass


class EndpointManager(MutableSequence):
def __init__(
Expand Down
Loading

0 comments on commit a6d8b95

Please sign in to comment.