Skip to content

Commit

Permalink
Add script to download whisper models
Browse files Browse the repository at this point in the history
  • Loading branch information
philmcmahon committed Jan 21, 2025
1 parent 4a06aa7 commit 5b5faa0
Showing 1 changed file with 134 additions and 0 deletions.
134 changes: 134 additions & 0 deletions scripts/download_whisperx_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
import torchaudio
from pyannote.audio import Pipeline
import sys
import huggingface_hub
import typer

# ASR Models
# Should be kept in sync with https://github.com/m-bain/whisperX/blob/main/whisperx/asr.py
DEFAULT_ALIGN_MODELS_TORCH = {
"en": "WAV2VEC2_ASR_BASE_960H",
"fr": "VOXPOPULI_ASR_BASE_10K_FR",
"de": "VOXPOPULI_ASR_BASE_10K_DE",
"es": "VOXPOPULI_ASR_BASE_10K_ES",
"it": "VOXPOPULI_ASR_BASE_10K_IT",
}

DEFAULT_ALIGN_MODELS_HF = {
"ja": "jonatasgrosman/wav2vec2-large-xlsr-53-japanese",
"zh": "jonatasgrosman/wav2vec2-large-xlsr-53-chinese-zh-cn",
"nl": "jonatasgrosman/wav2vec2-large-xlsr-53-dutch",
"uk": "Yehor/wav2vec2-xls-r-300m-uk-with-small-lm",
"pt": "jonatasgrosman/wav2vec2-large-xlsr-53-portuguese",
"ar": "jonatasgrosman/wav2vec2-large-xlsr-53-arabic",
"cs": "comodoro/wav2vec2-xls-r-300m-cs-250",
"ru": "jonatasgrosman/wav2vec2-large-xlsr-53-russian",
"pl": "jonatasgrosman/wav2vec2-large-xlsr-53-polish",
"hu": "jonatasgrosman/wav2vec2-large-xlsr-53-hungarian",
"fi": "jonatasgrosman/wav2vec2-large-xlsr-53-finnish",
"fa": "jonatasgrosman/wav2vec2-large-xlsr-53-persian",
"el": "jonatasgrosman/wav2vec2-large-xlsr-53-greek",
"tr": "mpoyraz/wav2vec2-xls-r-300m-cv7-turkish",
"da": "saattrupdan/wav2vec2-xls-r-300m-ftspeech",
"he": "imvladikon/wav2vec2-xls-r-300m-hebrew",
"vi": 'nguyenvulebinh/wav2vec2-base-vi',
"ko": "kresnik/wav2vec2-large-xlsr-korean",
"ur": "kingabzpro/wav2vec2-large-xls-r-300m-Urdu",
"te": "anuragshas/wav2vec2-large-xlsr-53-telugu",
"hi": "theainerd/Wav2Vec2-large-xlsr-hindi",
"ca": "softcatala/wav2vec2-large-xlsr-catala",
"ml": "gvs/wav2vec2-large-xlsr-malayalam",
"no": "NbAiLab/nb-wav2vec2-1b-bokmaal",
"nn": "NbAiLab/nb-wav2vec2-300m-nynorsk",
}

def download_torch_align_models():
for lang, model_name in DEFAULT_ALIGN_MODELS_TORCH.items():
print(f"Downloading {model_name} for {lang}")
bundle = torchaudio.pipelines.__dict__[model_name]
bundle.get_model()
print(f"Downloaded {model_name} for {lang}")

def download_huggingface_align_models():
for lang, model_name in DEFAULT_ALIGN_MODELS_HF.items():
print(f"Downloading {model_name} for {lang}")
huggingface_hub.snapshot_download(model_name)
print(f"Downloaded {model_name} for {lang}")


# Diarization - see https://github.com/m-bain/whisperX/blob/main/whisperx/diarize.py

def download_diarization_models(auth_token):
PYANNOTE_MODEL="pyannote/speaker-diarization-3.1"
Pipeline.from_pretrained(PYANNOTE_MODEL, use_auth_token=auth_token)

# faster-whisper models

################
# Note - this section below is copied from https://github.com/SYSTRAN/faster-whisper/blob/master/faster_whisper/utils.py
# and then heavily simplified to only include the models we need
###############

WHISPER_MODELS = {
"tiny": "Systran/faster-whisper-tiny",
"small": "Systran/faster-whisper-small",
"medium": "Systran/faster-whisper-medium",
"large": "Systran/faster-whisper-large-v3",
}

def download_model(
model: str,
):
"""Downloads a CTranslate2 Whisper model from the Hugging Face Hub.
Args:
model: Size of the model to download from https://huggingface.co/Systran
(see https://github.com/SYSTRAN/faster-whisper/blob/master/faster_whisper/utils.py#L12 for full list) - here
limited to tiny, small, medium, large.
Returns:
The path to the downloaded model.
"""
repo_id = WHISPER_MODELS.get(model)

allow_patterns = [
"config.json",
"preprocessor_config.json",
"model.bin",
"tokenizer.json",
"vocabulary.*",
]

kwargs = {
"allow_patterns": allow_patterns,
}

return huggingface_hub.snapshot_download(repo_id, **kwargs)

def download_all_whisper_models():
for model_name in WHISPER_MODELS.keys():
download_model(model_name)

app = typer.Typer()

@app.command()
def main(
whisper_models: bool = typer.Option(False, help="Download whisper models"),
diarization_models: bool = typer.Option(False, help="Download diarization models"),
torch_align_models: bool = typer.Option(False, help="Download torch align models"),
huggingface_align_models: bool = typer.Option(False, help="Download huggingface align models"),
huggingface_token: str = typer.Option("", help="Huggingface authentication token")):
if whisper_models:
download_all_whisper_models()
if diarization_models:
if not huggingface_token:
print("Please provide a Huggingface authentication token (--huggingface-token <token>)")
sys.exit(1)
download_diarization_models(huggingface_token)
if torch_align_models:
download_torch_align_models()
if huggingface_align_models:
download_huggingface_align_models()


if __name__ == "__main__":
typer.run(main)

0 comments on commit 5b5faa0

Please sign in to comment.