From 0a55cdd53fb5f0bfbb42af19ebb533fa1c1b15e7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniele=20Trifir=C3=B2?= Date: Fri, 19 Jul 2024 17:35:53 +0200 Subject: [PATCH] tests: fix lora tests handling --- tests/conftest.py | 35 ++++++++++++++++------------------- 1 file changed, 16 insertions(+), 19 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index e2655226..96b16ac5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -39,24 +39,24 @@ def monkeysession(): @pytest.fixture(scope="session") -def lora_enabled(): +def lora_available() -> bool: # lora does not work on cpu return not vllm.config.is_cpu() @pytest.fixture(scope="session") -def requires_lora(lora_enabled): # noqa: PT004 - if not lora_enabled: - pytest.skip(reason="Lora is not enabled. (disabled on cpu)") +def lora_adapter_name(request: pytest.FixtureRequest): + if not request.getfixturevalue("lora_available"): + pytest.skip("Lora is not available with this configuration") - -@pytest.fixture(scope="session") -def lora_adapter_name(requires_lora): return "lora-test" @pytest.fixture(scope="session") -def lora_adapter_path(requires_lora): +def lora_adapter_path(request: pytest.FixtureRequest): + if not request.getfixturevalue("lora_available"): + pytest.skip("Lora is not available with this configuration") + from huggingface_hub import snapshot_download path = snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test") @@ -64,25 +64,22 @@ def lora_adapter_path(requires_lora): @pytest.fixture(scope="session") -def args( # noqa: PLR0913 +def args( + request: pytest.FixtureRequest, monkeysession, grpc_server_thread_port, http_server_thread_port, - lora_enabled, - lora_adapter_name, - lora_adapter_path, + lora_available, ) -> argparse.Namespace: """Return parsed CLI arguments for the adapter/vLLM.""" # avoid parsing pytest arguments as vllm/vllm_tgis_adapter arguments extra_args: list[str] = [] - if lora_enabled: - extra_args.extend( - ( - "--enable-lora", - f"--lora-modules={lora_adapter_name}={lora_adapter_path}", - ) - ) + if lora_available: + name = request.getfixturevalue("lora_adapter_name") + path = request.getfixturevalue("lora_adapter_path") + + extra_args.extend(("--enable-lora", f"--lora-modules={name}={path}")) monkeysession.setattr( sys,