Skip to content

Commit

Permalink
feat: support configuring assistant system message (#156)
Browse files Browse the repository at this point in the history
* feat: support configuring assistant system message

* code review

* tweak prompts

---------

Co-authored-by: Avram Tudor <tudor.avram@8x8.com>
  • Loading branch information
quitrk and Avram Tudor authored Feb 20, 2025
1 parent dea42aa commit e487c28
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 16 deletions.
6 changes: 6 additions & 0 deletions skynet/modules/ttt/assistant/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Fallback system prompt, used when the caller supplies no system_message of its own.
assistant_default_system_message = "You are an AI assistant who provides the next answer in a given conversation."
# Prompt used to distill the provided conversation text into a standalone question
# suitable for RAG retrieval.
# NOTE(review): the model is told to prefix its output with "Response:" — presumably
# the consumer strips or keys off that prefix; confirm against the caller.
assistant_rag_question_extractor = """
Based on the provided text, formulate a proper question for RAG.
Start your response with "Response:".
"""
# Extra system instruction that restricts the model to the retrieved (RAG) context only.
assistant_limit_data_to_rag = "Only respond based on the information provided below."
25 changes: 25 additions & 0 deletions skynet/modules/ttt/assistant/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from skynet.modules.ttt.assistant.constants import assistant_default_system_message, assistant_limit_data_to_rag


def get_assistant_chat_messages(
    use_rag: bool,
    use_only_rag_data: bool,
    text: str,
    prompt: str,
    system_message: str,
):
    """Assemble the ordered (role, content) message list for the assistant chat prompt.

    The list always starts with a system message (the caller-provided one, or the
    module default when it is empty/None). When RAG is enabled, a literal
    '{context}' placeholder system message is appended for the retriever to fill
    in — preceded, when use_only_rag_data is set, by the instruction limiting the
    model to the retrieved data. Non-empty user text and prompt are appended last,
    in that order, as human messages.
    """
    chat = [('system', system_message if system_message else assistant_default_system_message)]

    if use_rag:
        if use_only_rag_data:
            chat.append(('system', assistant_limit_data_to_rag))
        # Placeholder substituted with the retrieved documents by the prompt template.
        chat.append(('system', '{context}'))

    # Empty strings (and None) are skipped — only real user content becomes a message.
    for role, content in (('human', text), ('human', prompt)):
        if content:
            chat.append((role, content))

    return chat
17 changes: 15 additions & 2 deletions skynet/modules/ttt/assistant/v1/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from pydantic import BaseModel

from skynet.modules.ttt.summaries.v1.models import DocumentPayload
from skynet.modules.ttt.summaries.v1.models import DocumentPayload, HintType

default_max_depth = 5

Expand All @@ -15,12 +15,19 @@ class RagStatus(Enum):


class RagPayload(BaseModel):
system_message: Optional[str] = None
max_depth: Optional[int] = default_max_depth
urls: list[str]

model_config = {
'json_schema_extra': {
'examples': [{'urls': ['https://jitsi.github.io/handbook'], 'max_depth': default_max_depth}]
'examples': [
{
'urls': ['https://jitsi.github.io/handbook'],
'max_depth': default_max_depth,
'system_message': 'You are an AI assistant of Jitsi, a video conferencing platform. You provide response suggestions to the support agent',
}
]
}
}

Expand All @@ -36,6 +43,7 @@ class RagConfig(RagPayload):
'error': None,
'max_depth': default_max_depth,
'status': 'running',
'system_message': 'You are an AI assistant of Jitsi, a video conferencing platform. You provide response suggestions to the support agent',
'urls': ['https://jitsi.github.io/handbook'],
}
]
Expand All @@ -44,13 +52,18 @@ class RagConfig(RagPayload):


class AssistantDocumentPayload(DocumentPayload):
hint: HintType = HintType.CONVERSATION
use_only_rag_data: bool = False

model_config = {
'json_schema_extra': {
'examples': [
{
'text': 'User provided context here (will be appended to the RAG one)',
'prompt': 'User prompt here',
'max_completion_tokens': None,
'hint': 'conversation',
'use_only_rag_data': False, # If True and a vector store is available, only the RAG data will be used for assistance
}
]
}
Expand Down
36 changes: 22 additions & 14 deletions skynet/modules/ttt/processor.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from operator import itemgetter

from langchain.chains.summarize import load_summarize_chain
from langchain.prompts import ChatPromptTemplate, PromptTemplate
from langchain.prompts import ChatPromptTemplate
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import FlashrankRerank
from langchain.text_splitter import RecursiveCharacterTextSplitter
Expand All @@ -26,6 +26,8 @@
use_oci,
)
from skynet.logs import get_logger
from skynet.modules.ttt.assistant.constants import assistant_rag_question_extractor
from skynet.modules.ttt.assistant.utils import get_assistant_chat_messages

from skynet.modules.ttt.rag.app import get_vector_store
from skynet.modules.ttt.summaries.prompts.action_items import (
Expand Down Expand Up @@ -135,6 +137,8 @@ async def assist(payload: DocumentPayload, customer_id: str | None = None, model

store = await get_vector_store()
vector_store = await store.get(customer_id)
config = await store.get_config(customer_id)
question = payload.prompt

base_retriever = vector_store.as_retriever(search_kwargs={'k': 3}) if vector_store else None
retriever = (
Expand All @@ -143,26 +147,30 @@ async def assist(payload: DocumentPayload, customer_id: str | None = None, model
else None
)

prompt_template = '''
Context: {context}
Additional context: {additional_context}
User prompt: {user_prompt}
'''
if retriever and payload.text:
question_payload = DocumentPayload(**(payload.model_dump() | {'prompt': assistant_rag_question_extractor}))
question = await summarize(question_payload, JobType.SUMMARY, current_model)

prompt = PromptTemplate(template=prompt_template, input_variables=['context', 'user_prompt', 'additional_context'])
log.info(f'Using question: {question}')

template = ChatPromptTemplate(
get_assistant_chat_messages(
use_rag=bool(retriever),
use_only_rag_data=payload.use_only_rag_data,
text=payload.text,
prompt=payload.prompt,
system_message=config.system_message,
)
)

rag_chain = (
{
'context': (itemgetter('user_prompt') | retriever | format_docs) if retriever else lambda x: '',
'user_prompt': itemgetter('user_prompt'),
'additional_context': itemgetter('additional_context'),
}
| prompt
{'context': (itemgetter('question') | retriever | format_docs) if retriever else lambda _: ''}
| template
| current_model
| StrOutputParser()
)

return await rag_chain.ainvoke(input={'user_prompt': payload.prompt, 'additional_context': payload.text})
return await rag_chain.ainvoke(input={'question': question})


async def summarize(payload: DocumentPayload, job_type: JobType, model: BaseChatModel = None) -> str:
Expand Down

0 comments on commit e487c28

Please sign in to comment.