Skip to content

Commit

Permalink
rag: bypass ingestion when payloads are matching already saved data (#…
Browse files Browse the repository at this point in the history
…157)

Co-authored-by: Avram Tudor <tudor.avram@8x8.com>
  • Loading branch information
quitrk and Avram Tudor authored Feb 20, 2025
1 parent e487c28 commit a96150b
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 5 deletions.
6 changes: 4 additions & 2 deletions skynet/modules/ttt/assistant/v1/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

from pydantic import BaseModel

from skynet.modules.ttt.assistant.constants import assistant_default_system_message

from skynet.modules.ttt.summaries.v1.models import DocumentPayload, HintType

default_max_depth = 5
Expand All @@ -25,7 +27,7 @@ class RagPayload(BaseModel):
{
'urls': ['https://jitsi.github.io/handbook'],
'max_depth': default_max_depth,
'system_message': 'You are an AI assistant of Jitsi, a video conferencing platform. You provide response suggestions to the support agent',
'system_message': assistant_default_system_message,
}
]
}
Expand All @@ -43,7 +45,7 @@ class RagConfig(RagPayload):
'error': None,
'max_depth': default_max_depth,
'status': 'running',
'system_message': 'You are an AI assistant of Jitsi, a video conferencing platform. You provide response suggestions to the support agent',
'system_message': assistant_default_system_message,
'urls': ['https://jitsi.github.io/handbook'],
}
]
Expand Down
2 changes: 1 addition & 1 deletion skynet/modules/ttt/assistant/v1/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ async def create_rag_db(payload: RagPayload, customer_id=Depends(CustomerId()))
"""

store = await get_vector_store()
return await store.create_from_urls(payload, customer_id)
return await store.update_from_urls(payload, customer_id)


@api_version(1)
Expand Down
22 changes: 20 additions & 2 deletions skynet/modules/ttt/rag/vector_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,19 @@
log = get_logger(__name__)


def bypass_ingestion(existing_payload: RagPayload, new_payload: RagPayload):
    """
    Tell whether a new payload matches an existing one closely enough to skip ingestion.

    Two payloads match when they reference the same set of urls (order and
    duplicate entries are ignored) and have the same max_depth. Other fields,
    such as the system message, are not taken into account.

    Neither payload is modified by this check.
    """

    # Compare as sets rather than sorting the lists in place: the url lists belong
    # to the payload objects, and reordering them here would leak into the caller's data.
    same_urls = set(existing_payload.urls) == set(new_payload.urls)

    return same_urls and existing_payload.max_depth == new_payload.max_depth


class SkynetVectorStore(ABC):
embedding = HuggingFaceEmbeddings(
model_name=embeddings_model_path, model_kwargs={'device': 'cpu', 'trust_remote_code': True}
Expand Down Expand Up @@ -111,13 +124,18 @@ async def workflow(self, payload: RagPayload, store_id: str):

await db.lrem(RUNNING_RAG_KEY, 0, store_id)

async def create_from_urls(self, payload: RagPayload, store_id: str) -> Optional[RagConfig]:
async def update_from_urls(self, payload: RagPayload, store_id: str) -> Optional[RagConfig]:
"""
Create a vector store with the given id, using the documents crawled from the given URL.
"""

config = await self.get_config(store_id)

if store_id in await db.lrange(RUNNING_RAG_KEY, 0, -1):
return await self.get_config(store_id)
return config

if config and config.status == RagStatus.SUCCESS and bypass_ingestion(config, payload):
return await self.update_config(store_id, system_message=payload.system_message)

await db.rpush(RUNNING_RAG_KEY, store_id)
config = RagConfig(urls=payload.urls, max_depth=payload.max_depth)
Expand Down

0 comments on commit a96150b

Please sign in to comment.