From c0f21fbe8034b17f006a8953ad9a2de52c16b4a4 Mon Sep 17 00:00:00 2001 From: Ishaan Sehgal Date: Tue, 14 Jan 2025 10:36:53 -0800 Subject: [PATCH] fix: Ensure `model` provided in vLLM inference (#820) **Reason for Change**: Update to include model name in requests made to VLLM (has a /v1/models endpoint) --- presets/ragengine/config.py | 2 +- presets/ragengine/inference/inference.py | 118 +++++++++++++++--- presets/ragengine/main.py | 14 ++- presets/ragengine/models.py | 33 ++++- presets/ragengine/tests/api/test_main.py | 2 +- .../tests/vector_store/test_base_store.py | 4 +- 6 files changed, 144 insertions(+), 29 deletions(-) diff --git a/presets/ragengine/config.py b/presets/ragengine/config.py index 0eae6413b..8adde1fbd 100644 --- a/presets/ragengine/config.py +++ b/presets/ragengine/config.py @@ -38,7 +38,7 @@ """ # LLM (Large Language Model) configuration -LLM_INFERENCE_URL = os.getenv("LLM_INFERENCE_URL", "http://localhost:5000/chat") +LLM_INFERENCE_URL = os.getenv("LLM_INFERENCE_URL", "http://localhost:5000/v1/completions") LLM_ACCESS_SECRET = os.getenv("LLM_ACCESS_SECRET", "default-access-secret") # LLM_RESPONSE_FIELD = os.getenv("LLM_RESPONSE_FIELD", "result") # Uncomment if needed in the future diff --git a/presets/ragengine/inference/inference.py b/presets/ragengine/inference/inference.py index f48248463..7728c7ab3 100644 --- a/presets/ragengine/inference/inference.py +++ b/presets/ragengine/inference/inference.py @@ -1,18 +1,32 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. +import logging from typing import Any +from dataclasses import field from llama_index.core.llms import CustomLLM, CompletionResponse, LLMMetadata, CompletionResponseGen from llama_index.llms.openai import OpenAI from llama_index.core.llms.callbacks import llm_completion_callback import requests +from requests.exceptions import HTTPError +from urllib.parse import urlparse, urljoin from ragengine.config import LLM_INFERENCE_URL, LLM_ACCESS_SECRET #, LLM_RESPONSE_FIELD +# Configure logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + OPENAI_URL_PREFIX = "https://api.openai.com" HUGGINGFACE_URL_PREFIX = "https://api-inference.huggingface.co" +DEFAULT_HEADERS = { + "Authorization": f"Bearer {LLM_ACCESS_SECRET}", + "Content-Type": "application/json" +} class Inference(CustomLLM): params: dict = {} + _default_model: str = None + _model_retrieval_attempted: bool = False def set_params(self, params: dict) -> None: self.params = params @@ -25,7 +39,7 @@ def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen: pass @llm_completion_callback() - def complete(self, prompt: str, **kwargs) -> CompletionResponse: + def complete(self, prompt: str, formatted: bool, **kwargs) -> CompletionResponse: try: if LLM_INFERENCE_URL.startswith(OPENAI_URL_PREFIX): return self._openai_complete(prompt, **kwargs, **self.params) @@ -38,29 +52,99 @@ def complete(self, prompt: str, **kwargs) -> CompletionResponse: self.params = {} def _openai_complete(self, prompt: str, **kwargs: Any) -> CompletionResponse: - llm = OpenAI( - api_key=LLM_ACCESS_SECRET, - **kwargs # Pass all kwargs directly; kwargs may include model, temperature, max_tokens, etc. 
- ) - return llm.complete(prompt) + return OpenAI(api_key=LLM_ACCESS_SECRET, **kwargs).complete(prompt) def _huggingface_remote_complete(self, prompt: str, **kwargs: Any) -> CompletionResponse: - headers = {"Authorization": f"Bearer {LLM_ACCESS_SECRET}"} - data = {"messages": [{"role": "user", "content": prompt}]} - response = requests.post(LLM_INFERENCE_URL, json=data, headers=headers) - response_data = response.json() - return CompletionResponse(text=str(response_data)) + return self._post_request( + {"messages": [{"role": "user", "content": prompt}]}, + headers={"Authorization": f"Bearer {LLM_ACCESS_SECRET}"} + ) def _custom_api_complete(self, prompt: str, **kwargs: Any) -> CompletionResponse: - headers = {"Authorization": f"Bearer {LLM_ACCESS_SECRET}"} + model = kwargs.pop("model", self._get_default_model()) data = {"prompt": prompt, **kwargs} + if model: + data["model"] = model # Include the model only if it is not None + + # DEBUG: Call the debugging function + # self._debug_curl_command(data) + try: + return self._post_request(data, headers=DEFAULT_HEADERS) + except HTTPError as e: + if e.response.status_code == 400: + logger.warning( + f"Potential issue with 'model' parameter in API response. " + f"Response: {str(e)}. Attempting to update the model name as a mitigation..." + ) + self._default_model = self._fetch_default_model() # Fetch default model dynamically + if self._default_model: + logger.info(f"Default model '{self._default_model}' fetched successfully. Retrying request...") + data["model"] = self._default_model + return self._post_request(data, headers=DEFAULT_HEADERS) + else: + logger.error("Failed to fetch a default model. Aborting retry.") + raise # Re-raise the exception if not recoverable + except Exception as e: + logger.error(f"An unexpected error occurred: {e}") + raise + + def _get_models_endpoint(self) -> str: + """ + Constructs the URL for the /v1/models endpoint based on LLM_INFERENCE_URL. + """ + parsed = urlparse(LLM_INFERENCE_URL) + return urljoin(f"{parsed.scheme}://{parsed.netloc}", "/v1/models") - response = requests.post(LLM_INFERENCE_URL, json=data, headers=headers) - response_data = response.json() + def _fetch_default_model(self) -> str: + """ + Fetch the default model from the /v1/models endpoint. + """ + try: + models_url = self._get_models_endpoint() + response = requests.get(models_url, headers=DEFAULT_HEADERS) + response.raise_for_status() # Raise an exception for HTTP errors (includes 404) + + models = response.json().get("data", []) + return models[0].get("id") if models else None + except Exception as e: + logger.error(f"Error fetching models from {models_url}: {e}. \"model\" parameter will not be included with inference call.") + return None + + def _get_default_model(self) -> str: + """ + Returns the cached default model if available, otherwise fetches and caches it. 
+ """ + if not self._default_model and not self._model_retrieval_attempted: + self._model_retrieval_attempted = True + self._default_model = self._fetch_default_model() + return self._default_model - # Dynamically extract the field from the response based on the specified response_field - # completion_text = response_data.get(RESPONSE_FIELD, "No response field found") # not necessary for now - return CompletionResponse(text=str(response_data)) + def _post_request(self, data: dict, headers: dict) -> CompletionResponse: + try: + response = requests.post(LLM_INFERENCE_URL, json=data, headers=headers) + response.raise_for_status() # Raise exception for HTTP errors + response_data = response.json() + return CompletionResponse(text=str(response_data)) + except requests.RequestException as e: + logger.error(f"Error during POST request to {LLM_INFERENCE_URL}: {e}") + raise + + def _debug_curl_command(self, data: dict) -> None: + """ + Constructs and prints the equivalent curl command for debugging purposes. + """ + import json + # Construct curl command + curl_command = ( + f"curl -X POST {LLM_INFERENCE_URL} " + + " ".join([f'-H "{key}: {value}"' for key, value in { + "Authorization": f"Bearer {LLM_ACCESS_SECRET}", + "Content-Type": "application/json" + }.items()]) + + f" -d '{json.dumps(data)}'" + ) + logger.info("Equivalent curl command:") + logger.info(curl_command) @property def metadata(self) -> LLMMetadata: diff --git a/presets/ragengine/main.py b/presets/ragengine/main.py index 56f891178..fd2d26efc 100644 --- a/presets/ragengine/main.py +++ b/presets/ragengine/main.py @@ -60,11 +60,17 @@ async def index_documents(request: IndexRequest): # TODO: Research async/sync wh @app.post("/query", response_model=QueryResponse) async def query_index(request: QueryRequest): try: - llm_params = request.llm_params or {} # Default to empty dict if no params provided - rerank_params = request.rerank_params or {} # Default to empty dict if no params provided - return rag_ops.query(request.index_name, request.query, request.top_k, llm_params, rerank_params) + llm_params = request.llm_params or {} # Default to empty dict if no params provided + rerank_params = request.rerank_params or {} # Default to empty dict if no params provided + return rag_ops.query( + request.index_name, request.query, request.top_k, llm_params, rerank_params + ) + except ValueError as ve: + raise HTTPException(status_code=400, detail=str(ve)) # Validation issue except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) + raise HTTPException( + status_code=500, detail=f"An unexpected error occurred: {str(e)}" + ) @app.get("/indexed-documents", response_model=ListDocumentsResponse) async def list_all_indexed_documents(): diff --git a/presets/ragengine/models.py b/presets/ragengine/models.py index a1b2ff529..a1e89a21f 100644 --- a/presets/ragengine/models.py +++ b/presets/ragengine/models.py @@ -1,9 +1,10 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. 
-from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional + +from pydantic import BaseModel, Field, model_validator -from pydantic import BaseModel class Document(BaseModel): text: str @@ -22,8 +23,32 @@ class QueryRequest(BaseModel): index_name: str query: str top_k: int = 10 - llm_params: Optional[Dict] = None # Accept a dictionary for parameters - rerank_params: Optional[Dict] = None # Accept a dictionary for parameters + # Accept a dictionary for our LLM parameters + llm_params: Optional[Dict[str, Any]] = Field( + default_factory=dict, + description="Optional parameters for the language model, e.g., temperature, top_p", + ) + # Accept a dictionary for rerank parameters + rerank_params: Optional[Dict[str, Any]] = Field( + default_factory=dict, + description="Optional parameters for reranking, e.g., top_n, batch_size", + ) + + @model_validator(mode="before") + def validate_params(cls, values: Dict[str, Any]) -> Dict[str, Any]: + llm_params = values.get("llm_params", {}) + rerank_params = values.get("rerank_params", {}) + + # Validate LLM parameters + if "temperature" in llm_params and not (0.0 <= llm_params["temperature"] <= 1.0): + raise ValueError("Temperature must be between 0.0 and 1.0.") + # TODO: More LLM Param Validations here + # Validate rerank parameters + top_k = values.get("top_k") + if "top_n" in rerank_params and rerank_params["top_n"] > top_k: + raise ValueError("Invalid configuration: 'top_n' for reranking cannot exceed 'top_k' from the RAG query.") + + return values class ListDocumentsResponse(BaseModel): documents: Dict[str, Dict[str, Dict[str, str]]] diff --git a/presets/ragengine/tests/api/test_main.py b/presets/ragengine/tests/api/test_main.py index fee67dd7b..102936d9d 100644 --- a/presets/ragengine/tests/api/test_main.py +++ b/presets/ragengine/tests/api/test_main.py @@ -175,7 +175,7 @@ def test_query_index_failure(): } response = client.post("/query", json=request_data) - assert response.status_code == 500 + assert response.status_code == 400 assert response.json()["detail"] == "No such index: 'non_existent_index' exists." diff --git a/presets/ragengine/tests/vector_store/test_base_store.py b/presets/ragengine/tests/vector_store/test_base_store.py index d3f49848f..9f55bad95 100644 --- a/presets/ragengine/tests/vector_store/test_base_store.py +++ b/presets/ragengine/tests/vector_store/test_base_store.py @@ -89,8 +89,8 @@ def test_query_documents(self, mock_post, vector_store_manager): mock_post.assert_called_once_with( LLM_INFERENCE_URL, - json={"prompt": "Context information is below.\n---------------------\ntype: text\n\nFirst document\n---------------------\nGiven the context information and not prior knowledge, answer the query.\nQuery: First\nAnswer: ", "formatted": True, 'temperature': 0.7}, - headers={"Authorization": f"Bearer {LLM_ACCESS_SECRET}"} + json={"prompt": "Context information is below.\n---------------------\ntype: text\n\nFirst document\n---------------------\nGiven the context information and not prior knowledge, answer the query.\nQuery: First\nAnswer: ", 'temperature': 0.7}, + headers={"Authorization": f"Bearer {LLM_ACCESS_SECRET}", 'Content-Type': 'application/json'} ) def test_add_document(self, vector_store_manager):
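
**Example request flow** (for reviewers): the sketch below shows, outside the engine, the flow this change implements against a vLLM-style server — discover the served model via `/v1/models`, then include it in the `/v1/completions` payload, omitting `model` when discovery fails. The base URL, access secret, and helper names are illustrative assumptions, not code from this PR.

```python
# Minimal, standalone sketch of the model-discovery flow added in this PR (illustrative only).
# Assumed: a vLLM-compatible server at BASE_URL exposing /v1/models and /v1/completions,
# and a bearer token in ACCESS_SECRET. Both values are placeholders.
import requests

BASE_URL = "http://localhost:5000"       # assumed inference endpoint
ACCESS_SECRET = "default-access-secret"  # assumed access secret
HEADERS = {
    "Authorization": f"Bearer {ACCESS_SECRET}",
    "Content-Type": "application/json",
}

def fetch_default_model() -> str | None:
    """Return the first model id served by the endpoint, or None if it cannot be fetched."""
    try:
        resp = requests.get(f"{BASE_URL}/v1/models", headers=HEADERS)
        resp.raise_for_status()
        models = resp.json().get("data", [])
        return models[0].get("id") if models else None
    except requests.RequestException:
        return None

def complete(prompt: str, **params) -> dict:
    """POST a completion request, including 'model' only when one can be resolved."""
    payload = {"prompt": prompt, **params}
    model = params.get("model") or fetch_default_model()
    if model:
        payload["model"] = model  # send the served model name (the PR retries with it on HTTP 400)
    resp = requests.post(f"{BASE_URL}/v1/completions", json=payload, headers=HEADERS)
    resp.raise_for_status()
    return resp.json()

if __name__ == "__main__":
    print(complete("What is KAITO?", temperature=0.7, max_tokens=50))
```

The actual implementation in `presets/ragengine/inference/inference.py` additionally caches the discovered model id and only re-fetches it after a 400 response, as shown in the diff above.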