Port PT Profiler to habana_main #256

Merged · 10 commits · Sep 11, 2024
vllm/worker/habana_model_runner.py (91 additions, 3 deletions)
@@ -8,6 +8,7 @@
import itertools
import math
import operator
import sys
import os
import time
from enum import IntEnum
@@ -183,6 +184,75 @@ def align_workers(value, op):
    return value_t.item()


def setup_profiler():
    DEVICE='hpu'
    STEPS=3
    activities = [torch.profiler.ProfilerActivity.CPU]
    activities.extend([torch.profiler.ProfilerActivity.HPU] if DEVICE == 'hpu' else [])
    wait = 0
    active = 1
    warmup = STEPS - active

    schedule = torch.profiler.schedule(wait=wait, warmup=warmup, active=active, repeat=1)
    profiler = torch.profiler.profile(
        schedule=schedule,
        activities=activities,
        on_trace_ready=torch.profiler.tensorboard_trace_handler('.', use_gzip=True),
        record_shapes=False,
        with_stack=True)
    return profiler


def pt_profiler(schedule):
    DEVICE = 'hpu'
    activities = [torch.profiler.ProfilerActivity.CPU]
    activities.extend([torch.profiler.ProfilerActivity.HPU] if DEVICE == 'hpu' else [])
    #from habana_frameworks.torch.activity_profiler import DebugActivity
    #debug_activities=[DebugActivity.BRIDGE_FUNCTION_CALLS]

    profiler = torch.profiler.profile(
        schedule=schedule,
        activities=activities,
        #debug_activities=debug_activities,
        on_trace_ready=torch.profiler.tensorboard_trace_handler('.', use_gzip=True),
        record_shapes=False,
        with_stack=True)
    return profiler


def hltv_profiler(schedule):
    pt_tools_path = os.environ.get('PT_TOOLS_PATH', None)
    assert pt_tools_path is not None, "Need to specify PT_TOOLS_PATH to use hltv profiling method"
    sys.path.append(pt_tools_path)
    from topologies import SynapseProfilerApi, TraceType
    api = SynapseProfilerApi()
    class SynapseProfiler:
        def check(self):
            if schedule(self.cur_step) == torch.profiler.ProfilerAction.RECORD_AND_SAVE:
                api.profiler_start(TraceType.TraceAll, 0)
        def start(self):
            self.cur_step = 0
            self.check()
        def step(self):
            self.cur_step = self.cur_step + 1
            self.check()
        def stop(self):
            api.profiler_stop(TraceType.TraceAll, 0)
            api.profiler_get_trace_json(TraceType.TraceAll, 0)
    return SynapseProfiler()


def setup_profiler():
    prof_wait = 0
    prof_warmup = 2
    prof_active = 1
    prof_type = os.environ.get('VLLM_PT_PROFILE_METHOD', 'pt')
    assert prof_type in ['pt', 'hltv']
    method = pt_profiler if prof_type == 'pt' else hltv_profiler
    schedule = torch.profiler.schedule(wait=prof_wait, warmup=prof_warmup, active=prof_active, repeat=1)
    return method(schedule)


class HpuModelAdapter():

    def __init__(self, model, enforce_eager):
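Editorial note on the profiler setup above: setup_profiler picks the backend from VLLM_PT_PROFILE_METHOD ('pt' for the stock torch.profiler path, 'hltv' for the SynapseProfilerApi path), and both backends are driven by the same wait=0 / warmup=2 / active=1 schedule; hltv_profiler only starts the Synapse trace on the step for which that schedule returns RECORD_AND_SAVE. A minimal sketch (not part of this diff) of how the schedule maps step indices to actions:

import torch

# Same schedule as the one built in setup_profiler above.
schedule = torch.profiler.schedule(wait=0, warmup=2, active=1, repeat=1)
for step in range(4):
    print(step, schedule(step))
# steps 0-1 -> ProfilerAction.WARMUP, step 2 -> ProfilerAction.RECORD_AND_SAVE,
# step 3 -> ProfilerAction.NONE (the single repeat cycle is exhausted)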
@@ -1146,8 +1216,7 @@ def profile_run(self) -> None:
        self.warmup_scenario(max_batch_size,
                             max_seq_len,
                             True,
                             kv_caches,
                             is_profile_run=True)
                             kv_caches)

    def warmup_scenario(self,
                        batch_size,
@@ -1186,7 +1255,7 @@ def warmup_scenario(self,
                    for idx in range(max_num_seqs)
                ]
        self.profiler.start('internal', scenario_name)
        times = 3 if use_graphs else 1
        times = 3 if use_graphs or is_profile_run else 1
        if self.lora_config and not is_profile_run:
            lora_mapping = LoRAMapping(
                [0] * batch_size * seq_len,
@@ -1203,10 +1272,19 @@
                for i in range(batch_size)
            ]
        torch.hpu.synchronize()
        profiler = None
        if is_profile_run and self.is_driver_worker:
            profiler = setup_profiler()
            profiler.start()
        self.profiler.start('internal', scenario_name)
        for _ in range(times):
            inputs = self.prepare_model_input(seqs)
            self.execute_model(inputs, kv_caches, warmup_mode=True)
            torch.hpu.synchronize()
            if profiler:
                profiler.step()
        if profiler:
            profiler.stop()
        self.profiler.end()
        gc.collect()

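Usage note on the hunk above: on profile runs the warmup loop executes times = 3 iterations, matching the 0 + 2 + 1 steps of the schedule, so only the third (final) iteration is actually recorded. A stand-alone sketch of the same start/step/stop pattern with the plain torch.profiler backend; the helper name, the CPU-only workload, and the step count are illustrative assumptions, not taken from the PR:

import torch

def run_with_profiler(fn, steps=3):
    # wait=0, warmup=steps-1, active=1: only the last iteration is recorded.
    profiler = torch.profiler.profile(
        schedule=torch.profiler.schedule(wait=0, warmup=steps - 1, active=1, repeat=1),
        activities=[torch.profiler.ProfilerActivity.CPU],
        on_trace_ready=torch.profiler.tensorboard_trace_handler('.', use_gzip=True))
    profiler.start()
    for _ in range(steps):
        fn()             # one iteration of the workload being traced
        profiler.step()  # advance the schedule; the final step triggers the trace dump
    profiler.stop()

run_with_profiler(lambda: torch.mm(torch.rand(256, 256), torch.rand(256, 256)))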
@@ -1314,6 +1392,16 @@ def log_graph_warmup_summary(self, buckets, is_prompt, total_mem):

    @torch.inference_mode()
    def warmup_model(self, kv_caches: List[torch.Tensor]) -> None:
        if profile := os.environ.get('VLLM_PT_PROFILE', None):
            phase, bs, seq_len, graphs = profile.split('_')
            is_prompt = phase == 'prompt'
            bs = int(bs)
            seq_len = int(seq_len)
            graphs = graphs == 't'
            if graphs:
                self.graphed_buckets.add((bs, seq_len, is_prompt))
            self.warmup_scenario(bs, seq_len, is_prompt, kv_caches, True)
            assert False
        if os.environ.get('VLLM_SKIP_WARMUP', 'false').lower() == 'true':
            logger.info("Skipping warmup...")
            return
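Usage note on warmup_model above: when VLLM_PT_PROFILE is set, its value is split on '_' into phase, batch size, sequence length, and an HPU-graphs flag, a single profiled warmup scenario is run, and the worker then aborts via assert False instead of continuing with the normal warmup. A hypothetical example value, parsed the same way as the code above ("prompt_4_128_t" is illustrative, not taken from the PR):

value = "prompt_4_128_t"  # "<phase>_<batch_size>_<seq_len>_<t|f>", format inferred from the code
phase, bs, seq_len, graphs = value.split('_')
print(phase == 'prompt', int(bs), int(seq_len), graphs == 't')
# -> True 4 128 True: prompt phase, batch size 4, seq len 128, HPU graphs enabled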