-
Notifications
You must be signed in to change notification settings - Fork 70
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Part 4 (Final) - Introduce Main RAG Service API and its tests (#…
…603) **Reason for Change**: This series of PR will integrate llamaindex RAG service for Kaito. This PR introduces the main API and its endpoints. `main.py` introduces the main API Three endpoints are introduced: ## 1. `POST /index` **Description**: Indexes a list of documents into the specified index. **Request Body** (`IndexRequest`): - `index_name` (str): The name of the index. - `documents` (List[Document]): - `text` (str): The document's content. - `metadata` (Optional[dict]): Additional metadata (default: empty). **Response** (`List[DocumentResponse]`): - A list of indexed documents: - `doc_id` (str): The generated document ID. - `text` (str): The document content. - `metadata` (Optional[dict]): Document metadata. ## 2. `POST /query` **Description**: Queries the specified index and returns relevant results. **Request Body** (`QueryRequest`): - `index_name` (str): The index to query. - `query` (str): The search query. - `top_k` (int, default=10): The number of top results to return. - `llm_params` (Optional[Dict]): Parameters for LLM processing. **Response** (`QueryResponse`): - `response` (str): The result or completion from the API. - `source_nodes` (List[NodeWithScore]): - `node_id` (str): The node ID. - `text` (str): The node content. - `score` (float): The relevance score. - `metadata` (Optional[dict]): Node metadata. - `metadata` (Optional[dict]): Query metadata. ## 3. `GET /indexed-documents` **Description**: Lists all documents currently indexed. **Response** (`ListDocumentsResponse`): - `documents` (Dict[str, Dict[str, Dict[str, str]]]): - Key: `index_name` - Value: A dictionary of documents in that index, where: - `doc_id`: The document ID. - `text`: The document content. ## 4. `Additional Files` - `config.py` introduces the configurable params passed via environment variables from the ragengine controller - `models.py` specifies the schema required for valid HTTP requests to the endpoints specified in main.py --------- Signed-off-by: ishaansehgal99 <ishaanforthewin@gmail.com> Co-authored-by: Fei-Guo <vrgf2003@gmail.com>
- Loading branch information
1 parent
791c175
commit 8906190
Showing
13 changed files
with
327 additions
and
84 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
from typing import List | ||
from vector_store_manager.manager import VectorStoreManager | ||
from embedding.huggingface_local import LocalHuggingFaceEmbedding | ||
from embedding.huggingface_remote import RemoteHuggingFaceEmbedding | ||
from fastapi import FastAPI, HTTPException | ||
from models import (IndexRequest, ListDocumentsResponse, | ||
QueryRequest, QueryResponse, DocumentResponse) | ||
from vector_store.faiss_store import FaissVectorStoreHandler | ||
|
||
from ragengine.config import ACCESS_SECRET, EMBEDDING_TYPE, MODEL_ID | ||
|
||
app = FastAPI() | ||
|
||
# Initialize embedding model | ||
if EMBEDDING_TYPE.lower() == "local": | ||
embedding_manager = LocalHuggingFaceEmbedding(MODEL_ID) | ||
elif EMBEDDING_TYPE.lower() == "remote": | ||
embedding_manager = RemoteHuggingFaceEmbedding(MODEL_ID, ACCESS_SECRET) | ||
else: | ||
raise ValueError("Invalid Embedding Type Specified (Must be Local or Remote)") | ||
|
||
# Initialize vector store | ||
# TODO: Dynamically set VectorStore from EnvVars (which ultimately comes from CRD StorageSpec) | ||
vector_store_handler = FaissVectorStoreHandler(embedding_manager) | ||
|
||
# Initialize RAG operations | ||
rag_ops = VectorStoreManager(vector_store_handler) | ||
|
||
@app.post("/index", response_model=List[DocumentResponse]) | ||
async def index_documents(request: IndexRequest): # TODO: Research async/sync what to use (inference is calling) | ||
try: | ||
doc_ids = rag_ops.index(request.index_name, request.documents) | ||
documents = [ | ||
DocumentResponse(doc_id=doc_id, text=doc.text, metadata=doc.metadata) | ||
for doc_id, doc in zip(doc_ids, request.documents) | ||
] | ||
return documents | ||
except Exception as e: | ||
raise HTTPException(status_code=500, detail=str(e)) | ||
|
||
@app.post("/query", response_model=QueryResponse) | ||
async def query_index(request: QueryRequest): | ||
try: | ||
llm_params = request.llm_params or {} # Default to empty dict if no params provided | ||
return rag_ops.query(request.index_name, request.query, request.top_k, llm_params) | ||
except Exception as e: | ||
raise HTTPException(status_code=500, detail=str(e)) | ||
|
||
@app.get("/indexed-documents", response_model=ListDocumentsResponse) | ||
async def list_all_indexed_documents(): | ||
try: | ||
documents = rag_ops.list_all_indexed_documents() | ||
return ListDocumentsResponse(documents=documents) | ||
except Exception as e: | ||
raise HTTPException(status_code=500, detail=str(e)) | ||
|
||
if __name__ == "__main__": | ||
import uvicorn | ||
uvicorn.run(app, host="0.0.0.0", port=8000) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,15 @@ | ||
# RAG Library Requirements | ||
llama-index | ||
# HF Embeddings | ||
llama-index-embeddings-huggingface | ||
llama-index-embeddings-huggingface-api | ||
# HF LLMs | ||
llama-index-llms-huggingface | ||
llama-index-llms-huggingface-api | ||
|
||
fastapi | ||
faiss-cpu | ||
llama-index-vector-stores-faiss | ||
uvicorn | ||
# For UTs | ||
pytest |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
import sys | ||
import os | ||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..'))) | ||
os.environ["CUDA_VISIBLE_DEVICES"] = "-1" # Force CPU-only execution for testing | ||
os.environ["OMP_NUM_THREADS"] = "1" # Force single-threaded for testing to prevent segfault while loading embedding model | ||
os.environ["MKL_NUM_THREADS"] = "1" # Force MKL to use a single thread |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,143 @@ | ||
from unittest.mock import patch | ||
|
||
from llama_index.core.storage.index_store import SimpleIndexStore | ||
|
||
from ragengine.main import app, vector_store_handler, rag_ops | ||
from fastapi.testclient import TestClient | ||
import pytest | ||
|
||
AUTO_GEN_DOC_ID_LEN = 64 | ||
|
||
client = TestClient(app) | ||
|
||
@pytest.fixture(autouse=True) | ||
def clear_index(): | ||
vector_store_handler.index_map.clear() | ||
vector_store_handler.index_store = SimpleIndexStore() | ||
|
||
def test_index_documents_success(): | ||
request_data = { | ||
"index_name": "test_index", | ||
"documents": [ | ||
{"text": "This is a test document"}, | ||
{"text": "Another test document"} | ||
] | ||
} | ||
|
||
response = client.post("/index", json=request_data) | ||
assert response.status_code == 200 | ||
doc1, doc2 = response.json() | ||
assert (doc1["text"] == "This is a test document") | ||
assert len(doc1["doc_id"]) == AUTO_GEN_DOC_ID_LEN | ||
assert not doc1["metadata"] | ||
|
||
assert (doc2["text"] == "Another test document") | ||
assert len(doc2["doc_id"]) == AUTO_GEN_DOC_ID_LEN | ||
assert not doc2["metadata"] | ||
|
||
@patch('requests.post') | ||
def test_query_index_success(mock_post): | ||
# Define Mock Response for Custom Inference API | ||
mock_response = { | ||
"result": "This is the completion from the API" | ||
} | ||
mock_post.return_value.json.return_value = mock_response | ||
# Index | ||
request_data = { | ||
"index_name": "test_index", | ||
"documents": [ | ||
{"text": "This is a test document"}, | ||
{"text": "Another test document"} | ||
] | ||
} | ||
|
||
response = client.post("/index", json=request_data) | ||
assert response.status_code == 200 | ||
|
||
# Query | ||
request_data = { | ||
"index_name": "test_index", | ||
"query": "test query", | ||
"top_k": 1, | ||
"llm_params": {"temperature": 0.7} | ||
} | ||
|
||
response = client.post("/query", json=request_data) | ||
assert response.status_code == 200 | ||
assert response.json()["response"] == "{'result': 'This is the completion from the API'}" | ||
assert len(response.json()["source_nodes"]) == 1 | ||
assert response.json()["source_nodes"][0]["text"] == "This is a test document" | ||
assert response.json()["source_nodes"][0]["score"] == pytest.approx(0.5354418754577637, rel=1e-6) | ||
assert response.json()["source_nodes"][0]["metadata"] == {} | ||
assert mock_post.call_count == 1 | ||
|
||
def test_query_index_failure(): | ||
# Prepare request data for querying. | ||
request_data = { | ||
"index_name": "non_existent_index", # Use an index name that doesn't exist | ||
"query": "test query", | ||
"top_k": 1, | ||
"llm_params": {"temperature": 0.7} | ||
} | ||
|
||
response = client.post("/query", json=request_data) | ||
assert response.status_code == 500 | ||
assert response.json()["detail"] == "No such index: 'non_existent_index' exists." | ||
|
||
|
||
def test_list_all_indexed_documents_success(): | ||
response = client.get("/indexed-documents") | ||
assert response.status_code == 200 | ||
assert response.json() == {'documents': {}} | ||
|
||
request_data = { | ||
"index_name": "test_index", | ||
"documents": [ | ||
{"text": "This is a test document"}, | ||
{"text": "Another test document"} | ||
] | ||
} | ||
|
||
response = client.post("/index", json=request_data) | ||
assert response.status_code == 200 | ||
|
||
response = client.get("/indexed-documents") | ||
assert response.status_code == 200 | ||
assert "test_index" in response.json()["documents"] | ||
response_idx = response.json()["documents"]["test_index"] | ||
assert len(response_idx) == 2 # Two Documents Indexed | ||
assert ({item["text"] for item in response_idx.values()} | ||
== {item["text"] for item in request_data["documents"]}) | ||
|
||
|
||
""" | ||
Example of a live query test. This test is currently commented out as it requires a valid | ||
INFERENCE_URL in config.py. To run the test, ensure that a valid INFERENCE_URL is provided. | ||
Upon execution, RAG results should be observed. | ||
def test_live_query_test(): | ||
# Index | ||
request_data = { | ||
"index_name": "test_index", | ||
"documents": [ | ||
{"text": "Polar bear – can lift 450Kg (approximately 0.7 times their body weight) \ | ||
Adult male polar bears can grow to be anywhere between 300 and 700kg"}, | ||
{"text": "Giraffes are the tallest mammals and are well-adapted to living in trees. \ | ||
They have few predators as adults."} | ||
] | ||
} | ||
response = client.post("/index", json=request_data) | ||
assert response.status_code == 200 | ||
# Query | ||
request_data = { | ||
"index_name": "test_index", | ||
"query": "What is the strongest bear?", | ||
"top_k": 1, | ||
"llm_params": {"temperature": 0.7} | ||
} | ||
response = client.post("/query", json=request_data) | ||
assert response.status_code == 200 | ||
""" |
Oops, something went wrong.