Skip to content

Commit 97ebd8a

Browse files
committed
Save tokens by filtering fields in documents returned from index.
Update the prompt to limit tool calls and avoid recursion errors.
1 parent c712fe4 commit 97ebd8a

File tree

2 files changed

+56
-33
lines changed

2 files changed

+56
-33
lines changed

chat/src/agent/search_agent.py

+9-33
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1+
import json
12
from typing import Literal, List
2-
from langchain_aws import ChatBedrock
33
from langchain_core.messages import HumanMessage, ToolMessage
44
from agent.tools import aggregate, discover_fields, search, retrieve_documents
55
from langchain_core.messages.base import BaseMessage
@@ -9,6 +9,7 @@
99
from langgraph.graph import END, START, StateGraph, MessagesState
1010
from langgraph.prebuilt import ToolNode
1111
from langgraph.errors import GraphRecursionError
12+
from core.document import minimize_documents
1213
from core.setup import checkpoint_saver
1314
from agent.callbacks.socket import SocketCallbackHandler
1415
from typing import Optional
@@ -18,7 +19,9 @@
1819
Please provide a brief answer to the question using the tools provided. Include specific details from multiple documents that
1920
support your answer. Answer in raw markdown, but not within a code block. When citing source documents, construct Markdown
2021
links using the document's canonical_link field. Do not include intermediate messages explaining your process. If the user's
21-
question is unclear, ask for clarification.
22+
question is unclear, ask for clarification. Use no more than 6 tool calls. If you still cannot answer the question after 6
23+
tool calls, summarize the information you have gathered so far and suggest ways in which the user might narrow the scope
24+
of their question to make it more answerable.
2225
"""
2326

2427
MAX_RECURSION_LIMIT = 16
@@ -27,7 +30,6 @@ class SearchWorkflow:
2730
def __init__(self, model: BaseModel, system_message: str, metrics = None):
2831
self.metrics = metrics
2932
self.model = model
30-
self.summarization_model = ChatBedrock(model="us.anthropic.claude-3-5-sonnet-20241022-v2:0", streaming=False)
3133
self.system_message = system_message
3234

3335
def should_continue(self, state: MessagesState) -> Literal["tools", END]:
@@ -41,44 +43,18 @@ def should_continue(self, state: MessagesState) -> Literal["tools", END]:
4143

4244
def summarize(self, state: MessagesState):
4345
messages = state["messages"]
44-
question = messages[0].content
4546
last_message = messages[-1]
4647
if last_message.name not in ["search", "retrieve_documents"]:
4748
return {"messages": messages}
4849

49-
summary_prompt = f"""
50-
Summarize the following content. Return ONLY a valid JSON list where each
51-
document is replaced with a new dict with the `id`, `title`, `canonical_link`,
52-
and `api_link` fields, as well as any other information from the original that
53-
might be useful in answering questions. Flatten any nested structures to retain
54-
only semantically useful information (e.g., [{{'id': id1, 'label': label1}},
55-
{{'id': id2, 'label': label2}}, {{'id': id3, 'label': label3}}] becomes
56-
[label1, label2, label3]). Be judicious about what information is retained,
57-
but keep enough to answer the question "{question}" and any likely followups.
58-
59-
It is extremely important that you return only the valid, parsable summarized
60-
JSON with no additional text or explanation, no markdown code fencing, and all
61-
unnecessary whitespace removed.
62-
63-
Prioritize speed over comprehensiveness.
64-
65-
{last_message.content}
66-
"""
67-
68-
config = {
69-
"callbacks": [self.metrics] if self.metrics else [],
70-
"metadata": {"source": "summarize"}
71-
}
72-
7350
start_time = time.time()
74-
75-
summary = self.summarization_model.invoke([HumanMessage(content=summary_prompt)], config=config)
76-
51+
content = minimize_documents(json.loads(last_message.content))
52+
content = json.dumps(content, separators=(',', ':'))
7753
end_time = time.time()
7854
elapsed_time = end_time - start_time
79-
print(f'Condensed {len(last_message.content)} bytes to {len(summary.content)} bytes in {elapsed_time:.2f} seconds')
55+
print(f'Condensed {len(last_message.content)} bytes to {len(content)} bytes in {elapsed_time:.2f} seconds. Savings: {100 * (1 - len(content) / len(last_message.content)):.2f}%')
8056

81-
last_message.content = summary.content
57+
last_message.content = content
8258

8359
return {"messages": messages}
8460

chat/src/core/document.py

+47
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
def minimize_documents(docs):
    """Return a list holding the minimized form of every document in ``docs``.

    Each element is passed through ``minimize_document``; input order is kept.
    """
    return list(map(minimize_document, docs))
3+
4+
def minimize_document(doc):
    """Build a reduced copy of ``doc`` containing a fixed set of fields.

    Scalar/text fields are passed through ``minimize`` (empty values become
    ``None``), controlled-term lists are reduced to their labels via
    ``labels_only``, and the collection is reduced to its title.

    Args:
        doc: A dict-like document (must support ``.get``).

    Returns:
        A new dict with the same keys on every call; fields that are missing,
        null, or empty in ``doc`` are reported as ``None``.
    """
    # ``dict.get(key, default)`` only falls back when the key is absent, so a
    # key that is present with a null value needs the explicit ``or`` guard.
    collection = doc.get('collection') or {}

    def term_labels(key):
        # A missing or null term list is treated as an empty list instead of
        # being iterated (which would raise TypeError).
        return labels_only(doc.get(key) or [])

    return {
        'id': doc.get('id'),
        'title': minimize(doc.get('title')),
        'alternate_title': minimize(doc.get('alternate_title')),
        'description': minimize(doc.get('description')),
        'abstract': minimize(doc.get('abstract')),
        'subject': term_labels('subject'),
        'date_created': minimize(doc.get('date_created')),
        'provenance': minimize(doc.get('provenance')),
        'collection': minimize(collection.get('title')),
        'creator': term_labels('creator'),
        'contributor': term_labels('contributor'),

        'work_type': minimize(doc.get('work_type')),
        'genre': term_labels('genre'),
        'scope_and_contents': minimize(doc.get('scope_and_contents')),
        'table_of_contents': minimize(doc.get('table_of_contents')),
        'cultural_context': minimize(doc.get('cultural_context')),
        'notes': minimize(doc.get('notes')),
        'keywords': minimize(doc.get('keywords')),
        'visibility': minimize(doc.get('visibility')),
        'canonical_link': minimize(doc.get('canonical_link')),

        'rights_statement': label_only(doc.get('rights_statement')),
    }
30+
31+
def labels_only(list_of_fields):
    """Reduce a list of term dicts to a list of their labels.

    Args:
        list_of_fields: An iterable of term dicts (or ``None`` entries), or
            ``None`` when the document has no such field.

    Returns:
        The list of labels produced by ``label_only`` for each entry, or
        ``None`` when the input is ``None`` or yields no entries.
    """
    # A document may lack the field entirely; iterating None would raise
    # TypeError, so treat it the same way the sibling helpers treat None.
    if list_of_fields is None:
        return None
    return minimize([label_only(field) for field in list_of_fields])
33+
34+
def label_only(field):
    """Return the display label of a single term dict.

    The ``label_with_role`` entry is used when that key is present; otherwise
    the ``label`` entry is used. ``None`` is returned when ``field`` is
    ``None`` or when neither key exists.
    """
    if field is None:
        return None
    plain_label = field.get('label')
    return field.get('label_with_role', plain_label)
38+
39+
def minimize(field):
    """Collapse empty values to ``None`` and pass everything else through.

    ``None`` and any value whose ``len()`` is zero yield ``None``. Values that
    do not support ``len()`` (for example numbers) are returned unchanged, as
    are non-empty sized values.
    """
    if field is None:
        return None
    try:
        is_empty = len(field) == 0
    except TypeError:
        # Unsized values cannot be "empty"; hand them back untouched.
        return field
    return None if is_empty else field

0 commit comments

Comments
 (0)