feat: RAG API Server to use Async/Await (#835)
**Reason for Change**:
Endpoints are currently synchronous and block sequentially. After review, this PR applies the async/await pattern correctly.
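
For illustration, the pattern applied throughout the diff is to await the underlying operation instead of calling it synchronously inside a coroutine (a minimal sketch of one handler, not the full production code):

# before: declared async, but the synchronous call inside blocks the event loop
async def index_documents(request: IndexRequest):
    doc_ids = rag_ops.index(request.index_name, request.documents)
    ...

# after: the call is awaited, so the event loop can keep serving other
# requests while the indexing I/O is in flight
async def index_documents(request: IndexRequest):
    doc_ids = await rag_ops.index(request.index_name, request.documents)
    ...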

The PR also makes the endpoints more REST-friendly - here are the three new/fixed endpoints (a client-side usage sketch follows the list):
@app.get("/indexes", response_model=List[str]) - returns all indexes
names (new endpoint)
...
 
@app.get("/indexes/{index_name}/documents",
response_model=ListDocumentsResponse) - returns all documents in an
index (new endpoint)
....
 
@app.get("/documents", response_model=ListDocumentsResponse) - returns
all documents across all indexes
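
A minimal client-side usage sketch (assuming the service is reachable at http://localhost:5000, the default port in main.py; "my_index" is a placeholder index name):

import requests  # any HTTP client works; requests is used here for brevity

BASE = "http://localhost:5000"

index_names = requests.get(f"{BASE}/indexes").json()                       # List[str]
docs_in_index = requests.get(f"{BASE}/indexes/my_index/documents").json()  # ListDocumentsResponse
all_docs = requests.get(f"{BASE}/documents").json()                        # ListDocumentsResponse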
ishaansehgal99 authored Jan 28, 2025
1 parent 59583af commit 979b739
Showing 13 changed files with 173 additions and 173 deletions.
2 changes: 1 addition & 1 deletion Makefile
@@ -104,7 +104,7 @@ unit-test: ## Run unit tests.

.PHONY: rag-service-test
rag-service-test:
pip install -r presets/ragengine/requirements.txt
pip install -r presets/ragengine/requirements-test.txt
pip install pytest-cov
pytest --cov -o log_cli=true -o log_cli_level=INFO presets/ragengine/tests

45 changes: 38 additions & 7 deletions presets/ragengine/main.py
@@ -12,6 +12,7 @@

from ragengine.config import (REMOTE_EMBEDDING_URL, REMOTE_EMBEDDING_ACCESS_SECRET,
EMBEDDING_SOURCE_TYPE, LOCAL_EMBEDDING_MODEL_ID)
from urllib.parse import unquote

app = FastAPI()

@@ -45,9 +46,9 @@ def health_check():
raise HTTPException(status_code=500, detail=str(e))

@app.post("/index", response_model=List[DocumentResponse])
async def index_documents(request: IndexRequest): # TODO: Research async/sync what to use (inference is calling)
async def index_documents(request: IndexRequest):
try:
doc_ids = rag_ops.index(request.index_name, request.documents)
doc_ids = await rag_ops.index(request.index_name, request.documents)
documents = [
DocumentResponse(doc_id=doc_id, text=doc.text, metadata=doc.metadata)
for doc_id, doc in zip(doc_ids, request.documents)
@@ -61,7 +62,7 @@ async def query_index(request: QueryRequest):
try:
llm_params = request.llm_params or {} # Default to empty dict if no params provided
rerank_params = request.rerank_params or {} # Default to empty dict if no params provided
return rag_ops.query(
return await rag_ops.query(
request.index_name, request.query, request.top_k, llm_params, rerank_params
)
except ValueError as ve:
@@ -71,14 +72,44 @@ async def query_index(request: QueryRequest):
status_code=500, detail=f"An unexpected error occurred: {str(e)}"
)

@app.get("/indexed-documents", response_model=ListDocumentsResponse)
async def list_all_indexed_documents():
@app.get("/indexes", response_model=List[str])
def list_indexes():
try:
documents = rag_ops.list_all_indexed_documents()
return rag_ops.list_indexes()
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))

@app.get("/indexes/{index_name}/documents", response_model=ListDocumentsResponse)
async def list_documents_in_index(index_name: str):
"""
Handles URL-encoded index names sent by the client.
Examples:
Raw Index Name | URL-Encoded Form | Decoded Form
------------------|--------------------|--------------
my_index | my_index | my_index
my index | my%20index | my index
index/name | index%2Fname | index/name
index@name | index%40name | index@name
index#1 | index%231 | index#1
index?query | index%3Fquery | index?query
"""
try:
# Decode the index_name in case it was URL-encoded by the client
decoded_index_name = unquote(index_name)
documents = await rag_ops.list_documents_in_index(decoded_index_name)
return ListDocumentsResponse(documents=documents)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))

@app.get("/documents", response_model=ListDocumentsResponse)
async def list_all_documents():
try:
documents = await rag_ops.list_all_documents()
return ListDocumentsResponse(documents=documents)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))

if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=5000)
uvicorn.run(app, host="0.0.0.0", port=5000, loop="asyncio")
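
To illustrate the URL-decoding behavior documented in list_documents_in_index above: a client percent-encodes the index name when building the path, and the server recovers the raw name with unquote before looking up the index (a hypothetical round trip using an example from the docstring table):

from urllib.parse import quote, unquote

raw_name = "index/name"              # raw index name containing a reserved character
encoded = quote(raw_name, safe="")   # -> "index%2Fname", safe to place in a path segment
# hypothetical client request: GET /indexes/index%2Fname/documents
assert unquote(encoded) == raw_name  # what the endpoint does before querying the index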
6 changes: 6 additions & 0 deletions presets/ragengine/requirements-test.txt
@@ -0,0 +1,6 @@
# Common dependencies
-r requirements.txt

# Test dependencies
pytest
pytest-asyncio
3 changes: 1 addition & 2 deletions presets/ragengine/requirements.txt
@@ -13,5 +13,4 @@ llama-index-vector-stores-faiss
llama-index-vector-stores-chroma
llama-index-vector-stores-azurecosmosmongo
uvicorn
# For UTs
pytest
asyncio
6 changes: 3 additions & 3 deletions presets/ragengine/tests/api/test_main.py
@@ -179,8 +179,8 @@ def test_query_index_failure():
assert response.json()["detail"] == "No such index: 'non_existent_index' exists."


def test_list_all_indexed_documents_success():
response = client.get("/indexed-documents")
def test_list_all_documents_success():
response = client.get("/documents")
assert response.status_code == 200
assert response.json() == {'documents': {}}

@@ -195,7 +195,7 @@ def test_list_all_indexed_documents_success():
response = client.post("/index", json=request_data)
assert response.status_code == 200

response = client.get("/indexed-documents")
response = client.get("/documents")
assert response.status_code == 200
assert "test_index" in response.json()["documents"]
response_idx = response.json()["documents"]["test_index"]
1 change: 1 addition & 0 deletions presets/ragengine/tests/vector_store/conftest.py
@@ -3,6 +3,7 @@

import sys
import os

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))
os.environ["CUDA_VISIBLE_DEVICES"] = "-1" # Force CPU-only execution for testing
os.environ["OMP_NUM_THREADS"] = "1" # Force single-threaded for testing to prevent segfault while loading embedding model
46 changes: 26 additions & 20 deletions presets/ragengine/tests/vector_store/test_base_store.py
@@ -31,20 +31,22 @@ def expected_query_score(self):
"""Override this in implementation-specific test classes."""
pass

def test_index_documents(self, vector_store_manager):
@pytest.mark.asyncio
async def test_index_documents(self, vector_store_manager):
first_doc_text, second_doc_text = "First document", "Second document"
documents = [
Document(text=first_doc_text, metadata={"type": "text"}),
Document(text=second_doc_text, metadata={"type": "text"})
]

doc_ids = vector_store_manager.index_documents("test_index", documents)
doc_ids = await vector_store_manager.index_documents("test_index", documents)

assert len(doc_ids) == 2
assert set(doc_ids) == {BaseVectorStore.generate_doc_id(first_doc_text),
BaseVectorStore.generate_doc_id(second_doc_text)}

def test_index_documents_isolation(self, vector_store_manager):
@pytest.mark.asyncio
async def test_index_documents_isolation(self, vector_store_manager):
documents1 = [
Document(text="First document in index1", metadata={"type": "text"}),
]
@@ -54,19 +56,20 @@ def test_index_documents_isolation(self, vector_store_manager):

# Index documents in separate indices
index_name_1, index_name_2 = "index1", "index2"
vector_store_manager.index_documents(index_name_1, documents1)
vector_store_manager.index_documents(index_name_2, documents2)
await vector_store_manager.index_documents(index_name_1, documents1)
await vector_store_manager.index_documents(index_name_2, documents2)

# Call the backend-specific check method
self.check_indexed_documents(vector_store_manager)
await self.check_indexed_documents(vector_store_manager)

@abstractmethod
def check_indexed_documents(self, vector_store_manager):
"""Abstract method to check indexed documents in backend-specific format."""
pass

@pytest.mark.asyncio
@patch('requests.post')
def test_query_documents(self, mock_post, vector_store_manager):
async def test_query_documents(self, mock_post, vector_store_manager):
mock_response = {
"result": "This is the completion from the API"
}
@@ -76,10 +79,10 @@ def test_query_documents(self, mock_post, vector_store_manager):
Document(text="First document", metadata={"type": "text"}),
Document(text="Second document", metadata={"type": "text"})
]
vector_store_manager.index_documents("test_index", documents)
await vector_store_manager.index_documents("test_index", documents)

params = {"temperature": 0.7}
query_result = vector_store_manager.query("test_index", "First", top_k=1,
query_result = await vector_store_manager.query("test_index", "First", top_k=1,
llm_params=params, rerank_params={})

assert query_result is not None
@@ -93,28 +96,31 @@ def test_query_documents(self, mock_post, vector_store_manager):
headers={"Authorization": f"Bearer {LLM_ACCESS_SECRET}", 'Content-Type': 'application/json'}
)

def test_add_document(self, vector_store_manager):
@pytest.mark.asyncio
async def test_add_document(self, vector_store_manager):
documents = [Document(text="Third document", metadata={"type": "text"})]
vector_store_manager.index_documents("test_index", documents)
await vector_store_manager.index_documents("test_index", documents)

new_document = [Document(text="Fourth document", metadata={"type": "text"})]
vector_store_manager.index_documents("test_index", new_document)
await vector_store_manager.index_documents("test_index", new_document)

assert vector_store_manager.document_exists("test_index", new_document[0],
assert await vector_store_manager.document_exists("test_index", new_document[0],
BaseVectorStore.generate_doc_id("Fourth document"))

def test_persist_index_1(self, vector_store_manager):
@pytest.mark.asyncio
async def test_persist_index_1(self, vector_store_manager):
documents = [Document(text="Test document", metadata={"type": "text"})]
vector_store_manager.index_documents("test_index", documents)
vector_store_manager._persist("test_index")
await vector_store_manager.index_documents("test_index", documents)
await vector_store_manager._persist("test_index")
assert os.path.exists(VECTOR_DB_PERSIST_DIR)

def test_persist_index_2(self, vector_store_manager):
@pytest.mark.asyncio
async def test_persist_index_2(self, vector_store_manager):
documents = [Document(text="Test document", metadata={"type": "text"})]
vector_store_manager.index_documents("test_index", documents)
await vector_store_manager.index_documents("test_index", documents)

documents = [Document(text="Another Test document", metadata={"type": "text"})]
vector_store_manager.index_documents("another_test_index", documents)
await vector_store_manager.index_documents("another_test_index", documents)

vector_store_manager._persist_all()
await vector_store_manager._persist_all()
assert os.path.exists(VECTOR_DB_PERSIST_DIR)
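
The converted tests rely on pytest-asyncio (added in requirements-test.txt). Assuming the plugin's strict mode (its default), a coroutine test only runs on the event loop when it carries the explicit marker, which is why the diff adds @pytest.mark.asyncio to each converted test. A minimal standalone sketch of the pattern, using a hypothetical coroutine in place of the real store methods:

import asyncio
import pytest

async def fake_index(documents):
    # stand-in for an awaitable store call such as index_documents
    await asyncio.sleep(0)
    return [f"doc-{i}" for i, _ in enumerate(documents)]

@pytest.mark.asyncio
async def test_fake_index_returns_one_id_per_document():
    doc_ids = await fake_index(["first", "second"])
    assert len(doc_ids) == 2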
33 changes: 0 additions & 33 deletions presets/ragengine/tests/vector_store/test_chromadb_store.py

This file was deleted.

5 changes: 3 additions & 2 deletions presets/ragengine/tests/vector_store/test_faiss_store.py
@@ -18,7 +18,8 @@ def vector_store_manager(self, init_embed_manager):
os.environ['PERSIST_DIR'] = temp_dir
yield FaissVectorStoreHandler(init_embed_manager)

def check_indexed_documents(self, vector_store_manager):
@pytest.mark.asyncio
async def check_indexed_documents(self, vector_store_manager):
expected_output = {
'index1': {"87117028123498eb7d757b1507aa3e840c63294f94c27cb5ec83c939dedb32fd": {
'hash': '1e64a170be48c45efeaa8667ab35919106da0489ec99a11d0029f2842db133aa',
@@ -29,7 +30,7 @@ def check_indexed_documents(self, vector_store_manager):
'text': 'First document in index2'
}}
}
assert vector_store_manager.list_all_indexed_documents() == expected_output
assert await vector_store_manager.list_all_documents() == expected_output

@property
def expected_query_score(self):
