Skip to content

Commit

Permalink
Add FileBucketingScheme + use LegacyBucketingScheme by default
Browse files Browse the repository at this point in the history
  • Loading branch information
madamczykhabana committed Feb 24, 2025
1 parent 37d74e0 commit 1d1d30f
Showing 1 changed file with 42 additions and 18 deletions.
60 changes: 42 additions & 18 deletions vllm/worker/hpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@
LORA_WARMUP_RANK = 8

VLLM_DELAYED_SAMPLING = os.environ.get('VLLM_DELAYED_SAMPLING', 'false').lower() == 'true'
VLLM_BUCKETING_SCHEME = os.environ.get('VLLM_BUCKETING_SCHEME', 'polynomial')
VLLM_BUCKETING_SCHEME = os.environ.get('VLLM_BUCKETING_SCHEME', 'legacy')
VLLM_STRICT_BUCKETS = os.environ.get('VLLM_STRICT_BUCKETS', 'false') == 'true'
DUMMY_TOKEN_ID = -1

Expand Down Expand Up @@ -245,7 +245,41 @@ def generate_prompt_buckets(self):
print(msg)


class PolynomialBucketingScheme:
class BucketingScheme:
    """Base class for decode bucketing schemes.

    Subclasses are expected to populate ``self.buckets`` as a list of
    ``(batch_size, [num_blocks, ...])`` pairs, sorted ascending by batch
    size, with each block list sorted ascending as well (``find_bucket``
    relies on that ordering to return the smallest fitting bucket).
    """

    def find_bucket(self, real_bs, real_blocks):
        """Return the smallest configured bucket that covers the request.

        Scans buckets in ascending order and returns the first
        ``(bs, blocks)`` pair with ``bs >= real_bs`` and
        ``blocks >= real_blocks``.

        Raises:
            AssertionError: if no configured bucket is large enough.
        """
        for bs, blocks in self.buckets:
            if bs < real_bs:
                continue
            for b in blocks:
                if b < real_blocks:
                    continue
                return (bs, b)
        # Explicit raise instead of `assert False`: asserts are stripped
        # under `python -O`, which would make this method silently return
        # None on an unservable request instead of failing loudly.
        raise AssertionError(
            "Couldn't find bucket for {} {}".format(real_bs, real_blocks))

    def list_buckets(self):
        """Return all buckets flattened to a list of (bs, blocks) tuples."""
        return [(bs, b) for bs, blocks in self.buckets for b in blocks]


class FileBucketingScheme(BucketingScheme):
    """Decode bucketing scheme loaded from a user-provided file.

    File format: one bucket spec per line, ``<batch_size>:<blocks> <blocks> ...``
    (block values separated by whitespace). Values are clamped to ``max_bs``
    and ``max_blocks``; duplicate batch-size lines are merged.
    """

    def __init__(self, filename, max_bs, max_blocks):
        self.buckets = self._read_buckets(filename, max_bs, max_blocks)
        logger.info('Decode buckets [file]: {}'.format(self.list_buckets()))

    def _read_buckets(self, filename, max_bs, max_blocks):
        """Parse bucket specs from *filename*.

        Blank lines and lines starting with ``#`` are ignored (the previous
        implementation crashed with a ValueError on a blank line, e.g. a
        trailing newline at the end of the file).

        Returns:
            list of ``(bs, blocks)`` with batch sizes sorted ascending and
            each block list sorted ascending.
        """
        buckets = {}
        with open(filename) as f:
            # Iterate the file object directly instead of materializing
            # readlines() — same order, no throwaway list.
            for line in f:
                line = line.strip()
                if not line or line.startswith('#'):
                    continue
                bs, blocks = line.split(':')
                bs = min(int(bs), max_bs)
                blocks = set(min(int(b), max_blocks) for b in blocks.split())
                buckets.setdefault(bs, set()).update(blocks)
        return [(bs, sorted(buckets[bs])) for bs in sorted(buckets)]


class PolynomialBucketingScheme(BucketingScheme):
def __init__(self, max_bs, max_blocks):
min_blocks_per_seq = 2
max_blocks_per_seq = 8
Expand All @@ -256,6 +290,7 @@ def __init__(self, max_bs, max_blocks):
bs_div = 4
bs_range = self._gen_bs_range(min_bs, max_bs, bs_div)
self.buckets = self._gen_buckets(bs_range, max_block_steps, min_blocks_per_seq, max_blocks_per_seq, max_blocks, block_beta, block_rounding)
logger.info('Decode buckets [polynomial]: {}'.format(self.list_buckets()))

def _poly_fn(self, min_val, max_val, alpha, beta):
z = (max_val - min_val) * beta
Expand Down Expand Up @@ -304,21 +339,6 @@ def _gen_buckets(self, bs_range, max_block_steps, min_blocks_per_seq, max_blocks
buckets.append((bs, block_range))
return buckets

def list_buckets(self):
buckets = [[(bs, b) for b in blocks] for bs, blocks in self.buckets]
buckets = list(itertools.chain(*buckets))
return buckets

def find_bucket(self, real_bs, real_blocks):
for bs, blocks in self.buckets:
if bs < real_bs:
continue
for b in blocks:
if b < real_blocks:
continue
return (bs, b)
assert False, "Couldn't find bucket for {} {}".format(real_bs, real_blocks)


class DummyBucketingScheme:
def list_buckets(self):
Expand Down Expand Up @@ -892,6 +912,10 @@ def _init_buckets(self, max_blocks):
if VLLM_BUCKETING_SCHEME == 'polynomial':
self.bucketing_ctx.prompt_scheme = LegacyPromptBucketingScheme(legacy_ctx)
self.bucketing_ctx.decode_scheme = PolynomialBucketingScheme(self.max_num_seqs, max_blocks)
elif VLLM_BUCKETING_SCHEME.startswith('file:'):
_, filename = VLLM_BUCKETING_SCHEME.split(':')
self.bucketing_ctx.prompt_scheme = LegacyPromptBucketingScheme(legacy_ctx)
self.bucketing_ctx.decode_scheme = FileBucketingScheme(filename, self.max_num_seqs, max_blocks)
else:
self.bucketing_ctx.prompt_scheme = LegacyPromptBucketingScheme(legacy_ctx)
self.bucketing_ctx.decode_scheme = LegacyDecodeBucketingScheme(legacy_ctx, max_blocks)
Expand Down Expand Up @@ -2950,4 +2974,4 @@ def _patch_prev_output(self):
# This is a hack. Assigning output_token_ids triggers
# a cache recomputation and we only need to update the last token
seq_data.output_token_ids_array[-1] = real_out
seq_data._cached_all_token_ids[-1] = real_out
seq_data._cached_all_token_ids[-1] = real_out

0 comments on commit 1d1d30f

Please sign in to comment.