class BucketingScheme:
    """Shared bucket lookup helpers for decode bucketing schemes.

    Subclasses are expected to set ``self.buckets`` to a list of
    ``(bs, blocks)`` pairs, where ``blocks`` is an iterable of block counts
    available for that batch size. Lookups scan ``self.buckets`` in stored
    order and return the first entry that fits, so the ordering chosen by
    the subclass decides which bucket wins.
    """

    def find_bucket(self, real_bs, real_blocks):
        """Return the first ``(bs, blocks)`` bucket that fits the request.

        Args:
            real_bs: Actual batch size; the chosen ``bs`` is >= this value.
            real_blocks: Actual block count; the chosen ``blocks`` is >= this
                value.

        Returns:
            A ``(bs, blocks)`` tuple taken from ``self.buckets``.

        Raises:
            AssertionError: If no stored bucket is large enough.
        """
        for bs, blocks in self.buckets:
            if bs < real_bs:
                continue
            for b in blocks:
                if b >= real_blocks:
                    return (bs, b)
        # Raised explicitly instead of `assert False`: assert statements are
        # removed under `python -O`, which would turn this failure into a
        # silent `None` return.
        raise AssertionError(
            "Couldn't find bucket for {} {}".format(real_bs, real_blocks))

    def list_buckets(self):
        """Return every ``(bs, blocks)`` combination as a flat list."""
        return [(bs, b) for bs, blocks in self.buckets for b in blocks]


class FileBucketingScheme(BucketingScheme):
    """Bucketing scheme whose buckets are read from a text file.

    Each non-blank line of the file has the form ``<bs>:<blocks> <blocks> ...``
    (one batch size, a colon, then whitespace-separated block counts).
    """

    def __init__(self, filename, max_bs, max_blocks):
        """Load buckets from ``filename`` and log the resulting list.

        Args:
            filename: Path of the bucket definition file.
            max_bs: Upper bound applied to every batch size read.
            max_blocks: Upper bound applied to every block count read.
        """
        self.buckets = self._read_buckets(filename, max_bs, max_blocks)
        logger.info('Decode buckets [file]: {}'.format(self.list_buckets()))

    def _read_buckets(self, filename, max_bs, max_blocks):
        """Parse ``filename`` into a sorted list of ``(bs, [blocks...])``.

        Values above ``max_bs`` / ``max_blocks`` are clamped to those limits;
        entries that end up with the same batch size are merged and
        de-duplicated. Blank lines are ignored.

        Returns:
            A list of ``(bs, blocks)`` pairs sorted by ``bs``, each ``blocks``
            being a sorted list of unique block counts.

        Raises:
            ValueError: If a non-blank line is not ``<bs>:<blocks> ...`` with
                integer fields. The message names the file and line number.
        """
        buckets = {}
        with open(filename) as f:
            for lineno, line in enumerate(f, start=1):
                line = line.strip()
                if not line:
                    # Tolerate blank/trailing empty lines instead of failing
                    # on tuple unpacking.
                    continue
                try:
                    bs, blocks = line.split(':')
                    bs = min(int(bs), max_bs)
                    blocks = {min(int(b), max_blocks) for b in blocks.split()}
                except ValueError as err:
                    raise ValueError(
                        "{}:{}: expected '<bs>:<blocks> ...', got {!r}".format(
                            filename, lineno, line)) from err
                # Clamping can map several file entries onto the same key, so
                # accumulate into a set per batch size.
                buckets.setdefault(bs, set()).update(blocks)
        return [(bs, sorted(buckets[bs])) for bs in sorted(buckets)]
self._gen_bs_range(min_bs, max_bs, bs_div) self.buckets = self._gen_buckets(bs_range, max_block_steps, min_blocks_per_seq, max_blocks_per_seq, max_blocks, block_beta, block_rounding) + logger.info('Decode buckets [polynomial]: {}'.format(self.list_buckets())) def _poly_fn(self, min_val, max_val, alpha, beta): z = (max_val - min_val) * beta @@ -304,21 +339,6 @@ def _gen_buckets(self, bs_range, max_block_steps, min_blocks_per_seq, max_blocks buckets.append((bs, block_range)) return buckets - def list_buckets(self): - buckets = [[(bs, b) for b in blocks] for bs, blocks in self.buckets] - buckets = list(itertools.chain(*buckets)) - return buckets - - def find_bucket(self, real_bs, real_blocks): - for bs, blocks in self.buckets: - if bs < real_bs: - continue - for b in blocks: - if b < real_blocks: - continue - return (bs, b) - assert False, "Couldn't find bucket for {} {}".format(real_bs, real_blocks) - class DummyBucketingScheme: def list_buckets(self): @@ -892,6 +912,10 @@ def _init_buckets(self, max_blocks): if VLLM_BUCKETING_SCHEME == 'polynomial': self.bucketing_ctx.prompt_scheme = LegacyPromptBucketingScheme(legacy_ctx) self.bucketing_ctx.decode_scheme = PolynomialBucketingScheme(self.max_num_seqs, max_blocks) + elif VLLM_BUCKETING_SCHEME.startswith('file:'): + _, filename = VLLM_BUCKETING_SCHEME.split(':') + self.bucketing_ctx.prompt_scheme = LegacyPromptBucketingScheme(legacy_ctx) + self.bucketing_ctx.decode_scheme = FileBucketingScheme(filename, self.max_num_seqs, max_blocks) else: self.bucketing_ctx.prompt_scheme = LegacyPromptBucketingScheme(legacy_ctx) self.bucketing_ctx.decode_scheme = LegacyDecodeBucketingScheme(legacy_ctx, max_blocks) @@ -2950,4 +2974,4 @@ def _patch_prev_output(self): # This is a hack. 
Assigning output_token_ids triggers # a cache recomputation and we only need to update the last token seq_data.output_token_ids_array[-1] = real_out - seq_data._cached_all_token_ids[-1] = real_out \ No newline at end of file + seq_data._cached_all_token_ids[-1] = real_out