Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ci][bugfix] fix kernel tests #10431

Merged
merged 3 commits into from
Nov 18, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 13 additions & 10 deletions vllm/plugins/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,6 @@

if TYPE_CHECKING:
from vllm.config import CompilationConfig, VllmConfig
else:
CompilationConfig = None
VllmConfig = None

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -50,23 +47,23 @@ def load_general_plugins():
logger.exception("Failed to load plugin %s", plugin.name)


_compilation_config: Optional[CompilationConfig] = None
_compilation_config: Optional["CompilationConfig"] = None


def set_compilation_config(config: Optional[CompilationConfig]):
def set_compilation_config(config: Optional["CompilationConfig"]):
global _compilation_config
_compilation_config = config


def get_compilation_config() -> Optional[CompilationConfig]:
def get_compilation_config() -> Optional["CompilationConfig"]:
return _compilation_config


_current_vllm_config: Optional[VllmConfig] = None
_current_vllm_config: Optional["VllmConfig"] = None


@contextmanager
def set_current_vllm_config(vllm_config: VllmConfig):
def set_current_vllm_config(vllm_config: "VllmConfig"):
"""
Temporarily set the current VLLM config.
Used during model initialization.
Expand All @@ -83,6 +80,12 @@ def set_current_vllm_config(vllm_config: VllmConfig):
_current_vllm_config = old_vllm_config


def get_current_vllm_config() -> VllmConfig:
assert _current_vllm_config is not None, "Current VLLM config is not set."
def get_current_vllm_config() -> "VllmConfig":
if _current_vllm_config is None:
# in ci, usually when we test custom ops/modules directly,
# we don't set the vllm config. In that case, we set a default
# config.
logger.warning("Current VLLM config is not set.")
from vllm.config import VllmConfig
return VllmConfig()
return _current_vllm_config