Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable LoRA support for HPU #170

Merged
merged 10 commits into from
Aug 20, 2024
47 changes: 47 additions & 0 deletions examples/lora_inference_hpu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from huggingface_hub import snapshot_download

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

sql_lora_path = snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test")

llm = LLM(model="meta-llama/Llama-2-7b-hf",
enable_lora=True,
max_num_seqs=2,
dtype='bfloat16')

sampling_params = SamplingParams(temperature=0,
max_tokens=1024,
stop=["[/assistant]"])

prompts = [
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", # noqa: E501
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]", # noqa: E501
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_95 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a low tone mora with a gloss of /˩okiru/ [òkìɽɯ́]? [/user] [assistant]", # noqa: E501
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE candidate (people_id VARCHAR, unsure_rate INTEGER); CREATE TABLE people (sex VARCHAR, people_id VARCHAR)\n\n question: which gender got the highest average uncertain ratio. [/user] [assistant]", # noqa: E501
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_60 (pick INTEGER, former_wnba_team VARCHAR)\n\n question: What pick was a player that previously played for the Minnesota Lynx? [/user] [assistant]", # noqa: E501
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]" # noqa: E501
]

expected_output = [
" SELECT icao FROM table_name_74 WHERE airport = 'lilongwe international airport' ", # noqa: E501
" SELECT nationality FROM table_name_11 WHERE elector = 'Anchero Pantaleone' ", # noqa: E501
" SELECT one_mora FROM table_name_95 WHERE gloss = 'low tone mora with a gloss of /˩okiru/' [òkìɽɯ́] AND accented_mora = 'low tone mora with a gloss of /˩okiru/' [òkìɽɯ́] ", # noqa: E501
" SELECT sex FROM people WHERE people_id IN (SELECT people_id FROM candidate GROUP BY sex ORDER BY COUNT(people_id) DESC LIMIT 1) ", # noqa: E501
" SELECT pick FROM table_name_60 WHERE former_wnba_team = 'Minnesota Lynx' ", # noqa: E501
" SELECT womens_doubles FROM table_28138035_4 WHERE mens_singles = 'Werner Schlager' " # noqa: E501
]

outputs = llm.generate(prompts,
sampling_params,
lora_request=LoRARequest("sql_adapter", 1,
sql_lora_path))

for i, output in enumerate(outputs):
prompt = output.prompt
generated_text = output.outputs[0].text
match = expected_output[i] == generated_text
if not match:
print(
f"Comparison failed for request_id::{i}\n\t[PROMPT]{prompt!r}\n\t[GENERATED]{generated_text!r}\n\t[EXPECTED]{expected_output[i]!r}" # noqa: E501
)
8 changes: 8 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,9 +590,17 @@ def caplog_vllm(temporary_enable_log_propagate, caplog):
yield caplog


def is_hpu():
from importlib import util
return util.find_spec('habana_frameworks') is not None


@pytest.fixture(scope="session")
def num_gpus_available():
"""Get number of GPUs without initializing the CUDA context
in current process."""

if is_hpu():
return torch.hpu.device_count()

return cuda_device_count_stateless()
8 changes: 7 additions & 1 deletion tests/lora/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,19 @@ class ContextInfo(TypedDict):
}]


def is_hpu():
from importlib import util
return util.find_spec('habana_frameworks') is not None


def cleanup():
destroy_model_parallel()
destroy_distributed_environment()
with contextlib.suppress(AssertionError):
torch.distributed.destroy_process_group()
gc.collect()
torch.cuda.empty_cache()
if not is_hpu():
torch.cuda.empty_cache()
ray.shutdown()


Expand Down
100 changes: 100 additions & 0 deletions tests/lora/test_llama_hpu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
from multiprocessing import Process
from typing import List

from conftest import cleanup

import vllm
from vllm.lora.request import LoRARequest

MODEL_PATH = "meta-llama/Llama-2-7b-hf"


def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
prompts = [
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", # noqa: E501
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]", # noqa: E501
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_95 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a low tone mora with a gloss of /˩okiru/ [òkìɽɯ́]? [/user] [assistant]", # noqa: E501
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE candidate (people_id VARCHAR, unsure_rate INTEGER); CREATE TABLE people (sex VARCHAR, people_id VARCHAR)\n\n question: which gender got the highest average uncertain ratio. [/user] [assistant]", # noqa: E501
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_60 (pick INTEGER, former_wnba_team VARCHAR)\n\n question: What pick was a player that previously played for the Minnesota Lynx? [/user] [assistant]", # noqa: E501
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]" # noqa: E501
]
sampling_params = vllm.SamplingParams(temperature=0,
max_tokens=256,
stop=["[/assistant]"])
outputs = llm.generate(
prompts,
sampling_params,
lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
if lora_id else None)
# Print the outputs.
generated_texts: List[str] = []
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
generated_texts.append(generated_text)
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
return generated_texts


def _test_llama_lora(sql_lora_files, tp_size):
llm = vllm.LLM(MODEL_PATH,
enable_lora=True,
max_num_seqs=16,
max_loras=4,
dtype='float32',
tensor_parallel_size=tp_size)

expected_no_lora_output = [
"\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_75 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_76 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_77 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_78 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user]", # noqa: E501
" Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? ", # noqa: E501
"\n\n answer: 1\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_96 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a high tone mora with a gloss of /˧kot/ [kòt]? [/user] [assistant]\n\n answer: 2\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_97 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a high tone mora with a gloss of /˧kot/ [kòt]? [/user] [assistant]\n\n answer: 2\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_98 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one m", # noqa: E501
" Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE candidate (people_id VARCHAR, unsure_rate INTEGER); CREATE TABLE people (sex VARCHAR, people_id VARCHAR)\n\n question: which gender got the highest average uncertain ratio. ", # noqa: E501
" Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_60 (pick INTEGER, former_wnba_team VARCHAR)\n\n question: What pick was a player that previously played for the Minnesota Lynx? ", # noqa: E501
"\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE", # noqa: E501
]
expected_lora_output = [
" SELECT icao FROM table_name_74 WHERE airport = 'lilongwe international airport' ", # noqa: E501
" SELECT nationality FROM table_name_11 WHERE elector = 'anchero pantaleone' ", # noqa: E501
" SELECT one_mora FROM table_name_95 WHERE gloss = 'low tone mora with a gloss of /˩okiru/' [òkìɽɯ́] AND accented_mora = 'low tone mora with a gloss of /˩okiru/' [òkìɽɯ́] ", # noqa: E501
" SELECT sex FROM people WHERE people_id IN (SELECT people_id FROM candidate GROUP BY sex ORDER BY COUNT(people_id) DESC LIMIT 1) ", # noqa: E501
" SELECT pick FROM table_name_60 WHERE former_wnba_team = 'Minnesota Lynx' ", # noqa: E501
" SELECT womens_doubles FROM table_28138035_4 WHERE mens_singles = 'Werner Schlager' " # noqa: E501
]

print("lora adapter created")
assert do_sample(llm, sql_lora_files, lora_id=0) == expected_no_lora_output

print("lora 1")
assert do_sample(llm, sql_lora_files, lora_id=1) == expected_lora_output

print("no lora")
assert do_sample(llm, sql_lora_files, lora_id=0) == expected_no_lora_output

print("lora 2")
assert do_sample(llm, sql_lora_files, lora_id=2) == expected_lora_output

print("removing lora")
cleanup()


def test_llama_lora_1x(sql_lora_files):
p = Process(target=_test_llama_lora, args=(sql_lora_files, 1))
p.start()
p.join()
assert p.exitcode == 0


def test_llama_lora_2x(sql_lora_files):
# Work-around to resolve stalling issue in multi-card scenario
p = Process(target=_test_llama_lora, args=(sql_lora_files, 2))
p.start()
p.join()
assert p.exitcode == 0


def test_llama_lora_4x(sql_lora_files):
# Work-around to resolve stalling issue in multi-card scenario
p = Process(target=_test_llama_lora, args=(sql_lora_files, 4))
p.start()
p.join()
assert p.exitcode == 0
Loading
Loading