Skip to content

Commit

Permalink
feat: Part 3 - Introduce Vector Store Manager and Vector Store Class (#…
Browse files Browse the repository at this point in the history
…633)

**Reason for Change**:
This series of PR will integrate llamaindex RAG service for Kaito.

This PR introduces the main logic for indexing, querying, adding and
listing documents.`faiss_store.py` is the most important file.

`faiss_store.py` introduces the code to perform RAG operations,
including persisting and loading DB from disk incase it is needed. PR
also includes tests for `faiss_store.py` in `test_faiss_store.py`.
`conftest.py` is a test configuration file, in it we ensure the tests
run synchronously and CPU only.

`manager.py` is a light abstraction around the API (`main.py`) which
will be introduced in subsequent PRs.

---------

Signed-off-by: Ishaan Sehgal <ishaanforthewin@gmail.com>
Co-authored-by: jerryzhuang <zhuangqhc@gmail.com>
Co-authored-by: Fei Guo <guofei@microsoft.com>
  • Loading branch information
3 people authored Oct 21, 2024
1 parent 65b844a commit 941170b
Show file tree
Hide file tree
Showing 12 changed files with 391 additions and 2 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ unit-test: ## Run unit tests.

inference-api-e2e:
pip install -r presets/inference/text-generation/requirements.txt
pytest -o log_cli=true -o log_cli_level=INFO .
pytest -o log_cli=true -o log_cli_level=INFO presets/inference/text-generation/tests

# Ginkgo configurations
GINKGO_FOCUS ?=
Expand Down
2 changes: 1 addition & 1 deletion ragengine/inference/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from llama_index.llms.openai import OpenAI
from llama_index.core.llms.callbacks import llm_completion_callback
import requests
from config import INFERENCE_URL, INFERENCE_ACCESS_SECRET #, RESPONSE_FIELD
from ragengine.config import INFERENCE_URL, INFERENCE_ACCESS_SECRET #, RESPONSE_FIELD

class Inference(CustomLLM):
params: dict = {}
Expand Down
20 changes: 20 additions & 0 deletions ragengine/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from typing import Dict, List, Optional

from pydantic import BaseModel

class Document(BaseModel):
text: str
metadata: Optional[dict] = {}

class IndexRequest(BaseModel):
index_name: str
documents: List[Document]

class QueryRequest(BaseModel):
index_name: str
query: str
top_k: int = 10
llm_params: Optional[Dict] = None # Accept a dictionary for parameters

class ListDocumentsResponse(BaseModel):
documents:Dict[str, Dict[str, Dict[str, str]]]
Empty file added ragengine/tests/__init__.py
Empty file.
Empty file.
6 changes: 6 additions & 0 deletions ragengine/tests/vector_store/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
import sys
import os
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))
os.environ["CUDA_VISIBLE_DEVICES"] = "-1" # Force CPU-only execution for testing
os.environ["OMP_NUM_THREADS"] = "1" # Force single-threaded for testing to prevent segfault while loading embedding model
os.environ["MKL_NUM_THREADS"] = "1" # Force MKL to use a single thread
142 changes: 142 additions & 0 deletions ragengine/tests/vector_store/test_faiss_store.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
import os
from tempfile import TemporaryDirectory
from unittest.mock import patch

import pytest
from ragengine.vector_store.faiss_store import FaissVectorStoreHandler
from ragengine.models import Document
from ragengine.embedding.huggingface_local import LocalHuggingFaceEmbedding
from ragengine.config import MODEL_ID, INFERENCE_URL, INFERENCE_ACCESS_SECRET

@pytest.fixture(scope='session')
def init_embed_manager():
return LocalHuggingFaceEmbedding(MODEL_ID)

@pytest.fixture
def vector_store_manager(init_embed_manager):
with TemporaryDirectory() as temp_dir:
print(f"Saving temporary test storage at: {temp_dir}")
# Mock the persistence directory
os.environ['PERSIST_DIR'] = temp_dir
yield FaissVectorStoreHandler(init_embed_manager)

def test_index_documents(vector_store_manager):
documents = [
Document(doc_id="1", text="First document", metadata={"type": "text"}),
Document(doc_id="2", text="Second document", metadata={"type": "text"})
]

doc_ids = vector_store_manager.index_documents("test_index", documents)

assert len(doc_ids) == 2
assert doc_ids == ["1", "2"]

def test_index_documents_isolation(vector_store_manager):
doc_1_id, doc_2_id = "1", "2"
documents1 = [
Document(doc_id=doc_1_id, text="First document in index1", metadata={"type": "text"}),
]
documents2 = [
Document(doc_id=doc_2_id, text="First document in index2", metadata={"type": "text"}),
]

# Index documents in separate indices
index_name_1, index_name_2 = "index1", "index2"
vector_store_manager.index_documents(index_name_1, documents1)
vector_store_manager.index_documents(index_name_2, documents2)

# Ensure documents are correctly persisted and separated by index
doc_1 = vector_store_manager.get_document(index_name_1, doc_1_id)
assert doc_1 and doc_1.node_ids # Ensure documents were created

doc_2 = vector_store_manager.get_document(index_name_2, doc_2_id)
assert doc_2 and doc_2.node_ids # Ensure documents were created

# Ensure that the documents do not mix between indices
assert vector_store_manager.get_document(index_name_2, doc_1_id) is None, f"Document {doc_1_id} should not exist in {index_name_2}"
assert vector_store_manager.get_document(index_name_1, doc_2_id) is None, f"Document {doc_2_id} should not exist in {index_name_1}"

@patch('requests.post')
def test_query_documents(mock_post, vector_store_manager):
# Define Mock Response for Custom Inference API
mock_response = {
"result": "This is the completion from the API"
}

mock_post.return_value.json.return_value = mock_response

# Add documents to index
documents = [
Document(doc_id="1", text="First document", metadata={"type": "text"}),
Document(doc_id="2", text="Second document", metadata={"type": "text"})
]
vector_store_manager.index_documents("test_index", documents)

params = {"temperature": 0.7}
# Mock query and results
query_result = vector_store_manager.query("test_index", "First", top_k=1, params=params)

assert query_result is not None
assert query_result.response == "This is the completion from the API"

mock_post.assert_called_once_with(
INFERENCE_URL,
# Auto-Generated by LlamaIndex
json={"prompt": "Context information is below.\n---------------------\ntype: text\n\nFirst document\n---------------------\nGiven the context information and not prior knowledge, answer the query.\nQuery: First\nAnswer: ", "formatted": True, 'temperature': 0.7},
headers={"Authorization": f"Bearer {INFERENCE_ACCESS_SECRET}"}
)

def test_add_document(vector_store_manager, capsys):
documents = [Document(doc_id="3", text="Third document", metadata={"type": "text"})]
vector_store_manager.index_documents("test_index", documents)

# Add a document to the existing index
new_document = Document(doc_id="4", text="Fourth document", metadata={"type": "text"})
vector_store_manager.index_documents("test_index", new_document)

# Assert that the document exists
assert vector_store_manager.document_exists("test_index", "4")

def test_persist_and_load_index_store(vector_store_manager):
"""Test that the index store is persisted and loaded correctly."""
# Add a document and persist the index
documents = [Document(doc_id="1", text="Test document", metadata={"type": "text"})]
vector_store_manager.index_documents("test_index", documents)
vector_store_manager._persist("test_index")

# Simulate a fresh load of the index store (clearing in-memory state)
vector_store_manager.index_store = None # Clear current in-memory store
vector_store_manager._load_index_store()

# Verify that the store was reloaded and contains the expected index structure
assert vector_store_manager.index_store is not None
assert len(vector_store_manager.index_store.index_structs()) > 0

# TODO: Prevent default re-indexing from load_index_from_storage
def test_persist_and_load_index(vector_store_manager):
"""Test that an index is persisted and then loaded correctly."""
# Add a document and persist the index
documents = [Document(doc_id="1", text="Test document", metadata={"type": "text"})]
vector_store_manager.index_documents("test_index", documents)

documents = [Document(doc_id="1", text="Another Test document", metadata={"type": "text"})]
vector_store_manager.index_documents("another_test_index", documents)

vector_store_manager._persist_all()

# Simulate a fresh load of the index (clearing in-memory state)
vector_store_manager.index_map = {} # Clear current in-memory index map
loaded_indices = vector_store_manager._load_indices()

# Verify that the index was reloaded and contains the expected document
assert loaded_indices is not None
assert vector_store_manager.document_exists("test_index", "1")
assert vector_store_manager.document_exists("another_test_index", "1")

vector_store_manager.index_map = {} # Clear current in-memory index map
loaded_index = vector_store_manager._load_index("test_index")

assert loaded_index is not None
assert vector_store_manager.document_exists("test_index", "1")
assert not vector_store_manager.document_exists("another_test_index", "1") # Since we didn't load this index

Empty file.
32 changes: 32 additions & 0 deletions ragengine/vector_store/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from abc import ABC, abstractmethod
from typing import Dict, List

from ragengine.models import Document
from llama_index.core import VectorStoreIndex
import hashlib


class BaseVectorStore(ABC):
def generate_doc_id(text: str) -> str:
"""Generates a unique document ID based on the hash of the document text."""
return hashlib.sha256(text.encode('utf-8')).hexdigest()

@abstractmethod
def index_documents(self, index_name: str, documents: List[Document]) -> List[str]:
pass

@abstractmethod
def query(self, index_name: str, query: str, top_k: int, params: dict):
pass

@abstractmethod
def add_document(self, index_name: str, document: Document):
pass

@abstractmethod
def list_all_indexed_documents(self) -> Dict[str, VectorStoreIndex]:
pass

@abstractmethod
def document_exists(self, index_name: str, doc_id: str) -> bool:
pass
167 changes: 167 additions & 0 deletions ragengine/vector_store/faiss_store.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
import os
from typing import Dict, List

import faiss
from llama_index.core import Document as LlamaDocument
from llama_index.core import (StorageContext, VectorStoreIndex)
from llama_index.core.storage.index_store import SimpleIndexStore
from llama_index.core.storage.docstore.types import RefDocInfo
from llama_index.vector_stores.faiss import FaissVectorStore

from ragengine.models import Document
from ragengine.inference.inference import Inference

from config import PERSIST_DIR

from .base import BaseVectorStore
from ragengine.embedding.base import BaseEmbeddingModel


class FaissVectorStoreHandler(BaseVectorStore):
def __init__(self, embedding_manager: BaseEmbeddingModel):
self.embedding_manager = embedding_manager
self.embed_model = self.embedding_manager.model
self.dimension = self.embedding_manager.get_embedding_dimension()
# TODO: Consider allowing user custom indexing method (would require configmap?) e.g.
"""
# Choose the FAISS index type based on the provided index_method
if index_method == 'FlatL2':
faiss_index = faiss.IndexFlatL2(self.dimension) # L2 (Euclidean distance) index
elif index_method == 'FlatIP':
faiss_index = faiss.IndexFlatIP(self.dimension) # Inner product (cosine similarity) index
elif index_method == 'IVFFlat':
quantizer = faiss.IndexFlatL2(self.dimension) # Quantizer for IVF
faiss_index = faiss.IndexIVFFlat(quantizer, self.dimension, 100) # IVF with flat quantization
elif index_method == 'HNSW':
faiss_index = faiss.IndexHNSWFlat(self.dimension, 32) # HNSW index with 32 neighbors
else:
raise ValueError(f"Unknown index method: {index_method}")
"""
self.index_map = {} # Used to store the in-memory index via namespace (e.g. index_name -> VectorStoreIndex)
self.index_store = SimpleIndexStore() # Use to store global index metadata
self.llm = Inference()

def index_documents(self, index_name: str, documents: List[Document]) -> List[str]:
"""
Called by the /index endpoint to index documents into the specified index.
If the index already exists, appends new documents to it.
Otherwise, creates a new index with the provided documents.
Args:
index_name (str): The name of the index to update or create.
documents (List[Document]): A list of documents to index.
Returns:
List[str]: A list of document IDs that were successfully indexed.
"""
if index_name in self.index_map:
return self._append_documents_to_index(index_name, documents)
else:
return self._create_new_index(index_name, documents)

def _append_documents_to_index(self, index_name: str, documents: List[Document]) -> List[str]:
"""
Appends documents to an existing index.
Args:
index_name (str): The name of the existing index.
documents (List[Document]): A list of documents to append.
Returns:
List[str]: A list of document IDs that were successfully indexed.
"""
print(f"Index {index_name} already exists. Appending documents to existing index.")
indexed_doc_ids = set()

for doc in documents:
doc.doc_id = self.generate_doc_id(doc.text)
if not self.document_exists(index_name, doc.doc_id):
self.add_document_to_index(index_name, doc)
indexed_doc_ids.add(doc.doc_id)
else:
print(f"Document {doc.doc_id} already exists in index {index_name}. Skipping.")

if indexed_doc_ids:
self._persist(index_name)
return list(indexed_doc_ids)

def _create_new_index(self, index_name: str, documents: List[Document]) -> List[str]:
"""
Creates a new index with the provided documents.
Args:
index_name (str): The name of the new index to create.
documents (List[Document]): A list of documents to index.
Returns:
List[str]: A list of document IDs that were successfully indexed.
"""
faiss_index = faiss.IndexFlatL2(self.dimension)
vector_store = FaissVectorStore(faiss_index=faiss_index)
storage_context = StorageContext.from_defaults(vector_store=vector_store)

llama_docs = []
indexed_doc_ids = set()

for doc in documents:
doc.doc_id = self.generate_doc_id(doc.text)
llama_doc = LlamaDocument(id_=doc.doc_id, text=doc.text, metadata=doc.metadata)
llama_docs.append(llama_doc)
indexed_doc_ids.add(doc.doc_id)

if llama_docs:
index = VectorStoreIndex.from_documents(
llama_docs,
storage_context=storage_context,
embed_model=self.embed_model,
# use_async=True # TODO: Indexing Process Performed Async
)
index.set_index_id(index_name)
self.index_map[index_name] = index
self.index_store.add_index_struct(index.index_struct)
self._persist(index_name)
return list(indexed_doc_ids)

def add_document_to_index(self, index_name: str, document: Document):
"""Inserts a single document into the existing FAISS index."""
if index_name not in self.index_map:
raise ValueError(f"No such index: '{index_name}' exists.")
llama_doc = LlamaDocument(text=document.text, metadata=document.metadata, id_=document.doc_id)
self.index_map[index_name].insert(llama_doc)

def query(self, index_name: str, query: str, top_k: int, llm_params: dict):
"""Queries the FAISS vector store."""
if index_name not in self.index_map:
raise ValueError(f"No such index: '{index_name}' exists.")
self.llm.set_params(llm_params)

query_engine = self.index_map[index_name].as_query_engine(llm=self.llm, similarity_top_k=top_k)
return query_engine.query(query)

def list_all_indexed_documents(self) -> Dict[str, VectorStoreIndex]:
"""Lists all documents in the vector store."""
return self.index_map

def document_exists(self, index_name: str, doc_id: str) -> bool:
"""Checks if a document exists in the vector store."""
if index_name not in self.index_map:
print(f"No such index: '{index_name}' exists in vector store.")
return False
return doc_id in self.index_map[index_name].ref_doc_info

def _persist_all(self):
self.index_store.persist(os.path.join(PERSIST_DIR, "store.json")) # Persist global index store
for idx in self.index_store.index_structs():
self._persist(idx.index_id)

def _persist(self, index_name: str):
"""Saves the existing FAISS index to disk."""
self.index_store.persist(os.path.join(PERSIST_DIR, "store.json")) # Persist global index store
assert index_name in self.index_map, f"No such index: '{index_name}' exists."

# Persist each index's storage context separately
storage_context = self.index_map[index_name].storage_context
storage_context.persist(
persist_dir=os.path.join(PERSIST_DIR, index_name)
)
Empty file.
Loading

0 comments on commit 941170b

Please sign in to comment.