diff --git a/examples/lora_inference.py b/examples/lora_inference.py index 537e13a42c8ea..df8a4670730ba 100644 --- a/examples/lora_inference.py +++ b/examples/lora_inference.py @@ -4,7 +4,7 @@ sql_lora_path = snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test") -llm = LLM(model="meta-llama/Llama-2-7b-hf", enable_lora=True, max_num_seqs=2) +llm = LLM(model="meta-llama/Llama-2-7b-hf", enable_lora=True, max_num_seqs=2, dtype='bfloat16') sampling_params = SamplingParams( temperature=0, @@ -20,9 +20,11 @@ "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_60 (pick INTEGER, former_wnba_team VARCHAR)\n\n question: What pick was a player that previously played for the Minnesota Lynx? [/user] [assistant]", # noqa: E501 "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]" # noqa: E501 ] + +# References from GPU with dtype=bfloat16 expected_output = [ " SELECT icao FROM table_name_74 WHERE airport = 'lilongwe international airport' ", # noqa: E501 - " SELECT nationality FROM table_name_11 WHERE elector = 'anchero pantaleone' ", # noqa: E501 + " SELECT nationality FROM table_name_11 WHERE elector = 'Anchero Pantaleone' ", # noqa: E501 " SELECT one_mora FROM table_name_95 WHERE gloss = 'low tone mora with a gloss of /˩okiru/' [òkìɽɯ́] AND accented_mora = 'low tone mora with a gloss of /˩okiru/' [òkìɽɯ́] ", # noqa: E501 " SELECT sex FROM people WHERE people_id IN (SELECT people_id FROM candidate GROUP BY sex ORDER BY COUNT(people_id) DESC LIMIT 1) ", # noqa: E501 " SELECT pick FROM table_name_60 WHERE former_wnba_team = 'Minnesota Lynx' ", # noqa: E501 @@ -38,7 +40,7 @@ for i, output in enumerate(outputs): prompt = output.prompt generated_text = output.outputs[0].text - matching = expected_output[i].lower() == generated_text.lower() + matching = expected_output[i] == generated_text if not matching: print(f"{i} matching::{matching} Prompt: {prompt!r}, Generated text: {generated_text!r} expected_output: {expected_output[i]!r}") else: diff --git a/examples/multilora_inference_hpu.py b/examples/multilora_inference_hpu.py new file mode 100644 index 0000000000000..308179945752c --- /dev/null +++ b/examples/multilora_inference_hpu.py @@ -0,0 +1,136 @@ +""" +This example shows how to use the multi-LoRA functionality +for offline inference. + +Requires HuggingFace credentials for access to Llama2. +""" + +from typing import List, Optional, Tuple + +from huggingface_hub import snapshot_download + +from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams +from vllm.lora.request import LoRARequest + + +def create_test_prompts( + lora_path: str +) -> List[Tuple[str, SamplingParams, Optional[LoRARequest]]]: + """Create a list of test prompts with their sampling parameters. + + 2 requests for base model, 4 requests for the LoRA. We define 2 + different LoRA adapters (using the same model for demo purposes). + Since we also set `max_loras=1`, the expectation is that the requests + with the second LoRA adapter will be ran after all requests with the + first adapter have finished. + """ + # TODO Fix issues when enabling paramerters [prompt_logprobs=1, presence_penalty=0.2, + # (n=3, best_of=3, use_beam_search=True)] in SamplingParams. + + return [ + ("A robot may not injure a human being", + SamplingParams(temperature=0.0, + logprobs=1, + max_tokens=128), None), + ("To be or not to be,", + SamplingParams(temperature=0.8, + top_k=5, + max_tokens=128), None), + ( + "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", # noqa: E501 + SamplingParams(temperature=0.0, + logprobs=1, + max_tokens=128, + stop_token_ids=[32003]), + LoRARequest("sql-lora", 1, lora_path)), + ( + "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]", # noqa: E501 + SamplingParams( + temperature=0, + max_tokens=128, + stop_token_ids=[32003]), + LoRARequest("sql-lora", 1, lora_path)), + ( + "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", # noqa: E501 + SamplingParams(temperature=0.0, + logprobs=1, + max_tokens=128, + stop_token_ids=[32003]), + LoRARequest("sql-lora2", 2, lora_path)), + ( + "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]", # noqa: E501 + SamplingParams(temperature=0, + max_tokens=128, + stop_token_ids=[32003]), + LoRARequest("sql-lora", 1, lora_path)), + ] + + +def process_requests(engine: LLMEngine, + test_prompts: List[Tuple[str, SamplingParams, + Optional[LoRARequest]]]): + """Continuously process a list of prompts and handle the outputs.""" + request_id = 0 + result = {} + + while test_prompts or engine.has_unfinished_requests(): + if test_prompts: + prompt, sampling_params, lora_request = test_prompts.pop(0) + engine.add_request(str(request_id), + prompt, + sampling_params, + lora_request=lora_request) + request_id += 1 + + request_outputs: List[RequestOutput] = engine.step() + + for request_output in request_outputs: + if request_output.finished: + result[request_output.request_id] = request_output.outputs[0].text + return result + + +def initialize_engine() -> LLMEngine: + """Initialize the LLMEngine.""" + # max_loras: controls the number of LoRAs that can be used in the same + # batch. Larger numbers will cause higher memory usage, as each LoRA + # slot requires its own preallocated tensor. + # max_lora_rank: controls the maximum supported rank of all LoRAs. Larger + # numbers will cause higher memory usage. If you know that all LoRAs will + # use the same rank, it is recommended to set this as low as possible. + # max_cpu_loras: controls the size of the CPU LoRA cache. + engine_args = EngineArgs(model="meta-llama/Llama-2-7b-hf", + enable_lora=True, + max_loras=6, + max_lora_rank=8, + max_num_seqs=16, + dtype='bfloat16') + return LLMEngine.from_engine_args(engine_args) + +# References from GPU with dtype=bfloat16 +expected_output = [ +" or, through inaction, allow a human being to come to harm.\nA robot must obey the orders given it by human beings except where such orders would conflict with the First Law.\nA robot must protect its own existence as long as such protection does not conflict with the First or Second Law.\nThe Three Laws of Robotics were created by Isaac Asimov in 1942. They are the foundation of robotics and artificial intelligence.\nThe Three Laws of Robotics are the foundation of robotics and artificial intelligence. They were created by Isaac Asimov in 194", +" that is the question.\nI am not sure what I would do if I had to make a decision to live or die.\nI would probably die, just to be sure.\nI think I would die, because if I were to live, I would not be able to live.\nIf I were dead, I could live, but I would not be happy.\nIf I had to choose between living and dying, I would die.\nIf I chose to live, I would die, but I would be happy, because I would be dead.\nSo if I were alive, I would die and be happy.", +" SELECT icao FROM table_name_74 WHERE airport = 'lilongwe international airport' ", +" SELECT nationality FROM table_name_11 WHERE elector = 'Anchero Pantaleone' ", +" SELECT icao FROM table_name_74 WHERE airport = 'lilongwe international airport' ", +" SELECT nationality FROM table_name_11 WHERE elector = 'Anchero Pantaleone' " +] + +def main(): + """Main function that sets up and runs the prompt processing.""" + engine = initialize_engine() + lora_path = snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test") + test_prompts = create_test_prompts(lora_path) + result = process_requests(engine, test_prompts) + for idx in result: + generated_text = result[idx] + matching = expected_output[int(idx)] == generated_text + if not matching: + print(f"{idx} matching::{matching} Generated text: {generated_text!r} expected_output: {expected_output[int(idx)]!r}") + else: + print(f"{idx} matching::{matching}") + + +if __name__ == '__main__': + main() diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index acf62c0375000..fd00f323da6fd 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -27,6 +27,7 @@ LinearScalingRotaryEmbedding, RotaryEmbedding) from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) +from vllm.utils import is_hpu if TYPE_CHECKING: pass @@ -64,33 +65,26 @@ def dec(*args, **kwargs): def custom_bgmv(y: torch.Tensor, x: torch.Tensor, wa_t_all: torch.Tensor, wb_t_all: torch.Tensor, indices: torch.LongTensor, layer_idx: int, scale: float,): - import habana_frameworks.torch as htorch - htorch.core.mark_step() - unique_indices = torch.unique(indices) - unique_indices = unique_indices[unique_indices!=-1] - - for lora_idx in unique_indices: - _indices = torch.where(indices == lora_idx)[0] - x_current = torch.index_select(x, 0, _indices) - wa_current = wa_t_all[lora_idx, layer_idx].transpose(-1, -2) - wb_current = wb_t_all[lora_idx, layer_idx].transpose(-1, -2) - tmp = x_current @ wa_current - tmp = tmp @ wb_current - tmp *= scale - y.index_add_(0, _indices, tmp) + max_loras = wa_t_all.size(0) + indices = (indices + max_loras + 1) % (max_loras + 1) + wa = torch.index_select(wa_t_all, 0, indices)[:,layer_idx,:,:].transpose(-1, -2) + wb = torch.index_select(wb_t_all, 0, indices)[:,layer_idx,:,:].transpose(-1, -2) + + x = x.unsqueeze(1) + out = x @ wa + out = out @ wb + out = out.squeeze(1) + y += out * scale def custom_bgmv_embed(y: torch.Tensor, x: torch.Tensor, wa_t_all: torch.Tensor, indices: torch.LongTensor, layer_idx: int, scale: float,): - import habana_frameworks.torch as htorch - htorch.core.mark_step() - unique_indices = torch.unique(indices) - - for lora_idx in unique_indices: - _indices = torch.where(indices == lora_idx)[0] - x_current = torch.index_select(x, 0, _indices) - wa_current = wa_t_all[lora_idx, layer_idx].transpose(-1, -2) - tmp = x_current @ wa_current - tmp *= scale - y.index_add_(0, _indices, tmp) + max_loras = wa_t_all.size(0) + indices = (indices + max_loras + 1) % (max_loras + 1) + wa = torch.index_select(wa_t_all, 0, indices)[:,layer_idx,:,:].transpose(-1, -2) + + x = x.unsqueeze(1) + out = x @ wa + out = out.squeeze(1) + y += out * scale def _apply_lora( x: torch.Tensor, @@ -118,8 +112,10 @@ def _apply_lora( x = x.view(-1, x.shape[-1]) output = output.view(-1, output.shape[-1]) indices = indices.view(-1) - # add_lora(output, x, lora_a_stacked, lora_b_stacked, indices, 0, 1.0) - custom_bgmv(output, x, lora_a_stacked, lora_b_stacked, indices, 0, 1.0) + if is_hpu(): + custom_bgmv(output, x, lora_a_stacked, lora_b_stacked, indices, 0, 1.0) + else: + add_lora(output, x, lora_a_stacked, lora_b_stacked, indices, 0, 1.0) return output.view_as(org_output) @@ -156,16 +152,15 @@ def _apply_lora_packed_nslice( output = output.view(-1, output.shape[-1]) indices = indices.view(-1) offset_left = 0 - slice_size = output.shape[-1] // len(output_slices) for slice_idx in range(len(output_slices)): - # add_lora_slice(output, x, lora_a_stacked[slice_idx], - # lora_b_stacked[slice_idx], indices, 0, 1.0, offset_left, - # output_slices[slice_idx]) - # offset_left += output_slices[slice_idx] - start = slice_idx * slice_size - end = min((slice_idx + 1)* slice_size, output.shape[-1]) - custom_bgmv(output[:, start:end], x, lora_a_stacked[slice_idx], + if is_hpu(): + custom_bgmv(output[:, offset_left: offset_left+output_slices[slice_idx]], x, lora_a_stacked[slice_idx], lora_b_stacked[slice_idx], indices, 0, 1.0) + else: + add_lora_slice(output, x, lora_a_stacked[slice_idx], + lora_b_stacked[slice_idx], indices, 0, 1.0, offset_left, + output_slices[slice_idx]) + offset_left += output_slices[slice_idx] return output.view_as(org_output) @@ -365,9 +360,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: full_lora_a_embeddings = full_lora_a_embeddings.view( full_lora_a_embeddings.shape[0] * full_lora_a_embeddings.shape[1], -1) - custom_bgmv_embed(full_output, full_lora_a_embeddings, self.lora_b_stacked, self.indices[:self.indices_len[0]], 0, 1.0) - # bgmv(full_output, full_lora_a_embeddings, self.lora_b_stacked, - # self.indices[:self.indices_len[0]], 0, 1.0) + if is_hpu(): + custom_bgmv_embed(full_output, full_lora_a_embeddings, self.lora_b_stacked, self.indices[:self.indices_len[0]], 0, 1.0) + else: + bgmv(full_output, full_lora_a_embeddings, self.lora_b_stacked, + self.indices[:self.indices_len[0]], 0, 1.0) return full_output.view_as(full_output_org) @classmethod diff --git a/vllm/lora/models.py b/vllm/lora/models.py index 04446d46c3059..34ae8b1373481 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -24,7 +24,7 @@ from vllm.lora.utils import (from_layer, from_layer_logits_processor, parse_fine_tuned_lora_name, replace_submodule) from vllm.model_executor.models.interfaces import SupportsLoRA -from vllm.utils import is_pin_memory_available +from vllm.utils import is_pin_memory_available, is_hpu logger = init_logger(__name__) @@ -465,11 +465,17 @@ def __init__( @property def capacity(self) -> int: - return self.lora_config.max_cpu_loras + if is_hpu(): + return self.lora_config.max_cpu_loras + 1 + else: + return self.lora_config.max_cpu_loras @property def lora_slots(self) -> int: - return self.lora_config.max_loras + if is_hpu(): + return self.lora_config.max_loras + 1 + else: + return self.lora_config.max_loras @property def adapter_slots(self) -> int: