Commit

Implement BGMV algorithm using index_select op on LoRA weights
Signed-off-by: Sanju C Sudhakaran <scsudhakaran@habana.ai>
SanjuCSudhakaran committed Aug 2, 2024
1 parent 5c6a312 commit 1ee15b4
Showing 4 changed files with 185 additions and 44 deletions.
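
For context: BGMV here is the batched gather matrix-vector LoRA kernel, i.e. for every token t the update y[t] += scale * x[t] @ A[i_t] @ B[i_t], where i_t is the adapter slot assigned to that token. The commit replaces a Python loop over the adapters present in a batch with a single torch.index_select gather of per-token A/B matrices followed by one batched matmul chain. A minimal standalone sketch of the idea, with hypothetical names and shapes and a zero-filled spare slot standing in for "no adapter" (not the commit's exact code):

import torch

def bgmv_index_select(y, x, wa_stacked, wb_stacked, indices, scale):
    # Illustrative sketch: y[t] += scale * x[t] @ A[i_t] @ B[i_t]
    # wa_stacked: (slots + 1, in_dim, rank)  -- last row kept all-zero
    # wb_stacked: (slots + 1, rank, out_dim) -- last row kept all-zero
    # indices:    (tokens,) adapter slot per token, -1 meaning "no adapter"
    no_lora_slot = wa_stacked.size(0) - 1
    idx = torch.where(indices < 0, torch.full_like(indices, no_lora_slot), indices)
    wa = torch.index_select(wa_stacked, 0, idx)      # (tokens, in_dim, rank)
    wb = torch.index_select(wb_stacked, 0, idx)      # (tokens, rank, out_dim)
    out = (x.unsqueeze(1) @ wa @ wb).squeeze(1)      # one batched matmul chain
    y += out * scale
    return y

slots, in_dim, rank, out_dim, tokens = 2, 8, 4, 8, 5
wa = torch.zeros(slots + 1, in_dim, rank); wa[:slots].normal_()
wb = torch.zeros(slots + 1, rank, out_dim); wb[:slots].normal_()
x = torch.randn(tokens, in_dim)
y = torch.zeros(tokens, out_dim)
bgmv_index_select(y, x, wa, wb, torch.tensor([0, 1, -1, 0, 1]), scale=1.0)

Written this way, the number of launched ops does not depend on how many distinct adapters the batch contains, which is what the HPU changes below rely on.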
8 changes: 5 additions & 3 deletions examples/lora_inference.py
@@ -4,7 +4,7 @@

sql_lora_path = snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test")

llm = LLM(model="meta-llama/Llama-2-7b-hf", enable_lora=True, max_num_seqs=2)
llm = LLM(model="meta-llama/Llama-2-7b-hf", enable_lora=True, max_num_seqs=2, dtype='bfloat16')

sampling_params = SamplingParams(
temperature=0,
@@ -20,9 +20,11 @@
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_60 (pick INTEGER, former_wnba_team VARCHAR)\n\n question: What pick was a player that previously played for the Minnesota Lynx? [/user] [assistant]", # noqa: E501
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]" # noqa: E501
]

# References from GPU with dtype=bfloat16
expected_output = [
" SELECT icao FROM table_name_74 WHERE airport = 'lilongwe international airport' ", # noqa: E501
" SELECT nationality FROM table_name_11 WHERE elector = 'anchero pantaleone' ", # noqa: E501
" SELECT nationality FROM table_name_11 WHERE elector = 'Anchero Pantaleone' ", # noqa: E501
" SELECT one_mora FROM table_name_95 WHERE gloss = 'low tone mora with a gloss of /˩okiru/' [òkìɽɯ́] AND accented_mora = 'low tone mora with a gloss of /˩okiru/' [òkìɽɯ́] ", # noqa: E501
" SELECT sex FROM people WHERE people_id IN (SELECT people_id FROM candidate GROUP BY sex ORDER BY COUNT(people_id) DESC LIMIT 1) ", # noqa: E501
" SELECT pick FROM table_name_60 WHERE former_wnba_team = 'Minnesota Lynx' ", # noqa: E501
@@ -38,7 +40,7 @@
for i, output in enumerate(outputs):
prompt = output.prompt
generated_text = output.outputs[0].text
matching = expected_output[i].lower() == generated_text.lower()
matching = expected_output[i] == generated_text
if not matching:
print(f"{i} matching::{matching} Prompt: {prompt!r}, Generated text: {generated_text!r} expected_output: {expected_output[i]!r}")
else:
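Presumably the switch from a case-insensitive to an exact comparison works because decoding is greedy (temperature=0) and the engine is now pinned to bfloat16, matching the dtype used to produce the GPU reference strings above.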
136 changes: 136 additions & 0 deletions examples/multilora_inference_hpu.py
@@ -0,0 +1,136 @@
"""
This example shows how to use the multi-LoRA functionality
for offline inference.
Requires HuggingFace credentials for access to Llama2.
"""

from typing import List, Optional, Tuple

from huggingface_hub import snapshot_download

from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams
from vllm.lora.request import LoRARequest


def create_test_prompts(
lora_path: str
) -> List[Tuple[str, SamplingParams, Optional[LoRARequest]]]:
"""Create a list of test prompts with their sampling parameters.
2 requests for base model, 4 requests for the LoRA. We define 2
different LoRA adapters (using the same model for demo purposes).
Since this example sets `max_loras=6`, both adapters can be resident in the
same batch; if `max_loras` were 1, requests using the second LoRA adapter
would only run after all requests with the first adapter had finished.
"""
# TODO: Fix issues when enabling the parameters [prompt_logprobs=1, presence_penalty=0.2,
# (n=3, best_of=3, use_beam_search=True)] in SamplingParams.

return [
("A robot may not injure a human being",
SamplingParams(temperature=0.0,
logprobs=1,
max_tokens=128), None),
("To be or not to be,",
SamplingParams(temperature=0.8,
top_k=5,
max_tokens=128), None),
(
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", # noqa: E501
SamplingParams(temperature=0.0,
logprobs=1,
max_tokens=128,
stop_token_ids=[32003]),
LoRARequest("sql-lora", 1, lora_path)),
(
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]", # noqa: E501
SamplingParams(
temperature=0,
max_tokens=128,
stop_token_ids=[32003]),
LoRARequest("sql-lora", 1, lora_path)),
(
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", # noqa: E501
SamplingParams(temperature=0.0,
logprobs=1,
max_tokens=128,
stop_token_ids=[32003]),
LoRARequest("sql-lora2", 2, lora_path)),
(
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]", # noqa: E501
SamplingParams(temperature=0,
max_tokens=128,
stop_token_ids=[32003]),
LoRARequest("sql-lora", 1, lora_path)),
]


def process_requests(engine: LLMEngine,
test_prompts: List[Tuple[str, SamplingParams,
Optional[LoRARequest]]]):
"""Continuously process a list of prompts and handle the outputs."""
request_id = 0
result = {}

while test_prompts or engine.has_unfinished_requests():
if test_prompts:
prompt, sampling_params, lora_request = test_prompts.pop(0)
engine.add_request(str(request_id),
prompt,
sampling_params,
lora_request=lora_request)
request_id += 1

request_outputs: List[RequestOutput] = engine.step()

for request_output in request_outputs:
if request_output.finished:
result[request_output.request_id] = request_output.outputs[0].text
return result


def initialize_engine() -> LLMEngine:
"""Initialize the LLMEngine."""
# max_loras: controls the number of LoRAs that can be used in the same
# batch. Larger numbers will cause higher memory usage, as each LoRA
# slot requires its own preallocated tensor.
# max_lora_rank: controls the maximum supported rank of all LoRAs. Larger
# numbers will cause higher memory usage. If you know that all LoRAs will
# use the same rank, it is recommended to set this as low as possible.
# max_cpu_loras: controls the size of the CPU LoRA cache.
engine_args = EngineArgs(model="meta-llama/Llama-2-7b-hf",
enable_lora=True,
max_loras=6,
max_lora_rank=8,
max_num_seqs=16,
dtype='bfloat16')
return LLMEngine.from_engine_args(engine_args)

# References from GPU with dtype=bfloat16
expected_output = [
" or, through inaction, allow a human being to come to harm.\nA robot must obey the orders given it by human beings except where such orders would conflict with the First Law.\nA robot must protect its own existence as long as such protection does not conflict with the First or Second Law.\nThe Three Laws of Robotics were created by Isaac Asimov in 1942. They are the foundation of robotics and artificial intelligence.\nThe Three Laws of Robotics are the foundation of robotics and artificial intelligence. They were created by Isaac Asimov in 194",
" that is the question.\nI am not sure what I would do if I had to make a decision to live or die.\nI would probably die, just to be sure.\nI think I would die, because if I were to live, I would not be able to live.\nIf I were dead, I could live, but I would not be happy.\nIf I had to choose between living and dying, I would die.\nIf I chose to live, I would die, but I would be happy, because I would be dead.\nSo if I were alive, I would die and be happy.",
" SELECT icao FROM table_name_74 WHERE airport = 'lilongwe international airport' ",
" SELECT nationality FROM table_name_11 WHERE elector = 'Anchero Pantaleone' ",
" SELECT icao FROM table_name_74 WHERE airport = 'lilongwe international airport' ",
" SELECT nationality FROM table_name_11 WHERE elector = 'Anchero Pantaleone' "
]

def main():
"""Main function that sets up and runs the prompt processing."""
engine = initialize_engine()
lora_path = snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test")
test_prompts = create_test_prompts(lora_path)
result = process_requests(engine, test_prompts)
for idx in result:
generated_text = result[idx]
matching = expected_output[int(idx)] == generated_text
if not matching:
print(f"{idx} matching::{matching} Generated text: {generated_text!r} expected_output: {expected_output[int(idx)]!r}")
else:
print(f"{idx} matching::{matching}")


if __name__ == '__main__':
main()
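
Assuming an HPU-enabled vLLM build and Hugging Face access to meta-llama/Llama-2-7b-hf, the new example can be run directly, e.g. `python examples/multilora_inference_hpu.py`; each finished request is printed with a matching::True/False flag against the bfloat16 GPU references above.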
73 changes: 35 additions & 38 deletions vllm/lora/layers.py
@@ -27,6 +27,7 @@
LinearScalingRotaryEmbedding, RotaryEmbedding)
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.utils import is_hpu

if TYPE_CHECKING:
pass
@@ -64,33 +65,26 @@ def dec(*args, **kwargs):


def custom_bgmv(y: torch.Tensor, x: torch.Tensor, wa_t_all: torch.Tensor, wb_t_all: torch.Tensor, indices: torch.LongTensor, layer_idx: int, scale: float,):
import habana_frameworks.torch as htorch
htorch.core.mark_step()
unique_indices = torch.unique(indices)
unique_indices = unique_indices[unique_indices!=-1]

for lora_idx in unique_indices:
_indices = torch.where(indices == lora_idx)[0]
x_current = torch.index_select(x, 0, _indices)
wa_current = wa_t_all[lora_idx, layer_idx].transpose(-1, -2)
wb_current = wb_t_all[lora_idx, layer_idx].transpose(-1, -2)
tmp = x_current @ wa_current
tmp = tmp @ wb_current
tmp *= scale
y.index_add_(0, _indices, tmp)
max_loras = wa_t_all.size(0)
indices = (indices + max_loras + 1) % (max_loras + 1)
wa = torch.index_select(wa_t_all, 0, indices)[:,layer_idx,:,:].transpose(-1, -2)
wb = torch.index_select(wb_t_all, 0, indices)[:,layer_idx,:,:].transpose(-1, -2)

x = x.unsqueeze(1)
out = x @ wa
out = out @ wb
out = out.squeeze(1)
y += out * scale

def custom_bgmv_embed(y: torch.Tensor, x: torch.Tensor, wa_t_all: torch.Tensor, indices: torch.LongTensor, layer_idx: int, scale: float,):
import habana_frameworks.torch as htorch
htorch.core.mark_step()
unique_indices = torch.unique(indices)

for lora_idx in unique_indices:
_indices = torch.where(indices == lora_idx)[0]
x_current = torch.index_select(x, 0, _indices)
wa_current = wa_t_all[lora_idx, layer_idx].transpose(-1, -2)
tmp = x_current @ wa_current
tmp *= scale
y.index_add_(0, _indices, tmp)
max_loras = wa_t_all.size(0)
indices = (indices + max_loras + 1) % (max_loras + 1)
wa = torch.index_select(wa_t_all, 0, indices)[:,layer_idx,:,:].transpose(-1, -2)

x = x.unsqueeze(1)
out = x @ wa
out = out.squeeze(1)
y += out * scale

def _apply_lora(
x: torch.Tensor,
@@ -118,8 +112,10 @@ def _apply_lora(
x = x.view(-1, x.shape[-1])
output = output.view(-1, output.shape[-1])
indices = indices.view(-1)
# add_lora(output, x, lora_a_stacked, lora_b_stacked, indices, 0, 1.0)
custom_bgmv(output, x, lora_a_stacked, lora_b_stacked, indices, 0, 1.0)
if is_hpu():
custom_bgmv(output, x, lora_a_stacked, lora_b_stacked, indices, 0, 1.0)
else:
add_lora(output, x, lora_a_stacked, lora_b_stacked, indices, 0, 1.0)
return output.view_as(org_output)


@@ -156,16 +152,15 @@ def _apply_lora_packed_nslice(
output = output.view(-1, output.shape[-1])
indices = indices.view(-1)
offset_left = 0
slice_size = output.shape[-1] // len(output_slices)
for slice_idx in range(len(output_slices)):
# add_lora_slice(output, x, lora_a_stacked[slice_idx],
# lora_b_stacked[slice_idx], indices, 0, 1.0, offset_left,
# output_slices[slice_idx])
# offset_left += output_slices[slice_idx]
start = slice_idx * slice_size
end = min((slice_idx + 1)* slice_size, output.shape[-1])
custom_bgmv(output[:, start:end], x, lora_a_stacked[slice_idx],
if is_hpu():
custom_bgmv(output[:, offset_left: offset_left+output_slices[slice_idx]], x, lora_a_stacked[slice_idx],
lora_b_stacked[slice_idx], indices, 0, 1.0)
else:
add_lora_slice(output, x, lora_a_stacked[slice_idx],
lora_b_stacked[slice_idx], indices, 0, 1.0, offset_left,
output_slices[slice_idx])
offset_left += output_slices[slice_idx]
return output.view_as(org_output)


@@ -365,9 +360,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
full_lora_a_embeddings = full_lora_a_embeddings.view(
full_lora_a_embeddings.shape[0] *
full_lora_a_embeddings.shape[1], -1)
custom_bgmv_embed(full_output, full_lora_a_embeddings, self.lora_b_stacked, self.indices[:self.indices_len[0]], 0, 1.0)
# bgmv(full_output, full_lora_a_embeddings, self.lora_b_stacked,
# self.indices[:self.indices_len[0]], 0, 1.0)
if is_hpu():
custom_bgmv_embed(full_output, full_lora_a_embeddings, self.lora_b_stacked, self.indices[:self.indices_len[0]], 0, 1.0)
else:
bgmv(full_output, full_lora_a_embeddings, self.lora_b_stacked,
self.indices[:self.indices_len[0]], 0, 1.0)
return full_output.view_as(full_output_org)

@classmethod
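
The key change in layers.py: the earlier custom_bgmv/custom_bgmv_embed iterated over torch.unique(indices) and scattered results back with index_add_, which introduces data-dependent shapes and host-side control flow; on Gaudi, where graphs are compiled per shape, that can force frequent recompilation. The rewritten versions gather an adapter matrix for every token with index_select and run one batched matmul chain, so the op sequence stays static no matter how many adapters appear in the batch. The remapping `(indices + max_loras + 1) % (max_loras + 1)` leaves non-negative slot numbers unchanged and turns the -1 "no adapter" sentinel into a fixed non-negative slot, presumably the spare slot reserved by the models.py change below. A small sketch of that arithmetic with hypothetical values:

import torch

num_slots = 7                                # plays the role of wa_t_all.size(0); hypothetical
indices = torch.tensor([0, 3, -1, 5, -1])    # -1 = token uses no adapter
remapped = (indices + num_slots + 1) % (num_slots + 1)
print(remapped)                              # tensor([0, 3, 7, 5, 7])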
12 changes: 9 additions & 3 deletions vllm/lora/models.py
@@ -24,7 +24,7 @@
from vllm.lora.utils import (from_layer, from_layer_logits_processor,
parse_fine_tuned_lora_name, replace_submodule)
from vllm.model_executor.models.interfaces import SupportsLoRA
from vllm.utils import is_pin_memory_available
from vllm.utils import is_pin_memory_available, is_hpu

logger = init_logger(__name__)

@@ -465,11 +465,17 @@ def __init__(

@property
def capacity(self) -> int:
return self.lora_config.max_cpu_loras
if is_hpu():
return self.lora_config.max_cpu_loras + 1
else:
return self.lora_config.max_cpu_loras

@property
def lora_slots(self) -> int:
return self.lora_config.max_loras
if is_hpu():
return self.lora_config.max_loras + 1
else:
return self.lora_config.max_loras

@property
def adapter_slots(self) -> int:
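
A plausible reading of the models.py change: on HPU one extra slot is reserved beyond the configured max_loras (and max_cpu_loras), so the index remapping in custom_bgmv always has a spare, never-populated position to which tokens without an adapter can be routed, while the full complement of real adapters remains usable in the same batch.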
