feat: expose openai api endpoints from vllm #112

Merged (4 commits) on Oct 25, 2024
Changes from 2 commits
80 changes: 46 additions & 34 deletions poetry.lock

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion requirements.txt
@@ -138,7 +138,7 @@ tzdata==2024.2 ; python_version >= "3.11" and python_version < "3.12"
 urllib3==2.2.3 ; python_version >= "3.11" and python_version < "3.12"
 uuid6==2024.7.10 ; python_version >= "3.11" and python_version < "3.12"
 uvicorn[standard]==0.29.0 ; python_version >= "3.11" and python_version < "3.12"
-uvloop==0.20.0 ; (sys_platform != "win32" and sys_platform != "cygwin") and platform_python_implementation != "PyPy" and python_version >= "3.11" and python_version < "3.12"
+uvloop==0.21.0 ; (sys_platform != "win32" and sys_platform != "cygwin") and platform_python_implementation != "PyPy" and python_version >= "3.11" and python_version < "3.12"
 vllm==0.6.2 ; python_version >= "3.11" and python_version < "3.12"
 watchfiles==0.24.0 ; python_version >= "3.11" and python_version < "3.12"
 websockets==13.1 ; python_version >= "3.11" and python_version < "3.12"
13 changes: 10 additions & 3 deletions skynet/env.py
@@ -2,10 +2,15 @@
 import sys
 import uuid
 
+import torch
+
 app_uuid = str(uuid.uuid4())
 
 is_mac = sys.platform == 'darwin'
 
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
+use_vllm = device == 'cuda'
+
 
 # utilities
 def tobool(val: str | None):
@@ -18,6 +23,7 @@ def tobool(val: str | None):
 
 
 # general
+app_port = int(os.environ.get('SKYNET_PORT', 8000))
 log_level = os.environ.get('LOG_LEVEL', 'DEBUG').strip().upper()
 supported_modules = {'summaries:dispatcher', 'summaries:executor', 'streaming_whisper'}
 enabled_modules = set(os.environ.get('ENABLED_MODULES', 'summaries:dispatcher,summaries:executor').split(','))
@@ -36,9 +42,10 @@ def tobool(val: str | None):
 
 # openai api
 llama_cpp_server_path = os.environ.get('LLAMA_CPP_SERVER_PATH', './llama.cpp/llama-server')
+vllm_server_path = os.environ.get('VLLM_SERVER_PATH', 'vllm.entrypoints.openai.api_server')
-openai_api_server_port = int(os.environ.get('OPENAI_API_SERVER_PORT', 8003))
-openai_api_base_url = os.environ.get('OPENAI_API_BASE_URL', f'http://localhost:{openai_api_server_port}')
+openai_api_server_port = int(os.environ.get('OPENAI_API_SERVER_PORT', app_port if use_vllm else 8003))

Review comment (Member):
Shall we get rid of this while we're here, and use the URL, defaulting to ollama's default port when not doing vllm?

+openai_api_base_url = os.environ.get(
+    'OPENAI_API_BASE_URL', f'http://localhost:{openai_api_server_port}{"/openai" if use_vllm else ""}'
+)
 
 # openai
 openai_credentials_file = os.environ.get('SKYNET_CREDENTIALS_PATH')
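
The reviewer's suggestion might look roughly like the sketch below (illustration only, not part of this PR; it assumes ollama's standard port 11434 as the non-vLLM fallback, and that the configured URL always carries an explicit port):

    from urllib.parse import urlparse

    # Configure only the base URL; default to ollama's standard endpoint when not using vLLM.
    openai_api_base_url = os.environ.get(
        'OPENAI_API_BASE_URL',
        f'http://localhost:{app_port}/openai' if use_vllm else 'http://localhost:11434',
    )
    # Derive the port from the URL instead of keeping a second setting in sync.
    openai_api_server_port = urlparse(openai_api_base_url).port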
3 changes: 3 additions & 0 deletions skynet/index.html
@@ -7,6 +7,9 @@ <h1>Skynet</h1>
       <li>
         <a href="/summaries/docs">Summaries API</a>
       </li>
+      <li>
+        <a href="/openai/docs">OpenAI API</a>
+      </li>
     </ul>
   </body>
 </html>
20 changes: 16 additions & 4 deletions skynet/main.py
@@ -7,18 +7,23 @@
 from fastapi.responses import FileResponse
 
 from skynet.agent import create_tcpserver
-from skynet.env import enable_haproxy_agent, enable_metrics, modules
+from skynet.env import app_port, device, enable_haproxy_agent, enable_metrics, is_mac, modules, use_vllm
 from skynet.logs import get_logger
 from skynet.utils import create_app, create_webserver
 
 log = get_logger(__name__)
 
 if not modules:
-    log.warn('No modules enabled!')
+    log.warning('No modules enabled!')
     sys.exit(1)
 
 log.info(f'Enabled modules: {modules}')
 
+if device == 'cuda' or is_mac:
+    log.info('Using GPU')

Review comment (Member):
I like how you solved this!

+else:
+    log.info('Using CPU')
+
 
 @asynccontextmanager
 async def lifespan(main_app: FastAPI):
@@ -40,7 +45,14 @@ async def lifespan(main_app: FastAPI):
     if 'summaries:executor' in modules:
         from skynet.modules.ttt.summaries.app import executor_startup as executor_startup
 
-        await executor_startup()
+        if use_vllm:
+            from vllm.entrypoints.openai.api_server import lifespan
+
+            app = create_app(lifespan=lifespan)
+            await executor_startup(app)
+            main_app.mount('/openai', app)
+        else:
+            await executor_startup()
 
     yield
 
@@ -61,7 +73,7 @@ def root():
 
 
 async def main():
-    tasks = [asyncio.create_task(create_webserver('skynet.main:app', port=8000))]
+    tasks = [asyncio.create_task(create_webserver('skynet.main:app', port=app_port))]
 
     if enable_metrics:
         tasks.append(asyncio.create_task(create_webserver('skynet.metrics:metrics', port=8001)))
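
With the mount in place, the OpenAI-compatible routes should be reachable under the main app. A minimal smoke test might look like this (a sketch: the model name is a placeholder, httpx is an arbitrary client choice, and the path assumes vLLM's standard /v1/chat/completions route under the /openai prefix):

    import httpx

    # Skynet's main app listens on SKYNET_PORT (default 8000); the vLLM
    # OpenAI-compatible server is mounted at /openai.
    response = httpx.post(
        'http://localhost:8000/openai/v1/chat/completions',
        json={
            'model': 'placeholder-model',  # depends on what vLLM is actually serving
            'messages': [{'role': 'user', 'content': 'Summarize this meeting.'}],
        },
        timeout=60,
    )
    print(response.json())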
13 changes: 9 additions & 4 deletions skynet/modules/stt/streaming_whisper/cfg.py
@@ -1,19 +1,24 @@
 import os
 
-import torch
 from faster_whisper import WhisperModel
 
-from skynet.env import whisper_compute_type, whisper_device, whisper_gpu_indices, whisper_model_name, whisper_model_path
+from skynet.env import (
+    device,
+    whisper_compute_type,
+    whisper_device,
+    whisper_gpu_indices,
+    whisper_model_name,
+    whisper_model_path,
+)
 from skynet.logs import get_logger
 from skynet.modules.stt.streaming_whisper.utils import vad_utils as vad
-from skynet.utils import get_device
 
 log = get_logger(__name__)
 
 
 vad_model = vad.init_jit_model(f'{os.getcwd()}/skynet/modules/stt/streaming_whisper/models/vad/silero_vad.jit')
 
-device = whisper_device if whisper_device != 'auto' else get_device()
+device = whisper_device if whisper_device != 'auto' else device
 log.info(f'Using {device}')
 num_workers = 1
 gpu_indices = [0]
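
For context, these resolved settings are presumably what feeds faster_whisper further down in cfg.py. A sketch under that assumption (the actual construction sits below this hunk and is not shown in the diff):

    # Sketch only: illustrative use of the settings above with faster_whisper.
    model = WhisperModel(
        whisper_model_path or whisper_model_name,
        device=device,
        device_index=gpu_indices,
        compute_type=whisper_compute_type,
        num_workers=num_workers,
    )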