From 79e5ac6e245fb5c00d3ec9f866b3a22c6a142557 Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Mon, 18 Nov 2024 10:45:11 -0800
Subject: [PATCH 1/3] fix rotary embedding test

Signed-off-by: youkaichao
---
 tests/kernels/test_rotary_embedding.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/tests/kernels/test_rotary_embedding.py b/tests/kernels/test_rotary_embedding.py
index da879406b3936..d5e9b875bb1b7 100644
--- a/tests/kernels/test_rotary_embedding.py
+++ b/tests/kernels/test_rotary_embedding.py
@@ -8,7 +8,9 @@
 import torch
 
 from tests.kernels.utils import opcheck
+from vllm.config import VllmConfig
 from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
+from vllm.plugins import set_current_vllm_config
 
 
 def rotary_embedding_opcheck(rot,
@@ -42,8 +44,9 @@ def test_rotary_embedding_opcheck(dist_init, device, max_position,
     batch_size = 1
     base = 0
     num_heads = 7
-    rot = RotaryEmbedding(head_size, rotary_dim, max_position, base,
-                          is_neox_style, torch.float32)
+    with set_current_vllm_config(VllmConfig()):
+        rot = RotaryEmbedding(head_size, rotary_dim, max_position, base,
+                              is_neox_style, torch.float32)
 
     positions = torch.randint(0,
                               max_position, (batch_size, seq_len),

From 18e68fb2c13cba45be92113e6bba1c780446bbc0 Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Mon, 18 Nov 2024 10:53:19 -0800
Subject: [PATCH 2/3] fix from root: fall back to a default config

Signed-off-by: youkaichao
---
 vllm/plugins/__init__.py | 23 +++++++++++++----------
 1 file changed, 13 insertions(+), 10 deletions(-)

diff --git a/vllm/plugins/__init__.py b/vllm/plugins/__init__.py
index a0c73a752b5e8..5ef6b3d6449d9 100644
--- a/vllm/plugins/__init__.py
+++ b/vllm/plugins/__init__.py
@@ -6,9 +6,6 @@
 
 if TYPE_CHECKING:
     from vllm.config import CompilationConfig, VllmConfig
-else:
-    CompilationConfig = None
-    VllmConfig = None
 
 logger = logging.getLogger(__name__)
 
@@ -50,23 +47,23 @@ def load_general_plugins():
             logger.exception("Failed to load plugin %s", plugin.name)
 
 
-_compilation_config: Optional[CompilationConfig] = None
+_compilation_config: Optional["CompilationConfig"] = None
 
 
-def set_compilation_config(config: Optional[CompilationConfig]):
+def set_compilation_config(config: Optional["CompilationConfig"]):
     global _compilation_config
     _compilation_config = config
 
 
-def get_compilation_config() -> Optional[CompilationConfig]:
+def get_compilation_config() -> Optional["CompilationConfig"]:
     return _compilation_config
 
 
-_current_vllm_config: Optional[VllmConfig] = None
+_current_vllm_config: Optional["VllmConfig"] = None
 
 
 @contextmanager
-def set_current_vllm_config(vllm_config: VllmConfig):
+def set_current_vllm_config(vllm_config: "VllmConfig"):
     """
     Temporarily set the current VLLM config.
     Used during model initialization.
@@ -83,6 +80,12 @@ def set_current_vllm_config(vllm_config: VllmConfig):
         _current_vllm_config = old_vllm_config
 
 
-def get_current_vllm_config() -> VllmConfig:
-    assert _current_vllm_config is not None, "Current VLLM config is not set."
+def get_current_vllm_config() -> "VllmConfig":
+    if _current_vllm_config is None:
+        # In CI, when we test custom ops/modules directly, the
+        # vLLM config is usually not set. Fall back to a default
+        # config in that case.
+        logger.warning("Current VLLM config is not set.")
+        from vllm.config import VllmConfig
+        return VllmConfig()
     return _current_vllm_config
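For context on patch 2: the quoted "CompilationConfig"/"VllmConfig" annotations are resolved only under TYPE_CHECKING, which is what lets the runtime "else: ... = None" shim go away, and get_current_vllm_config now warns and returns a default config instead of asserting. Below is a minimal, self-contained Python sketch of that set/get-with-fallback pattern; the names (Config, set_current_config, get_current_config, compilation_level) are illustrative stand-ins, not vLLM's actual API.

import logging
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Iterator, Optional

logger = logging.getLogger(__name__)


@dataclass
class Config:
    # Stand-in for vllm.config.VllmConfig: every field has a default,
    # so Config() is always a valid "empty" configuration.
    compilation_level: int = 0


_current_config: Optional[Config] = None


@contextmanager
def set_current_config(config: Config) -> Iterator[None]:
    # Save and restore the previous value so nested uses of the
    # context manager behave like a stack, matching the patch's
    # save/restore in set_current_vllm_config.
    global _current_config
    old = _current_config
    _current_config = config
    try:
        yield
    finally:
        _current_config = old


def get_current_config() -> Config:
    if _current_config is None:
        # Mirrors the patch: warn instead of asserting, and hand back
        # a default so code under direct unit test still runs.
        logger.warning("Current config is not set; using defaults.")
        return Config()
    return _current_config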
From 421f5f5b4635533a9a5cca91b70f8fe2ede2eae6 Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Mon, 18 Nov 2024 11:06:22 -0800
Subject: [PATCH 3/3] revert

---
 tests/kernels/test_rotary_embedding.py | 7 ++-----
 1 file changed, 2 insertions(+), 5 deletions(-)

diff --git a/tests/kernels/test_rotary_embedding.py b/tests/kernels/test_rotary_embedding.py
index d5e9b875bb1b7..da879406b3936 100644
--- a/tests/kernels/test_rotary_embedding.py
+++ b/tests/kernels/test_rotary_embedding.py
@@ -8,9 +8,7 @@
 import torch
 
 from tests.kernels.utils import opcheck
-from vllm.config import VllmConfig
 from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
-from vllm.plugins import set_current_vllm_config
 
 
 def rotary_embedding_opcheck(rot,
@@ -44,9 +42,8 @@ def test_rotary_embedding_opcheck(dist_init, device, max_position,
     batch_size = 1
     base = 0
     num_heads = 7
-    with set_current_vllm_config(VllmConfig()):
-        rot = RotaryEmbedding(head_size, rotary_dim, max_position, base,
-                              is_neox_style, torch.float32)
+    rot = RotaryEmbedding(head_size, rotary_dim, max_position, base,
+                          is_neox_style, torch.float32)
 
     positions = torch.randint(0,
                               max_position, (batch_size, seq_len),
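With the root fix in place, constructing RotaryEmbedding outside any set_current_vllm_config block succeeds with a default config, which is why patch 3 can drop the test-side wrapper added in patch 1. A short usage sketch against the stand-in code above, showing both paths:

# Explicit config: the context manager scopes it to the enclosed block.
with set_current_config(Config(compilation_level=2)):
    assert get_current_config().compilation_level == 2

# No context entered: the getter warns and falls back to defaults,
# so the caller does not crash; this is the behavior patch 3 relies on.
assert get_current_config().compilation_level == 0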