From 7171eb27b3c6008b8bf52c5e54786db7f8430600 Mon Sep 17 00:00:00 2001 From: Siddharth Venkatesan Date: Fri, 20 Dec 2024 11:55:49 -0800 Subject: [PATCH] update unit tests --- .../properties_manager/vllm_rb_properties.py | 90 ++++++----- .../tests/test_properties_manager.py | 152 ++++++++++++++++-- engines/python/setup/djl_python_engine.py | 2 +- 3 files changed, 194 insertions(+), 50 deletions(-) diff --git a/engines/python/setup/djl_python/properties_manager/vllm_rb_properties.py b/engines/python/setup/djl_python/properties_manager/vllm_rb_properties.py index e972c9fb3..59231ef2e 100644 --- a/engines/python/setup/djl_python/properties_manager/vllm_rb_properties.py +++ b/engines/python/setup/djl_python/properties_manager/vllm_rb_properties.py @@ -21,36 +21,62 @@ from djl_python.properties_manager.properties import Properties DTYPE_MAPPER = { + "float32": "float32", "fp32": "float32", + "float16": "float16", "fp16": "float16", + "bfloat16": "bfloat16", "bf16": "bfloat16", "auto": "auto" } +def construct_vllm_args_list(vllm_engine_args: dict, + parser: FlexibleArgumentParser): + # Modified from https://github.com/vllm-project/vllm/blob/v0.6.4/vllm/utils.py#L1258 + args_list = [] + store_boolean_arguments = { + action.dest + for action in parser._actions if isinstance(action, StoreBoolean) + } + for engine_arg, engine_arg_value in vllm_engine_args.items(): + if str(engine_arg_value).lower() in { + 'true', 'false' + } and engine_arg not in store_boolean_arguments: + if str(engine_arg_value).lower() == 'true': + args_list.append(f"--{engine_arg}") + else: + args_list.append(f"--{engine_arg}={engine_arg_value}") + return args_list + + class VllmRbProperties(Properties): engine: Optional[str] = None # The following configs have different names in DJL compared to vLLM, we only accept DJL name currently tensor_parallel_degree: int = 1 pipeline_parallel_degree: int = 1 # The following configs have different names in DJL compared to vLLM, either is accepted - quantize: Optional[str] = Field(alias="quantization", default=None) + quantize: Optional[str] = Field(alias="quantization", + default=EngineArgs.quantization) max_rolling_batch_prefill_tokens: Optional[int] = Field( - alias="max_num_batched_tokens", default=None) - cpu_offload_gb_per_gpu: Optional[float] = Field(alias="cpu_offload_gb", - default=None) + alias="max_num_batched_tokens", + default=EngineArgs.max_num_batched_tokens) + cpu_offload_gb_per_gpu: float = Field(alias="cpu_offload_gb", + default=EngineArgs.cpu_offload_gb) # The following configs have different defaults, or additional processing in DJL compared to vLLM dtype: str = "auto" max_loras: int = 4 + # The following configs have broken processing in vllm via the FlexibleArgumentParser long_lora_scaling_factors: Optional[Tuple[float, ...]] = None + use_v2_block_manager: bool = True # Neuron vLLM properties - device: Optional[str] = None + device: str = 'auto' preloaded_model: Optional[Any] = None generation_config: Optional[Any] = None # This allows generic vllm engine args to be passed in and set with vllm - model_config = ConfigDict(extra='allow') + model_config = ConfigDict(extra='allow', populate_by_name=True) @field_validator('engine') def validate_engine(cls, engine): @@ -59,6 +85,14 @@ def validate_engine(cls, engine): f"Need python engine to start vLLM RollingBatcher") return engine + @field_validator('dtype') + def validate_dtype(cls, val): + if val not in DTYPE_MAPPER: + raise ValueError( + f"Invalid dtype={val} provided. Must be one of {DTYPE_MAPPER.keys()}" + ) + return DTYPE_MAPPER[val] + @model_validator(mode='after') def validate_pipeline_parallel(self): if self.pipeline_parallel_degree != 1: @@ -67,9 +101,9 @@ def validate_pipeline_parallel(self): ) return self - @field_validator('long_lora_scaling_factors', mode='before') # TODO: processing of this field is broken in vllm via from_cli_args # we should upstream a fix for this to vllm + @field_validator('long_lora_scaling_factors', mode='before') def validate_long_lora_scaling_factors(cls, val): if isinstance(val, str): val = ast.literal_eval(val) @@ -96,7 +130,7 @@ def validate_potential_lmi_vllm_config_conflict( if vllm_config_val != lmi_config_val: raise ValueError( f"Both the DJL {lmi_config_val}={lmi_config_val} and vLLM {vllm_config_name}={vllm_config_val} configs have been set with conflicting values." - f"We currently only accept the DJL config {lmi_config_val}, please remove the vllm {vllm_config_name} configuration." + f"We currently only accept the DJL config {lmi_config_name}, please remove the vllm {vllm_config_name} configuration." ) validate_potential_lmi_vllm_config_conflict("tensor_parallel_degree", @@ -117,20 +151,18 @@ def generate_vllm_engine_arg_dict(self, 'revision': self.revision, 'max_loras': self.max_loras, 'enable_lora': self.enable_lora, + 'trust_remote_code': self.trust_remote_code, + 'cpu_offload_gb': self.cpu_offload_gb_per_gpu, + 'use_v2_block_manager': self.use_v2_block_manager, + 'quantization': self.quantize, + 'max_num_batched_tokens': self.max_rolling_batch_prefill_tokens, + 'device': self.device, } - if self.quantize is not None: - vllm_engine_args['quantization'] = self.quantize - if self.max_rolling_batch_prefill_tokens is not None: - vllm_engine_args[ - 'max_num_batched_tokens'] = self.max_rolling_batch_prefill_tokens - if self.cpu_offload_gb_per_gpu is not None: - vllm_engine_args['cpu_offload_gb'] = self.cpu_offload_gb_per_gpu - if self.device is not None: - vllm_engine_args['device'] = self.device - if self.preloaded_model is not None: + if self.device == 'neuron': vllm_engine_args['preloaded_model'] = self.preloaded_model - if self.generation_config is not None: vllm_engine_args['generation_config'] = self.generation_config + vllm_engine_args['block_size'] = passthrough_vllm_engine_args.get( + "max_model_len") vllm_engine_args.update(passthrough_vllm_engine_args) return vllm_engine_args @@ -143,7 +175,7 @@ def get_engine_args(self) -> EngineArgs: f"Construction vLLM engine args from the following DJL configs: {vllm_engine_arg_dict}" ) parser = EngineArgs.add_cli_args(FlexibleArgumentParser()) - args_list = self.construct_vllm_args_list(vllm_engine_arg_dict, parser) + args_list = construct_vllm_args_list(vllm_engine_arg_dict, parser) args = parser.parse_args(args=args_list) engine_args = EngineArgs.from_cli_args(args) # we have to do this separately because vllm converts it into a string @@ -156,21 +188,3 @@ def get_additional_vllm_engine_args(self) -> Dict[str, Any]: for k, v in self.__pydantic_extra__.items() if k in EngineArgs.__annotations__ } - - def construct_vllm_args_list(self, vllm_engine_args: dict, - parser: FlexibleArgumentParser): - # Modified from https://github.com/vllm-project/vllm/blob/v0.6.4/vllm/utils.py#L1258 - args_list = [] - store_boolean_arguments = { - action.dest - for action in parser._actions if isinstance(action, StoreBoolean) - } - for engine_arg, engine_arg_value in vllm_engine_args.items(): - if str(engine_arg_value).lower() in { - 'true', 'false' - } and engine_arg not in store_boolean_arguments: - if str(engine_arg_value).lower() == 'true': - args_list.append(f"--{engine_arg}") - else: - args_list.append(f"--{engine_arg}={engine_arg_value}") - return args_list diff --git a/engines/python/setup/djl_python/tests/test_properties_manager.py b/engines/python/setup/djl_python/tests/test_properties_manager.py index 5647bf88f..b3de24f8d 100644 --- a/engines/python/setup/djl_python/tests/test_properties_manager.py +++ b/engines/python/setup/djl_python/tests/test_properties_manager.py @@ -11,7 +11,7 @@ TnXMemoryLayout, TnXDtypeName, TnXModelLoaders) from djl_python.properties_manager.trt_properties import TensorRtLlmProperties from djl_python.properties_manager.hf_properties import HuggingFaceProperties -from djl_python.properties_manager.vllm_rb_properties import VllmRbProperties, DTYPE_MAPPER +from djl_python.properties_manager.vllm_rb_properties import VllmRbProperties from djl_python.properties_manager.sd_inf2_properties import StableDiffusionNeuronXProperties from djl_python.properties_manager.lmi_dist_rb_properties import LmiDistRbProperties from djl_python.properties_manager.scheduler_rb_properties import SchedulerRbProperties @@ -423,7 +423,7 @@ def test_hf_error_case(self, params): HuggingFaceProperties(**params) def test_vllm_properties(self): - # test with valid vllm properties + def validate_vllm_config_and_engine_args_match( vllm_config_value, engine_arg_value, @@ -435,7 +435,7 @@ def validate_vllm_config_and_engine_args_match( def test_vllm_default_properties(): required_properties = { "engine": "Python", - "model_id_or_path": "some_model", + "model_id": "some_model", } vllm_configs = VllmRbProperties(**required_properties) engine_args = vllm_configs.get_engine_args() @@ -451,22 +451,120 @@ def test_vllm_default_properties(): vllm_configs.quantize, engine_args.quantization, None) validate_vllm_config_and_engine_args_match( vllm_configs.max_rolling_batch_size, engine_args.max_num_seqs, - HuggingFaceProperties.max_rolling_batch_size) + 32) validate_vllm_config_and_engine_args_match(vllm_configs.dtype, engine_args.dtype, 'auto') validate_vllm_config_and_engine_args_match(vllm_configs.max_loras, engine_args.max_loras, 4) - self.assertEqual(vllm_configs.cpu_offload_gb_per_gpu, None) + validate_vllm_config_and_engine_args_match( + vllm_configs.cpu_offload_gb_per_gpu, + engine_args.cpu_offload_gb, EngineArgs.cpu_offload_gb) self.assertEqual( len(vllm_configs.get_additional_vllm_engine_args()), 0) + def test_invalid_pipeline_parallel(): + properties = { + "engine": "Python", + "model_id": "some_model", + "tensor_parallel_degree": "4", + "pipeline_parallel_degree": "2", + } + with self.assertRaises(ValueError): + _ = VllmRbProperties(**properties) + + def test_invalid_engine(): + properties = { + "engine": "bad_engine", + "model_id": "some_model", + } + with self.assertRaises(ValueError): + _ = VllmRbProperties(**properties) + + def test_aliases(): + properties = { + "engine": "Python", + "model_id": "some_model", + "quantization": "awq", + "max_num_batched_tokens": "546", + "cpu_offload_gb": "7" + } + vllm_configs = VllmRbProperties(**properties) + engine_args = vllm_configs.get_engine_args() + validate_vllm_config_and_engine_args_match( + vllm_configs.quantize, engine_args.quantization, "awq") + validate_vllm_config_and_engine_args_match( + vllm_configs.max_rolling_batch_prefill_tokens, + engine_args.max_num_batched_tokens, 546) + validate_vllm_config_and_engine_args_match( + vllm_configs.cpu_offload_gb_per_gpu, + engine_args.cpu_offload_gb, 7) + + def test_vllm_passthrough_properties(): + properties = { + "engine": "Python", + "model_id": "some_model", + "tensor_parallel_degree": "4", + "pipeline_parallel_degree": "1", + "max_rolling_batch_size": "111", + "quantize": "awq", + "max_rolling_batch_prefill_tokens": "400", + "cpu_offload_gb_per_gpu": "8", + "dtype": "bf16", + "max_loras": "7", + "long_lora_scaling_factors": "1.1, 2.0", + "trust_remote_code": "true", + "max_model_len": "1024", + "enforce_eager": "true", + "enable_chunked_prefill": "False", + "gpu_memory_utilization": "0.4", + } + vllm_configs = VllmRbProperties(**properties) + engine_args = vllm_configs.get_engine_args() + self.assertTrue( + len(vllm_configs.get_additional_vllm_engine_args()) > 0) + validate_vllm_config_and_engine_args_match( + vllm_configs.model_id_or_path, engine_args.model, "some_model") + validate_vllm_config_and_engine_args_match( + vllm_configs.tensor_parallel_degree, + engine_args.tensor_parallel_size, 4) + validate_vllm_config_and_engine_args_match( + vllm_configs.pipeline_parallel_degree, + engine_args.pipeline_parallel_size, 1) + validate_vllm_config_and_engine_args_match( + vllm_configs.max_rolling_batch_size, engine_args.max_num_seqs, + 111) + validate_vllm_config_and_engine_args_match( + vllm_configs.quantize, engine_args.quantization, "awq") + validate_vllm_config_and_engine_args_match( + vllm_configs.max_rolling_batch_prefill_tokens, + engine_args.max_num_batched_tokens, 400) + validate_vllm_config_and_engine_args_match( + vllm_configs.cpu_offload_gb_per_gpu, + engine_args.cpu_offload_gb, 8.0) + validate_vllm_config_and_engine_args_match(vllm_configs.dtype, + engine_args.dtype, + "bfloat16") + validate_vllm_config_and_engine_args_match(vllm_configs.max_loras, + engine_args.max_loras, + 7) + validate_vllm_config_and_engine_args_match( + vllm_configs.long_lora_scaling_factors, + engine_args.long_lora_scaling_factors, (1.1, 2.0)) + validate_vllm_config_and_engine_args_match( + vllm_configs.trust_remote_code, engine_args.trust_remote_code, + True) + self.assertEqual(engine_args.max_model_len, 1024) + self.assertEqual(engine_args.enforce_eager, True) + self.assertEqual(engine_args.enable_chunked_prefill, False) + self.assertEqual(engine_args.gpu_memory_utilization, 0.4) + def test_long_lora_scaling_factors(): properties = { "engine": "Python", - "model_id_or_path": "some_model", - 'long_lora_scaling_factors': "3.0" + "model_id": "some_model", + "long_lora_scaling_factors": "3.0" } vllm_props = VllmRbProperties(**properties) engine_args = vllm_props.get_engine_args() @@ -500,16 +598,48 @@ def test_long_lora_scaling_factors(): def test_invalid_long_lora_scaling_factors(): properties = { "engine": "Python", - "model_id_or_path": "some_model", - 'long_lora_scaling_factors': "a,b" + "model_id": "some_model", + "long_lora_scaling_factors": "a,b" + } + with self.assertRaises(ValueError): + _ = VllmRbProperties(**properties) + + def test_conflicting_djl_vllm_conflicts(): + properties = { + "engine": "Python", + "model_id": "some_model", + "tensor_parallel_degree": 2, + "tensor_parallel_size": 1, + } + with self.assertRaises(ValueError): + _ = VllmRbProperties(**properties) + + properties = { + "engine": "Python", + "model_id": "some_model", + "pipeline_parallel_degree": 1, + "pipeline_parallel_size": 0, + } + with self.assertRaises(ValueError): + _ = VllmRbProperties(**properties) + + properties = { + "engine": "Python", + "model_id": "some_model", + "max_rolling_batch_size": 1, + "max_num_seqs": 2, } - vllm_props = VllmRbProperties(**properties) with self.assertRaises(ValueError): - vllm_props.get_engine_args() + _ = VllmRbProperties(**properties) test_vllm_default_properties() + test_invalid_pipeline_parallel() + test_invalid_engine() + test_aliases() + test_vllm_passthrough_properties() test_long_lora_scaling_factors() test_invalid_long_lora_scaling_factors() + test_conflicting_djl_vllm_conflicts() def test_sd_inf2_properties(self): properties = { diff --git a/engines/python/setup/djl_python_engine.py b/engines/python/setup/djl_python_engine.py index e06270f82..f11c9fa2d 100644 --- a/engines/python/setup/djl_python_engine.py +++ b/engines/python/setup/djl_python_engine.py @@ -189,7 +189,7 @@ def main(): # noinspection PyBroadException try: - args = ArgParser.python_engine_args().parse_args(args=sys.argv[1:]) + args = ArgParser.python_engine_args().parse_args() logging.basicConfig(stream=sys.stdout, format="%(levelname)s::%(message)s", level=args.log_level.upper())