feat: Part 4 (Final) - Introduce Main RAG Service API and its tests (#603)

**Reason for Change**:
This series of PRs integrates the llamaindex RAG service into Kaito.

This PR introduces the main API and its endpoints.

`main.py` implements the main FastAPI application and exposes three endpoints:
## 1. `POST /index`

**Description**:  
Indexes a list of documents into the specified index.

**Request Body** (`IndexRequest`):
- `index_name` (str): The name of the index.
- `documents` (List[Document]):
  - `text` (str): The document's content.
  - `metadata` (Optional[dict]): Additional metadata (default: empty).

**Response** (`List[DocumentResponse]`):
- A list of indexed documents:
  - `doc_id` (str): The generated document ID.
  - `text` (str): The document content.
  - `metadata` (Optional[dict]): Document metadata.
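
For illustration, a minimal client sketch using `requests`, assuming the service is running locally on the default host/port from `main.py`; the `"source"` metadata value is hypothetical:

```python
import requests

# Index two documents into "test_index"; the service assigns the doc_ids.
payload = {
    "index_name": "test_index",
    "documents": [
        {"text": "This is a test document"},
        {"text": "Another test document", "metadata": {"source": "unit-test"}},
    ],
}
resp = requests.post("http://localhost:8000/index", json=payload)
resp.raise_for_status()
for doc in resp.json():  # List[DocumentResponse]
    print(doc["doc_id"], doc["text"], doc["metadata"])
```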

## 2. `POST /query`

**Description**:  
Queries the specified index and returns relevant results.

**Request Body** (`QueryRequest`):
- `index_name` (str): The index to query.
- `query` (str): The search query.
- `top_k` (int, default=10): The number of top results to return.
- `llm_params` (Optional[Dict]): Parameters for LLM processing.

**Response** (`QueryResponse`):
- `response` (str): The result or completion from the API.
- `source_nodes` (List[NodeWithScore]):
  - `node_id` (str): The node ID.
  - `text` (str): The node content.
  - `score` (float): The relevance score.
  - `metadata` (Optional[dict]): Node metadata.
- `metadata` (Optional[dict]): Query metadata.
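
Again as a sketch, under the same local-service assumption as above, querying the index just built:

```python
import requests

# Query "test_index" for the single most relevant node.
payload = {
    "index_name": "test_index",
    "query": "test query",
    "top_k": 1,
    "llm_params": {"temperature": 0.7},  # passed through to the LLM
}
resp = requests.post("http://localhost:8000/query", json=payload)
resp.raise_for_status()
body = resp.json()  # QueryResponse
print(body["response"])
for node in body["source_nodes"]:  # List[NodeWithScore]
    print(node["node_id"], node["score"], node["text"])
```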


## 3. `GET /indexed-documents`

**Description**:  
Lists all documents currently indexed.

**Response** (`ListDocumentsResponse`):
- `documents` (Dict[str, Dict[str, Dict[str, str]]]):  
  - Key: `index_name`
  - Value: A dictionary of the documents in that index, keyed by `doc_id`, where each document entry maps field names to values:
    - `text`: The document content.
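
A sketch of walking this nested response shape (same local-service assumption as the examples above):

```python
import requests

# List every indexed document, grouped by index name.
resp = requests.get("http://localhost:8000/indexed-documents")
resp.raise_for_status()
for index_name, docs in resp.json()["documents"].items():
    for doc_id, fields in docs.items():
        print(index_name, doc_id, fields["text"])
```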

## 4. Additional Files

- `config.py` defines the configurable params passed in via environment
variables from the ragengine controller (a rough sketch follows below).

- `models.py` defines the schemas required for valid HTTP requests to
the endpoints in `main.py`.
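
The `config.py` diff is not expanded in this view; as a rough sketch only (the variable names come from the imports in `main.py` and the test notes, while the defaults are assumptions, not the committed values), it reads environment variables along these lines:

```python
import os

# Sketch of config.py, not the committed file. The names match what main.py
# imports; every default below is an assumption for illustration only.
EMBEDDING_TYPE = os.getenv("EMBEDDING_TYPE", "local")  # "local" or "remote"
MODEL_ID = os.getenv("MODEL_ID", "BAAI/bge-small-en-v1.5")  # assumed default model
ACCESS_SECRET = os.getenv("ACCESS_SECRET", "")  # token for remote HF embeddings
INFERENCE_URL = os.getenv("INFERENCE_URL", "")  # inference endpoint used by /query
```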

---------

Signed-off-by: ishaansehgal99 <ishaanforthewin@gmail.com>
Co-authored-by: Fei-Guo <vrgf2003@gmail.com>
ishaansehgal99 and Fei-Guo authored Oct 23, 2024
1 parent 791c175 commit 8906190
Showing 13 changed files with 327 additions and 84 deletions.
4 changes: 3 additions & 1 deletion .github/workflows/tests.yml
@@ -45,8 +45,10 @@ jobs:
       - name: Run unit tests & Generate coverage
         run: |
           make unit-test
+          make rag-service-test
+          make tuning-metrics-server-test
-      - name: Run inference api unit tests
+      - name: Run inference api e2e tests
         run: |
           make inference-api-e2e
13 changes: 12 additions & 1 deletion Makefile
@@ -97,11 +97,22 @@ unit-test: ## Run unit tests.
 	-race -coverprofile=coverage.txt -covermode=atomic
 	go tool cover -func=coverage.txt
 
+.PHONY: rag-service-test
+rag-service-test:
+	pip install -r ragengine/requirements.txt
+	pytest -o log_cli=true -o log_cli_level=INFO ragengine/tests
+
+.PHONY: tuning-metrics-server-test
+tuning-metrics-server-test:
+	pip install -r presets/inference/text-generation/requirements.txt
+	pytest -o log_cli=true -o log_cli_level=INFO presets/tuning/text-generation/metrics
+
 ## --------------------------------------
 ## E2E tests
 ## --------------------------------------
 
-inference-api-e2e:
+.PHONY: inference-api-e2e
+inference-api-e2e:
 	pip install -r presets/inference/text-generation/requirements.txt
 	pytest -o log_cli=true -o log_cli_level=INFO presets/inference/text-generation/tests
 
Empty file added ragengine/README.md
Empty file.
59 changes: 59 additions & 0 deletions ragengine/main.py
@@ -0,0 +1,59 @@
from typing import List
from vector_store_manager.manager import VectorStoreManager
from embedding.huggingface_local import LocalHuggingFaceEmbedding
from embedding.huggingface_remote import RemoteHuggingFaceEmbedding
from fastapi import FastAPI, HTTPException
from models import (IndexRequest, ListDocumentsResponse,
                    QueryRequest, QueryResponse, DocumentResponse)
from vector_store.faiss_store import FaissVectorStoreHandler

from ragengine.config import ACCESS_SECRET, EMBEDDING_TYPE, MODEL_ID

app = FastAPI()

# Initialize embedding model
if EMBEDDING_TYPE.lower() == "local":
    embedding_manager = LocalHuggingFaceEmbedding(MODEL_ID)
elif EMBEDDING_TYPE.lower() == "remote":
    embedding_manager = RemoteHuggingFaceEmbedding(MODEL_ID, ACCESS_SECRET)
else:
    raise ValueError("Invalid Embedding Type Specified (Must be Local or Remote)")

# Initialize vector store
# TODO: Dynamically set VectorStore from EnvVars (which ultimately comes from CRD StorageSpec)
vector_store_handler = FaissVectorStoreHandler(embedding_manager)

# Initialize RAG operations
rag_ops = VectorStoreManager(vector_store_handler)

@app.post("/index", response_model=List[DocumentResponse])
async def index_documents(request: IndexRequest):  # TODO: Research async/sync what to use (inference is calling)
    try:
        doc_ids = rag_ops.index(request.index_name, request.documents)
        documents = [
            DocumentResponse(doc_id=doc_id, text=doc.text, metadata=doc.metadata)
            for doc_id, doc in zip(doc_ids, request.documents)
        ]
        return documents
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.post("/query", response_model=QueryResponse)
async def query_index(request: QueryRequest):
    try:
        llm_params = request.llm_params or {}  # Default to empty dict if no params provided
        return rag_ops.query(request.index_name, request.query, request.top_k, llm_params)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.get("/indexed-documents", response_model=ListDocumentsResponse)
async def list_all_indexed_documents():
    try:
        documents = rag_ops.list_all_indexed_documents()
        return ListDocumentsResponse(documents=documents)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
19 changes: 18 additions & 1 deletion ragengine/models.py
@@ -6,6 +6,11 @@ class Document(BaseModel):
     text: str
     metadata: Optional[dict] = {}
 
+class DocumentResponse(BaseModel):
+    doc_id: str
+    text: str
+    metadata: Optional[dict] = None
+
 class IndexRequest(BaseModel):
     index_name: str
     documents: List[Document]
@@ -17,4 +22,16 @@ class QueryRequest(BaseModel):
     llm_params: Optional[Dict] = None  # Accept a dictionary for parameters
 
 class ListDocumentsResponse(BaseModel):
-    documents:Dict[str, Dict[str, Dict[str, str]]]
+    documents: Dict[str, Dict[str, Dict[str, str]]]
+
+# Define models for NodeWithScore and QueryResponse
+class NodeWithScore(BaseModel):
+    node_id: str
+    text: str
+    score: float
+    metadata: Optional[dict] = None
+
+class QueryResponse(BaseModel):
+    response: str
+    source_nodes: List[NodeWithScore]
+    metadata: Optional[dict] = None
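
As a quick illustration of how these schemas behave when FastAPI validates a request body, a minimal sketch relying only on the `top_k=10` default documented in the endpoint description above:

```python
from ragengine.models import QueryRequest

# Pydantic fills in defaults for fields omitted from the request body.
req = QueryRequest(index_name="test_index", query="test query")
assert req.top_k == 10         # default documented for POST /query
assert req.llm_params is None  # main.py substitutes {} when None
```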
8 changes: 8 additions & 0 deletions ragengine/requirements.txt
@@ -1,7 +1,15 @@
 # RAG Library Requirements
 llama-index
+# HF Embeddings
+llama-index-embeddings-huggingface
+llama-index-embeddings-huggingface-api
+# HF LLMs
+llama-index-llms-huggingface
+llama-index-llms-huggingface-api
+
 fastapi
 faiss-cpu
 llama-index-vector-stores-faiss
 uvicorn
+# For UTs
 pytest
Empty file added ragengine/tests/api/__init__.py
Empty file.
6 changes: 6 additions & 0 deletions ragengine/tests/api/conftest.py
@@ -0,0 +1,6 @@
import sys
import os
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))
os.environ["CUDA_VISIBLE_DEVICES"] = "-1" # Force CPU-only execution for testing
os.environ["OMP_NUM_THREADS"] = "1" # Force single-threaded for testing to prevent segfault while loading embedding model
os.environ["MKL_NUM_THREADS"] = "1" # Force MKL to use a single thread
143 changes: 143 additions & 0 deletions ragengine/tests/api/test_main.py
@@ -0,0 +1,143 @@
from unittest.mock import patch

from llama_index.core.storage.index_store import SimpleIndexStore

from ragengine.main import app, vector_store_handler, rag_ops
from fastapi.testclient import TestClient
import pytest

AUTO_GEN_DOC_ID_LEN = 64

client = TestClient(app)

@pytest.fixture(autouse=True)
def clear_index():
    vector_store_handler.index_map.clear()
    vector_store_handler.index_store = SimpleIndexStore()

def test_index_documents_success():
    request_data = {
        "index_name": "test_index",
        "documents": [
            {"text": "This is a test document"},
            {"text": "Another test document"}
        ]
    }

    response = client.post("/index", json=request_data)
    assert response.status_code == 200
    doc1, doc2 = response.json()
    assert doc1["text"] == "This is a test document"
    assert len(doc1["doc_id"]) == AUTO_GEN_DOC_ID_LEN
    assert not doc1["metadata"]

    assert doc2["text"] == "Another test document"
    assert len(doc2["doc_id"]) == AUTO_GEN_DOC_ID_LEN
    assert not doc2["metadata"]

@patch('requests.post')
def test_query_index_success(mock_post):
    # Define Mock Response for Custom Inference API
    mock_response = {
        "result": "This is the completion from the API"
    }
    mock_post.return_value.json.return_value = mock_response

    # Index
    request_data = {
        "index_name": "test_index",
        "documents": [
            {"text": "This is a test document"},
            {"text": "Another test document"}
        ]
    }

    response = client.post("/index", json=request_data)
    assert response.status_code == 200

    # Query
    request_data = {
        "index_name": "test_index",
        "query": "test query",
        "top_k": 1,
        "llm_params": {"temperature": 0.7}
    }

    response = client.post("/query", json=request_data)
    assert response.status_code == 200
    assert response.json()["response"] == "{'result': 'This is the completion from the API'}"
    assert len(response.json()["source_nodes"]) == 1
    assert response.json()["source_nodes"][0]["text"] == "This is a test document"
    assert response.json()["source_nodes"][0]["score"] == pytest.approx(0.5354418754577637, rel=1e-6)
    assert response.json()["source_nodes"][0]["metadata"] == {}
    assert mock_post.call_count == 1

def test_query_index_failure():
    # Prepare request data for querying.
    request_data = {
        "index_name": "non_existent_index",  # Use an index name that doesn't exist
        "query": "test query",
        "top_k": 1,
        "llm_params": {"temperature": 0.7}
    }

    response = client.post("/query", json=request_data)
    assert response.status_code == 500
    assert response.json()["detail"] == "No such index: 'non_existent_index' exists."


def test_list_all_indexed_documents_success():
    response = client.get("/indexed-documents")
    assert response.status_code == 200
    assert response.json() == {'documents': {}}

    request_data = {
        "index_name": "test_index",
        "documents": [
            {"text": "This is a test document"},
            {"text": "Another test document"}
        ]
    }

    response = client.post("/index", json=request_data)
    assert response.status_code == 200

    response = client.get("/indexed-documents")
    assert response.status_code == 200
    assert "test_index" in response.json()["documents"]
    response_idx = response.json()["documents"]["test_index"]
    assert len(response_idx) == 2  # Two Documents Indexed
    assert ({item["text"] for item in response_idx.values()}
            == {item["text"] for item in request_data["documents"]})


"""
Example of a live query test. This test is currently commented out as it requires a valid
INFERENCE_URL in config.py. To run the test, ensure that a valid INFERENCE_URL is provided.
Upon execution, RAG results should be observed.

def test_live_query_test():
    # Index
    request_data = {
        "index_name": "test_index",
        "documents": [
            {"text": "Polar bear – can lift 450Kg (approximately 0.7 times their body weight) \
                Adult male polar bears can grow to be anywhere between 300 and 700kg"},
            {"text": "Giraffes are the tallest mammals and are well-adapted to living in trees. \
                They have few predators as adults."}
        ]
    }
    response = client.post("/index", json=request_data)
    assert response.status_code == 200

    # Query
    request_data = {
        "index_name": "test_index",
        "query": "What is the strongest bear?",
        "top_k": 1,
        "llm_params": {"temperature": 0.7}
    }
    response = client.post("/query", json=request_data)
    assert response.status_code == 200
"""