From 6e6a67e01af177de016e0e0b11e874485ff142ed Mon Sep 17 00:00:00 2001 From: Tony Lin Date: Mon, 11 Nov 2024 06:01:24 +0000 Subject: [PATCH] Enable DeepseekV2 Lite/Chat models --- .jenkins/test_config.yaml | 3 + README_GAUDI.md | 24 +- .../getting_started/gaudi-installation.rst | 1 + examples/offline_inference_eaglespeculator.py | 68 ++++++ .../offline_inference_medusaspeculator.py | 67 ++++++ requirements-hpu.txt | 2 +- vllm/attention/backends/hpu_attn.py | 4 + vllm/executor/hpu_executor.py | 1 - vllm/executor/ray_hpu_executor.py | 3 +- .../model_executor/layers/rotary_embedding.py | 6 +- vllm/model_executor/models/deepseek_v2.py | 54 ++++- vllm/model_executor/models/gpt_bigcode.py | 6 +- vllm/model_executor/models/llama.py | 7 +- vllm/spec_decode/hpu_draft_model_runner.py | 62 +++++ vllm/spec_decode/medusa_worker.py | 4 +- vllm/spec_decode/multi_step_worker.py | 21 +- vllm/spec_decode/spec_decode_worker.py | 19 +- vllm/spec_decode/target_model_runner.py | 45 +++- vllm/worker/hpu_model_runner.py | 215 ++++++++++-------- vllm/worker/hpu_worker.py | 30 ++- vllm/worker/selector.py | 35 ++- 21 files changed, 509 insertions(+), 168 deletions(-) create mode 100644 examples/offline_inference_eaglespeculator.py create mode 100644 examples/offline_inference_medusaspeculator.py create mode 100644 vllm/spec_decode/hpu_draft_model_runner.py diff --git a/.jenkins/test_config.yaml b/.jenkins/test_config.yaml index e57bd37c5fb7e..3707725161576 100644 --- a/.jenkins/test_config.yaml +++ b/.jenkins/test_config.yaml @@ -27,6 +27,9 @@ stages: - name: gsm8k_small_g3_tp1_fp8 flavor: g3 command: cd .jenkins/lm-eval-harness && bash run-tests.sh -c configs/models-fp8.txt -t 1 + - name: gsm8k_small_g3_tp2_fp8 + flavor: g3 + command: cd .jenkins/lm-eval-harness && bash run-tests.sh -c configs/models-fp8.txt -t 2 - name: test_gsm8k_mss steps: - name: gsm8k_small_g3_tp1_mss diff --git a/README_GAUDI.md b/README_GAUDI.md index 61aee39768210..22e4320eec384 100644 --- a/README_GAUDI.md +++ b/README_GAUDI.md @@ -277,16 +277,34 @@ INFO 08-02 17:38:43 hpu_executor.py:91] init_cache_engine took 37.92 GiB of devi - block size min (`VLLM_DECODE_BLOCK_BUCKET_MIN`): `block_size` - block size step (`VLLM_DECODE_BLOCK_BUCKET_STEP`): `block_size` - block size max (`VLLM_DECODE_BLOCK_BUCKET_MAX`): `max(128, (max_num_seqs*max_model_len)/block_size)` -- ``VLLM_HANDLE_TOPK_DUPLICATES``: if ``true``, will handle duplicates that are outside of top-k, ``false`` by default +- `VLLM_HANDLE_TOPK_DUPLICATES`: if ``true``, will handle duplicates that are outside of top-k, ``false`` by default +- `VLLM_CONFIG_HIDDEN_LAYERS`: configure how many hidden layers to run in a HPUGraph for model splitting among hidden layers when TP is 1. The default is 1. It helps with throughput improvement under inter-token latency limitation for some models. Additionally, there are HPU PyTorch Bridge environment variables impacting vLLM execution: - `PT_HPU_LAZY_MODE`: if `0`, PyTorch Eager backend for Gaudi will be used, if `1` PyTorch Lazy backend for Gaudi will be used, `1` is default - `PT_HPU_ENABLE_LAZY_COLLECTIVES`: required to be `true` for tensor parallel inference with HPU Graphs -# Quantization and FP8 model calibration process +# Quantization, FP8 inference and model calibration process -The FP8 model calibration procedure has been described as a part of [vllm-hpu-extention](https://github.com/HabanaAI/vllm-hpu-extension/tree/main/calibration/README.md) package. 
+> [!NOTE]
+> Measurement files are required to run quantized models with vLLM on Gaudi accelerators. The FP8 model calibration procedure is described in the [vllm-hpu-extension](https://github.com/HabanaAI/vllm-hpu-extension/tree/main/calibration/README.md) package.
+
+Once you've completed the model calibration process and collected the measurements, you can run FP8 inference with vLLM using the following command:
+```bash
+export QUANT_CONFIG=/path/to/quant/config/inc/meta-llama-3.1-405b-instruct/maxabs_measure_g3.json
+vllm serve meta-llama/Llama-3.1-405B-Instruct --quantization inc --kv-cache-dtype fp8_inc --weights-load-device cpu --tensor_parallel_size 8
+```
+
+`QUANT_CONFIG` is an environment variable that points to the measurement or quantization configuration file. The measurement configuration file is used during the calibration procedure to collect measurements for a given model. The quantization configuration is used during inference.
+
+> [!TIP]
+> If you are just prototyping or testing your model with FP8, you can use the `VLLM_SKIP_WARMUP=true` environment variable to disable the warmup stage, which can take a long time. However, we do not recommend disabling this feature in production environments, as it causes a dramatic performance drop.
+
+> [!TIP]
+> When using FP8 models, you may experience timeouts caused by the long compilation time of FP8 operations. To mitigate this problem, you can use these two environment variables:
+> - `VLLM_ENGINE_ITERATION_TIMEOUT_S` - to adjust the vLLM server timeout. You can set the value in seconds, e.g., 600 equals 10 minutes.
+> - `VLLM_RPC_TIMEOUT` - to adjust the RPC protocol timeout used by the OpenAI-compatible API. This value is in milliseconds, e.g., 600000 equals 10 minutes.
 
 # Troubleshooting: Tweaking HPU Graphs
 
diff --git a/docs/source/getting_started/gaudi-installation.rst b/docs/source/getting_started/gaudi-installation.rst
index 8ff983044dd76..79d40293fd470 100644
--- a/docs/source/getting_started/gaudi-installation.rst
+++ b/docs/source/getting_started/gaudi-installation.rst
@@ -379,6 +379,7 @@ Environment variables
 - sequence length step (``VLLM_DECODE_BLOCK_BUCKET_STEP``): ``block_size``
 - sequence length max (``VLLM_DECODE_BLOCK_BUCKET_MAX``): ``max(128, (max_num_seqs*max_model_len)/block_size)``
 - ``VLLM_HANDLE_TOPK_DUPLICATES``: if ``true``, will handle duplicates that are outside of top-k, ``false`` by default
+- ``VLLM_CONFIG_HIDDEN_LAYERS``: configure how many hidden layers to run in a HPUGraph for model splitting among hidden layers when TP is 1. The default is 1. It helps with throughput improvement under inter-token latency limitation for some models.
 
 Additionally, there are HPU PyTorch Bridge environment variables impacting vLLM execution:
 
diff --git a/examples/offline_inference_eaglespeculator.py b/examples/offline_inference_eaglespeculator.py
new file mode 100644
index 0000000000000..e13965d77e6ea
--- /dev/null
+++ b/examples/offline_inference_eaglespeculator.py
@@ -0,0 +1,68 @@
+import gc
+import time
+from typing import List
+
+from vllm import LLM, SamplingParams
+
+
+def time_generation(llm: LLM, prompts: List[str],
+                    sampling_params: SamplingParams):
+    # Generate texts from the prompts. The output is a list of RequestOutput
+    # objects that contain the prompt, generated text, and other information.
+ # Warmup first + llm.generate(prompts, sampling_params) + llm.generate(prompts, sampling_params) + start = time.time() + outputs = llm.generate(prompts, sampling_params) + end = time.time() + latency_per_token = (end - start) / sum( + [len(o.outputs[0].token_ids) for o in outputs]) + # Print the outputs. + ret = [] + for output in outputs: + generated_text = output.outputs[0].text + ret.append(generated_text) + return ret, latency_per_token + + +if __name__ == "__main__": + + prompts = [ + "The future of AI is", + ] + sampling_params = SamplingParams(temperature=0.8, + top_p=0.95, + max_tokens=20) + + # Create an LLM without spec decoding + print("==============Without speculation==================") + llm = LLM(model="JackFram/llama-68m") + + ret_non_spec, latency_per_token_non_spec = time_generation( + llm, prompts, sampling_params) + + del llm + gc.collect() + + # Create an LLM with spec decoding + print("==============With speculation=====================") + llm = LLM( + model="JackFram/llama-68m", + speculative_model="abhigoyal/vllm-eagle-llama-68m-random", + num_speculative_tokens=5, + # These are currently required for MLPSpeculator decoding + use_v2_block_manager=True, + ) + + ret_spec, latency_per_token_spec = time_generation(llm, prompts, + sampling_params) + + del llm + gc.collect() + print("================= Summary =====================") + print("input is ", prompts, "\n") + print("Non Spec Decode - latency_per_token is ", + latency_per_token_non_spec) + print("Generated Text is :", ret_non_spec, "\n") + print("Spec Decode - latency_per_token is ", latency_per_token_spec) + print("Generated Text is :", ret_spec) diff --git a/examples/offline_inference_medusaspeculator.py b/examples/offline_inference_medusaspeculator.py new file mode 100644 index 0000000000000..100d452d1bc75 --- /dev/null +++ b/examples/offline_inference_medusaspeculator.py @@ -0,0 +1,67 @@ +import gc +import time +from typing import List + +from vllm import LLM, SamplingParams + + +def time_generation(llm: LLM, prompts: List[str], + sampling_params: SamplingParams): + # Generate texts from the prompts. The output is a list of RequestOutput + # objects that contain the prompt, generated text, and other information. + # Warmup first + llm.generate(prompts, sampling_params) + llm.generate(prompts, sampling_params) + start = time.time() + outputs = llm.generate(prompts, sampling_params) + end = time.time() + latency_per_token = (end - start) / sum( + [len(o.outputs[0].token_ids) for o in outputs]) + # Print the outputs. 
+ ret = [] + for output in outputs: + generated_text = output.outputs[0].text + ret.append(generated_text) + return ret, latency_per_token + + +if __name__ == "__main__": + + prompts = [ + "The future of AI is", + ] + sampling_params = SamplingParams(temperature=0.8, + top_p=0.95, + max_tokens=20) + + # Create an LLM without spec decoding + print("==============Without speculation==================") + llm = LLM(model="JackFram/llama-68m") + + ret_non_spec, latency_per_token_non_spec = time_generation( + llm, prompts, sampling_params) + + del llm + gc.collect() + + # Create an LLM with spec decoding + print("==============With speculation=====================") + llm = LLM( + model="JackFram/llama-68m", + speculative_model="abhigoyal/vllm-medusa-llama-68m-random", + num_speculative_tokens=5, + use_v2_block_manager=True, + ) + + ret_spec, latency_per_token_spec = time_generation(llm, prompts, + sampling_params) + + del llm + gc.collect() + print("================= Summary =====================") + print("input is ", prompts, "\n") + print("Non Spec Decode - latency_per_token is ", + latency_per_token_non_spec) + print("Generated Text is :", ret_non_spec, "\n") + print("Spec Decode - latency_per_token is ", latency_per_token_spec) + print("Generated Text is :", ret_spec) diff --git a/requirements-hpu.txt b/requirements-hpu.txt index d7eafaa86638d..ede320866b9b2 100644 --- a/requirements-hpu.txt +++ b/requirements-hpu.txt @@ -8,5 +8,5 @@ pandas tabulate setuptools>=61 setuptools-scm>=8 -vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@0063520 +vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@250622e diff --git a/vllm/attention/backends/hpu_attn.py b/vllm/attention/backends/hpu_attn.py index 6448278c2f10c..28e27919f2e3b 100644 --- a/vllm/attention/backends/hpu_attn.py +++ b/vllm/attention/backends/hpu_attn.py @@ -113,6 +113,8 @@ def __init__( self.matmul_qk = Matmul() self.softmax = Softmax() self.matmul_av = Matmul() + self.batch2block_matmul = Matmul() + self.block2batch_matmul = Matmul() self.k_cache = VLLMKVCache() self.v_cache = VLLMKVCache() self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads @@ -251,6 +253,8 @@ def forward( scale=self.scale, matmul_qk_op=self.matmul_qk, matmul_av_op=self.matmul_av, + batch2block_matmul_op=self.batch2block_matmul, + block2batch_matmul_op=self.block2batch_matmul, keys_fetch_func=self.k_cache.fetch_from_cache, values_fetch_func=self.v_cache.fetch_from_cache) # Reshape the output tensor. 
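The hunk above threads two additional `Matmul` modules (`batch2block_matmul`, `block2batch_matmul`) into the paged-attention call; the `hpu_model_runner.py` changes later in this patch import `batch2block`/`block2batch` from `vllm_hpu_extension.ops` and use them in `_set_block_scales` to derive a per-block normalization factor from the one-hot block mapping. The sketch below is illustrative only: the matmul-based helper bodies are assumptions standing in for the real `vllm_hpu_extension` implementations, but they show the arithmetic the `block_scales` tensor encodes (the reciprocal of the number of blocks owned by each sequence).

```python
import torch


def block2batch(x, block_mapping):
    # Aggregate per-block values into per-sequence sums: [num_blocks] -> [batch].
    return block_mapping.t() @ x


def batch2block(x, block_mapping):
    # Broadcast per-sequence values back onto their blocks: [batch] -> [num_blocks].
    return block_mapping @ x


batch_size = 2
block_groups = torch.tensor([0, 0, 0, 1])  # sequence 0 owns 3 blocks, sequence 1 owns 1
block_mapping = torch.nn.functional.one_hot(block_groups,
                                            num_classes=batch_size).float()

ones = torch.ones(block_mapping.size(0))
sums = batch2block(block2batch(ones, block_mapping), block_mapping)
block_scales = torch.reciprocal(torch.maximum(ones, sums))
print(block_scales)  # tensor([0.3333, 0.3333, 0.3333, 1.0000])
```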
diff --git a/vllm/executor/hpu_executor.py b/vllm/executor/hpu_executor.py index 1f78144814c24..e82cc10d0e9f0 100644 --- a/vllm/executor/hpu_executor.py +++ b/vllm/executor/hpu_executor.py @@ -42,7 +42,6 @@ def _get_worker_kwargs( rank=rank, distributed_init_method=distributed_init_method, is_driver_worker=rank == 0, - speculative_config=self.speculative_config, ) def _create_worker(self, diff --git a/vllm/executor/ray_hpu_executor.py b/vllm/executor/ray_hpu_executor.py index ebfaafd29f92c..041511d64cdfa 100644 --- a/vllm/executor/ray_hpu_executor.py +++ b/vllm/executor/ray_hpu_executor.py @@ -73,9 +73,8 @@ def _init_executor(self) -> None: def shutdown(self) -> None: if hasattr(self, "forward_dag") and self.forward_dag is not None: self.forward_dag.teardown() - import ray for worker in self.workers: - ray.kill(worker) + worker.__ray_terminate__.remote() self.forward_dag = None def finish_measurements(self): diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py index b81ef6c03278b..46ba73feb1a06 100644 --- a/vllm/model_executor/layers/rotary_embedding.py +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -28,6 +28,8 @@ from vllm.model_executor.custom_op import CustomOp +from vllm.platforms import current_platform +is_hpu = current_platform.is_hpu() def _rotate_neox(x: torch.Tensor) -> torch.Tensor: x1 = x[..., :x.shape[-1] // 2] @@ -653,7 +655,7 @@ def __init__( def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor: pos_freqs = self.base**(torch.arange( - 0, self.rotary_dim, 2, dtype=torch.float, device="cuda") / + 0, self.rotary_dim, 2, dtype=torch.float, device="hpu" if is_hpu else "cuda") / self.rotary_dim) inv_freq_extrapolation = 1.0 / pos_freqs inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs) @@ -672,7 +674,7 @@ def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor: def _compute_cos_sin_cache(self) -> torch.Tensor: inv_freq = self._compute_inv_freq(self.scaling_factor) t = torch.arange(self.max_position_embeddings * self.scaling_factor, - device="cuda", + device="hpu" if is_hpu else "cuda", dtype=torch.float32) freqs = torch.einsum("i,j -> ij", t, inv_freq) cos = (freqs.cos() * self.mscale) diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 834be78bce87b..ce96182e6f876 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -52,7 +52,8 @@ from .interfaces import SupportsPP from .utils import (PPMissingLayer, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers) - +from vllm.platforms import current_platform +is_hpu = current_platform.is_hpu() class DeepseekV2MLP(nn.Module): @@ -110,8 +111,21 @@ def __init__( if config.hidden_act != "silu": raise ValueError(f"Unsupported activation: {config.hidden_act}. 
" "Only silu is supported for now.") + if is_hpu: + self.experts = FusedMoE(num_experts=config.n_routed_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=config.norm_topk_prob, + quant_config=quant_config, + use_grouped_topk=False, + #num_expert_group=config.n_group, + #topk_group=config.topk_group, + prefix=f"{prefix}.experts") - self.experts = FusedMoE(num_experts=config.n_routed_experts, + else: + self.experts = FusedMoE(num_experts=config.n_routed_experts, top_k=config.num_experts_per_tok, hidden_size=config.hidden_size, intermediate_size=config.moe_intermediate_size, @@ -276,9 +290,19 @@ def forward( kv_cache: torch.Tensor, attn_metadata: AttentionMetadata, ) -> torch.Tensor: + if is_hpu: + # need reshape from tensor(x0, y0) to tensor(x1) for hpu + _batch_size = positions.shape[0] + positions = positions.reshape(positions.shape[0] * positions.shape[1]) + hidden_states = hidden_states.reshape(hidden_states.shape[0] * hidden_states.shape[1], hidden_states.shape[2]) if self.q_lora_rank is not None: - q = self.q_a_proj(hidden_states)[0] - q = self.q_a_layernorm(q) + if is_hpu: + # w/a of SW-208144 + q = self.q_a_proj(hidden_states)[0].unsqueeze(0) + q = self.q_a_layernorm(q).squeeze(0) + else: + q = self.q_a_proj(hidden_states)[0] + q = self.q_a_layernorm(q) q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim) else: @@ -290,7 +314,10 @@ def forward( kv_a, _ = latent_cache.split( [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) latent_cache = latent_cache.unsqueeze(1) - kv_a = self.kv_a_layernorm(kv_a.contiguous()) + if is_hpu: + kv_a = self.kv_a_layernorm(kv_a.contiguous().unsqueeze(0)).squeeze(0) # w/a of SW-208144 + else: + kv_a = self.kv_a_layernorm(kv_a.contiguous()) kv = self.kv_b_proj(kv_a)[0] kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim) @@ -310,11 +337,21 @@ def forward( v = torch.nn.functional.pad(v, [0, 256 - self.v_head_dim], value=0).view(-1, self.num_local_heads * 256) + if is_hpu: + # need restore from tensor(x0, y0) to tensor(x1, y1, z1) for hpu + q = q.reshape(_batch_size, q.shape[0] // _batch_size, q.shape[1]) + k = k.reshape(_batch_size, k.shape[0] // _batch_size, k.shape[1]) + v = v.reshape(_batch_size, v.shape[0] // _batch_size, v.shape[1]) attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + if is_hpu: + # need restore from tensor(x0, y0, z0) to tensor(x1, y1) for hpu + attn_output = attn_output.reshape(attn_output.shape[0] * attn_output.shape[1], attn_output.shape[2]) attn_output = attn_output.view( -1, self.num_local_heads, 256)[..., :self.v_head_dim].reshape( -1, self.num_local_heads * self.v_head_dim) output, _ = self.o_proj(attn_output) + if is_hpu: + output = output.reshape(_batch_size, output.shape[0] // _batch_size, output.shape[1]) return output @@ -382,6 +419,8 @@ def forward( attn_metadata: AttentionMetadata, residual: Optional[torch.Tensor], ) -> torch.Tensor: + if is_hpu: + _batch_size = positions.shape[0] # Self Attention if residual is None: residual = hidden_states @@ -399,7 +438,12 @@ def forward( # Fully Connected hidden_states, residual = self.post_attention_layernorm( hidden_states, residual) + if is_hpu: + # need reshape from tensor(x0, y0) to tensor(x1) for hpu + hidden_states = hidden_states.reshape(hidden_states.shape[0] * hidden_states.shape[1], hidden_states.shape[2]) hidden_states = self.mlp(hidden_states) + if is_hpu: + hidden_states = 
hidden_states.reshape(_batch_size, hidden_states.shape[0] // _batch_size, hidden_states.shape[1]) return hidden_states, residual diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py index e27200d8e5167..1f4a785b7d379 100644 --- a/vllm/model_executor/models/gpt_bigcode.py +++ b/vllm/model_executor/models/gpt_bigcode.py @@ -222,6 +222,10 @@ def __init__( self.make_empty_intermediate_tensors = ( make_empty_intermediate_tensors_factory(["hidden_states"], config.n_embd)) + if is_hpu: + import os + self.config_hidden_layers = int( + os.getenv('VLLM_CONFIG_HIDDEN_LAYERS', '1')) def forward( self, @@ -246,7 +250,7 @@ def forward( hidden_states = layer(hidden_states, kv_caches[i - self.start_layer], attn_metadata) - if is_hpu: + if is_hpu and i % self.config_hidden_layers == 0: htorch.core.mark_step() if not get_pp_group().is_last_rank: return IntermediateTensors({"hidden_states": hidden_states}) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 7403a2e56ff94..5f01ca42cb744 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -316,6 +316,11 @@ def __init__( make_empty_intermediate_tensors_factory( ["hidden_states", "residual"], config.hidden_size)) + if is_hpu: + import os + self.config_hidden_layers = int( + os.getenv('VLLM_CONFIG_HIDDEN_LAYERS', '1')) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -347,7 +352,7 @@ def forward( hidden_states, residual = layer(positions, hidden_states, kv_caches[i - self.start_layer], attn_metadata, residual) - if is_hpu: + if is_hpu and i % self.config_hidden_layers == 0: htorch.core.mark_step() if not get_pp_group().is_last_rank: return IntermediateTensors({ diff --git a/vllm/spec_decode/hpu_draft_model_runner.py b/vllm/spec_decode/hpu_draft_model_runner.py new file mode 100644 index 0000000000000..a5943bdd7d804 --- /dev/null +++ b/vllm/spec_decode/hpu_draft_model_runner.py @@ -0,0 +1,62 @@ +from typing import List, Optional + +import torch + +from vllm.logger import init_logger +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import IntermediateTensors +from vllm.worker.hpu_model_runner import HPUModelRunner as ModelRunnerBaseCls +from vllm.worker.hpu_model_runner import ModelInputForHPUWithSamplingMetadata + +logger = init_logger(__name__) + +# A flag to enable debug prints for the updated input tensors +# before each step. +debug_advance_input = False +# A flag to allow GPU advance step for draft model runner. +# Set to False for debugging. +allow_gpu_advance_step = True + + +class HPUTP1DraftModelRunner(ModelRunnerBaseCls): + """Specialized model runner for speculative decoding draft model. + Since the draft model always execute k forward passes consecutively to + generate k speculative tokens in a single speculative decoding step, + we could get rid of most CPU-GPU synchronization and data transfer + overheads by keeping model input and output tensors on GPU all the time. + + TODOs: + 1. Support TP > 1 (this requires some designs because we do not expect + any broadcasting inside execute_model). + """ + + def __init__(self, *args, **kwargs): + if kwargs.get("return_hidden_states"): + raise ValueError( + "return_hidden_states is not supported for TP1DraftModelRunner." 
+ ) + + super().__init__(*args, **kwargs) + + self.indices_of_seq_with_bonus_tokens = None + + @torch.inference_mode() + def execute_model( + self, + model_input: ModelInputForHPUWithSamplingMetadata, + kv_caches: List[torch.Tensor], + previous_hidden_states: Optional[torch.Tensor] = None, + intermediate_tensors: Optional[IntermediateTensors] = None, + num_steps: int = 1, + ) -> Optional[List[SamplerOutput]]: + if previous_hidden_states is not None: + _, block_size = model_input.input_tokens.shape + previous_hidden_states = previous_hidden_states.expand( + block_size, -1).unsqueeze(0) + return super().execute_model( + model_input=model_input, + kv_caches=kv_caches, + previous_hidden_states=previous_hidden_states, + intermediate_tensors=intermediate_tensors, + num_steps=num_steps, + ) diff --git a/vllm/spec_decode/medusa_worker.py b/vllm/spec_decode/medusa_worker.py index 0d233f393cb8c..961d0c07db35a 100644 --- a/vllm/spec_decode/medusa_worker.py +++ b/vllm/spec_decode/medusa_worker.py @@ -9,10 +9,10 @@ from vllm.spec_decode.interfaces import SpeculativeProposals from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase from vllm.spec_decode.top1_proposer import Top1Proposer -from vllm.worker.worker import Worker +from vllm.worker.selector import WorkerCls -class MedusaWorker(NonLLMProposerWorkerBase, Worker): +class MedusaWorker(NonLLMProposerWorkerBase, WorkerCls): """Worker for Medusa. """ diff --git a/vllm/spec_decode/multi_step_worker.py b/vllm/spec_decode/multi_step_worker.py index d030a8a25a6ee..85cfdeec6816d 100644 --- a/vllm/spec_decode/multi_step_worker.py +++ b/vllm/spec_decode/multi_step_worker.py @@ -5,7 +5,6 @@ import torch from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.platforms import current_platform from vllm.sequence import (ExecuteModelRequest, HiddenStates, SequenceData, SequenceGroupMetadata) from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner @@ -13,24 +12,10 @@ SpeculativeProposer) from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase from vllm.spec_decode.top1_proposer import Top1Proposer +from vllm.worker.selector import WorkerCls -if current_platform.is_neuron(): - from vllm.worker.neuron_worker import NeuronWorker as WorkerBaseCls -elif current_platform.is_hpu(): - from vllm.worker.hpu_worker import HPUWorker as WorkerBaseCls -elif current_platform.is_openvino: - from vllm.worker.openvino_worker import OpenVINOWorker as WorkerBaseCls -elif current_platform.is_cpu(): - from vllm.worker.cpu_worker import CPUWorker as WorkerBaseCls -elif current_platform.is_tpu(): - from vllm.worker.tpu_worker import TPUWorker as WorkerBaseCls -elif current_platform.is_xpu(): - from vllm.worker.xpu_worker import XPUWorker as WorkerBaseCls -else: - from vllm.worker.worker import Worker as WorkerBaseCls - - -class MultiStepWorker(WorkerBaseCls, ProposerWorkerBase): + +class MultiStepWorker(WorkerCls, ProposerWorkerBase): """The MultiStepWorker is equivalent to a Worker except that it allows multiple forward passes in a single call, assuming the scheduler has allocated enough space to store the additional KV. 
This reduces overhead diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 80cfa55efd9bc..7b5415d60dd1e 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -14,12 +14,12 @@ SpecDecodeBaseSampler, SpecDecodeStochasticBaseSampler) from vllm.model_executor.layers.typical_acceptance_sampler import ( TypicalAcceptanceSampler) +from vllm.platforms import current_platform from vllm.sequence import (VLLM_INVALID_TOKEN_ID, CompletionSequenceGroupOutput, ExecuteModelRequest, HiddenStates, SequenceGroupMetadata, get_all_seq_ids_and_request_ids) from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer -from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner from vllm.spec_decode.interfaces import (SpeculativeProposals, SpeculativeScorer, SpeculativeScores) from vllm.spec_decode.medusa_worker import MedusaWorker @@ -40,6 +40,11 @@ from vllm.worker.worker import Worker from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase +if current_platform.is_hpu(): + from vllm.spec_decode.hpu_draft_model_runner import HPUTP1DraftModelRunner +else: + from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner + logger = init_logger(__name__) @@ -159,8 +164,16 @@ def create_worker( proposer_worker = MedusaWorker(**draft_worker_kwargs) else: if draft_tp == 1: - draft_worker_kwargs[ - "model_runner_cls"] = TP1DraftModelRunner + if current_platform.is_cuda_alike(): + draft_worker_kwargs[ + "model_runner_cls"] = TP1DraftModelRunner + elif current_platform.is_hpu(): + draft_worker_kwargs[ + "model_runner_cls"] = HPUTP1DraftModelRunner + else: + raise NotImplementedError( + "DraftModelRunner not implemented for this platform" + ) else: if draft_model_config.hf_config.model_type == "eagle": raise NotImplementedError( diff --git a/vllm/spec_decode/target_model_runner.py b/vllm/spec_decode/target_model_runner.py index e61cde5b17f20..f1c87c7bfda3c 100644 --- a/vllm/spec_decode/target_model_runner.py +++ b/vllm/spec_decode/target_model_runner.py @@ -1,12 +1,42 @@ from typing import List, Optional from vllm.config import VllmConfig +from vllm.platforms import current_platform from vllm.sequence import SequenceGroupMetadata -from vllm.worker.model_runner import (ModelInputForGPUWithSamplingMetadata, - ModelRunner) +if current_platform.is_cuda_alike(): + from vllm.worker.model_runner import ( + ModelInputForGPUWithSamplingMetadata as ModelInputCls) # yapf: disable + from vllm.worker.model_runner import ModelRunner as ModelRunnerCls +elif current_platform.is_neuron(): + from vllm.worker.neuron_model_runner import ( + ModelInputForNeuron as ModelInputCls) # yapf: disable + from vllm.worker.neuron_model_runner import ( + NeuronModelRunner as ModelRunnerCls) # yapf: disable +elif current_platform.is_hpu(): + from vllm.worker.hpu_model_runner import HPUModelRunner as ModelRunnerCls + from vllm.worker.hpu_model_runner import ( + ModelInputForHPUWithSamplingMetadata as ModelInputCls) # yapf: disable +elif current_platform.is_openvino(): + from vllm.worker.openvino_model_runner import ModelInput as ModelInputCls + from vllm.worker.openvino_model_runner import ( + OpenVINOModelRunner as ModelRunnerCls) # yapf: disable +elif current_platform.is_cpu(): + from vllm.worker.cpu_model_runner import CPUModelRunner as ModelRunnerCls + from vllm.worker.cpu_model_runner import ( + ModelInputForCPUWithSamplingMetadata as ModelInputCls) # yapf: disable +elif current_platform.is_tpu(): + from 
vllm.worker.tpu_model_runner import ModelInputForTPU as ModelInputCls + from vllm.worker.tpu_model_runner import TPUModelRunner as ModelRunnerCls +elif current_platform.is_xpu(): + from vllm.worker.xpu_model_runner import ( + ModelInputForXPUWithSamplingMetadata as ModelInputCls) # yapf: disable + from vllm.worker.xpu_model_runner import XPUModelRunner as ModelRunnerCls +else: + raise ValueError(f"Unsupported platform: {current_platform}") -class TargetModelRunner(ModelRunner): + +class TargetModelRunner(ModelRunnerCls): """Specialized model runner for speculative decoding target model. In speculative decoding, the log probabilities selected finally may not be the same ones as selected by the target model sampling. This means @@ -39,11 +69,10 @@ def prepare_model_input( self, seq_group_metadata_list: List[SequenceGroupMetadata], virtual_engine: int = 0, - finished_requests_ids: Optional[List[str]] = None - ) -> ModelInputForGPUWithSamplingMetadata: - model_input: ModelInputForGPUWithSamplingMetadata = super( - ).prepare_model_input(seq_group_metadata_list, virtual_engine, - finished_requests_ids) + finished_requests_ids: Optional[List[str]] = None, + ) -> ModelInputCls: + model_input: ModelInputCls = super().prepare_model_input( + seq_group_metadata_list, virtual_engine, finished_requests_ids) # If token log probabilities is disabled then skip generating sampler # CPU output. We directly serialize the GPU sampled_token_id tensors # as needed. If log probabilities is enabled then synchronize all the diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index 59aa52879c39c..97ad0a6893dd4 100755 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -22,11 +22,11 @@ import habana_frameworks.torch.internal.bridge_config as bc import torch from vllm_hpu_extension.ops import LoraMask as LoraMask +from vllm_hpu_extension.ops import batch2block, block2batch from vllm_hpu_extension.profiler import (HabanaHighLevelProfiler, HabanaMemoryProfiler, format_bytes) from vllm.attention import AttentionMetadata, get_attn_backend -from vllm.attention.backends.hpu_attn import HPUAttentionBackend from vllm.config import DeviceConfig, VllmConfig from vllm.distributed import broadcast_tensor_dict from vllm.distributed.parallel_state import get_world_group @@ -263,10 +263,19 @@ def setup_profiler(): return profiler -def pad_list(list, k, v): - target_len = round_up(len(list), k) - padding = target_len - len(list) - return list + [v] * padding +def pad_list(input, k, v): + input_len = len(input) + target_len = round_up(input_len, k) + padding = target_len - input_len + return input + [v] * padding + + +def gather_list(input, indices, v): + return [input[i] if i is not None else v for i in indices] + + +def flatten(in_list): + return list(itertools.chain(*in_list)) def precompute_indices_and_offsets(block_size, slot_mapping, is_prompt): @@ -344,32 +353,44 @@ def _set_block_mapping(self, metadata, batch_size, device, dtype): mask, -math.inf)) if not is_fake_hpu() and htorch.utils.internal.is_lazy(): - block_mapping = torch.nn.functional.one_hot(metadata.block_mapping, + block_mapping = torch.nn.functional.one_hot(metadata.block_groups, num_classes=batch_size) else: # Unfortunately one_hot on CPU/torch.compile mode/eager mode - # doesn't handle out of bounds classes, - # so we convert all negative values to 0. 
- block_mapping = torch.nn.functional.relu(metadata.block_mapping) + # doesn't handle out of bounds classes so we need to convert + # all negative values to 0 (block_mapping) or bs (block_groups) + block_groups = metadata.block_groups.to(torch.long) + block_mapping = torch.nn.functional.relu(block_groups) block_mapping = torch.nn.functional.one_hot(block_mapping, num_classes=batch_size) - oob_values = metadata.block_mapping.lt(0) + oob_values = block_groups.lt(0) block_mapping.masked_fill_(oob_values.unsqueeze(-1), 0) + block_groups.masked_fill_(oob_values, batch_size) + metadata = metadata._replace(block_groups=block_groups) block_mapping = block_mapping.to(dtype) metadata = metadata._replace(block_mapping=block_mapping, attn_bias=attn_bias) return metadata + def _set_block_scales(self, metadata, device): + block_mapping = metadata.block_mapping + ones = torch.ones((block_mapping.size(0), ), + device=device, + dtype=block_mapping.dtype) + sums = batch2block(block2batch(ones, block_mapping), block_mapping) + block_scales = torch.reciprocal(torch.maximum(ones, sums)) + metadata = metadata._replace(block_scales=block_scales) + return metadata + def _update_metadata(self, attn_metadata, batch_size, seq_len, device, dtype): if attn_metadata.is_prompt: - meta = attn_metadata - attn_metadata = self._set_attn_bias(meta, batch_size, seq_len, - device, dtype) + attn_metadata = self._set_attn_bias(attn_metadata, batch_size, + seq_len, device, dtype) else: - meta = attn_metadata - attn_metadata = self._set_block_mapping(meta, batch_size, device, - dtype) + attn_metadata = self._set_block_mapping(attn_metadata, batch_size, + device, dtype) + attn_metadata = self._set_block_scales(attn_metadata, device) return attn_metadata def forward(self, *args, **kwargs): @@ -393,6 +414,9 @@ def compute_logits(self, *args, **kwargs): def sample(self, *args, **kwargs): return self.model.sample(*args, **kwargs) + def generate_proposals(self, *args, **kwargs): + return self.model.generate_proposals(*args, **kwargs) + # sampler property will be used by spec_decode_worker # don't rename @property @@ -563,6 +587,7 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]): def __init__( self, vllm_config: VllmConfig, + kv_cache_dtype: Optional[str] = "auto", is_driver_worker: bool = False, return_hidden_states: bool = False, ): @@ -591,14 +616,17 @@ def __init__( self.pin_memory = is_pin_memory_available() self.kv_cache_dtype = self.cache_config.cache_dtype + num_attn_heads = self.model_config.get_num_attention_heads( + self.parallel_config) + needs_attn_backend = (num_attn_heads != 0 + or self.model_config.is_attention_free) self.attn_backend = get_attn_backend( self.model_config.get_head_size(), self.model_config.dtype, self.kv_cache_dtype, self.block_size, self.model_config.is_attention_free, - ) - assert self.attn_backend == HPUAttentionBackend + ) if needs_attn_backend else None # Lazy initialization self.lora_manager: LRUCacheWorkerLoRAManager = None @@ -1076,86 +1104,49 @@ def _prepare_decode( num_decode_tokens = sum(seq_lens) - block_mapping: Union[List[Union[None, int]], torch.Tensor] - block_usage: Union[List[Union[None, int]], torch.Tensor] - block_scales: Union[List[Union[None, float]], torch.Tensor] - block_list: Union[List[int], torch.Tensor] + last_block_usage = [ + slot[0] % self.block_size + 1 for slot in slot_mapping + ] + block_groups = [[i] * len(bt) for i, bt in enumerate(block_tables)] + block_usage = [[self.block_size] * (len(bt) - 1) + [lbu] + for bt, lbu in zip(block_tables, 
last_block_usage) + if bt] + + block_list = flatten(block_tables) + block_groups = flatten(block_groups) + block_usage = flatten(block_usage) + + assert len(block_list) == len(block_groups) + assert len(block_list) == len(block_usage) + padding_fn = None if self.use_contiguous_pa: - block_list = list(itertools.chain(*block_tables)) - max_idx = max(block_list) - max_blocks = max(max_idx + 1, len(block_list)) + block_bucket_size = max(max(block_list) + 1, len(block_list)) block_bucket_size = find_bucket( - max_blocks, + block_bucket_size, self.bucketing_global_state.decode_block_bucket_cfg) - block_bucket_size = min(block_bucket_size, - self.cache_config.num_gpu_blocks) - - block_mapping = [None] * block_bucket_size - block_usage = [None] * block_bucket_size - block_scales = [None] * block_bucket_size - - for i, bt in enumerate(block_tables): - if bt: - blocks_in_group = len(bt) - scale = 1.0 / blocks_in_group - for b in bt: - if block_mapping[b] is None: - block_mapping[b] = i - block_usage[b] = self.block_size - block_scales[b] = scale - - block_mapping = [b if b is not None else -1 for b in block_mapping] - block_scales = [b if b is not None else 0.0 for b in block_scales] - - for bt, sl in zip(block_tables, slot_mapping): - if bt: - block_usage[bt[-1]] = sl[-1] % self.block_size + 1 - block_usage = [u if u is not None else 1 for u in block_usage] - + indices: List[Any] + indices = [None] * block_bucket_size + for i, bid in enumerate(block_list): + indices[bid] = i + padding_fn = lambda tensor, pad_value: gather_list( + tensor, indices, pad_value) else: - blocks_used = [len(bt) for bt in block_tables if bt] - block_list = [] - block_scales = [] - for bt in block_tables: - block_list.extend(bt) - blocks_in_group = len(bt) - if blocks_in_group > 0: - scale = 1.0 / blocks_in_group - block_scales.extend([scale] * blocks_in_group) - - block_mapping_nested: List[List[int]] = [ - [i] * b_u for i, b_u in enumerate(blocks_used) - ] - block_mapping = list( - itertools.chain.from_iterable(block_mapping_nested)) - - last_block = [ - sl % self.block_size + 1 - for sl in itertools.chain(*slot_mapping) - ] - block_usage_ = [[self.block_size] * (b_u - 1) + [lb] - for b_u, lb in zip(blocks_used, last_block)] - block_usage = list(itertools.chain(*block_usage_)) - block_bucket_size = find_bucket( len(block_list), self.bucketing_global_state.decode_block_bucket_cfg) - block_mapping = pad_list(block_mapping, block_bucket_size, -1) - block_usage = pad_list(block_usage, block_bucket_size, 1) - block_scales = pad_list(block_scales, block_bucket_size, 0.0) + padding_fn = lambda tensor, pad_value: pad_list( + tensor, block_bucket_size, pad_value) + + block_list = padding_fn(block_list, _PAD_BLOCK_ID) + block_groups = padding_fn(block_groups, -1) + block_usage = padding_fn(block_usage, 1) - block_list = pad_list(block_list, block_bucket_size, _PAD_BLOCK_ID) - block_groups = pad_list(block_mapping, block_bucket_size, - len(block_tables)) block_list = torch.tensor(block_list, dtype=torch.int, device=self.device) - block_mapping = torch.tensor(block_mapping, - dtype=torch.long, - device=self.device) block_groups = torch.tensor(block_groups, - dtype=torch.long, + dtype=torch.int, device=self.device) block_usage = torch.tensor(block_usage, dtype=self.model_config.dtype, @@ -1166,18 +1157,15 @@ def _prepare_decode( block_indices, block_offsets = precompute_indices_and_offsets( self.block_size, slot_mapping, False) - block_scales = torch.tensor(block_scales, - dtype=self.model_config.dtype, - device=self.device) 
attn_metadata = self.attn_backend.make_metadata( is_prompt=False, block_list=block_list, - block_mapping=block_mapping, + block_mapping=None, block_usage=block_usage, block_indices=block_indices, block_offsets=block_offsets, - block_scales=block_scales, + block_scales=None, block_groups=block_groups, attn_bias=None, seq_lens_tensor=None, @@ -1503,7 +1491,27 @@ def warmup_scenario(self, profiler.start() for _ in range(times): inputs = self.prepare_model_input(seqs) - self.execute_model(inputs, kv_caches, warmup_mode=True) + is_single_step = \ + self.vllm_config.scheduler_config.num_scheduler_steps == 1 + if is_prompt or is_single_step: + self.execute_model(inputs, kv_caches, warmup_mode=True) + else: # decode with multi-step + inputs = dataclasses.replace(inputs, + is_first_multi_step=True, + is_last_step=False) + self.execute_model(inputs, + kv_caches, + warmup_mode=True, + num_steps=2, + seqs=seqs) + inputs = dataclasses.replace(inputs, + is_first_multi_step=False, + is_last_step=True) + self.execute_model(inputs, + kv_caches, + warmup_mode=True, + num_steps=2, + seqs=seqs) torch.hpu.synchronize() if profiler: profiler.step() @@ -2030,6 +2038,8 @@ def execute_model( intermediate_tensors: Optional[IntermediateTensors] = None, num_steps: int = 1, warmup_mode=False, + previous_hidden_states: Optional[torch.Tensor] = None, + seqs=None, ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]: if not model_input.is_first_multi_step: if not model_input.is_last_step: @@ -2080,6 +2090,9 @@ def execute_model( "lora_mask": lora_mask, **(model_input.multi_modal_kwargs or {}), } + if previous_hidden_states is not None: + execute_model_kwargs.update( + {"previous_hidden_states": previous_hidden_states}) if htorch.utils.internal.is_lazy(): execute_model_kwargs.update( {"bypass_hpu_graphs": not use_graphs}) @@ -2172,9 +2185,16 @@ def try_revert_dummy_output_tokens(): htorch.core.mark_step() if i < num_steps - 1: if i == 0: - ctx = model_input.async_callback.keywords[ # type: ignore - "ctx"] - seq_group_metadata_list = ctx.seq_group_metadata_list + if model_input.async_callback is not None: + ctx = model_input.async_callback.keywords[ # type: ignore + "ctx"] + seq_group_metadata_list = \ + ctx.seq_group_metadata_list + elif seqs is not None: + seq_group_metadata_list = seqs + else: + raise RuntimeError( + "seq_group_metadata_list is uninitialized") # Cache the original output token ids for i, seq_group_metadata in enumerate( seq_group_metadata_list): @@ -2232,9 +2252,16 @@ def try_revert_dummy_output_tokens(): is_prompt=is_prompt) self.profiler.record_counter(self.event_start, counters) if num_steps == 1: + if self.return_hidden_states: + # we only need to pass hidden states of most recent token + assert model_input.sampling_metadata is not None + if model_input.is_prompt: + output.prefill_hidden_states = hidden_states + output.hidden_states = hidden_states return [output] if self.is_driver_worker else [] else: return [] + return output if type(output) is list else [output] def _decode_sampler_outputs(self, model_input): diff --git a/vllm/worker/hpu_worker.py b/vllm/worker/hpu_worker.py index 8aa159172d0e1..2b8f955265792 100644 --- a/vllm/worker/hpu_worker.py +++ b/vllm/worker/hpu_worker.py @@ -12,7 +12,7 @@ from vllm_hpu_extension.profiler import HabanaMemoryProfiler, format_bytes import vllm.envs as envs -from vllm.config import ParallelConfig, SpeculativeConfig, VllmConfig +from vllm.config import ParallelConfig, VllmConfig from vllm.distributed import (ensure_model_parallel_initialized, 
init_distributed_environment) from vllm.logger import init_logger @@ -23,7 +23,6 @@ from vllm.utils import hpu_backend_string, hpu_device_string, is_fake_hpu from vllm.worker.cache_engine import CacheEngine from vllm.worker.hpu_model_runner import HPUModelRunner -from vllm.worker.model_runner_base import ModelRunnerBase from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase, WorkerInput) @@ -45,8 +44,7 @@ def __init__( rank: int, distributed_init_method: str, is_driver_worker: bool = False, - speculative_config: Optional[SpeculativeConfig] = None, - model_runner_cls: Optional[Type[ModelRunnerBase]] = None, + model_runner_cls: Optional[Type[HPUModelRunner]] = None, ) -> None: WorkerBase.__init__(self, vllm_config=vllm_config) self.parallel_config.rank = rank @@ -62,8 +60,28 @@ def __init__( from vllm.utils import init_cached_hf_modules init_cached_hf_modules() - self.model_runner: HPUModelRunner = HPUModelRunner( - vllm_config=vllm_config, is_driver_worker=is_driver_worker) + # Return hidden states from target model if the draft model is an + # mlp_speculator + speculative_config = self.speculative_config + model_config = self.model_config + speculative_args = {} if speculative_config is None \ + or (speculative_config.draft_model_config.model == + model_config.model) \ + or (speculative_config.draft_model_config.hf_config.model_type + not in ["medusa", "mlp_speculator", "eagle"]) \ + else {"return_hidden_states": True} + + ModelRunnerClass: Type[HPUModelRunner] = HPUModelRunner + if model_runner_cls is not None: + ModelRunnerClass = model_runner_cls + else: + ModelRunnerClass = HPUModelRunner + self.model_runner: HPUModelRunner = ModelRunnerClass( + vllm_config=vllm_config, + kv_cache_dtype=self.cache_config.cache_dtype, + is_driver_worker=is_driver_worker, + **speculative_args, + ) # Uninitialized cache engine. Will be initialized by # initialize_cache. 
        self.cache_engine: List[HPUCacheEngine]
diff --git a/vllm/worker/selector.py b/vllm/worker/selector.py
index 4afb3a694693a..8e5fb0e867719 100644
--- a/vllm/worker/selector.py
+++ b/vllm/worker/selector.py
@@ -1,25 +1,18 @@
 from vllm.platforms import current_platform
 
+if current_platform.is_neuron():
+    from vllm.worker.neuron_worker import NeuronWorker as WorkerCls
+elif current_platform.is_hpu():
+    from vllm.worker.hpu_worker import HPUWorker as WorkerCls  # type: ignore
+elif current_platform.is_cpu():
+    from vllm.worker.cpu_worker import CPUWorker as WorkerCls  # type: ignore
+elif current_platform.is_tpu():
+    from vllm.worker.tpu_worker import TPUWorker as WorkerCls  # type: ignore
+elif current_platform.is_xpu():
+    from vllm.worker.xpu_worker import XPUWorker as WorkerCls  # type: ignore
+else:
+    from vllm.worker.worker import Worker as WorkerCls  # type: ignore
+
 
 def init_worker(*args, **kwargs):
-    if current_platform.is_neuron():
-        from vllm.worker.neuron_worker import NeuronWorker
-        return NeuronWorker(*args, **kwargs)
-    elif current_platform.is_tpu():
-        from vllm.worker.tpu_worker import TPUWorker
-        return TPUWorker(*args, **kwargs)
-    elif current_platform.is_cpu():
-        from vllm.worker.cpu_worker import CPUWorker
-        return CPUWorker(*args, **kwargs)
-    elif current_platform.is_hpu():
-        from vllm.worker.hpu_worker import HPUWorker
-        return HPUWorker(*args, **kwargs)
-    elif current_platform.is_openvino():
-        from vllm.worker.openvino_worker import OpenVINOWorker
-        return OpenVINOWorker(*args, **kwargs)
-    elif current_platform.is_xpu():
-        from vllm.worker.xpu_worker import XPUWorker
-        return XPUWorker(*args, **kwargs)
-    else:
-        from vllm.worker.worker import Worker
-        return Worker(*args, **kwargs)
+    return WorkerCls(*args, **kwargs)
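With the worker selector above in place, the patch as a whole (per the commit subject) enables DeepseekV2 Lite/Chat models on Gaudi. A minimal offline smoke test, written in the same style as the example scripts added earlier in this patch, might look like the sketch below; the model name, `trust_remote_code` flag, and dtype are illustrative assumptions rather than part of the patch.

```python
import os

from vllm import LLM, SamplingParams

# Optional HPU knob documented in README_GAUDI.md above; 1 is the assumed default.
os.environ.setdefault("VLLM_CONFIG_HIDDEN_LAYERS", "1")

if __name__ == "__main__":
    # Model name, trust_remote_code, and dtype are assumptions for illustration.
    llm = LLM(model="deepseek-ai/DeepSeek-V2-Lite-Chat",
              trust_remote_code=True,
              dtype="bfloat16")
    sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=32)
    outputs = llm.generate(["The future of AI is"], sampling_params)
    for output in outputs:
        print(output.outputs[0].text)
```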