Add host memory profiling to HabanaMemoryProfiler (vllm-project#51)
kzawora-intel authored Jun 5, 2024
1 parent ab359ac commit cf6952d
Showing 3 changed files with 56 additions and 31 deletions.
vllm/executor/habana_executor.py (3 changes: 1 addition & 2 deletions)

@@ -80,8 +80,7 @@ def initialize_cache(self, num_gpu_blocks : int, num_cpu_blocks) -> None:

         with HabanaMemoryProfiler() as cache_init_m:
             self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
-        logger.info(f"init_cache_engine took "
-                    f"{format_bytes(cache_init_m.consumed_memory)} ({cache_init_m.consumed_memory/HabanaMemoryProfiler.total_memory():.2%} of total memory, gpu_memory_utilization: {self.cache_config.gpu_memory_utilization}, {format_bytes(HabanaMemoryProfiler.current_memory_usage())}/{format_bytes(HabanaMemoryProfiler.total_memory())} used)")
+        logger.info(f"init_cache_engine took {cache_init_m.get_summary_string()}")

     def execute_model(
         self,
vllm/utils.py (40 changes: 31 additions & 9 deletions)

@@ -496,33 +496,55 @@ class HabanaMemoryProfiler:
     def __init__(self, device=None):
         self.device = device

-    def current_memory_usage() -> float:
-        # Return the memory usage in bytes.
+    def current_device_memory_usage() -> float:
+        # Return the device memory usage in bytes.
         free_hpu_memory, total_hpu_memory = torch.hpu.mem_get_info()
         return total_hpu_memory - free_hpu_memory

-    def current_free_memory() -> float:
-        # Return the memory usage in bytes.
+    def current_free_device_memory() -> float:
+        # Return the free device memory in bytes.
         free_hpu_memory, _ = torch.hpu.mem_get_info()
         return free_hpu_memory

-    def total_memory() -> float:
-        # Return the memory usage in bytes.
+    def total_device_memory() -> float:
+        # Return the total device memory in bytes.
         _, total_hpu_memory = torch.hpu.mem_get_info()
         return total_hpu_memory

+    def current_host_memory_usage() -> float:
+        # Return the host memory usage in bytes.
+        return HabanaMemoryProfiler.total_host_memory() - HabanaMemoryProfiler.current_free_host_memory()
+
+    def current_free_host_memory() -> float:
+        # Return the free host memory in bytes.
+        return psutil.virtual_memory().available
+
+    def total_host_memory() -> float:
+        # Return the total host memory in bytes.
+        return psutil.virtual_memory().total
+
+    def get_summary_string(self):
+        if getattr(self, 'final_device_memory', None) is None or getattr(self, 'final_host_memory', None) is None:
+            raise RuntimeError("HabanaMemoryProfiler.get_summary_string() can only be called after closing context manager")
+        return (f"{format_bytes(self.consumed_device_memory)} of device memory ({format_bytes(self.final_device_memory)}/{format_bytes(HabanaMemoryProfiler.total_device_memory())} used) and "
+                f"{format_bytes(self.consumed_host_memory)} of host memory ({format_bytes(self.final_host_memory)}/{format_bytes(HabanaMemoryProfiler.total_host_memory())} used)")

     def __enter__(self):
         # Force garbage collection
         gc.collect()
-        self.initial_memory = HabanaMemoryProfiler.current_memory_usage()
+        self.initial_device_memory = HabanaMemoryProfiler.current_device_memory_usage()
+        self.initial_host_memory = HabanaMemoryProfiler.current_host_memory_usage()
         # This allows us to call methods of the context manager if needed
         return self

     def __exit__(self, exc_type, exc_val, exc_tb):
         # Force garbage collection
         gc.collect()
-        self.final_memory = HabanaMemoryProfiler.current_memory_usage()
-        self.consumed_memory = self.final_memory - self.initial_memory
+        self.final_device_memory = HabanaMemoryProfiler.current_device_memory_usage()
+        self.final_host_memory = HabanaMemoryProfiler.current_host_memory_usage()
+        self.consumed_device_memory = self.final_device_memory - self.initial_device_memory
+        self.consumed_host_memory = self.final_host_memory - self.initial_host_memory



# Adapted from https://stackoverflow.com/a/49361727
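For reference, a minimal usage sketch of the extended profiler API defined in the diff above. The class, attribute, and method names follow the diff; it assumes that format_bytes is exported from vllm.utils alongside the profiler and that torch.hpu and psutil are available on the target system. run_some_workload is a hypothetical placeholder, not part of the commit.

    # Sketch only: measure device and host memory consumed by a block of work,
    # using the context-manager protocol shown in the diff above.
    from vllm.utils import HabanaMemoryProfiler, format_bytes

    with HabanaMemoryProfiler() as m:
        run_some_workload()  # hypothetical placeholder for the profiled code

    # get_summary_string() may only be called after the context manager exits,
    # since the final readings are taken in __exit__.
    print(f"workload took {m.get_summary_string()}")
    print(f"device delta: {format_bytes(m.consumed_device_memory)}, "
          f"host delta: {format_bytes(m.consumed_host_memory)}")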
vllm/worker/habana_model_runner.py (44 changes: 24 additions & 20 deletions)

@@ -271,21 +271,25 @@ def __init__(

     def load_model(self) -> None:
         with HabanaMemoryProfiler() as m:
-            self.model = get_model(
-                model_config=self.model_config,
-                device_config=self.device_config,
-                load_config=self.load_config,
-                lora_config=self.lora_config,
-                vision_language_config=self.vision_language_config,
-                parallel_config=self.parallel_config,
-                scheduler_config=self.scheduler_config,
-            )
-            # FIXME: Running with disable_tensor_cache=True causes RuntimeErrors. This needs to be debugged
-            self.model = _maybe_wrap_in_hpu_graph(self.model)
+            with HabanaMemoryProfiler() as m_getmodel:
+                self.model = get_model(
+                    model_config=self.model_config,
+                    device_config=self.device_config,
+                    load_config=self.load_config,
+                    lora_config=self.lora_config,
+                    vision_language_config=self.vision_language_config,
+                    parallel_config=self.parallel_config,
+                    scheduler_config=self.scheduler_config,
+                )
+            logger.info(f"Pre-loading model weights on {next(self.model.parameters()).device} took {m_getmodel.get_summary_string()}")

-        self.model_memory_usage = m.consumed_memory
-        logger.info(f"Loading model weights took "
-                    f"{format_bytes(self.model_memory_usage)} ({format_bytes(HabanaMemoryProfiler.current_memory_usage())}/{format_bytes(HabanaMemoryProfiler.total_memory())} used)")
+            # FIXME: Running with disable_tensor_cache=True causes RuntimeErrors. This needs to be debugged
+            with HabanaMemoryProfiler() as m_wrap:
+                self.model = _maybe_wrap_in_hpu_graph(self.model)
+            logger.info(f"Wrapping in HPU Graph took {m_wrap.get_summary_string()}")

+        self.model_memory_usage = m.consumed_device_memory
+        logger.info(f"Loading model weights took in total {m.get_summary_string()}")

         if self.lora_config:
             assert hasattr(self.model, "supported_lora_modules"
@@ -932,12 +936,12 @@ def warmup_scenario(self, batch_size, seq_len, is_prompt, kv_caches) -> None:
         gc.collect()

     def log_warmup(self, phase, i, max_i, batch_size, seq_len):
-        free_mem = format_bytes(HabanaMemoryProfiler.current_free_memory())
+        free_mem = format_bytes(HabanaMemoryProfiler.current_free_device_memory())
         logger.info(f"[Warmup][{phase}][{i+1}/{max_i}] batch_size:{batch_size} seq_len:{seq_len} free_mem:{free_mem}")

     def warmup_all_buckets(self, buckets, is_prompt, kv_caches):
         for i, (batch_size, seq_len) in enumerate(reversed(buckets)):
-            mem_usage = 100.0 * HabanaMemoryProfiler.current_memory_usage() / HabanaMemoryProfiler.total_memory()
+            mem_usage = 100.0 * HabanaMemoryProfiler.current_device_memory_usage() / HabanaMemoryProfiler.total_device_memory()
             self.log_warmup('Prompt' if is_prompt else 'Decode', i, len(buckets), batch_size, seq_len)
             self.warmup_scenario(batch_size, seq_len, is_prompt, kv_caches)

@@ -966,7 +970,7 @@ def warmup_graphs(self, strategy, buckets, is_prompt, kv_caches, available_mem):
             self.log_warmup(phase, idx, num_candidates, batch_size, seq_len)
             with HabanaMemoryProfiler() as mem_prof:
                 self.warmup_scenario(batch_size, seq_len, is_prompt, kv_caches)
-            used_mem = align_workers(mem_prof.consumed_memory, torch.distributed.ReduceOp.MAX)
+            used_mem = align_workers(mem_prof.consumed_device_memory, torch.distributed.ReduceOp.MAX)
             available_mem -= used_mem
             total_mem += used_mem
             total_batch_seq += batch_seq
@@ -980,14 +984,14 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None:
             logger.info("Skipping warmup...")
             return
         self.profiler.start('internal', 'warmup')
-        start_mem = HabanaMemoryProfiler.current_memory_usage()
+        start_mem = HabanaMemoryProfiler.current_device_memory_usage()
         start_time = time.perf_counter()
         self.warmup_all_buckets(self.prompt_buckets, True, kv_caches)
         self.warmup_all_buckets(self.decode_buckets, False, kv_caches)

         if not self.enforce_eager:
             mem_margin = 1.0 - float(os.environ.get('VLLM_GRAPH_MEM_MARGIN', '0.02'))
-            free_mem = mem_margin * HabanaMemoryProfiler.current_free_memory()
+            free_mem = mem_margin * HabanaMemoryProfiler.current_free_device_memory()
             free_mem = align_workers(free_mem, torch.distributed.ReduceOp.MIN)
             prompt_graph_mem_ratio = float(os.environ.get('VLLM_GRAPH_PROMPT_RATIO', '0.5'))
             prompt_available_memory = prompt_graph_mem_ratio * free_mem
@@ -998,7 +1002,7 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None:
             self.warmup_graphs(decode_strategy, self.decode_buckets, False, kv_caches, decode_available_memory)

         end_time = time.perf_counter()
-        end_mem = HabanaMemoryProfiler.current_memory_usage()
+        end_mem = HabanaMemoryProfiler.current_device_memory_usage()
         elapsed_time = end_time - start_time
         logger.info(f"Warmup finished in {elapsed_time:.0f} secs, allocated {format_bytes(end_mem - start_mem)} of device memory")
         self.profiler.end()
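As a side note, the warmup_model hunks above budget HPU Graph capture memory out of the device memory still free after warmup. A rough sketch of that arithmetic under the default values visible in the diff follows; the numeric free-memory figure is illustrative only, the real code additionally min-reduces free_mem across workers, and the decode share is an assumption, since the diff is truncated before decode_available_memory is computed.

    import os

    # Illustrative figure only: free device memory reported after warmup.
    free_device_mem = 20 * 1024**3  # e.g. 20 GiB

    # Defaults from the diff: keep a 2% safety margin, split 50/50 between phases.
    mem_margin = 1.0 - float(os.environ.get('VLLM_GRAPH_MEM_MARGIN', '0.02'))
    usable = mem_margin * free_device_mem             # 0.98 * 20 GiB = 19.6 GiB
    prompt_ratio = float(os.environ.get('VLLM_GRAPH_PROMPT_RATIO', '0.5'))
    prompt_budget = prompt_ratio * usable             # 9.8 GiB for prompt graphs
    decode_budget = usable - prompt_budget            # assumed remainder for decode graphs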
