Skip to content

Commit

Permalink
Merge branch 'main' into RAG_secret
Browse files Browse the repository at this point in the history
  • Loading branch information
bangqipropel authored Feb 26, 2025
2 parents 9229f8a + 0bbd6ef commit 23dbcc8
Show file tree
Hide file tree
Showing 8 changed files with 140 additions and 39 deletions.
15 changes: 9 additions & 6 deletions presets/ragengine/embedding/base.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from typing import List
from abc import ABC, abstractmethod
from llama_index.core.embeddings import BaseEmbedding
import asyncio

class BaseEmbeddingModel(BaseEmbedding, ABC):
async def _aget_text_embedding(self, text: str) -> List[float]:
return await asyncio.to_thread(self._get_text_embedding, text)

async def _aget_query_embedding(self, query: str) -> List[float]:
return await asyncio.to_thread(self._get_query_embedding, query)

class BaseEmbeddingModel(ABC):
@abstractmethod
def get_text_embedding(self, text: str):
"""Returns the text embedding for a given input string."""
pass

@abstractmethod
def get_embedding_dimension(self) -> int:
"""Returns the embedding dimension for the model."""
Expand Down
14 changes: 2 additions & 12 deletions presets/ragengine/embedding/huggingface_local_embedding.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,12 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from typing import Any
from llama_index.embeddings.huggingface import HuggingFaceEmbedding

from .base import BaseEmbeddingModel


class LocalHuggingFaceEmbedding(BaseEmbeddingModel):
def __init__(self, model_name: str):
self.model = HuggingFaceEmbedding(model_name=model_name) # TODO: Ensure/test loads on GPU (when available)

def get_text_embedding(self, text: str):
"""Returns the text embedding for a given input string."""
return self.model.get_text_embedding(text)

class LocalHuggingFaceEmbedding(HuggingFaceEmbedding, BaseEmbeddingModel):
    def get_embedding_dimension(self) -> int:
        """Return the embedding dimension, measured by embedding a fixed dummy sentence locally."""
        # The dimension is simply the length of the vector produced for the probe text.
        probe_vector = self.get_text_embedding("This is a dummy sentence.")
        return len(probe_vector)
13 changes: 8 additions & 5 deletions presets/ragengine/embedding/remote_embedding.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,24 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from typing import Any
import requests
import json
from .base import BaseEmbeddingModel


class RemoteEmbeddingModel(BaseEmbeddingModel):
def __init__(self, model_url: str, api_key: str):
def __init__(self, model_url: str, api_key: str, /, **data: Any):
"""
Initialize the RemoteEmbeddingModel.
Args:
model_url (str): The URL of the embedding model API endpoint.
api_key (str): The API key for accessing the API.
"""
super().__init__(**data)
self.model_url = model_url
self.api_key = api_key

def get_text_embedding(self, text: str):
def _get_text_embedding(self, text: str):
"""Returns the text embedding for a given input string."""
headers = {
"Authorization": f"Bearer {self.api_key}",
Expand All @@ -39,9 +39,12 @@ def get_text_embedding(self, text: str):
except requests.exceptions.RequestException as e:
raise RuntimeError(f"Failed to get embedding from remote model: {e}")

def _get_query_embedding(self, query: str):
return self.get_text_embedding(query)

def get_embedding_dimension(self) -> int:
"""Infers the embedding dimension by making a remote call to get the embedding of a dummy text."""
dummy_input = "This is a dummy sentence."
embedding = self.get_text_embedding(dummy_input)
embedding = self._get_text_embedding(dummy_input)

return len(embedding)
35 changes: 35 additions & 0 deletions presets/ragengine/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from ragengine.config import (REMOTE_EMBEDDING_URL, REMOTE_EMBEDDING_ACCESS_SECRET,
EMBEDDING_SOURCE_TYPE, LOCAL_EMBEDDING_MODEL_ID, DEFAULT_VECTOR_DB_PERSIST_DIR)
from urllib.parse import unquote
import os

app = FastAPI()

Expand Down Expand Up @@ -279,11 +280,45 @@ async def persist_index(
): # TODO: Provide endpoint for loading existing index(es)
# TODO: Extend support for other vector databases/integrations besides FAISS
try:
# Append index to save path to prevent saving conflicts/overwriting
path = os.path.join(path, index_name) if path == DEFAULT_VECTOR_DB_PERSIST_DIR else path
await rag_ops.persist(index_name, path)
return {"message": f"Successfully persisted index {index_name} to {path}."}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Persistence failed: {str(e)}")

@app.post(
    "/load/{index_name}",
    summary="Load Index Data from Disk",
    description="""
Load an existing index from disk at a specified location.
## Request Example:
```
POST /load/example_index?path=./custom_path/example_index
```
If no path is provided, will attempt to load from the default directory.
## Response Example:
```json
{
"message": "Successfully loaded index example_index from ./custom_path/example_index."
}
```
"""
)
async def load_index(
    index_name: str,
    path: str = Query(DEFAULT_VECTOR_DB_PERSIST_DIR, description="Path to load the index from"),
    overwrite: bool = Query(False, description="Overwrite the existing index if it already exists")
):
    """Load a persisted index from disk and register it under ``index_name``.

    Args:
        index_name: Name the loaded index is registered under.
        path: Directory to load from. When it equals the default persist
            directory, the index's own subdirectory is used, mirroring where
            ``persist_index`` writes indexes saved to the default location.
        overwrite: Forwarded to ``rag_ops.load``; whether an already loaded
            index with the same name may be replaced.

    Returns:
        A dict with a single ``message`` key describing the loaded index.

    Raises:
        HTTPException: Re-raised unchanged when ``rag_ops.load`` raises one,
            so its original status code is kept. Any other failure is
            reported as a 500.
    """
    # persist_index appends the index name to the default directory, so the
    # default load location must do the same or a persist/load round trip fails.
    if path == DEFAULT_VECTOR_DB_PERSIST_DIR:
        path = os.path.join(path, index_name)
    try:
        await rag_ops.load(index_name, path, overwrite)
        return {"message": f"Successfully loaded index {index_name} from {path}."}
    except HTTPException:
        # Keep the status code chosen deeper in the stack instead of masking it as a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Loading failed: {str(e)}") from e

@app.on_event("shutdown")
async def shutdown_event():
""" Ensure the client is properly closed when the server shuts down. """
Expand Down
26 changes: 24 additions & 2 deletions presets/ragengine/tests/api/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ async def test_persist_documents(async_client):
response = await async_client.post(f"/persist/{index_name}")
assert response.status_code == 200
response_json = response.json()
assert response_json == {"message": f"Successfully persisted index {index_name} to {DEFAULT_VECTOR_DB_PERSIST_DIR}."}
assert response_json == {"message": f"Successfully persisted index {index_name} to {DEFAULT_VECTOR_DB_PERSIST_DIR}/{index_name}."}
assert os.path.exists(os.path.join(DEFAULT_VECTOR_DB_PERSIST_DIR, index_name))

# Persist documents for the specific index at a custom path
Expand All @@ -344,7 +344,29 @@ async def test_persist_documents(async_client):
assert response.status_code == 200
response_json = response.json()
assert response_json == {"message": f"Successfully persisted index {index_name} to {custom_path}."}
assert os.path.exists(os.path.join(custom_path, index_name))
assert os.path.exists(custom_path)

@pytest.mark.asyncio
async def test_load_documents(async_client):
    """Load the index from an explicit path and check it is listed with its two documents."""
    index_name = "test_index"
    # Derive the path from the configured default so the expectation tracks the config value.
    load_path = f"{DEFAULT_VECTOR_DB_PERSIST_DIR}/{index_name}"
    response = await async_client.post(f"/load/{index_name}?path={load_path}")

    assert response.status_code == 200
    assert response.json() == {"message": f"Successfully loaded index {index_name} from {load_path}."}

    response = await async_client.get("/indexes")
    assert response.status_code == 200
    assert response.json() == [index_name]

    response = await async_client.get(f"/indexes/{index_name}/documents")
    assert response.status_code == 200
    response_data = response.json()

    assert response_data["count"] == 2
    assert len(response_data["documents"]) == 2
    assert response_data["documents"][0]["text"] == "This is a test document"
    assert response_data["documents"][1]["text"] == "Another test document"


"""
Example of a live query test. This test is currently commented out as it requires a valid
Expand Down
63 changes: 53 additions & 10 deletions presets/ragengine/vector_store/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,14 @@
import os
import asyncio
from itertools import islice
from collections import defaultdict

from llama_index.core import Document as LlamaDocument
from llama_index.core.storage.index_store import SimpleIndexStore
from llama_index.core import (StorageContext, VectorStoreIndex)
from llama_index.core import (StorageContext, VectorStoreIndex, load_index_from_storage)
from llama_index.core.postprocessor import LLMRerank # Query with LLM Reranking

from llama_index.vector_stores.faiss import FaissVectorStore

from ragengine.models import Document, DocumentResponse
from ragengine.embedding.base import BaseEmbeddingModel
from ragengine.inference.inference import Inference
Expand All @@ -29,12 +30,12 @@
logger = logging.getLogger(__name__)

class BaseVectorStore(ABC):
def __init__(self, embedding_manager: BaseEmbeddingModel, use_rwlock: bool = False):
self.embedding_manager = embedding_manager
self.embed_model = self.embedding_manager.model
def __init__(self, embed_model: BaseEmbeddingModel, use_rwlock: bool = False):
super().__init__()
self.llm = Inference()
self.embed_model = embed_model
self.index_map = {}
self.index_store = SimpleIndexStore()
self.llm = Inference()
# Use a reader/writer lock only if needed
self.use_rwlock = use_rwlock
self.rwlock = aiorwlock.RWLock() if self.use_rwlock else None
Expand Down Expand Up @@ -281,13 +282,55 @@ async def _persist_internal(self, index_name: str, path: str):
if index_name not in self.index_map:
raise HTTPException(status_code=404, detail=f"No such index: '{index_name}' exists.")

logger.info(f"Persisting index {index_name} into {path}.")
await asyncio.to_thread(self.index_store.persist, os.path.join(path, "store.json"))

# Persist the specific index
storage_context = self.index_map[index_name].storage_context
await asyncio.to_thread(storage_context.persist, os.path.join(path, index_name))
await asyncio.to_thread(storage_context.persist, path)
logger.info(f"Successfully persisted index {index_name}.")
except Exception as e:
logger.error(f"Failed to persist index {index_name}. Error: {str(e)}")
raise HTTPException(status_code=500, detail=f"Persistence failed: {str(e)}")

async def load(self, index_name: str, path: str, overwrite: bool):
    """Load an index from ``path``, holding the writer lock when locking is enabled.

    Raises:
        HTTPException: 404 when ``path`` does not exist.
    """
    # Reject a missing path up front so the lock is never taken for a doomed load.
    path_found = os.path.exists(path)
    if not path_found:
        raise HTTPException(status_code=404, detail=f"Path does not exist: {path}")

    if not self.use_rwlock:
        await self._load_internal(index_name, path, overwrite)
        return

    async with self.rwlock.writer_lock:
        await self._load_internal(index_name, path, overwrite)

async def _load_internal(self, index_name: str, path: str, overwrite: bool):
    """Load the index stored at ``path`` and register it in ``self.index_map``.

    Args:
        index_name: Key the loaded index is stored under in ``self.index_map``.
        path: Persist directory handed to ``StorageContext.from_defaults``.
        overwrite: When False, an existing entry for ``index_name`` is a conflict.

    Raises:
        HTTPException: 409 when ``index_name`` is already loaded and
            ``overwrite`` is False; 500 for any other failure.
    """
    try:
        if index_name in self.index_map and not overwrite:
            raise HTTPException(
                status_code=409,
                detail=f"Index '{index_name}' already exists. Use a different name or delete the existing index first."
            )

        logger.info(f"Loading index {index_name} from {path}.")

        try:
            storage_context = StorageContext.from_defaults(persist_dir=path)
        except UnicodeDecodeError:
            # Failed to load the index in the default json format, trying faissdb
            faiss_vs = FaissVectorStore.from_persist_dir(persist_dir=path)
            storage_context = StorageContext.from_defaults(persist_dir=path, vector_store=faiss_vs)
        except Exception as e:
            logger.error(f"Failed to load index '{index_name}'. Error: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Loading failed: {str(e)}") from e

        logger.info(f"Loading index '{index_name}' using the workspace's embedding model.")
        # Load the index using the workspace's embedding model, assuming all indices
        # were created using the same embedding model currently in use.
        loaded_index = await asyncio.to_thread(
            load_index_from_storage,
            storage_context,
            embed_model=self.embed_model,
            show_progress=True,
        )
        self.index_map[index_name] = loaded_index
        logger.info(f"Successfully loaded index {index_name}.")
    except HTTPException:
        # Already carries the intended status (409 conflict or the 500 built above);
        # re-wrapping it below would turn every such error into a generic 500.
        raise
    except Exception as e:
        logger.error(f"Failed to load index {index_name}. Error: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Loading failed: {str(e)}") from e
7 changes: 4 additions & 3 deletions presets/ragengine/vector_store/faiss_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,14 @@
import faiss
from llama_index.vector_stores.faiss import FaissVectorStore
from ragengine.models import Document
from ragengine.embedding.base import BaseEmbeddingModel
from .base import BaseVectorStore


class FaissVectorStoreHandler(BaseVectorStore):
def __init__(self, embedding_manager):
super().__init__(embedding_manager, use_rwlock=True)
self.dimension = self.embedding_manager.get_embedding_dimension()
def __init__(self, embed_model: BaseEmbeddingModel):
super().__init__(embed_model, use_rwlock=True)
self.dimension = self.embed_model.get_embedding_dimension()

async def _create_new_index(self, index_name: str, documents: List[Document]) -> List[str]:
faiss_index = faiss.IndexFlatL2(self.dimension)
Expand Down
6 changes: 5 additions & 1 deletion presets/ragengine/vector_store_manager/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,13 @@ async def list_documents_in_index(self,
)

async def persist(self, index_name: str, path: str) -> None:
"""Persist existing index(es)."""
"""Persist existing index."""
return await self.vector_store.persist(index_name, path)

async def load(self, index_name: str, path: str, overwrite: bool) -> None:
    """Load an existing index by delegating to the underlying vector store."""
    store = self.vector_store
    outcome = await store.load(index_name, path, overwrite)
    return outcome

async def shutdown(self):
    """Shut the manager down by awaiting the underlying vector store's ``shutdown()``."""
    store = self.vector_store
    await store.shutdown()

0 comments on commit 23dbcc8

Please sign in to comment.