Skip to content

Commit

Permalink
feat: implement BreviaBaseRetriever and integrate with query handling
Browse files Browse the repository at this point in the history
  • Loading branch information
nikazzio committed Jan 29, 2025
1 parent 08f8fcd commit f1d7458
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 13 deletions.
50 changes: 50 additions & 0 deletions brevia/base_retriever.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from typing import List
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.vectorstores import VectorStore, VectorStoreRetriever


class BreviaBaseRetriever(VectorStoreRetriever):
    """Vector store retriever used by BREVIA.

    The asynchronous lookup dispatches on ``search_type`` to the matching
    asynchronous search method of ``vectorstore``. For the
    ``similarity_score_threshold`` type, the relevance score returned with
    each document is also stored in that document's metadata.
    """

    vectorstore: VectorStore
    """Vector store the searches are executed against."""

    search_kwargs: dict
    """Keyword arguments forwarded unchanged to the vector store search call."""

    async def _aget_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """
        Asynchronously fetch the documents matching ``query``.

        Parameters:
            query (str): The search query.
            run_manager (CallbackManagerForRetrieverRun): Manager for
                retriever runs; it is not used inside this method.
        Returns:
            List[Document]: Documents returned by the vector store. With the
                ``similarity_score_threshold`` search type every document
                carries its relevance score in ``metadata["score"]``.
        Raises:
            ValueError: If ``search_type`` is not ``similarity``,
                ``similarity_score_threshold`` or ``mmr``.
        """
        mode = self.search_type
        store = self.vectorstore

        if mode == "similarity":
            return await store.asimilarity_search(query, **self.search_kwargs)

        if mode == "similarity_score_threshold":
            scored_docs = await store.asimilarity_search_with_relevance_scores(
                query, **self.search_kwargs
            )
            found = []
            for document, relevance in scored_docs:
                # Keep the score next to the document so it survives the
                # conversion from (document, score) pairs to plain documents.
                document.metadata["score"] = relevance
                found.append(document)
            return found

        if mode == "mmr":
            return await store.amax_marginal_relevance_search(
                query, **self.search_kwargs
            )

        msg = f"search_type of {self.search_type} not allowed."
        raise ValueError(msg)
23 changes: 20 additions & 3 deletions brevia/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from brevia.models import load_chatmodel, load_embeddings
from brevia.settings import get_settings
from brevia.utilities.types import load_type
from brevia.base_retriever import BreviaBaseRetriever

# system = load_prompt(f'{prompts_path}/qa/default.system.yaml')
# jinja2 template from file was disabled by langchain so, for now
Expand Down Expand Up @@ -142,6 +143,16 @@ class ChatParams(BaseModel):
filter: dict[str, str | dict] | None = None
source_docs: bool = False
multiquery: bool = False
search_type: str = "similarity"
score_threshold: float = 0.0

def get_search_kwargs(self) -> dict:
    """Assemble the search keyword arguments from these chat parameters.

    Returns:
        dict: A new dictionary with ``k`` set to ``docs_num``, ``filter``
        set to ``filter`` and ``score_threshold`` set to
        ``score_threshold``, inserted in that order.
    """
    search_kwargs: dict = {}
    search_kwargs['k'] = self.docs_num
    search_kwargs['filter'] = self.filter
    search_kwargs['score_threshold'] = self.score_threshold
    return search_kwargs


def create_custom_retriever(
Expand All @@ -166,14 +177,20 @@ def create_default_retriever(
store: VectorStore,
search_kwargs: dict,
llm: BaseChatModel,
search_type: str | None = None,
multiquery: bool = False,

) -> BaseRetriever:
"""
Create a default retriever.
Can be a vector store retriever or a multiquery retriever.
"""
retriever = store.as_retriever(search_kwargs=search_kwargs)
retriever = BreviaBaseRetriever(
vectorstore=store,
search_type=search_type,
search_kwargs=search_kwargs
)

if multiquery:
return MultiQueryRetriever.from_llm(retriever=retriever, llm=llm)

Expand All @@ -199,8 +216,7 @@ def create_conversation_retriever(
distance_strategy=strategy,
use_jsonb=True,
)

search_kwargs = {'k': chat_params.docs_num, 'filter': chat_params.filter}
search_kwargs = chat_params.get_search_kwargs()
retriever_conf = collection.cmetadata.get(
'qa_retriever',
get_settings().qa_retriever.copy()
Expand All @@ -210,6 +226,7 @@ def create_conversation_retriever(
store=document_search,
search_kwargs=search_kwargs,
llm=llm,
search_type=chat_params.search_type,
multiquery=chat_params.multiquery,
)

Expand Down
19 changes: 9 additions & 10 deletions brevia/routers/qa_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,19 +136,18 @@ async def run_chain(
):
"""Run chain using async methods and return result"""
result = await chain.ainvoke({
'input': chat_body.question,
'chat_history': retrieve_chat_history(
history=chat_body.chat_history,
question=chat_body.question,
session=x_chat_session,
embeddings=embeddings,
),
'lang': lang,
},
'input': chat_body.question,
'chat_history': retrieve_chat_history(
history=chat_body.chat_history,
question=chat_body.question,
session=x_chat_session,
embeddings=embeddings,
),
'lang': lang,
},
config={'callbacks': chain_callbacks},
return_only_outputs=True,
)

return chat_result(
result=result,
callb=token_callback,
Expand Down

0 comments on commit f1d7458

Please sign in to comment.