feat: Add model-util CLI #59

Merged Aug 14, 2024 · 27 commits (changes shown from 24 commits)

Commits
132f481  Port tgis cli changes from past PRs (rafvasq, Jul 23, 2024)
8f0283e  remove dead code (vllm<=0.5.0.post1) (dtrifiro, Jul 25, 2024)
aeeeea6  Consolidates cli and fixes imports (rafvasq, Jul 26, 2024)
156cb49  build(deps): bump ruff from 0.5.4 to 0.5.5 (dependabot[bot], Jul 29, 2024)
90a8b47  Changes name of cli and fixes bugs (rafvasq, Jul 30, 2024)
ba5c5e0  Merge branch 'main' into add-cli-cmds (rafvasq, Aug 1, 2024)
5b664bc  Merge branch 'main' into add-cli-cmds (rafvasq, Aug 2, 2024)
497879a  Change cmd to model-util (rafvasq, Aug 2, 2024)
1edcb55  Merge branch 'main' into add-cli-cmds (rafvasq, Aug 7, 2024)
f79c341  Update src/vllm_tgis_adapter/tgis_utils/hub.py (rafvasq, Aug 7, 2024)
45a3f76  Merge branch 'main' into add-cli-cmds (rafvasq, Aug 7, 2024)
af203d2  Updates (rafvasq, Aug 7, 2024)
2ef8dc8  Merge branch 'main' into add-cli-cmds (rafvasq, Aug 7, 2024)
6a58a2f  Remove unneeded handler and trycatch (rafvasq, Aug 7, 2024)
434e652  Merge branch 'main' into add-cli-cmds (rafvasq, Aug 7, 2024)
127062d  Add alias text-generation-server (rafvasq, Aug 7, 2024)
b302586  Merge branch 'main' into add-cli-cmds (rafvasq, Aug 7, 2024)
a668409  Reverts grpc_healthceck (rafvasq, Aug 8, 2024)
2ab8490  Merge branch 'main' into add-cli-cmds (rafvasq, Aug 8, 2024)
c42b992  Add min workers (rafvasq, Aug 8, 2024)
e211a68  Merge branch 'main' into add-cli-cmds (rafvasq, Aug 12, 2024)
e6c32d1  Fixes readability (rafvasq, Aug 12, 2024)
6e02a7e  Add custom marker for large download (rafvasq, Aug 12, 2024)
0ae7f55  Merge branch 'main' into add-cli-cmds (rafvasq, Aug 12, 2024)
9b6a3d5  Adds hf_data marker and deselects from pytest (rafvasq, Aug 13, 2024)
2fd00bc  Merge branch 'main' into add-cli-cmds (prashantgupta24, Aug 13, 2024)
fbb0cd7  Merge branch 'main' into add-cli-cmds (prashantgupta24, Aug 14, 2024)
5 changes: 5 additions & 0 deletions pyproject.toml
@@ -46,6 +46,8 @@ Source = "https://github.com/opendatahub-io/vllm-tgis-adapter"

[project.scripts]
grpc_healthcheck = "vllm_tgis_adapter.healthcheck:cli"
model-util = "vllm_tgis_adapter.tgis_utils.scripts:cli"
text-generation-server = "vllm_tgis_adapter.tgis_utils.scripts:cli"

[project.optional-dependencies]
tests = [
@@ -83,6 +85,9 @@ vllm_tgis_adapter = [

[tool.pytest.ini_options]
addopts = "-ra"
markers = [
  "large"
]

[tool.coverage.run]
branch = true
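
Both console scripts point at the same `cli` entry point, so `model-util` and its `text-generation-server` alias accept identical arguments; the new `large` pytest marker lets CI deselect tests that download real model weights (e.g. `pytest -m "not large"`). A minimal smoke-test sketch, assuming the package is installed so both scripts are on PATH and that they support `--help`:

```python
import subprocess

# Both scripts resolve to vllm_tgis_adapter.tgis_utils.scripts:cli, so they
# should print the same help text (hypothetical smoke test; assumes --help).
for script in ("model-util", "text-generation-server"):
    subprocess.run([script, "--help"], check=True)
```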
221 changes: 221 additions & 0 deletions src/vllm_tgis_adapter/tgis_utils/hub.py
@@ -0,0 +1,221 @@
from __future__ import annotations

import concurrent
import datetime
import json
import os
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from pathlib import Path

import torch
from huggingface_hub import HfApi, hf_hub_download, try_to_load_from_cache
from huggingface_hub.utils import LocalEntryNotFoundError
from safetensors.torch import _remove_duplicate_names, load_file, save_file
from tqdm import tqdm

from vllm_tgis_adapter.logging import init_logger

logger = init_logger(__name__)


def weight_hub_files(
    model_name: str,
    extension: str | list[str] = ".safetensors",
    revision: str | None = None,
    auth_token: str | None = None,
) -> list[str]:
    """Get the safetensors filenames on the hub."""
    exts = [extension] if isinstance(extension, str) else extension
    api = HfApi()
    info = api.model_info(model_name, revision=revision, token=auth_token)
    filenames = [
        s.rfilename
        for s in info.siblings
        if any(
            s.rfilename.endswith(ext)
            and len(s.rfilename.split("/")) == 1
            and "arguments" not in s.rfilename
            and "args" not in s.rfilename
            and "training" not in s.rfilename
            for ext in exts
        )
    ]
    return filenames
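

# Usage sketch for weight_hub_files (hypothetical repo id; requires network
# access to the Hub):
#   >>> weight_hub_files("facebook/opt-125m")
#   ['model.safetensors']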


def weight_files(
    model_name: str, extension: str = ".safetensors", revision: str | None = None
) -> list:
    """Get the local safetensors filenames."""
    filenames = weight_hub_files(model_name, extension, revision=revision)
    files = []
    for filename in filenames:
        cache_file = try_to_load_from_cache(
            model_name, filename=filename, revision=revision
        )
        if cache_file is None:
            raise LocalEntryNotFoundError(
                f"File {filename} of model {model_name} not found in "
                f"{os.getenv('HUGGINGFACE_HUB_CACHE', 'the local cache')}. "
                f"Please run `vllm download-weights {model_name}` first."
            )
        files.append(cache_file)

    return files


def download_weights(
    model_name: str,
    extension: str = ".safetensors",
    revision: str | None = None,
    auth_token: str | None = None,
) -> list:
    """Download the safetensors files from the hub."""
    filenames = weight_hub_files(
        model_name, extension, revision=revision, auth_token=auth_token
    )

    download_function = partial(
        hf_hub_download,
        repo_id=model_name,
        local_files_only=False,
        revision=revision,
        token=auth_token,
    )

    logger.info("Downloading %d files for model %s", len(filenames), model_name)
    # os.cpu_count() can return None; fall back to a single worker in that case
    with ThreadPoolExecutor(max_workers=min(16, os.cpu_count() or 1)) as executor:
        futures = [
            executor.submit(download_function, filename=filename)
            for filename in filenames
        ]
        files = [
            future.result()
            for future in tqdm(
                concurrent.futures.as_completed(futures), total=len(futures)
            )
        ]

    return files
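

# Usage sketch for download_weights (hypothetical repo id; files land in the
# local HF hub cache and the cached paths are returned):
#   >>> download_weights("facebook/opt-125m")
#   ['/.../hub/models--facebook--opt-125m/snapshots/.../model.safetensors']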


def get_model_path(model_name: str, revision: str | None = None) -> str:
    """Get path to model dir in local huggingface hub (model) cache."""
    config_file = "config.json"
    config_path = try_to_load_from_cache(
        model_name,
        config_file,
        cache_dir=os.getenv(
            "TRANSFORMERS_CACHE"
        ),  # will fall back to HUGGINGFACE_HUB_CACHE
        revision=revision,
    )
    if config_path is not None:
        return config_path.removesuffix(f"/{config_file}")
    if Path(f"{model_name}/{config_file}").is_file():
        return model_name  # Just treat the model name as an explicit model path

    raise ValueError(f"Weights not found in local cache for model {model_name}")
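

# Resolution order above: the HF cache is checked first (TRANSFORMERS_CACHE,
# falling back to HUGGINGFACE_HUB_CACHE), then model_name is treated as a
# literal directory containing config.json; otherwise ValueError is raised.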


def local_weight_files(model_path: str, extension: str = ".safetensors") -> list[Path]:
    """Get the local safetensors filenames."""
    ext = "" if extension is None else extension
    return list(Path(f"{model_path}").glob(f"*{ext}"))


def local_index_files(model_path: str, extension: str = ".safetensors") -> list[Path]:
    """Get the local .index.json filename."""
    ext = "" if extension is None else extension
    return list(Path(f"{model_path}").glob(f"*{ext}.index.json"))


def convert_file(pt_file: Path, sf_file: Path, discard_names: list[str]) -> None:
    """Convert a pytorch file to a safetensors file.

    This will remove duplicate tensors from the file. Unfortunately, this might not
    respect the *transformers* convention, forcing us to check for potentially
    different keys during load when looking for specific tensors (making tensor
    sharing explicit).
    """
    loaded = torch.load(pt_file, map_location="cpu")
    if "state_dict" in loaded:
        loaded = loaded["state_dict"]
    to_removes = _remove_duplicate_names(loaded, discard_names=discard_names)

    metadata = {"format": "pt"}
    for kept_name, to_remove_group in to_removes.items():
        for to_remove in to_remove_group:
            if to_remove not in metadata:
                metadata[to_remove] = kept_name
            del loaded[to_remove]
    # Force tensors to be contiguous
    loaded = {k: v.contiguous() for k, v in loaded.items()}

    sf_file.parent.mkdir(parents=True, exist_ok=True)
    save_file(loaded, sf_file, metadata=metadata)
    reloaded = load_file(sf_file)
    for k in loaded:
        pt_tensor = loaded[k]
        sf_tensor = reloaded[k]
        if not torch.equal(pt_tensor, sf_tensor):
            raise RuntimeError(f"The output tensors do not match for key {k}")
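

# Usage sketch for convert_file (hypothetical paths): the reload-and-compare
# loop above makes the conversion fail loudly rather than write lossy output.
#   >>> convert_file(Path("pytorch_model.bin"), Path("model.safetensors"), [])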


def convert_index_file(
    source_file: Path, dest_file: Path, pt_files: list[Path], sf_files: list[Path]
) -> None:
    """Rewrite a .bin.index.json weight map to reference the converted .safetensors files."""
    weight_file_map = {s.name: d.name for s, d in zip(pt_files, sf_files)}

    logger.info("Converting pytorch .bin.index.json files to .safetensors.index.json")
    with open(source_file) as f:
        index = json.load(f)

    index["weight_map"] = {
        k: weight_file_map[v] for k, v in index["weight_map"].items()
    }

    with open(dest_file, "w") as f:
        json.dump(index, f)


def convert_files(
    pt_files: list[Path], sf_files: list[Path], discard_names: list[str] | None = None
) -> None:
    """Convert pytorch .bin weight files to .safetensors, skipping non-inference files."""
    assert len(pt_files) == len(sf_files)

    # Filter non-inference files
    pairs = [
        p
        for p in zip(pt_files, sf_files)
        if not any(
            s in p[0].name
            for s in [
                "arguments",
                "args",
                "training",
                "optimizer",
                "scheduler",
                "index",
            ]
        )
    ]

    n = len(pairs)

    if n == 0:
        logger.warning("No pytorch .bin weight files found to convert")
        return

    logger.info("Converting %d pytorch .bin files to .safetensors...", n)

    for i, (pt_file, sf_file) in enumerate(pairs):
        # i + 1 is the 1-based position of this file within the conversion run
        logger.info('Converting: [%d/%d] "%s"', i + 1, n, pt_file.name)
        start = datetime.datetime.now(tz=datetime.UTC)
        convert_file(pt_file, sf_file, discard_names)
        elapsed = datetime.datetime.now(tz=datetime.UTC) - start
        logger.info(
            'Converted: [%d/%d] "%s" -- Took: %d seconds',
            i + 1,
            n,
            sf_file.name,
            elapsed.total_seconds(),
        )
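
Putting the conversion helpers together, a minimal end-to-end sketch (hypothetical repo id; assumes the pytorch `.bin` weights are already available locally):

```python
from pathlib import Path

from vllm_tgis_adapter.tgis_utils import hub

# Hypothetical flow: resolve the cached model directory, find its pytorch
# weight files, and convert each one to a .safetensors file alongside it.
model_path = hub.get_model_path("my-org/my-model")  # hypothetical repo id
pt_files = hub.local_weight_files(model_path, ".bin")
sf_files = [p.with_suffix(".safetensors") for p in pt_files]
hub.convert_files(pt_files, sf_files, discard_names=[])
```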