Port PT Profiler to habana_main #256

Merged
merged 10 commits on Sep 11, 2024
46 changes: 40 additions & 6 deletions vllm/worker/habana_model_runner.py
@@ -210,6 +210,26 @@ def align_workers(value, op):
return value_t.item()


def setup_profiler():
schedule = torch.profiler.schedule(wait=0, warmup=2, active=1, repeat=1)
DEVICE = 'hpu'
activities = [torch.profiler.ProfilerActivity.CPU]
activities.extend([torch.profiler.ProfilerActivity.HPU] if DEVICE ==
'hpu' else [])
#from habana_frameworks.torch.activity_profiler import DebugActivity
#debug_activities=[DebugActivity.BRIDGE_FUNCTION_CALLS]

profiler = torch.profiler.profile(
schedule=schedule,
activities=activities,
#debug_activities=debug_activities,
on_trace_ready=torch.profiler.tensorboard_trace_handler('.',
use_gzip=True),
record_shapes=False,
with_stack=True)
return profiler


def pad_list(list, k, v):
target_len = round_up(len(list), k)
padding = target_len - len(list)
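For context — not part of this diff: with `wait=0, warmup=2, active=1, repeat=1`, the schedule built in `setup_profiler()` treats the first two `profiler.step()` calls as warmup and records only the third, after which `tensorboard_trace_handler` saves the trace. A minimal self-contained sketch of the same pattern (CPU-only, with a dummy matmul standing in for model execution — both are assumptions for illustration):

```python
import torch

# Same schedule as setup_profiler(): no wait, 2 warmup steps, 1 recorded step.
schedule = torch.profiler.schedule(wait=0, warmup=2, active=1, repeat=1)

with torch.profiler.profile(
        schedule=schedule,
        activities=[torch.profiler.ProfilerActivity.CPU],
        on_trace_ready=torch.profiler.tensorboard_trace_handler('.', use_gzip=True),
        with_stack=True) as prof:
    for _ in range(3):  # three iterations: exactly enough to reach the active step
        torch.matmul(torch.rand(64, 64), torch.rand(64, 64))  # dummy workload
        prof.step()  # advances the wait/warmup/active schedule
```

This is also why `warmup_scenario` below bumps `times` to 3 when `is_profile_run` is set: three iterations are needed for the schedule to reach its single active step.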
@@ -1237,11 +1257,7 @@ def profile_run(self) -> None:
max_seq_len = min(self.prompt_seq_bucket_cfg[-1],
self.max_num_batched_tokens // max_batch_size)

-        self.warmup_scenario(max_batch_size,
-                             max_seq_len,
-                             True,
-                             kv_caches,
-                             is_profile_run=True)
+        self.warmup_scenario(max_batch_size, max_seq_len, True, kv_caches)
return

def warmup_scenario(self,
@@ -1281,7 +1297,7 @@ def warmup_scenario(self,
for idx in range(max_num_seqs)
]
self.profiler.start('internal', scenario_name)
-        times = 3 if use_graphs else 1
+        times = 3 if use_graphs or is_profile_run else 1
if self.lora_config and not is_profile_run:
lora_mapping = LoRAMapping(
[0] * batch_size * seq_len,
@@ -1312,10 +1328,19 @@
for i, b in enumerate(blocks)
]
torch.hpu.synchronize()
profiler = None
if is_profile_run and self.is_driver_worker:
profiler = setup_profiler()
profiler.start()
for _ in range(times):
inputs = self.prepare_model_input(seqs)
self.execute_model(inputs, kv_caches, warmup_mode=True)
torch.hpu.synchronize()
if profiler:
profiler.step()
if profiler:
profiler.stop()
self.profiler.end()
gc.collect()

def remove_all_loras(self):
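For context — not part of this diff: once the active step completes, the `tensorboard_trace_handler('.', use_gzip=True)` callback from `setup_profiler()` writes a gzipped Chrome trace named `<worker>.<timestamp>.pt.trace.json.gz` into the working directory of the driver worker. A small sketch for locating the newest one (the working-directory assumption matches the `'.'` passed above):

```python
import glob
import os

# Traces written by tensorboard_trace_handler('.', use_gzip=True); open the
# result in TensorBoard's profiler plugin or at https://ui.perfetto.dev.
traces = sorted(glob.glob('*.pt.trace.json.gz'), key=os.path.getmtime)
print(traces[-1] if traces else 'no trace written yet')
```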
@@ -1427,6 +1452,15 @@ def log_graph_warmup_summary(self, buckets, is_prompt, total_mem):

@torch.inference_mode()
def warmup_model(self, kv_caches: List[torch.Tensor]) -> None:
if profile := os.environ.get('VLLM_PT_PROFILE', None):
phase, bs, seq_len, graph = profile.split('_')
is_prompt = phase == 'prompt'
graphs = graph == 't'
if graphs:
self.graphed_buckets.add((int(bs), int(seq_len), is_prompt))
self.warmup_scenario(int(bs), int(seq_len), is_prompt, kv_caches,
True)
raise AssertionError("Finished profiling")
if os.environ.get('VLLM_SKIP_WARMUP', 'false').lower() == 'true':
logger.info("Skipping warmup...")
return
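For context — not part of this diff: `VLLM_PT_PROFILE` is parsed above as `<phase>_<batch_size>_<seq_len>_<graph>`, where a `phase` of `prompt` selects a prompt bucket (anything else is treated as decode) and a `graph` of `t` adds the bucket to `self.graphed_buckets` so the scenario runs with HPU graphs; the process then stops deliberately with `AssertionError("Finished profiling")` once the trace is captured. A hypothetical invocation — the concrete bucket values are illustrative only:

```python
import os

# Profile a graphed prompt bucket with batch size 2 and sequence length 512.
# Must be set before warmup_model() runs, e.g. before launching the vLLM
# server or benchmark process.
os.environ['VLLM_PT_PROFILE'] = 'prompt_2_512_t'
```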