-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: implement BreviaBaseRetriever and integrate with query handling
- Loading branch information
Showing
3 changed files
with
79 additions
and
13 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
from typing import List | ||
from langchain_core.callbacks import CallbackManagerForRetrieverRun | ||
from langchain_core.documents import Document | ||
from langchain_core.vectorstores import VectorStore, VectorStoreRetriever | ||
|
||
|
||
class BreviaBaseRetriever(VectorStoreRetriever): | ||
""" Base custom Retriever for BREVIA""" | ||
|
||
vectorstore: VectorStore | ||
"""VectorStore used for retrieval.""" | ||
|
||
search_kwargs: dict | ||
"""Configuration containing settings for the search from the application""" | ||
|
||
async def _aget_relevant_documents( | ||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun | ||
) -> List[Document]: | ||
""" | ||
Asynchronous implementation for retrieving relevant documents with score | ||
Merges results from multiple custom searches using different filters. | ||
Parameters: | ||
query (str): The search query. | ||
run_manager (CallbackManagerForRetrieverRun): Manager for retriever runs. | ||
Returns: | ||
List[Document]: A list of relevant documents based on the search. | ||
""" | ||
if self.search_type == "similarity": | ||
docs = await self.vectorstore.asimilarity_search( | ||
query, **self.search_kwargs | ||
) | ||
elif self.search_type == "similarity_score_threshold": | ||
docs_and_similarities = ( | ||
await self.vectorstore.asimilarity_search_with_relevance_scores( | ||
query, **self.search_kwargs | ||
) | ||
) | ||
for doc, score in docs_and_similarities: | ||
doc.metadata["score"] = score | ||
docs = [doc for doc, _ in docs_and_similarities] | ||
elif self.search_type == "mmr": | ||
docs = await self.vectorstore.amax_marginal_relevance_search( | ||
query, **self.search_kwargs | ||
) | ||
else: | ||
msg = f"search_type of {self.search_type} not allowed." | ||
raise ValueError(msg) | ||
return docs |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters