feat: Part 3 - Introduce Vector Store Manager and Vector Store Class (#633)

**Reason for Change**: This series of PRs integrates the LlamaIndex RAG service into Kaito. This PR introduces the main logic for indexing, querying, adding, and listing documents. `faiss_store.py` is the most important file: it implements the RAG operations, including persisting the vector DB to disk and loading it back when needed. The PR also includes tests for `faiss_store.py` in `test_faiss_store.py`. `conftest.py` is a test configuration file that forces the tests to run synchronously and CPU-only. `manager.py` is a light abstraction around the API (`main.py`), which will be introduced in subsequent PRs.

---------

Signed-off-by: Ishaan Sehgal <ishaanforthewin@gmail.com>
Co-authored-by: jerryzhuang <zhuangqhc@gmail.com>
Co-authored-by: Fei Guo <guofei@microsoft.com>
1 parent 65b844a · commit 941170b
Showing 12 changed files with 391 additions and 2 deletions.
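Before the per-file diffs, here is a minimal sketch of how the new pieces fit together, mirroring the flow exercised in the tests below. The embedding and config modules come from earlier PRs in this series; treat this as an illustration, not the service API, since `main.py` arrives in a later PR.

# Illustrative usage sketch; mirrors the flow exercised in test_faiss_store.py.
from ragengine.config import MODEL_ID
from ragengine.embedding.huggingface_local import LocalHuggingFaceEmbedding
from ragengine.models import Document
from ragengine.vector_store.faiss_store import FaissVectorStoreHandler

embedder = LocalHuggingFaceEmbedding(MODEL_ID)   # local HF embedding model
store = FaissVectorStoreHandler(embedder)        # FAISS-backed vector store

# Index documents into a named index; IDs are derived from the text hash.
doc_ids = store.index_documents("docs", [
    Document(text="Kaito manages AI model deployment on Kubernetes.", metadata={"type": "text"}),
])

# RAG query: retrieve the top_k neighbors, then ask the inference LLM.
result = store.query("docs", "What does Kaito do?", top_k=1, params={"temperature": 0.7})
print(result.response)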
`ragengine/models.py` (new file, 20 lines):
from typing import Dict, List, Optional

from pydantic import BaseModel

class Document(BaseModel):
    text: str
    metadata: Optional[dict] = {}

class IndexRequest(BaseModel):
    index_name: str
    documents: List[Document]

class QueryRequest(BaseModel):
    index_name: str
    query: str
    top_k: int = 10
    llm_params: Optional[Dict] = None  # Accept a dictionary for parameters

class ListDocumentsResponse(BaseModel):
    documents: Dict[str, Dict[str, Dict[str, str]]]
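As a quick illustration (not part of the diff), these pydantic models validate and default request payloads before they reach the vector store:

# Illustrative only: validation and defaults of the request models above.
from ragengine.models import Document, IndexRequest, QueryRequest

req = IndexRequest(
    index_name="docs",
    documents=[Document(text="First document", metadata={"type": "text"})],
)
assert req.documents[0].metadata == {"type": "text"}

q = QueryRequest(index_name="docs", query="First")
assert q.top_k == 10         # default when the caller omits top_k
assert q.llm_params is None  # LLM parameters are optional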
Empty file.
Empty file.
`ragengine/tests/conftest.py` (new file, 6 lines):
import sys
import os
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"  # Force CPU-only execution for testing
os.environ["OMP_NUM_THREADS"] = "1"  # Force single-threaded for testing to prevent segfault while loading embedding model
os.environ["MKL_NUM_THREADS"] = "1"  # Force MKL to use a single thread
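The tests below import `MODEL_ID`, `INFERENCE_URL`, `INFERENCE_ACCESS_SECRET`, and `PERSIST_DIR` from `ragengine.config`, which is not part of this diff. A plausible shape for that module, offered purely as an assumption inferred from those imports:

# Hypothetical sketch of ragengine/config.py (not in this diff); every name and
# default below is an assumption inferred from the imports in the tests and in
# faiss_store.py.
import os

MODEL_ID = os.environ.get("MODEL_ID", "BAAI/bge-small-en-v1.5")
INFERENCE_URL = os.environ.get("INFERENCE_URL", "http://localhost:5000/chat")
INFERENCE_ACCESS_SECRET = os.environ.get("ACCESS_SECRET", "default-access-secret")
# Note: if PERSIST_DIR is read once at import time like this, overrides such as
# the fixture's os.environ['PERSIST_DIR'] must land before the first import.
PERSIST_DIR = os.environ.get("PERSIST_DIR", "storage")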
`ragengine/tests/test_faiss_store.py` (new file, 142 lines):
import os
from tempfile import TemporaryDirectory
from unittest.mock import patch

import pytest
from ragengine.vector_store.faiss_store import FaissVectorStoreHandler
from ragengine.models import Document
from ragengine.embedding.huggingface_local import LocalHuggingFaceEmbedding
from ragengine.config import MODEL_ID, INFERENCE_URL, INFERENCE_ACCESS_SECRET


@pytest.fixture(scope='session')
def init_embed_manager():
    return LocalHuggingFaceEmbedding(MODEL_ID)


@pytest.fixture
def vector_store_manager(init_embed_manager):
    with TemporaryDirectory() as temp_dir:
        print(f"Saving temporary test storage at: {temp_dir}")
        # Mock the persistence directory
        os.environ['PERSIST_DIR'] = temp_dir
        yield FaissVectorStoreHandler(init_embed_manager)


def test_index_documents(vector_store_manager):
    documents = [
        Document(doc_id="1", text="First document", metadata={"type": "text"}),
        Document(doc_id="2", text="Second document", metadata={"type": "text"})
    ]

    doc_ids = vector_store_manager.index_documents("test_index", documents)

    assert len(doc_ids) == 2
    assert doc_ids == ["1", "2"]


def test_index_documents_isolation(vector_store_manager):
    doc_1_id, doc_2_id = "1", "2"
    documents1 = [
        Document(doc_id=doc_1_id, text="First document in index1", metadata={"type": "text"}),
    ]
    documents2 = [
        Document(doc_id=doc_2_id, text="First document in index2", metadata={"type": "text"}),
    ]

    # Index documents in separate indices
    index_name_1, index_name_2 = "index1", "index2"
    vector_store_manager.index_documents(index_name_1, documents1)
    vector_store_manager.index_documents(index_name_2, documents2)

    # Ensure documents are correctly persisted and separated by index
    doc_1 = vector_store_manager.get_document(index_name_1, doc_1_id)
    assert doc_1 and doc_1.node_ids  # Ensure documents were created

    doc_2 = vector_store_manager.get_document(index_name_2, doc_2_id)
    assert doc_2 and doc_2.node_ids  # Ensure documents were created

    # Ensure that the documents do not mix between indices
    assert vector_store_manager.get_document(index_name_2, doc_1_id) is None, f"Document {doc_1_id} should not exist in {index_name_2}"
    assert vector_store_manager.get_document(index_name_1, doc_2_id) is None, f"Document {doc_2_id} should not exist in {index_name_1}"


@patch('requests.post')
def test_query_documents(mock_post, vector_store_manager):
    # Define Mock Response for Custom Inference API
    mock_response = {
        "result": "This is the completion from the API"
    }

    mock_post.return_value.json.return_value = mock_response

    # Add documents to index
    documents = [
        Document(doc_id="1", text="First document", metadata={"type": "text"}),
        Document(doc_id="2", text="Second document", metadata={"type": "text"})
    ]
    vector_store_manager.index_documents("test_index", documents)

    params = {"temperature": 0.7}
    # Mock query and results
    query_result = vector_store_manager.query("test_index", "First", top_k=1, params=params)

    assert query_result is not None
    assert query_result.response == "This is the completion from the API"

    mock_post.assert_called_once_with(
        INFERENCE_URL,
        # Auto-Generated by LlamaIndex
        json={"prompt": "Context information is below.\n---------------------\ntype: text\n\nFirst document\n---------------------\nGiven the context information and not prior knowledge, answer the query.\nQuery: First\nAnswer: ", "formatted": True, 'temperature': 0.7},
        headers={"Authorization": f"Bearer {INFERENCE_ACCESS_SECRET}"}
    )

def test_add_document(vector_store_manager, capsys):
    documents = [Document(doc_id="3", text="Third document", metadata={"type": "text"})]
    vector_store_manager.index_documents("test_index", documents)

    # Add a document to the existing index (index_documents expects a list)
    new_documents = [Document(doc_id="4", text="Fourth document", metadata={"type": "text"})]
    vector_store_manager.index_documents("test_index", new_documents)

    # Assert that the document exists
    assert vector_store_manager.document_exists("test_index", "4")

def test_persist_and_load_index_store(vector_store_manager):
    """Test that the index store is persisted and loaded correctly."""
    # Add a document and persist the index
    documents = [Document(doc_id="1", text="Test document", metadata={"type": "text"})]
    vector_store_manager.index_documents("test_index", documents)
    vector_store_manager._persist("test_index")

    # Simulate a fresh load of the index store (clearing in-memory state)
    vector_store_manager.index_store = None  # Clear current in-memory store
    vector_store_manager._load_index_store()

    # Verify that the store was reloaded and contains the expected index structure
    assert vector_store_manager.index_store is not None
    assert len(vector_store_manager.index_store.index_structs()) > 0


# TODO: Prevent default re-indexing from load_index_from_storage
def test_persist_and_load_index(vector_store_manager):
    """Test that an index is persisted and then loaded correctly."""
    # Add a document and persist the index
    documents = [Document(doc_id="1", text="Test document", metadata={"type": "text"})]
    vector_store_manager.index_documents("test_index", documents)

    documents = [Document(doc_id="1", text="Another Test document", metadata={"type": "text"})]
    vector_store_manager.index_documents("another_test_index", documents)

    vector_store_manager._persist_all()

    # Simulate a fresh load of the index (clearing in-memory state)
    vector_store_manager.index_map = {}  # Clear current in-memory index map
    loaded_indices = vector_store_manager._load_indices()

    # Verify that the index was reloaded and contains the expected document
    assert loaded_indices is not None
    assert vector_store_manager.document_exists("test_index", "1")
    assert vector_store_manager.document_exists("another_test_index", "1")

    vector_store_manager.index_map = {}  # Clear current in-memory index map
    loaded_index = vector_store_manager._load_index("test_index")

    assert loaded_index is not None
    assert vector_store_manager.document_exists("test_index", "1")
    assert not vector_store_manager.document_exists("another_test_index", "1")  # Since we didn't load this index
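The `requests.post` mock in `test_query_documents` pins down the contract the custom `Inference` LLM (from an earlier PR in this series) is expected to honor: a single POST to `INFERENCE_URL` with the LlamaIndex-formatted prompt, any `set_params` values merged into the JSON body, and a bearer-token header. A rough sketch of a class satisfying that contract, as an assumption rather than the actual `ragengine/inference/inference.py`:

# Hypothetical sketch consistent with the mocked call above; the real Inference
# class plugs into LlamaIndex as the query engine's LLM.
import requests
from ragengine.config import INFERENCE_URL, INFERENCE_ACCESS_SECRET

class InferenceSketch:
    def __init__(self):
        self.params = {}

    def set_params(self, params: dict):
        # Per-query generation parameters (e.g. temperature) from QueryRequest.
        self.params = params or {}

    def complete(self, prompt: str) -> str:
        payload = {"prompt": prompt, "formatted": True, **self.params}
        response = requests.post(
            INFERENCE_URL,
            json=payload,
            headers={"Authorization": f"Bearer {INFERENCE_ACCESS_SECRET}"},
        )
        return response.json()["result"]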
Empty file.
`ragengine/vector_store/base.py` (new file, 32 lines):
from abc import ABC, abstractmethod
from typing import Dict, List

from ragengine.models import Document
from llama_index.core import VectorStoreIndex
import hashlib


class BaseVectorStore(ABC):
    @staticmethod
    def generate_doc_id(text: str) -> str:
        """Generates a unique document ID based on the hash of the document text."""
        return hashlib.sha256(text.encode('utf-8')).hexdigest()

    @abstractmethod
    def index_documents(self, index_name: str, documents: List[Document]) -> List[str]:
        pass

    @abstractmethod
    def query(self, index_name: str, query: str, top_k: int, params: dict):
        pass

    @abstractmethod
    def add_document(self, index_name: str, document: Document):
        pass

    @abstractmethod
    def list_all_indexed_documents(self) -> Dict[str, VectorStoreIndex]:
        pass

    @abstractmethod
    def document_exists(self, index_name: str, doc_id: str) -> bool:
        pass
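Because `generate_doc_id` hashes the document text, document IDs are content-addressed: identical text always maps to the same ID, which is what lets `_append_documents_to_index` in `faiss_store.py` below detect and skip duplicates. A quick illustration:

# Content-addressed IDs: same text, same ID; any edit yields a new ID.
import hashlib

def generate_doc_id(text: str) -> str:
    return hashlib.sha256(text.encode('utf-8')).hexdigest()

assert generate_doc_id("First document") == generate_doc_id("First document")
assert generate_doc_id("First document") != generate_doc_id("First document.")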
`ragengine/vector_store/faiss_store.py` (new file, 167 lines):
import os
from typing import Dict, List

import faiss
from llama_index.core import Document as LlamaDocument
from llama_index.core import StorageContext, VectorStoreIndex
from llama_index.core.storage.index_store import SimpleIndexStore
from llama_index.core.storage.docstore.types import RefDocInfo
from llama_index.vector_stores.faiss import FaissVectorStore

from ragengine.models import Document
from ragengine.inference.inference import Inference

from ragengine.config import PERSIST_DIR

from .base import BaseVectorStore
from ragengine.embedding.base import BaseEmbeddingModel


class FaissVectorStoreHandler(BaseVectorStore):
    def __init__(self, embedding_manager: BaseEmbeddingModel):
        self.embedding_manager = embedding_manager
        self.embed_model = self.embedding_manager.model
        self.dimension = self.embedding_manager.get_embedding_dimension()
        # TODO: Consider allowing user custom indexing method (would require configmap?) e.g.
        """
        # Choose the FAISS index type based on the provided index_method
        if index_method == 'FlatL2':
            faiss_index = faiss.IndexFlatL2(self.dimension)  # L2 (Euclidean distance) index
        elif index_method == 'FlatIP':
            faiss_index = faiss.IndexFlatIP(self.dimension)  # Inner product (cosine similarity) index
        elif index_method == 'IVFFlat':
            quantizer = faiss.IndexFlatL2(self.dimension)  # Quantizer for IVF
            faiss_index = faiss.IndexIVFFlat(quantizer, self.dimension, 100)  # IVF with flat quantization
        elif index_method == 'HNSW':
            faiss_index = faiss.IndexHNSWFlat(self.dimension, 32)  # HNSW index with 32 neighbors
        else:
            raise ValueError(f"Unknown index method: {index_method}")
        """
        self.index_map = {}  # Used to store the in-memory index via namespace (e.g. index_name -> VectorStoreIndex)
        self.index_store = SimpleIndexStore()  # Used to store global index metadata
        self.llm = Inference()

    def index_documents(self, index_name: str, documents: List[Document]) -> List[str]:
        """
        Called by the /index endpoint to index documents into the specified index.

        If the index already exists, appends new documents to it.
        Otherwise, creates a new index with the provided documents.

        Args:
            index_name (str): The name of the index to update or create.
            documents (List[Document]): A list of documents to index.

        Returns:
            List[str]: A list of document IDs that were successfully indexed.
        """
        if index_name in self.index_map:
            return self._append_documents_to_index(index_name, documents)
        else:
            return self._create_new_index(index_name, documents)

    def _append_documents_to_index(self, index_name: str, documents: List[Document]) -> List[str]:
        """
        Appends documents to an existing index.

        Args:
            index_name (str): The name of the existing index.
            documents (List[Document]): A list of documents to append.

        Returns:
            List[str]: A list of document IDs that were successfully indexed.
        """
        print(f"Index {index_name} already exists. Appending documents to existing index.")
        indexed_doc_ids = set()

        for doc in documents:
            doc.doc_id = self.generate_doc_id(doc.text)
            if not self.document_exists(index_name, doc.doc_id):
                self.add_document_to_index(index_name, doc)
                indexed_doc_ids.add(doc.doc_id)
            else:
                print(f"Document {doc.doc_id} already exists in index {index_name}. Skipping.")

        if indexed_doc_ids:
            self._persist(index_name)
        return list(indexed_doc_ids)

    def _create_new_index(self, index_name: str, documents: List[Document]) -> List[str]:
        """
        Creates a new index with the provided documents.

        Args:
            index_name (str): The name of the new index to create.
            documents (List[Document]): A list of documents to index.

        Returns:
            List[str]: A list of document IDs that were successfully indexed.
        """
        faiss_index = faiss.IndexFlatL2(self.dimension)
        vector_store = FaissVectorStore(faiss_index=faiss_index)
        storage_context = StorageContext.from_defaults(vector_store=vector_store)

        llama_docs = []
        indexed_doc_ids = set()

        for doc in documents:
            doc.doc_id = self.generate_doc_id(doc.text)
            llama_doc = LlamaDocument(id_=doc.doc_id, text=doc.text, metadata=doc.metadata)
            llama_docs.append(llama_doc)
            indexed_doc_ids.add(doc.doc_id)

        if llama_docs:
            index = VectorStoreIndex.from_documents(
                llama_docs,
                storage_context=storage_context,
                embed_model=self.embed_model,
                # use_async=True  # TODO: perform the indexing process asynchronously
            )
            index.set_index_id(index_name)
            self.index_map[index_name] = index
            self.index_store.add_index_struct(index.index_struct)
            self._persist(index_name)
        return list(indexed_doc_ids)

    def add_document_to_index(self, index_name: str, document: Document):
        """Inserts a single document into the existing FAISS index."""
        if index_name not in self.index_map:
            raise ValueError(f"No such index: '{index_name}' exists.")
        llama_doc = LlamaDocument(text=document.text, metadata=document.metadata, id_=document.doc_id)
        self.index_map[index_name].insert(llama_doc)

    def query(self, index_name: str, query: str, top_k: int, params: dict):
        """Queries the FAISS vector store."""
        if index_name not in self.index_map:
            raise ValueError(f"No such index: '{index_name}' exists.")
        self.llm.set_params(params)

        query_engine = self.index_map[index_name].as_query_engine(llm=self.llm, similarity_top_k=top_k)
        return query_engine.query(query)

    def list_all_indexed_documents(self) -> Dict[str, VectorStoreIndex]:
        """Lists all documents in the vector store."""
        return self.index_map

    def document_exists(self, index_name: str, doc_id: str) -> bool:
        """Checks if a document exists in the vector store."""
        if index_name not in self.index_map:
            print(f"No such index: '{index_name}' exists in vector store.")
            return False
        return doc_id in self.index_map[index_name].ref_doc_info

    def _persist_all(self):
        self.index_store.persist(os.path.join(PERSIST_DIR, "store.json"))  # Persist global index store
        for idx in self.index_store.index_structs():
            self._persist(idx.index_id)

    def _persist(self, index_name: str):
        """Saves the existing FAISS index to disk."""
        self.index_store.persist(os.path.join(PERSIST_DIR, "store.json"))  # Persist global index store
        assert index_name in self.index_map, f"No such index: '{index_name}' exists."

        # Persist each index's storage context separately
        storage_context = self.index_map[index_name].storage_context
        storage_context.persist(
            persist_dir=os.path.join(PERSIST_DIR, index_name)
        )
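Given `_persist` and `_persist_all` above, the on-disk layout is `store.json` (the global `SimpleIndexStore`) at the top of `PERSIST_DIR`, plus one sub-directory per index holding that index's LlamaIndex `StorageContext` files (the exact file names inside each directory are LlamaIndex defaults, not asserted by this PR's tests). A small snippet to inspect it:

# Walk PERSIST_DIR to see what _persist_all() wrote: store.json at the root
# and one directory per index with its persisted storage context.
import os
from ragengine.config import PERSIST_DIR

for root, _dirs, files in os.walk(PERSIST_DIR):
    depth = root[len(PERSIST_DIR):].count(os.sep)
    print("  " * depth + os.path.basename(root) + "/")
    for name in files:
        print("  " * (depth + 1) + name)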
Empty file.