Commit ca748a4: push changes
vwxyzjn committed Aug 18, 2024 · 1 parent e680349
Showing 5 changed files with 42 additions and 10 deletions.
2 changes: 1 addition & 1 deletion vllm/executor/gpu_executor.py
@@ -51,7 +51,7 @@ def _get_worker_kwargs(
             device_config=self.device_config,
             cache_config=self.cache_config,
             load_config=self.load_config,
-            local_rank=local_rank,
+            local_rank=self.device_config.device.index,
             rank=rank,
             distributed_init_method=distributed_init_method,
             lora_config=self.lora_config,
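The executor change above derives the worker's local rank from the configured torch.device instead of the caller-supplied default, so a single-GPU run pinned to a non-default card (e.g. device="cuda:1") binds the worker to that card. A minimal sketch of the idea, with DeviceConfig reduced to an illustrative stand-in for vLLM's device config:

import torch

class DeviceConfig:
    """Illustrative stand-in: vLLM's device config holds a torch.device."""
    def __init__(self, device: str):
        self.device = torch.device(device)

cfg = DeviceConfig("cuda:1")       # assumption: the user requested GPU 1
local_rank = cfg.device.index      # -> 1, rather than a hard-coded 0
print(local_rank)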
2 changes: 1 addition & 1 deletion vllm/worker/cache_engine.py
@@ -64,7 +64,7 @@ def __init__(
 
         # Initialize the cache.
         self.gpu_cache = self._allocate_kv_cache(
-            self.num_gpu_blocks, self.device_config.device_type)
+            self.num_gpu_blocks, self.device_config.device)
         self.cpu_cache = self._allocate_kv_cache(self.num_cpu_blocks, "cpu")
 
     def _allocate_kv_cache(
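The cache_engine change passes the full torch.device (index included) into the KV-cache allocation instead of the bare device type string. A type-only "cuda" makes PyTorch allocate on whichever device is current, i.e. GPU 0 unless it has been switched. A minimal sketch of the difference, assuming a host with at least two GPUs and purely illustrative sizes:

import torch

num_blocks, block_bytes = 4, 16                      # illustrative sizes

# Type-only string: lands on the *current* CUDA device (cuda:0 by default).
kv_default = torch.zeros(num_blocks, block_bytes, device="cuda")

# Full device with an index: always lands on the requested GPU.
kv_pinned = torch.zeros(num_blocks, block_bytes, device=torch.device("cuda:1"))

print(kv_default.device, kv_pinned.device)           # cuda:0 cuda:1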
12 changes: 6 additions & 6 deletions vllm/worker/model_runner.py
@@ -888,7 +888,7 @@ def __init__(
 
     def load_model(self) -> None:
         logger.info("Starting to load model %s...", self.model_config.model)
-        with CudaMemoryProfiler() as m:
+        with CudaMemoryProfiler(self.device) as m:
             self.model = get_model(model_config=self.model_config,
                                    device_config=self.device_config,
                                    load_config=self.load_config,
@@ -1206,12 +1206,12 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
 
         # Prepare dummy inputs. These will be reused for all batch sizes.
         max_batch_size = max(_BATCH_SIZES_TO_CAPTURE)
-        input_tokens = torch.zeros(max_batch_size, dtype=torch.long).cuda()
-        input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda()
-        slot_mapping = torch.empty(max_batch_size, dtype=torch.long).cuda()
+        input_tokens = torch.zeros(max_batch_size, dtype=torch.long, device=self.device)
+        input_positions = torch.zeros(max_batch_size, dtype=torch.long, device=self.device)
+        slot_mapping = torch.empty(max_batch_size, dtype=torch.long, device=self.device)
         slot_mapping.fill_(_PAD_SLOT_ID)
-        seq_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda()
-        block_tables = torch.from_numpy(self.graph_block_tables).cuda()
+        seq_lens = torch.ones(max_batch_size, dtype=torch.int32, device=self.device)
+        block_tables = torch.from_numpy(self.graph_block_tables).to(self.device)
         intermediate_inputs = None
         if not get_pp_group().is_first_rank:
             intermediate_inputs = self.model.make_empty_intermediate_tensors(
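In capture_model, replacing .cuda() with device=self.device / .to(self.device) is what actually moves the CUDA-graph dummy inputs off the default GPU: a bare .cuda() always targets the process's current device. A minimal sketch of that behaviour, assuming a two-GPU host and a runner configured for cuda:1:

import numpy as np
import torch

device = torch.device("cuda:1")    # assumption: the runner was configured for GPU 1

# .cuda() with no argument uses the *current* CUDA device (cuda:0 unless
# torch.cuda.set_device was called), which is what the old code relied on.
on_current = torch.zeros(8, dtype=torch.long).cuda()

# device=... and .to(device) pin the tensors to the configured GPU instead.
on_target = torch.zeros(8, dtype=torch.long, device=device)
moved = torch.from_numpy(np.zeros(8, dtype=np.int64)).to(device)

print(on_current.device, on_target.device, moved.device)   # cuda:0 cuda:1 cuda:1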
4 changes: 2 additions & 2 deletions vllm/worker/worker.py
@@ -136,7 +136,7 @@ def init_device(self) -> None:
 
             _check_if_gpu_supports_dtype(self.model_config.dtype)
             torch.cuda.empty_cache()
-            self.init_gpu_memory = torch.cuda.mem_get_info()[0]
+            self.init_gpu_memory = torch.cuda.mem_get_info(self.device)[0]
         else:
             raise RuntimeError(
                 f"Not support device type: {self.device_config.device}")
@@ -193,7 +193,7 @@ def determine_num_available_blocks(self) -> Tuple[int, int]:
         # Calculate the number of blocks that can be allocated with the
         # profiled peak memory.
         torch.cuda.synchronize()
-        free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
+        free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info(self.device)
         # NOTE(woosuk): Here we assume that the other processes using the same
         # GPU did not change their memory usage during the profiling.
         peak_memory = self.init_gpu_memory - free_gpu_memory
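torch.cuda.mem_get_info() with no argument reports free/total memory for the current device, so a worker pinned to another GPU would profile the wrong card; passing the worker's device reads the card the model actually lives on. A minimal sketch, again assuming a two-GPU host:

import torch

device = torch.device("cuda:1")    # assumption: the worker is pinned to GPU 1

# No argument: reports the *current* device (cuda:0 by default).
free_default, total_default = torch.cuda.mem_get_info()

# Explicit device: reports the GPU the worker was pinned to.
free_target, total_target = torch.cuda.mem_get_info(device)

print(f"cuda:0 free={free_default}  {device} free={free_target}")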
32 changes: 32 additions & 0 deletions vllm_test.py
@@ -0,0 +1,32 @@

from transformers import AutoTokenizer

from vllm import SamplingParams, LLM

tok = AutoTokenizer.from_pretrained("facebook/opt-125m")
tok.chat_template = (
    "{% for message in messages %}"
    "{{'\n\n' if not loop.first else ''}}"
    "{{message['role']|capitalize + ': ' +message['content']}}"
    "{% if loop.last and not add_generation_prompt %}{{ eos_token }}{% endif %}"
    "{% endfor %}"
)
prompts = [
    {"role": "user", "content": "Compose a speech about the need for more affordable dental care."},
]

prompt_ids = tok.apply_chat_template(prompts, add_generation_prompt=True)
sampling_params = SamplingParams(temperature=0.001, top_p=1.0, max_tokens=1024, include_stop_str_in_output=True)
llm = LLM(model="facebook/opt-125m",
          tensor_parallel_size=1,
          device="cuda:1")
llmp = llm.llm_engine.model_executor.driver_worker.model_runner.model
print(f"🔥🔥🔥 vllm lives in {llmp.lm_head.weight.device}")
print("prepare to generate")
outputs = llm.generate(prompt_token_ids=[prompt_ids],
                       sampling_params=sampling_params)
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

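The new vllm_test.py only inspects lm_head; a hypothetical follow-up check (a sketch reusing the llm-derived llmp module from the script above, not part of the commit) would assert that every parameter landed on the requested GPU:

import torch

# Hypothetical extension of vllm_test.py: verify *all* weights sit on cuda:1.
param_devices = {p.device for p in llmp.parameters()}
assert param_devices == {torch.device("cuda:1")}, param_devices
print("all parameters on cuda:1")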