Changes name of cli and fixes bugs
Signed-off-by: Rafael Vasquez <rafvasq21@gmail.com>
rafvasq committed Aug 1, 2024
1 parent 156cb49 commit 90a8b47
Showing 4 changed files with 98 additions and 20 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -46,7 +46,7 @@ Source = "https://github.com/opendatahub-io/vllm_tgis_adapter"

[project.scripts]
grpc-healthcheck = "vllm_tgis_adapter.healthcheck:cli"
vllm = "vllm_tgis_adapter.tgis_utils.scripts:cli"
adapter = "vllm_tgis_adapter.tgis_utils.scripts:cli"

[project.optional-dependencies]
tests = [
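With this change the console script is installed as adapter instead of vllm, while still pointing at the same cli callable; the usage strings in scripts.py further down are updated to match. A minimal sketch of exercising the renamed entry point programmatically, assuming cli() reads sys.argv through its FlexibleArgumentParser (the model name is only illustrative):

    import sys

    from vllm_tgis_adapter.tgis_utils.scripts import cli

    # Simulate "adapter download-weights facebook/opt-125m" without a shell;
    # argv[0] is the program name the parser reports in usage messages.
    sys.argv = ["adapter", "download-weights", "facebook/opt-125m"]
    cli()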
22 changes: 13 additions & 9 deletions src/vllm_tgis_adapter/tgis_utils/hub.py
Expand Up @@ -3,7 +3,6 @@
import concurrent
import datetime
import json
import logging
import os
from concurrent.futures import ThreadPoolExecutor
from functools import partial
@@ -15,7 +14,9 @@
from safetensors.torch import _remove_duplicate_names, load_file, save_file
from tqdm import tqdm

logger = logging.getLogger(__name__)
from vllm_tgis_adapter.logging import init_logger

logger = init_logger(__name__)


def weight_hub_files(
@@ -84,7 +85,7 @@ def download_weights(
token=auth_token,
)

logger.info("Downloading {len(filenames)} files for model {model_name}")
logger.info("Downloading %s files for model %s", len(filenames), model_name)
executor = ThreadPoolExecutor(max_workers=5)
futures = [
executor.submit(download_function, filename=filename) for filename in filenames
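The old call passed a plain string literal (an f-string missing its f prefix), so the braces were logged verbatim; the replacement uses the logging module's lazy %-style formatting, which interpolates the arguments only when the record is actually emitted. A self-contained illustration with made-up values:

    import logging

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger("hub-example")

    filenames = ["pytorch_model-00001.bin", "pytorch_model-00002.bin"]
    model_name = "example/model"

    # Old behaviour: the placeholders are printed verbatim.
    logger.info("Downloading {len(filenames)} files for model {model_name}")

    # Fixed behaviour: arguments are interpolated lazily by the logging module.
    logger.info("Downloading %s files for model %s", len(filenames), model_name)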
@@ -115,7 +116,7 @@ def get_model_path(model_name: str, revision: str | None = None) -> str:
except ValueError as e:
err = e

if Path.isfile(f"{model_name}/{config_file}"):
if Path(f"{model_name}/{config_file}").is_file():
return model_name # Just treat the model name as an explicit model path

if err is not None:
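pathlib.Path has no isfile attribute (that spelling belongs to os.path), so the old classmethod-style call would raise AttributeError before the check ever ran; the fix builds a Path instance and calls its is_file() method. A tiny sketch with an illustrative path:

    from pathlib import Path

    config_path = "my-model-dir/config.json"  # illustrative local path

    # Old code: Path.isfile(config_path) raises AttributeError.
    # Fixed code: construct a Path and use the is_file() instance method.
    if Path(config_path).is_file():
        print("treating the model name as an explicit local model path")
    else:
        print("no local config.json found")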
@@ -127,13 +128,13 @@ def get_model_path(model_name: str, revision: str | None = None) -> str:
def local_weight_files(model_path: str, extension: str = ".safetensors") -> list:
"""Get the local safetensors filenames."""
ext = "" if extension is None else extension
return Path.glob(f"{model_path}/*{ext}")
return list(Path(f"{model_path}").glob(f"*{ext}"))


def local_index_files(model_path: str, extension: str = ".safetensors") -> list:
"""Get the local .index.json filename."""
ext = "" if extension is None else extension
return Path.glob(f"{model_path}/*{ext}.index.json")
return list(Path(f"{model_path}").glob(f"*{ext}.index.json"))
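glob() is likewise an instance method and returns a lazy generator, so the fixed helpers construct a Path for the model directory and materialize the matches with list(). Roughly, under an assumed directory layout:

    from pathlib import Path

    model_path = "/models/example"   # illustrative directory
    extension = ".safetensors"

    ext = "" if extension is None else extension
    weight_files = list(Path(model_path).glob(f"*{ext}"))
    index_files = list(Path(model_path).glob(f"*{ext}.index.json"))
    print(weight_files, index_files)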


def convert_file(pt_file: Path, sf_file: Path, discard_names: list[str]) -> None:
@@ -157,8 +158,8 @@ def convert_file(pt_file: Path, sf_file: Path, discard_names: list[str]) -> None
# Force tensors to be contiguous
loaded = {k: v.contiguous() for k, v in loaded.items()}

dirname = Path.parent(sf_file)
Path(dirname).mkdir(parents=True)
dirname = Path(sf_file).parent
Path(dirname).mkdir(parents=True, exist_ok=True)
save_file(loaded, sf_file, metadata=metadata)
reloaded = load_file(sf_file)
for k in loaded:
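Two fixes in this hunk: parent is a property on a Path instance, not a callable on the class, and mkdir(parents=True) without exist_ok raises FileExistsError when the output directory already exists (for example on a reconversion). A small sketch using a throwaway temporary directory:

    import tempfile
    from pathlib import Path

    # Illustrative output path inside a throwaway directory.
    sf_file = Path(tempfile.mkdtemp()) / "shards" / "model-00001-of-00002.safetensors"

    dirname = sf_file.parent              # instance property, not Path.parent(sf_file)
    dirname.mkdir(parents=True, exist_ok=True)
    dirname.mkdir(parents=True, exist_ok=True)  # safe to repeat with exist_ok=True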
@@ -222,5 +223,8 @@ def convert_files(
convert_file(pt_file, sf_file, discard_names)
elapsed = datetime.datetime.now(tz=datetime.UTC) - start
logger.info(
'Converted: [%d] "%s" -- Took: %d', file_count, sf_file.name, elapsed
'Converted: [%d] "%s" -- Took: %d seconds',
file_count,
sf_file.name,
elapsed.total_seconds(),
)
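elapsed here is a datetime.timedelta, which the %d placeholder cannot format, so the fixed call logs total_seconds() and names the unit in the message. A self-contained illustration:

    import datetime
    import logging
    import time

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger("convert-example")

    start = datetime.datetime.now(tz=datetime.UTC)
    time.sleep(0.1)  # stand-in for the actual file conversion
    elapsed = datetime.datetime.now(tz=datetime.UTC) - start

    # A timedelta is not a number, so convert it to seconds before logging.
    logger.info(
        'Converted: [%d] "%s" -- Took: %d seconds',
        1,
        "model-00001-of-00002.safetensors",
        elapsed.total_seconds(),
    )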
91 changes: 82 additions & 9 deletions src/vllm_tgis_adapter/tgis_utils/scripts.py
@@ -1,17 +1,16 @@
# The CLI entrypoint to vLLM.
from __future__ import annotations

import logging
from pathlib import Path
from typing import TYPE_CHECKING

import Path
from vllm.model_executor.model_loader.weight_utils import convert_bin_to_safetensor_file
from vllm.scripts import registrer_signal_handlers
from vllm.utils import FlexibleArgumentParser

from vllm_tgis_adapter.logging import init_logger
from vllm_tgis_adapter.tgis_utils import hub

logger = logging.getLogger(__name__)
logger = init_logger(__name__)

if TYPE_CHECKING:
import argparse
@@ -29,7 +28,7 @@ def tgis_cli(args: argparse.Namespace) -> None:
args.auto_convert,
)
elif args.command == "convert-to-safetensors":
convert_bin_to_safetensor_file(args.model_name, args.revision)
convert_to_safetensors(args.model_name, args.revision)
elif args.command == "convert-to-fast-tokenizer":
convert_to_fast_tokenizer(args.model_name, args.revision, args.output_path)

@@ -73,7 +72,7 @@ def download_weights(
".safetensors weights not found, \
converting from pytorch weights..."
)
convert_bin_to_safetensor_file(model_name, revision)
convert_to_safetensors(model_name, revision)
elif not any(f.endswith(".safetensors") for f in files):
logger.info(
".safetensors weights not found on hub, \
@@ -83,6 +82,80 @@
convert_to_fast_tokenizer(model_name, revision)


def convert_to_safetensors(
model_name: str,
revision: str | None = None,
) -> None:
# Get local pytorch file paths
model_path = hub.get_model_path(model_name, revision)
local_pt_files = hub.local_weight_files(model_path, ".bin")
local_pt_index_files = hub.local_index_files(model_path, ".bin")
if len(local_pt_index_files) > 1:
logger.info(
"Found more than one .bin.index.json file: %s", local_pt_index_files
)
return
if not local_pt_files:
logger.info("No pytorch .bin files found to convert")
return

local_pt_files = [Path(f) for f in local_pt_files]
local_pt_index_file = local_pt_index_files[0] if local_pt_index_files else None

# Safetensors final filenames
local_st_files = [
p.parent / f"{p.stem.removeprefix('pytorch_')}.safetensors"
for p in local_pt_files
]

if any(Path.exists(p) for p in local_st_files):
logger.info(
"Existing .safetensors weights found, remove them first to reconvert"
)
return

try:
import transformers

config = transformers.AutoConfig.from_pretrained(
model_name,
revision=revision,
)
architecture = config.architectures[0]

class_ = getattr(transformers, architecture)

# Name for this variable depends on transformers version
discard_names = getattr(class_, "_tied_weights_keys", [])
discard_names.extend(getattr(class_, "_keys_to_ignore_on_load_missing", []))

except TypeError:
discard_names = []

if local_pt_index_file:
local_pt_index_file = Path(local_pt_index_file)
st_prefix = local_pt_index_file.stem.removeprefix("pytorch_").removesuffix(
".bin.index"
)
local_st_index_file = (
local_pt_index_file.parent / f"{st_prefix}.safetensors.index.json"
)

if Path.exists(local_st_index_file):
logger.info(
"Existing .safetensors.index.json file found, remove it first to \
reconvert"
)
return

hub.convert_index_file(
local_pt_index_file, local_st_index_file, local_pt_files, local_st_files
)

# Convert pytorch weights to safetensors
hub.convert_files(local_pt_files, local_st_files, discard_names)


def convert_to_fast_tokenizer(
model_name: str,
revision: str | None = None,
@@ -119,7 +192,7 @@ def cli() -> None:
download_weights_parser = subparsers.add_parser(
"download-weights",
help=("Download the weights of a given model"),
usage="vllm download-weights <model_name> [options]",
usage="adapter download-weights <model_name> [options]",
)
download_weights_parser.add_argument("model_name")
download_weights_parser.add_argument("--revision")
@@ -133,7 +206,7 @@ def cli() -> None:
convert_to_safetensors_parser = subparsers.add_parser(
"convert-to-safetensors",
help=("Convert model weights to safetensors"),
usage="vllm convert-to-safetensors <model_name> [options]",
usage="adapter convert-to-safetensors <model_name> [options]",
)
convert_to_safetensors_parser.add_argument("model_name")
convert_to_safetensors_parser.add_argument("--revision")
@@ -144,7 +217,7 @@ def cli() -> None:
convert_to_fast_tokenizer_parser = subparsers.add_parser(
"convert-to-fast-tokenizer",
help=("Convert to fast tokenizer"),
usage="vllm convert-to-fast-tokenizer <model_name> [options]",
usage="adapter convert-to-fast-tokenizer <model_name> [options]",
)
convert_to_fast_tokenizer_parser.add_argument("model_name")
convert_to_fast_tokenizer_parser.add_argument("--revision")
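For reference, the new convert_to_safetensors helper derives the .safetensors filenames from the pytorch shard names purely through stem manipulation; a worked example of that mapping with illustrative paths:

    from pathlib import Path

    # Weight shard: pytorch_model-00001-of-00002.bin -> model-00001-of-00002.safetensors
    pt_file = Path("/models/example/pytorch_model-00001-of-00002.bin")
    sf_file = pt_file.parent / f"{pt_file.stem.removeprefix('pytorch_')}.safetensors"
    print(sf_file.name)   # model-00001-of-00002.safetensors

    # Index file: pytorch_model.bin.index.json -> model.safetensors.index.json
    pt_index = Path("/models/example/pytorch_model.bin.index.json")
    st_prefix = pt_index.stem.removeprefix("pytorch_").removesuffix(".bin.index")
    print(f"{st_prefix}.safetensors.index.json")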
3 changes: 2 additions & 1 deletion tests/test_hub.py
@@ -2,7 +2,8 @@

import pytest
from huggingface_hub.utils import LocalEntryNotFoundError
from tgis_utils.hub import (

from vllm_tgis_adapter.tgis_utils.hub import (
convert_files,
download_weights,
weight_files,
