base.py
import abc
from typing import Any, Generator, List, Optional

from injector import inject

from taskweaver.config.config_mgt import AppConfigSource
from taskweaver.config.module_config import ModuleConfig
from taskweaver.llm.util import ChatMessageType


class ExtLLMModuleConfig(ModuleConfig):
    def _configure(self) -> None:
        self._set_name("ext_llms")
        self.ext_llm_config_dicts = self._get_dict("llm_configs", {})
        self.ext_llm_config_mapping = {}
        for key, config_dict in self.ext_llm_config_dicts.items():
            config = self.src.clone()
            for k, v in config_dict.items():
                config.set_config_value(
                    var_name=k,
                    var_type="str",
                    value=v,
                    source="override",
                )  # override the base LLM config with this extra LLM's values
            self.ext_llm_config_mapping[key] = config
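

# A minimal sketch of what the "ext_llms.llm_configs" dict might look like in the
# app config; the entry names and exact override keys are assumptions for
# illustration, not defined in this file:
#
# {
#     "ext_llms.llm_configs": {
#         "gpt4": {"llm.api_type": "openai", "llm.model": "gpt-4"},
#         "backup": {"llm.api_type": "openai", "llm.model": "gpt-3.5-turbo"}
#     }
# }
#
# Each named entry is cloned from the base config with its values applied as
# overrides, producing one config source per extra LLM in ext_llm_config_mapping.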


class LLMModuleConfig(ModuleConfig):
    def _configure(self) -> None:
        self._set_name("llm")
        self.api_type = self._get_str(
            "api_type",
            "openai",
        )
        self.embedding_api_type = self._get_str(
            "embedding_api_type",
            "sentence_transformers",
        )
        self.api_base: Optional[str] = self._get_str("api_base", None, required=False)
        self.api_key: Optional[str] = self._get_str(
            "api_key",
            None,
            required=False,
        )
        self.model: Optional[str] = self._get_str("model", None, required=False)
        self.embedding_model: Optional[str] = self._get_str(
            "embedding_model",
            None,
            required=False,
        )
        self.response_format: Optional[str] = self._get_enum(
            "response_format",
            options=["json_object", "text", "json_schema"],
            default="json_object",
        )
        self.use_mock: bool = self._get_bool("use_mock", False)
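

# For reference, these options are typically supplied under the "llm." namespace
# in the project's taskweaver_config.json, e.g. (placeholder values):
#
# {
#     "llm.api_type": "openai",
#     "llm.api_base": "https://api.openai.com/v1",
#     "llm.api_key": "YOUR_API_KEY",
#     "llm.model": "gpt-4",
#     "llm.response_format": "json_object"
# }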


class LLMServiceConfig(ModuleConfig):
    @inject
    def __init__(
        self,
        src: AppConfigSource,
        llm_module_config: LLMModuleConfig,
    ) -> None:
        self.llm_module_config = llm_module_config
        super().__init__(src)

    def _set_name(self, name: str) -> None:
        self.name = f"llm.{name}"
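
# For illustration: a service config subclass calling self._set_name("openai")
# would read its options under the "llm.openai" namespace; the subclass name is
# an assumption, not defined in this file.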


class CompletionService(abc.ABC):
    @abc.abstractmethod
    def chat_completion(
        self,
        messages: List[ChatMessageType],
        stream: bool = True,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        top_p: Optional[float] = None,
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> Generator[ChatMessageType, None, None]:
        """
        Chat completion API
        :param messages: list of messages
        :param stream: whether to stream the response
        :param temperature: sampling temperature
        :param max_tokens: maximum number of tokens to generate
        :param top_p: nucleus sampling (top-p) probability mass
        :param stop: list of stop sequences
        :param kwargs: other model-specific keyword arguments
        :return: generator of response messages
        """
        raise NotImplementedError
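

# A minimal usage sketch for a concrete CompletionService implementation.
# `service` is a hypothetical instance, not defined in this file; the message
# dicts follow the {"role": ..., "content": ...} shape of ChatMessageType.
#
# messages: List[ChatMessageType] = [
#     {"role": "system", "content": "You are a helpful assistant."},
#     {"role": "user", "content": "Hello!"},
# ]
# for chunk in service.chat_completion(messages, stream=True, temperature=0.0):
#     print(chunk["content"], end="", flush=True)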


class EmbeddingService(abc.ABC):
    @abc.abstractmethod
    def get_embeddings(self, strings: List[str]) -> List[List[float]]:
        """
        Embedding API
        :param strings: list of strings to embed
        :return: list of embedding vectors, one per input string
        """
        raise NotImplementedError
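

# A minimal usage sketch for a concrete EmbeddingService implementation.
# `embedding_service` is a hypothetical instance, not defined in this file.
#
# vectors = embedding_service.get_embeddings(["hello", "world"])
# assert len(vectors) == 2  # one embedding vector per input string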