forked from SakuraLLM/SakuraLLM
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathserver.py
132 lines (100 loc) · 3.53 KB
/
server.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import os
import sys
# Fix for windows embedded environment
file_dir = os.path.dirname(__file__)
sys.path.append(file_dir)
import random
import asyncio
import coloredlogs
import logging
from argparse import ArgumentParser
from dacite import from_dict
from hypercorn import Config
from fastapi import FastAPI, Depends
from fastapi.middleware.cors import CORSMiddleware
from api import log_request
from api.auth import get_auth_username
from utils import *
from utils import model as M
from utils import state
from utils.state import ServerConfig
from utils.cli import parse_args
# App-level dependencies handed to FastAPI(dependencies=...) further down;
# request logging is always included.
dependencies = [Depends(log_request)]
def extra_args(parser: ArgumentParser) -> None:
    """Register the server-specific command-line options on *parser*.

    The options are added to their own "Server" argument group:
    ``--listen`` (defaults to ``127.0.0.1:5000``), ``--auth`` and the
    ``--no-auth`` flag.
    """
    group = parser.add_argument_group("Server")
    group.add_argument(
        "--listen",
        type=str,
        default="127.0.0.1:5000",
        help="listen address ip:port",
    )
    group.add_argument(
        "--auth",
        type=str,
        help="user:pass, user & pass should not contain ':'",
    )
    group.add_argument(
        "--no-auth",
        action="store_true",
        help="force disable auth",
    )
# Parse the command line. extra_args is handed to parse_args so the server
# options (--listen / --auth / --no-auth) are available on the result.
args = parse_args(add_extra_args_fn=extra_args)

# Install coloured logging at the requested level before anything else logs.
coloredlogs.install(level=args.logLevel.upper())
logger = logging.getLogger(__name__)
logger.debug(f"Current Log Level: {args.logLevel}")

# Split "host:port" on the LAST colon only, so a host that itself contains
# colons (e.g. a bracketed IPv6 literal such as "[::1]:5000") stays intact.
addr = args.listen.rsplit(":", 1)
ServerConfig.address = addr[0]
ServerConfig.port = int(addr[1])
# Hidden trick to disable auth, useful when you use docker-compose:
# passing --auth ":" is treated exactly like --no-auth.
if args.auth == ":":
    args.auth = None
    args.no_auth = True

# [username, password]; stays [None, None] when auth is disabled.
auth = [None, None]
if args.no_auth:
    logger.warning("Auth is disabled!")
else:
    if not args.auth:
        # No credentials supplied: generate a random password. SystemRandom
        # draws from the OS entropy source, so the value is not predictable
        # the way the default pseudo-random generator is.
        args.auth = f"sakura:{random.SystemRandom().randint(114514, 19194545)}"
        # Log the credentials that were just generated (args.auth); the `auth`
        # placeholder is still [None, None] at this point.
        logger.warning(f"Using random auth credentials. {args.auth}")
    if ":" not in args.auth:
        # Fail with a clear message instead of an IndexError further down.
        raise ValueError("--auth must be in the form user:pass")
    # Split on the first colon only, so a password containing ':' is kept whole.
    auth = args.auth.split(":", 1)
    # Insert http auth check
    dependencies.append(Depends(get_auth_username))

ServerConfig.username = auth[0]
ServerConfig.password = auth[1]
# Create the application; the `dependencies` list built above is passed to
# FastAPI as its app-level dependencies.
app = FastAPI(dependencies=dependencies)
# The routers are imported here, after the configuration code above has run,
# and registered on the app in the order below.
# NOTE(review): presumably the api modules rely on state configured earlier in
# this file — confirm before hoisting these imports to the top.
from api.legacy import router as legacy_router
app.include_router(legacy_router)
from api.openai.v1 import router as openai_router
app.include_router(openai_router)
from api.openai.v1.chat import router as openai_chat_router
app.include_router(openai_chat_router)
from api.core import router as core_router
app.include_router(core_router)
# CORS policy: every origin, method and header is allowed, with credentials.
origins = ["*"]
_cors_options = {
    "allow_origins": origins,
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
if __name__ == "__main__":
    # Script entry point: load the model, then serve `app` on the configured
    # address and port.
    logger.info(f"Current server config: {ServerConfig.show()}")

    # build cfg from args
    cfg = from_dict(data_class=M.SakuraModelConfig, data=args.__dict__)
    logger.info(f"Current model config: {cfg}")

    state.init_model(cfg)
    state.get_model().check_model_by_magic()

    logger.info(
        f"Server will run at http://{ServerConfig.address}:{ServerConfig.port}, preparing...")

    # The log level is upper-cased for coloredlogs at startup, so any casing is
    # accepted on the command line; normalise it here as well so "DEBUG" and
    # "debug" are treated the same below.
    log_level = args.logLevel.lower()

    # disable multiprocessing, since LLM model is not thread safe
    use_uvicorn = False  # set to True to serve with uvicorn instead of hypercorn
    if use_uvicorn:
        import uvicorn
        uvicorn.run(
            "server:app",
            host=ServerConfig.address,
            port=ServerConfig.port,
            log_level=log_level,
            workers=1,
        )
    else:
        from hypercorn.asyncio import serve

        config = Config()
        binding = f"{ServerConfig.address}:{ServerConfig.port}"
        logger.debug(f"hypercorn binding: {binding}")
        config.bind = [binding]
        config.loglevel = args.logLevel
        config.debug = log_level == "debug"
        asyncio.run(serve(app, config))