Enable PT profiler
kdamaszk committed Aug 26, 2024
1 parent 55ea658 commit e805b88
Showing 1 changed file with 38 additions and 3 deletions.
41 changes: 38 additions & 3 deletions vllm/worker/habana_model_runner.py
@@ -52,6 +52,20 @@
 _TYPE_CACHE = {}
 
 
+def setup_profiler():
+    prof_type = os.environ.get('VLLM_PT_PROFILE_METHOD', 'pt')
+    assert prof_type in ['pt']
+    schedule = torch.profiler.schedule(wait=1, warmup=2, active=1, repeat=1)
+    profiler = torch.profiler.profile(
+        schedule=schedule,
+        activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.HPU],
+        on_trace_ready=torch.profiler.tensorboard_trace_handler('.', use_gzip=True),
+        record_shapes=False,
+        with_stack=True
+    )
+    return profiler
+
+
 def read_bucket_settings(phase: str, dim: str, **defaults):
     """Read bucketing configuration from env variables.
@@ -1095,7 +1109,8 @@ def warmup_scenario(self,
                         seq_len,
                         is_prompt,
                         kv_caches,
-                        is_profile_run=False) -> None:
+                        is_profile_run=False,
+                        profile=False) -> None:
         use_graphs = self._use_graphs(batch_size, seq_len, is_prompt)
         scenario_name = ("warmup_"
                         f"{'prompt' if is_prompt else 'decode'}_"
@@ -1127,7 +1142,7 @@ def warmup_scenario(self,
                 for idx in range(max_num_seqs)
             ]
         self.profiler.start('internal', scenario_name)
-        times = 3 if use_graphs else 1
+        times = 5 if use_graphs or profile else 1
         if self.lora_config and not is_profile_run:
             lora_mapping = LoRAMapping(
                 [0] * batch_size * seq_len,
@@ -1144,10 +1159,19 @@
                 for i in range(batch_size)
             ]
         torch.hpu.synchronize()
-        for _ in range(times):
+        profiler = None
+        if profile and self.is_driver_worker:
+            profiler = setup_profiler()
+            profiler.start()
+        for i in range(times):
+            print(f"Run {i}...")
             inputs = self.prepare_model_input(seqs)
             self.execute_model(inputs, kv_caches, warmup_mode=True)
             torch.hpu.synchronize()
+            if profiler:
+                profiler.step()
+        if profiler:
+            profiler.stop()
         self.profiler.end()
         gc.collect()

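The bump from 3 to 5 iterations lines up with the schedule in `setup_profiler`: with `wait=1, warmup=2, active=1`, at least four profiler steps must elapse before the one active step is recorded, so five runs leave some margin. An illustrative way to see which action each step maps to (`torch.profiler.schedule` returns a callable from a 0-based step index to a `ProfilerAction`):

```python
import torch

sched = torch.profiler.schedule(wait=1, warmup=2, active=1, repeat=1)
for step in range(5):
    print(step, sched(step))
# 0 ProfilerAction.NONE             (wait)
# 1 ProfilerAction.WARMUP
# 2 ProfilerAction.WARMUP
# 3 ProfilerAction.RECORD_AND_SAVE  (the single recorded step)
# 4 ProfilerAction.NONE             (repeat=1, so profiling is finished)
```
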
@@ -1255,6 +1279,17 @@ def log_graph_warmup_summary(self, buckets, is_prompt, total_mem):
 
     @torch.inference_mode()
     def warmup_model(self, kv_caches: List[torch.Tensor]) -> None:
+        if profile := os.environ.get('VLLM_PT_PROFILE', None):
+            phase, bs, seq_len, graphs = profile.split('_')
+            is_prompt = phase == 'prompt'
+            bs = int(bs)
+            seq_len = int(seq_len)
+            graphs = graphs == 't'
+            if graphs:
+                self.graphed_buckets.add((bs, seq_len, is_prompt))
+            self.warmup_scenario(bs, seq_len, is_prompt, kv_caches, False, True)
+            assert False
+
         if os.environ.get('VLLM_SKIP_WARMUP', 'false').lower() == 'true':
             logger.info("Skipping warmup...")
             return
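
Based on the parsing above, `VLLM_PT_PROFILE` packs one scenario into an underscore-separated string, and the deliberate `assert False` stops the process once the trace has been written; this mode exists only to produce a profile. A small sketch of the inferred format (the concrete value is a hypothetical example, not taken from this commit):

```python
# Inferred format: "<phase>_<batch_size>_<seq_len>_<graphs:t|f>", e.g. profile
# a prompt bucket with batch size 4, sequence length 128, HPU graphs enabled.
profile = 'prompt_4_128_t'
phase, bs, seq_len, graphs = profile.split('_')
print(phase == 'prompt', int(bs), int(seq_len), graphs == 't')
# True 4 128 True
```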
