-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathARGO.py
80 lines (68 loc) · 2.4 KB
/
ARGO.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
#
# Wrapper classes for the Argonne Argo LLM service (chat and embedding endpoints)
#
import os
import requests
import json
# Model identifiers sent in the "model" field of chat requests
# (ArgoWrapper defaults to MODEL_GPT35).
MODEL_GPT35: str = "gpt35"
MODEL_GPT4: str = "gpt4"
class ArgoWrapper:
    """Minimal client for the Argo chat endpoint.

    Each call to :meth:`invoke` POSTs one JSON request carrying the configured
    user, model, system prompt and sampling parameters, and returns the
    decoded JSON response body.
    """

    # Endpoint used when no ``url`` is passed to the constructor.
    default_url = "https://apps-dev.inside.anl.gov/argoapi/api/v1/resource/chat/"

    def __init__(self,
                 url=None,
                 model=MODEL_GPT35,
                 system="",
                 temperature=0.8,
                 top_p=0.7,
                 user=os.getenv("USER"),
                 timeout=None) -> None:
        """Store the request settings used by :meth:`invoke`.

        Args:
            url: Chat endpoint URL; ``None`` selects ``ArgoWrapper.default_url``.
            model: Value sent in the request's ``"model"`` field.
            system: System prompt sent in the request's ``"system"`` field.
            temperature: Sampling temperature sent with every request.
            top_p: Nucleus-sampling parameter sent with every request.
            user: Value sent in the ``"user"`` field. NOTE: the default is
                read from ``$USER`` once, when this class is defined, not at
                each construction.
            timeout: Passed to ``requests.post``; ``None`` (the default) means
                wait indefinitely, matching the previous behaviour.
        """
        self.url = ArgoWrapper.default_url if url is None else url
        self.model = model
        # Previously hard-coded to "", which silently dropped the caller's prompt.
        self.system = system
        self.temperature = temperature
        self.top_p = top_p
        self.user = user
        self.timeout = timeout

    def invoke(self, prompt: str):
        """Send ``prompt`` to the chat endpoint and return the parsed JSON reply.

        Args:
            prompt: The user prompt; it is sent as a one-element ``"prompt"`` list.

        Returns:
            The response body decoded with ``json.loads``.

        Raises:
            Exception: If the server answers with any status other than 200.
        """
        headers = {"Content-Type": "application/json"}
        data = {
            "user": self.user,
            "model": self.model,
            "system": self.system,
            "prompt": [prompt],
            "stop": [],
            "temperature": self.temperature,
            "top_p": self.top_p,
        }
        response = requests.post(
            self.url,
            headers=headers,
            data=json.dumps(data),
            timeout=self.timeout,
        )
        if response.status_code != 200:
            raise Exception(f"Request failed with status code: {response.status_code}")
        return json.loads(response.text)
class ArgoEmbeddingWrapper:
    """Minimal client for the Argo embedding endpoint.

    :meth:`invoke` POSTs a list of prompts and returns the decoded JSON reply.
    :meth:`embed_documents` and :meth:`embed_query` are thin conveniences on
    top of it.
    """

    # Endpoint used when no ``url`` is passed to the constructor.
    default_url = "https://apps-dev.inside.anl.gov/argoapi/api/v1/resource/embed/"

    def __init__(self, url=None, user=os.getenv("USER"), timeout=None) -> None:
        """Store the request settings used by :meth:`invoke`.

        Args:
            url: Embedding endpoint URL; any falsy value (``None``, ``""``)
                selects ``ArgoEmbeddingWrapper.default_url``.
            user: Value sent in the ``"user"`` field. NOTE: the default is
                read from ``$USER`` once, when this class is defined.
            timeout: Passed to ``requests.post``; ``None`` (the default) means
                wait indefinitely, matching the previous behaviour.
        """
        self.url = url if url else ArgoEmbeddingWrapper.default_url
        self.user = user
        self.timeout = timeout

    def invoke(self, prompts: list):
        """POST ``prompts`` to the embedding endpoint and return the parsed JSON.

        Raises:
            Exception: If the server answers with any status other than 200.
        """
        headers = {"Content-Type": "application/json"}
        data = {
            "user": self.user,
            "prompt": prompts,
        }
        response = requests.post(
            self.url,
            headers=headers,
            data=json.dumps(data),
            timeout=self.timeout,
        )
        if response.status_code != 200:
            raise Exception(f"Request failed with status code: {response.status_code}")
        return json.loads(response.text)

    def embed_documents(self, texts):
        """Embed a list of texts; returns whatever :meth:`invoke` returns."""
        return self.invoke(texts)

    def embed_query(self, query):
        """Embed a single query and return the first element of the response.

        A string query is wrapped in a one-element list so the request matches
        the list-shaped ``prompt`` that :meth:`invoke` expects. Previously the
        bare string was sent, unlike ``ArgoWrapper``, which wraps its prompt
        as ``[prompt]``. Presumably the response holds one embedding per
        prompt; confirm against the API.
        """
        prompts = [query] if isinstance(query, str) else query
        return self.invoke(prompts)[0]