fix: Ensure model provided in vLLM inference (#820)
**Reason for Change**:
Include the model name in completion requests sent to vLLM. vLLM serves an OpenAI-compatible /v1/models endpoint, so the served model id can be looked up there when the caller does not supply one.
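For context, the request flow this enables looks roughly like the following (a sketch; the base URL and prompt are placeholders, and the server is assumed to expose the OpenAI-compatible /v1/models and /v1/completions endpoints, as vLLM does):

```python
import requests

BASE_URL = "http://localhost:5000"  # placeholder; the RAG engine derives this from LLM_INFERENCE_URL

# Ask the server which model it serves via the OpenAI-compatible /v1/models endpoint.
models = requests.get(f"{BASE_URL}/v1/models", timeout=10).json().get("data", [])
model_id = models[0]["id"] if models else None

payload = {"prompt": "Hello", "temperature": 0.7}
if model_id:
    payload["model"] = model_id  # include the model only when one was discovered

print(requests.post(f"{BASE_URL}/v1/completions", json=payload, timeout=30).json())
```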
ishaansehgal99 authored Jan 14, 2025
1 parent fd8cead commit c0f21fb
Showing 6 changed files with 144 additions and 29 deletions.
2 changes: 1 addition & 1 deletion presets/ragengine/config.py
@@ -38,7 +38,7 @@
"""

# LLM (Large Language Model) configuration
LLM_INFERENCE_URL = os.getenv("LLM_INFERENCE_URL", "http://localhost:5000/chat")
LLM_INFERENCE_URL = os.getenv("LLM_INFERENCE_URL", "http://localhost:5000/v1/completions")
LLM_ACCESS_SECRET = os.getenv("LLM_ACCESS_SECRET", "default-access-secret")
# LLM_RESPONSE_FIELD = os.getenv("LLM_RESPONSE_FIELD", "result") # Uncomment if needed in the future

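If the RAG engine should talk to a different vLLM deployment, the URL can be overridden before the config module is imported. A minimal sketch, assuming the package is importable as ragengine and using placeholder host and secret values:

```python
import os

# Point the engine at a self-hosted vLLM deployment before the config module is imported.
os.environ["LLM_INFERENCE_URL"] = "http://my-vllm-host:8000/v1/completions"  # placeholder host/port
os.environ["LLM_ACCESS_SECRET"] = "my-access-secret"                         # placeholder secret

from ragengine import config  # reads the variables above at import time

print(config.LLM_INFERENCE_URL)
```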
118 changes: 101 additions & 17 deletions presets/ragengine/inference/inference.py
@@ -1,18 +1,32 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging
from typing import Any
from dataclasses import field
from llama_index.core.llms import CustomLLM, CompletionResponse, LLMMetadata, CompletionResponseGen
from llama_index.llms.openai import OpenAI
from llama_index.core.llms.callbacks import llm_completion_callback
import requests
from requests.exceptions import HTTPError
from urllib.parse import urlparse, urljoin
from ragengine.config import LLM_INFERENCE_URL, LLM_ACCESS_SECRET #, LLM_RESPONSE_FIELD

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

OPENAI_URL_PREFIX = "https://api.openai.com"
HUGGINGFACE_URL_PREFIX = "https://api-inference.huggingface.co"
DEFAULT_HEADERS = {
"Authorization": f"Bearer {LLM_ACCESS_SECRET}",
"Content-Type": "application/json"
}

class Inference(CustomLLM):
params: dict = {}
_default_model: str = None
_model_retrieval_attempted: bool = False

def set_params(self, params: dict) -> None:
self.params = params
@@ -25,7 +39,7 @@ def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
pass

@llm_completion_callback()
def complete(self, prompt: str, **kwargs) -> CompletionResponse:
def complete(self, prompt: str, formatted: bool, **kwargs) -> CompletionResponse:
try:
if LLM_INFERENCE_URL.startswith(OPENAI_URL_PREFIX):
return self._openai_complete(prompt, **kwargs, **self.params)
@@ -38,29 +52,99 @@ def complete(self, prompt: str, **kwargs) -> CompletionResponse:
self.params = {}

def _openai_complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
llm = OpenAI(
api_key=LLM_ACCESS_SECRET,
**kwargs # Pass all kwargs directly; kwargs may include model, temperature, max_tokens, etc.
)
return llm.complete(prompt)
return OpenAI(api_key=LLM_ACCESS_SECRET, **kwargs).complete(prompt)

def _huggingface_remote_complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
headers = {"Authorization": f"Bearer {LLM_ACCESS_SECRET}"}
data = {"messages": [{"role": "user", "content": prompt}]}
response = requests.post(LLM_INFERENCE_URL, json=data, headers=headers)
response_data = response.json()
return CompletionResponse(text=str(response_data))
return self._post_request(
{"messages": [{"role": "user", "content": prompt}]},
headers={"Authorization": f"Bearer {LLM_ACCESS_SECRET}"}
)

def _custom_api_complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
headers = {"Authorization": f"Bearer {LLM_ACCESS_SECRET}"}
model = kwargs.pop("model", self._get_default_model())
data = {"prompt": prompt, **kwargs}
if model:
data["model"] = model # Include the model only if it is not None

# DEBUG: Call the debugging function
# self._debug_curl_command(data)
try:
return self._post_request(data, headers=DEFAULT_HEADERS)
except HTTPError as e:
if e.response.status_code == 400:
logger.warning(
f"Potential issue with 'model' parameter in API response. "
f"Response: {str(e)}. Attempting to update the model name as a mitigation..."
)
self._default_model = self._fetch_default_model() # Fetch default model dynamically
if self._default_model:
logger.info(f"Default model '{self._default_model}' fetched successfully. Retrying request...")
data["model"] = self._default_model
return self._post_request(data, headers=DEFAULT_HEADERS)
else:
logger.error("Failed to fetch a default model. Aborting retry.")
raise # Re-raise the exception if not recoverable
except Exception as e:
logger.error(f"An unexpected error occurred: {e}")
raise

def _get_models_endpoint(self) -> str:
"""
Constructs the URL for the /v1/models endpoint based on LLM_INFERENCE_URL.
"""
parsed = urlparse(LLM_INFERENCE_URL)
return urljoin(f"{parsed.scheme}://{parsed.netloc}", "/v1/models")

response = requests.post(LLM_INFERENCE_URL, json=data, headers=headers)
response_data = response.json()
def _fetch_default_model(self) -> str:
"""
Fetch the default model from the /v1/models endpoint.
"""
try:
models_url = self._get_models_endpoint()
response = requests.get(models_url, headers=DEFAULT_HEADERS)
response.raise_for_status() # Raise an exception for HTTP errors (includes 404)

models = response.json().get("data", [])
return models[0].get("id") if models else None
except Exception as e:
logger.error(f"Error fetching models from {models_url}: {e}. \"model\" parameter will not be included with inference call.")
return None

def _get_default_model(self) -> str:
"""
Returns the cached default model if available, otherwise fetches and caches it.
"""
if not self._default_model and not self._model_retrieval_attempted:
self._model_retrieval_attempted = True
self._default_model = self._fetch_default_model()
return self._default_model

# Dynamically extract the field from the response based on the specified response_field
# completion_text = response_data.get(RESPONSE_FIELD, "No response field found") # not necessary for now
return CompletionResponse(text=str(response_data))
def _post_request(self, data: dict, headers: dict) -> CompletionResponse:
try:
response = requests.post(LLM_INFERENCE_URL, json=data, headers=headers)
response.raise_for_status() # Raise exception for HTTP errors
response_data = response.json()
return CompletionResponse(text=str(response_data))
except requests.RequestException as e:
logger.error(f"Error during POST request to {LLM_INFERENCE_URL}: {e}")
raise

def _debug_curl_command(self, data: dict) -> None:
"""
Constructs and prints the equivalent curl command for debugging purposes.
"""
import json
# Construct curl command
curl_command = (
f"curl -X POST {LLM_INFERENCE_URL} "
+ " ".join([f'-H "{key}: {value}"' for key, value in {
"Authorization": f"Bearer {LLM_ACCESS_SECRET}",
"Content-Type": "application/json"
}.items()])
+ f" -d '{json.dumps(data)}'"
)
logger.info("Equivalent curl command:")
logger.info(curl_command)

@property
def metadata(self) -> LLMMetadata:
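Roughly how the updated class behaves end to end (a usage sketch, not part of the change; the prompt and parameters are illustrative and assume a non-OpenAI, non-HuggingFace LLM_INFERENCE_URL so that _custom_api_complete is used):

```python
from ragengine.inference.inference import Inference

llm = Inference()
llm.set_params({"temperature": 0.7, "max_tokens": 128})  # merged into the next request payload

# With a non-OpenAI, non-HuggingFace LLM_INFERENCE_URL this goes through _custom_api_complete:
# the model id is fetched once from /v1/models and cached, added to the payload only if the
# lookup succeeded, and a 400 response triggers a single retry with a re-fetched model name.
response = llm.complete("What does the RAG engine index?", formatted=False)
print(response.text)
```
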
14 changes: 10 additions & 4 deletions presets/ragengine/main.py
@@ -60,11 +60,17 @@ async def index_documents(request: IndexRequest): # TODO: Research async/sync wh
@app.post("/query", response_model=QueryResponse)
async def query_index(request: QueryRequest):
try:
llm_params = request.llm_params or {} # Default to empty dict if no params provided
rerank_params = request.rerank_params or {} # Default to empty dict if no params provided
return rag_ops.query(request.index_name, request.query, request.top_k, llm_params, rerank_params)
llm_params = request.llm_params or {} # Default to empty dict if no params provided
rerank_params = request.rerank_params or {} # Default to empty dict if no params provided
return rag_ops.query(
request.index_name, request.query, request.top_k, llm_params, rerank_params
)
except ValueError as ve:
raise HTTPException(status_code=400, detail=str(ve)) # Validation issue
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
raise HTTPException(
status_code=500, detail=f"An unexpected error occurred: {str(e)}"
)

@app.get("/indexed-documents", response_model=ListDocumentsResponse)
async def list_all_indexed_documents():
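A request against the /query endpoint would look roughly like the following (a sketch; the service address, index name, and parameter values are placeholders):

```python
import requests

payload = {
    "index_name": "my_index",              # placeholder; created earlier via POST /index
    "query": "What is in the indexed documents?",
    "top_k": 5,
    "llm_params": {"temperature": 0.7},    # forwarded to the Inference wrapper
    "rerank_params": {"top_n": 3},         # must not exceed top_k (validated in models.py)
}

resp = requests.post("http://localhost:8000/query", json=payload, timeout=60)
resp.raise_for_status()
print(resp.json())
```
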
33 changes: 29 additions & 4 deletions presets/ragengine/models.py
@@ -1,9 +1,10 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from typing import Dict, List, Optional
from typing import Any, Dict, List, Optional

from pydantic import BaseModel, Field, model_validator

from pydantic import BaseModel

class Document(BaseModel):
text: str
@@ -22,8 +23,32 @@ class QueryRequest(BaseModel):
index_name: str
query: str
top_k: int = 10
llm_params: Optional[Dict] = None # Accept a dictionary for parameters
rerank_params: Optional[Dict] = None # Accept a dictionary for parameters
# Accept a dictionary for our LLM parameters
llm_params: Optional[Dict[str, Any]] = Field(
default_factory=dict,
description="Optional parameters for the language model, e.g., temperature, top_p",
)
# Accept a dictionary for rerank parameters
rerank_params: Optional[Dict[str, Any]] = Field(
default_factory=dict,
description="Optional parameters for reranking, e.g., top_n, batch_size",
)

@model_validator(mode="before")
def validate_params(cls, values: Dict[str, Any]) -> Dict[str, Any]:
llm_params = values.get("llm_params", {})
rerank_params = values.get("rerank_params", {})

# Validate LLM parameters
if "temperature" in llm_params and not (0.0 <= llm_params["temperature"] <= 1.0):
raise ValueError("Temperature must be between 0.0 and 1.0.")
# TODO: More LLM Param Validations here
# Validate rerank parameters
top_k = values.get("top_k")
if "top_n" in rerank_params and rerank_params["top_n"] > top_k:
raise ValueError("Invalid configuration: 'top_n' for reranking cannot exceed 'top_k' from the RAG query.")

return values

class ListDocumentsResponse(BaseModel):
documents: Dict[str, Dict[str, Dict[str, str]]]
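The effect of the new validator, sketched with illustrative values:

```python
from pydantic import ValidationError

from ragengine.models import QueryRequest

# Accepted: temperature within [0.0, 1.0] and top_n does not exceed top_k.
ok = QueryRequest(index_name="docs", query="hello", top_k=5,
                  llm_params={"temperature": 0.7}, rerank_params={"top_n": 3})

# Rejected: top_n exceeds top_k, so the model_validator raises.
try:
    QueryRequest(index_name="docs", query="hello", top_k=5, rerank_params={"top_n": 10})
except ValidationError as err:
    print(err)
```
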
2 changes: 1 addition & 1 deletion presets/ragengine/tests/api/test_main.py
@@ -175,7 +175,7 @@ def test_query_index_failure():
}

response = client.post("/query", json=request_data)
assert response.status_code == 500
assert response.status_code == 400
assert response.json()["detail"] == "No such index: 'non_existent_index' exists."


4 changes: 2 additions & 2 deletions presets/ragengine/tests/vector_store/test_base_store.py
@@ -89,8 +89,8 @@ def test_query_documents(self, mock_post, vector_store_manager):

mock_post.assert_called_once_with(
LLM_INFERENCE_URL,
json={"prompt": "Context information is below.\n---------------------\ntype: text\n\nFirst document\n---------------------\nGiven the context information and not prior knowledge, answer the query.\nQuery: First\nAnswer: ", "formatted": True, 'temperature': 0.7},
headers={"Authorization": f"Bearer {LLM_ACCESS_SECRET}"}
json={"prompt": "Context information is below.\n---------------------\ntype: text\n\nFirst document\n---------------------\nGiven the context information and not prior knowledge, answer the query.\nQuery: First\nAnswer: ", 'temperature': 0.7},
headers={"Authorization": f"Bearer {LLM_ACCESS_SECRET}", 'Content-Type': 'application/json'}
)

def test_add_document(self, vector_store_manager):
