-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodels.py
37 lines (30 loc) · 1.02 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import os
from langchain_huggingface import ChatHuggingFace
from langchain_community.llms import HuggingFaceHub
from langchain_core.prompts import ChatPromptTemplate
ZEPHYR_ID = "HuggingFaceH4/zephyr-7b-beta"


def get_model(repo_id=ZEPHYR_ID, **kwargs):
    """Build a HuggingFaceHub text-generation LLM for *repo_id*.

    The API token is taken from ``kwargs["HUGGINGFACEHUB_API_TOKEN"]`` if
    provided, otherwise from the ``HUGGINGFACEHUB_API_TOKEN`` environment
    variable. The resolved token is also exported as ``HF_TOKEN`` so that
    downstream Hugging Face libraries pick it up.

    Args:
        repo_id: Hugging Face model repository id (defaults to Zephyr 7B beta).
        **kwargs: May contain ``HUGGINGFACEHUB_API_TOKEN``.

    Returns:
        A configured ``HuggingFaceHub`` LLM instance.

    Raises:
        ValueError: If no API token can be found (previously this crashed
            with an opaque ``TypeError`` when assigning ``None`` to
            ``os.environ``).
    """
    hf_token = kwargs.get("HUGGINGFACEHUB_API_TOKEN") or os.environ.get(
        "HUGGINGFACEHUB_API_TOKEN"
    )
    if not hf_token:
        raise ValueError(
            "No Hugging Face API token found: pass HUGGINGFACEHUB_API_TOKEN "
            "as a keyword argument or set it in the environment."
        )
    os.environ["HF_TOKEN"] = hf_token

    # NOTE: huggingfacehub_api_token is a constructor argument, not a
    # model_kwarg — putting it in model_kwargs would forward the token to the
    # inference endpoint as a generation parameter (leaking it) instead of
    # using it for authentication.
    llm = HuggingFaceHub(
        repo_id=repo_id,
        task="text-generation",
        huggingfacehub_api_token=hf_token,
        model_kwargs={
            "max_new_tokens": 512,
            "top_k": 30,
            "temperature": 0.1,
            "repetition_penalty": 1.03,
        },
    )
    return llm
def basic_chain(model=None, prompt=None):
    """Compose a minimal prompt-to-model chain.

    Falls back to :func:`get_model` when *model* is falsy, and to a trivial
    "Hello world" prompt template when *prompt* is falsy.

    Args:
        model: An LLM (or anything composable with ``|``); optional.
        prompt: A prompt template (or anything composable with ``|``); optional.

    Returns:
        The composed ``prompt | model`` runnable.
    """
    chosen_model = model if model else get_model()
    chosen_prompt = (
        prompt if prompt else ChatPromptTemplate.from_template("Hello world")
    )
    return chosen_prompt | chosen_model