This repository has been archived by the owner on Feb 15, 2025. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 33
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #59 from defenseunicorns/gerred/whisper-backend
Add Whisper FastAPI backend
- Loading branch information
Showing
8 changed files
with
542 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,5 @@ | ||
data/* | ||
*.tar.zst | ||
__pycache__ | ||
__pycache__ | ||
.venv | ||
.ipynb_checkpoints |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
cpython@3.11.3 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
{ | ||
"[python]": { | ||
"editor.defaultFormatter": "ms-python.black-formatter" | ||
}, | ||
"python.formatting.provider": "none" | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
# Whsiper API server | ||
|
||
## Pre-requisites | ||
|
||
* [Rye](https://github.com/mitsuhiko/rye) | ||
* System install of ffmpeg | ||
* For GPU usage, latest CUDA drivers | ||
|
||
## Getting Started | ||
|
||
```shell | ||
rye sync | ||
rye run uvicorn main:app --reload | ||
|
||
# to enter a Python shell: | ||
rye shell | ||
``` | ||
|
||
## Usage | ||
|
||
Open `http://localhost:8000/docs` for a Swagger interface. | ||
|
||
## TODO | ||
|
||
* Tests | ||
* Error handling for ffmpeg |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,119 @@ | ||
import whisper | ||
import aiofiles | ||
import asyncio | ||
from typing import Annotated | ||
from enum import Enum | ||
from fastapi import FastAPI, UploadFile, File, Form | ||
from pydantic import BaseModel | ||
from prometheus_fastapi_instrumentator import Instrumentator | ||
|
||
app = FastAPI() | ||
|
||
model = whisper.load_model("large") | ||
loop = asyncio.get_event_loop() | ||
|
||
|
||
def make_transcribe_request(filename, task, language): | ||
return model.transcribe(filename, task=task, language=language) | ||
|
||
|
||
class WhisperTask(str, Enum): | ||
transcribe = "transcribe" | ||
translate = "translate" | ||
|
||
|
||
# TODO: Make this class use a dict, and show languages in Swagger | ||
class LanguageCode(str, Enum): | ||
af = "af" | ||
ar = "ar" | ||
az = "az" | ||
be = "be" | ||
bg = "bg" | ||
bs = "bs" | ||
ca = "ca" | ||
cs = "cs" | ||
cy = "cy" | ||
da = "da" | ||
de = "de" | ||
el = "el" | ||
en = "en" | ||
es = "es" | ||
et = "et" | ||
fa = "fa" | ||
fi = "fi" | ||
fr = "fr" | ||
gl = "gl" | ||
he = "he" | ||
hi = "hi" | ||
hr = "hr" | ||
hu = "hu" | ||
hy = "hy" | ||
id = "id" | ||
icelandic = "is" | ||
it = "it" | ||
ja = "ja" | ||
kk = "kk" | ||
kn = "kn" | ||
ko = "ko" | ||
lt = "lt" | ||
lv = "lv" | ||
mk = "mk" | ||
ms = "ms" | ||
mr = "mr" | ||
mi = "mi" | ||
nl = "nl" | ||
ne = "ne" | ||
no = "no" | ||
pl = "pl" | ||
pt = "pt" | ||
ro = "ro" | ||
ru = "ru" | ||
sk = "sk" | ||
sl = "sl" | ||
sr = "sr" | ||
sv = "sv" | ||
sw = "sw" | ||
ta = "ta" | ||
th = "th" | ||
tl = "tl" | ||
tr = "tr" | ||
uk = "uk" | ||
ur = "ur" | ||
vi = "vi" | ||
zh = "zh" | ||
|
||
|
||
class TranscribeResponse(BaseModel): | ||
result: str | ||
|
||
|
||
@app.post("/transcribe") | ||
async def transcribe( | ||
file: Annotated[ | ||
UploadFile, File(description="Audio file to run the Whisper model on") | ||
], | ||
task: Annotated[ | ||
WhisperTask, Form(description="The Whisper task to perform") | ||
] = "transcribe", | ||
language: Annotated[ | ||
LanguageCode, | ||
Form( | ||
description="The ISO 3166-1 alpha-2 (two letter language) code of the source audio" | ||
), | ||
] = "en", | ||
) -> TranscribeResponse: | ||
async with aiofiles.tempfile.NamedTemporaryFile("wb") as f: | ||
contents = await file.read() | ||
await f.write(contents) | ||
result = await loop.run_in_executor( | ||
None, make_transcribe_request, f.name, task, language | ||
) | ||
return TranscribeResponse(result=result["text"]) | ||
|
||
|
||
@app.get("/healthz") | ||
async def healthz(): | ||
return {"status": "ok"} | ||
|
||
|
||
Instrumentator().instrument(app).expose(app) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
[project] | ||
name = "whisper" | ||
version = "0.1.0" | ||
description = "Add a short description here" | ||
authors = [ | ||
{ name = "Gerred Dillon", email = "hello@gerred.org" } | ||
] | ||
dependencies = [ | ||
"datasets~=2.12.0", | ||
"transformers~=4.29.2", | ||
"torch~=2.0.1", | ||
"soundfile~=0.12.1", | ||
"librosa~=0.10.0.post2", | ||
"jupyter~=1.0.0", | ||
"pydub~=0.25.1", | ||
"openai-whisper==20230314", | ||
"numpy~=1.24.3", | ||
"fastapi~=0.96.0", | ||
"uvicorn~=0.22.0", | ||
"python-multipart~=0.0.6", | ||
"prometheus-fastapi-instrumentator~=6.0.0", | ||
"asyncio~=3.4.3", | ||
"aiofiles~=23.1.0", | ||
"fastapi_utils~=0.2.1", | ||
] | ||
readme = "README.md" | ||
requires-python = ">= 3.8" | ||
|
||
[build-system] | ||
requires = ["hatchling"] | ||
build-backend = "hatchling.build" | ||
|
||
[tool.rye] | ||
managed = true | ||
|
||
[tool.hatch.metadata] | ||
allow-direct-references = true |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,175 @@ | ||
# generated by rye | ||
# use `rye lock` or `rye sync` to update this lockfile | ||
# | ||
# last locked with the following flags: | ||
# pre: false | ||
# features: [] | ||
# all-features: false | ||
|
||
-e file:. | ||
aiofiles==23.1.0 | ||
aiohttp==3.8.4 | ||
aiosignal==1.3.1 | ||
anyio==3.7.0 | ||
appdirs==1.4.4 | ||
argon2-cffi==21.3.0 | ||
argon2-cffi-bindings==21.2.0 | ||
arrow==1.2.3 | ||
asttokens==2.2.1 | ||
async-timeout==4.0.2 | ||
asyncio==3.4.3 | ||
attrs==23.1.0 | ||
audioread==3.0.0 | ||
backcall==0.2.0 | ||
beautifulsoup4==4.12.2 | ||
bleach==6.0.0 | ||
certifi==2023.5.7 | ||
cffi==1.15.1 | ||
charset-normalizer==3.1.0 | ||
click==8.1.3 | ||
cmake==3.26.3 | ||
comm==0.1.3 | ||
datasets==2.12.0 | ||
debugpy==1.6.7 | ||
decorator==5.1.1 | ||
defusedxml==0.7.1 | ||
dill==0.3.6 | ||
executing==1.2.0 | ||
fastapi==0.96.0 | ||
fastapi-utils==0.2.1 | ||
fastjsonschema==2.17.1 | ||
ffmpeg-python==0.2.0 | ||
filelock==3.12.0 | ||
fqdn==1.5.1 | ||
frozenlist==1.3.3 | ||
fsspec==2023.5.0 | ||
future==0.18.3 | ||
greenlet==2.0.2 | ||
h11==0.14.0 | ||
huggingface-hub==0.15.1 | ||
idna==3.4 | ||
ipykernel==6.23.1 | ||
ipython==8.13.2 | ||
ipython-genutils==0.2.0 | ||
ipywidgets==8.0.6 | ||
isoduration==20.11.0 | ||
jedi==0.18.2 | ||
jinja2==3.1.2 | ||
joblib==1.2.0 | ||
jsonpointer==2.3 | ||
jsonschema==4.17.3 | ||
jupyter==1.0.0 | ||
jupyter-client==8.2.0 | ||
jupyter-console==6.6.3 | ||
jupyter-core==5.3.0 | ||
jupyter-events==0.6.3 | ||
jupyter-server==2.6.0 | ||
jupyter-server-terminals==0.4.4 | ||
jupyterlab-pygments==0.2.2 | ||
jupyterlab-widgets==3.0.7 | ||
lazy-loader==0.2 | ||
librosa==0.10.0.post2 | ||
lit==16.0.5 | ||
llvmlite==0.40.0 | ||
markupsafe==2.1.2 | ||
matplotlib-inline==0.1.6 | ||
mistune==2.0.5 | ||
more-itertools==9.1.0 | ||
mpmath==1.3.0 | ||
msgpack==1.0.5 | ||
multidict==6.0.4 | ||
multiprocess==0.70.14 | ||
nbclassic==1.0.0 | ||
nbclient==0.8.0 | ||
nbconvert==7.4.0 | ||
nbformat==5.9.0 | ||
nest-asyncio==1.5.6 | ||
networkx==3.1 | ||
notebook==6.5.4 | ||
notebook-shim==0.2.3 | ||
numba==0.57.0 | ||
numpy==1.24.3 | ||
nvidia-cublas-cu11==11.10.3.66 | ||
nvidia-cuda-cupti-cu11==11.7.101 | ||
nvidia-cuda-nvrtc-cu11==11.7.99 | ||
nvidia-cuda-runtime-cu11==11.7.99 | ||
nvidia-cudnn-cu11==8.5.0.96 | ||
nvidia-cufft-cu11==10.9.0.58 | ||
nvidia-curand-cu11==10.2.10.91 | ||
nvidia-cusolver-cu11==11.4.0.1 | ||
nvidia-cusparse-cu11==11.7.4.91 | ||
nvidia-nccl-cu11==2.14.3 | ||
nvidia-nvtx-cu11==11.7.91 | ||
openai-whisper==20230314 | ||
overrides==7.3.1 | ||
packaging==23.1 | ||
pandas==2.0.2 | ||
pandocfilters==1.5.0 | ||
parso==0.8.3 | ||
pexpect==4.8.0 | ||
pickleshare==0.7.5 | ||
platformdirs==3.5.1 | ||
pooch==1.6.0 | ||
prometheus-client==0.17.0 | ||
prometheus-fastapi-instrumentator==6.0.0 | ||
prompt-toolkit==3.0.38 | ||
psutil==5.9.5 | ||
ptyprocess==0.7.0 | ||
pure-eval==0.2.2 | ||
pyarrow==12.0.0 | ||
pycparser==2.21 | ||
pydantic==1.10.8 | ||
pydub==0.25.1 | ||
pygments==2.15.1 | ||
pyrsistent==0.19.3 | ||
python-dateutil==2.8.2 | ||
python-json-logger==2.0.7 | ||
python-multipart==0.0.6 | ||
pytz==2023.3 | ||
pyyaml==6.0 | ||
pyzmq==25.1.0 | ||
qtconsole==5.4.3 | ||
qtpy==2.3.1 | ||
regex==2023.5.5 | ||
requests==2.31.0 | ||
responses==0.18.0 | ||
rfc3339-validator==0.1.4 | ||
rfc3986-validator==0.1.1 | ||
scikit-learn==1.2.2 | ||
scipy==1.10.1 | ||
send2trash==1.8.2 | ||
six==1.16.0 | ||
sniffio==1.3.0 | ||
soundfile==0.12.1 | ||
soupsieve==2.4.1 | ||
soxr==0.3.5 | ||
sqlalchemy==1.4.48 | ||
stack-data==0.6.2 | ||
starlette==0.27.0 | ||
sympy==1.12 | ||
terminado==0.17.1 | ||
threadpoolctl==3.1.0 | ||
tiktoken==0.3.1 | ||
tinycss2==1.2.1 | ||
tokenizers==0.13.3 | ||
torch==2.0.1 | ||
tornado==6.3.2 | ||
tqdm==4.65.0 | ||
traitlets==5.9.0 | ||
transformers==4.29.2 | ||
triton==2.0.0 | ||
typing-extensions==4.6.2 | ||
tzdata==2023.3 | ||
uri-template==1.2.0 | ||
urllib3==2.0.2 | ||
uvicorn==0.22.0 | ||
wcwidth==0.2.6 | ||
webcolors==1.13 | ||
webencodings==0.5.1 | ||
websocket-client==1.5.2 | ||
wheel==0.40.0 | ||
widgetsnbextension==4.0.7 | ||
xxhash==3.2.0 | ||
yarl==1.9.2 | ||
# The following packages are considered to be unsafe in a requirements file: | ||
setuptools==67.8.0 |
Oops, something went wrong.