Skip to content

Commit 031a468

Browse files
committed
Move metrics logging to the MetricsCallbackHandler
1 parent a4334fd commit 031a468

File tree

6 files changed

+128
-119
lines changed

6 files changed

+128
-119
lines changed

chat/src/agent/callbacks/metrics.py

+89-31
Original file line numberDiff line numberDiff line change
@@ -1,51 +1,109 @@
1+
from datetime import datetime
12
from typing import Any, Dict
23
from langchain_core.callbacks import BaseCallbackHandler
34
from langchain_core.outputs import LLMResult
45
from langchain_core.messages.tool import ToolMessage
6+
import boto3
57
import json
8+
import os
69

710
class MetricsCallbackHandler(BaseCallbackHandler):
8-
def __init__(self, *args, **kwargs):
9-
self.accumulator = {}
10-
self.answers = []
11-
self.artifacts = []
12-
super().__init__(*args, **kwargs)
13-
14-
def on_llm_end(self, response: LLMResult, **kwargs: Dict[str, Any]):
15-
if response is None:
16-
return
17-
18-
if not response.generations or not response.generations[0]:
19-
return
20-
21-
for generation in response.generations[0]:
22-
if generation.text != "":
23-
self.answers.append(generation.text)
24-
25-
if not hasattr(generation, 'message') or generation.message is None:
26-
continue
27-
28-
metadata = getattr(generation.message, 'usage_metadata', None)
29-
if metadata is None:
30-
continue
31-
32-
for k, v in metadata.items():
33-
self.accumulator[k] = self.accumulator.get(k, 0) + v
34-
35-
def on_tool_end(self, output: ToolMessage, **kwargs: Dict[str, Any]):
11+
def __init__(self, log_stream = None, *args, extra_data = {}, **kwargs):
12+
self.accumulator = {}
13+
self.answers = []
14+
self.artifacts = []
15+
self.log_stream = log_stream
16+
self.extra_data = extra_data
17+
super().__init__(*args, **kwargs)
18+
19+
def on_llm_end(self, response: LLMResult, **kwargs: Dict[str, Any]):
20+
if response is None:
21+
return
22+
23+
if not response.generations or not response.generations[0]:
24+
return
25+
26+
for generation in response.generations[0]:
27+
if generation.text != "":
28+
self.answers.append(generation.text)
29+
30+
if not hasattr(generation, "message") or generation.message is None:
31+
continue
32+
33+
metadata = getattr(generation.message, "usage_metadata", None)
34+
if metadata is None:
35+
continue
36+
37+
for k, v in metadata.items():
38+
self.accumulator[k] = self.accumulator.get(k, 0) + v
39+
40+
def on_tool_end(self, output: ToolMessage, **kwargs: Dict[str, Any]):
3641
content = output.content
3742
if isinstance(content, str):
3843
try:
3944
content = json.loads(content)
4045
except json.decoder.JSONDecodeError as e:
41-
print(f"Invalid json ({e}) returned from {output.name} tool: {output.content}")
46+
print(
47+
f"Invalid json ({e}) returned from {output.name} tool: {output.content}"
48+
)
4249
return
43-
50+
4451
match output.name:
4552
case "aggregate":
46-
self.artifacts.append({"type": "aggregation", "artifact": content.get("aggregation_result", {})})
53+
self.artifacts.append(
54+
{
55+
"type": "aggregation",
56+
"artifact": content.get("aggregation_result", {}),
57+
}
58+
)
4759
case "search":
4860
source_urls = [doc.get("api_link") for doc in content]
4961
self.artifacts.append({"type": "source_urls", "artifact": source_urls})
5062
case "summarize":
5163
print(output)
64+
65+
def log_metrics(self):
66+
if self.log_stream is None:
67+
return
68+
69+
log_group = os.getenv("METRICS_LOG_GROUP")
70+
if log_group and ensure_log_stream_exists(log_group, self.log_stream):
71+
client = log_client()
72+
message = {
73+
"answer": self.answers,
74+
"artifacts": self.artifacts,
75+
"token_counts": self.accumulator,
76+
}
77+
message.update(self.extra_data)
78+
79+
log_events = [
80+
{
81+
"timestamp": timestamp(),
82+
"message": json.dumps(message),
83+
}
84+
]
85+
client.put_log_events(
86+
logGroupName=log_group, logStreamName=self.log_stream, logEvents=log_events
87+
)
88+
89+
90+
def ensure_log_stream_exists(log_group, log_stream):
    """Return True once the CloudWatch log stream exists (creating it if absent).

    Returns False (after printing a diagnostic) when creation fails for any
    reason other than the stream already existing.
    """
    client = log_client()
    try:
        result = client.create_log_stream(
            logGroupName=log_group, logStreamName=log_stream
        )
    except client.exceptions.ResourceAlreadyExistsException:
        # The stream is already there — that is success for our purposes.
        return True
    except Exception:
        print(f"Could not create log stream: {log_group}:{log_stream}")
        return False
    print(result)
    return True
104+
def log_client():
    """Build a CloudWatch Logs client for the configured AWS region."""
    region = os.getenv("AWS_REGION", "us-east-1")
    return boto3.client("logs", region_name=region)
108+
def timestamp():
    """Current wall-clock time in whole milliseconds since the Unix epoch."""
    now = datetime.now()
    return round(now.timestamp() * 1000)

chat/src/agent/search_agent.py

+17-4
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from core.setup import checkpoint_saver
1313
from agent.callbacks.socket import SocketCallbackHandler
1414
from typing import Optional
15+
import time
1516

1617
DEFAULT_SYSTEM_MESSAGE = """
1718
Please provide a brief answer to the question using the tools provided. Include specific details from multiple documents that
@@ -23,7 +24,8 @@
2324
MAX_RECURSION_LIMIT = 16
2425

2526
class SearchWorkflow:
26-
def __init__(self, model: BaseModel, system_message: str):
27+
def __init__(self, model: BaseModel, system_message: str, metrics = None):
28+
self.metrics = metrics
2729
self.model = model
2830
self.summarization_model = ChatBedrock(model="us.anthropic.claude-3-5-sonnet-20241022-v2:0", streaming=False)
2931
self.system_message = system_message
@@ -57,15 +59,25 @@ def summarize(self, state: MessagesState):
5759
It is extremely important that you return only the valid, parsable summarized
5860
JSON with no additional text or explanation, no markdown code fencing, and all
5961
unnecessary whitespace removed.
62+
63+
Prioritize speed over comprehensiveness.
6064
6165
{last_message.content}
6266
"""
67+
6368
config = {
64-
"callbacks": [],
69+
"callbacks": [self.metrics] if self.metrics else [],
6570
"metadata": {"source": "summarize"}
6671
}
72+
73+
start_time = time.time()
74+
6775
summary = self.summarization_model.invoke([HumanMessage(content=summary_prompt)], config=config)
68-
print(f'Condensed {len(last_message.content)} bytes to {len(summary.content)} bytes via summarization')
76+
77+
end_time = time.time()
78+
elapsed_time = end_time - start_time
79+
print(f'Condensed {len(last_message.content)} bytes to {len(summary.content)} bytes in {elapsed_time:.2f} seconds')
80+
6981
last_message.content = summary.content
7082

7183
return {"messages": messages}
@@ -81,6 +93,7 @@ def __init__(
8193
self,
8294
model: BaseModel,
8395
*,
96+
metrics = None,
8497
system_message: str = DEFAULT_SYSTEM_MESSAGE,
8598
**kwargs
8699
):
@@ -92,7 +105,7 @@ def __init__(
92105
except NotImplementedError:
93106
pass
94107

95-
self.workflow_logic = SearchWorkflow(model=model, system_message=system_message)
108+
self.workflow_logic = SearchWorkflow(model=model, system_message=system_message, metrics=metrics)
96109

97110
# Define a new graph
98111
workflow = StateGraph(MessagesState)

chat/src/agent/tools.py

+2-33
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
import json
22

3-
from langchain_core.language_models.chat_models import BaseModel
4-
from langchain_core.messages import HumanMessage
53
from langchain_core.tools import tool
64
from core.setup import opensearch_vector_store
75
from typing import List
@@ -29,7 +27,8 @@ def filter_results(results):
2927
Filters out the embeddings from the results
3028
"""
3129
filtered = []
32-
for doc in results:
30+
for result in results:
31+
doc = result.metadata
3332
if 'embedding' in doc:
3433
doc.pop('embedding')
3534
filtered.append(doc)
@@ -101,33 +100,3 @@ def retrieve_documents(doc_ids: List[str]):
101100
return filter_results(response)
102101
except Exception as e:
103102
return {"error": str(e)}
104-
105-
@tool(response_format="content")
106-
def summarize(content, model: BaseModel):
107-
"""
108-
Summarize content. If content is a list of documents, each document will
109-
be replaced with a summary to reduce the amount of content passed to the agent's
110-
model at each turn. Otherwise, the content will be summarized as a whole.
111-
112-
Args:
113-
content: The content to summarize.
114-
model (BaseModel): The summarization model to use.
115-
116-
Returns:
117-
A new list of documents, pared down.
118-
"""
119-
120-
summary_prompt = f"""
121-
Summarize the following content. If the content is a list of documents
122-
with IDs, replace each document with a new dict with the shape
123-
{'id': doc.id, 'title': doc.title 'content': summary}, where summary is a
124-
concise but semantically meaningful summary of the document content for the
125-
agent to use on subsequent turns. Otherwise, produce a summary of the content
126-
as a whole.
127-
128-
{content}
129-
"""
130-
print(f"Summarizing content: {content}")
131-
summary = model.invoke([HumanMessage(content=summary_prompt)])
132-
print(f"Summarized content: {summary.content}")
133-
return summary.content

chat/src/handlers.py

+12-48
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,6 @@
1-
import boto3
21
import json
32
import logging
4-
import os
53
from core.secrets import load_secrets
6-
from datetime import datetime
74
from core.event_config import EventConfig
85
from honeybadger import honeybadger
96
from agent.search_agent import SearchAgent
@@ -63,59 +60,26 @@ def chat(event, context):
6360
config.socket.send({"type": "error", "message": "Question cannot be blank"})
6461
return {"statusCode": 400, "body": "Question cannot be blank"}
6562

66-
metrics = MetricsCallbackHandler()
63+
log_info = {
64+
"is_dev_team": config.api_token.is_dev_team(),
65+
"is_superuser": config.api_token.is_superuser(),
66+
"k": config.k,
67+
"model": config.model,
68+
"question": config.question,
69+
"ref": config.ref,
70+
}
71+
metrics = MetricsCallbackHandler(context.log_stream_name, extra_data=log_info)
6772
callbacks = [SocketCallbackHandler(config.socket, config.ref), metrics]
6873
model = chat_model(model=config.model, streaming=config.stream_response)
69-
search_agent = SearchAgent(model=model)
70-
74+
search_agent = SearchAgent(model=model, metrics=metrics)
75+
7176
try:
7277
search_agent.invoke(config.question, config.ref, forget=config.forget, docs=config.docs, callbacks=callbacks)
73-
log_metrics(context, metrics, config)
78+
metrics.log_metrics()
7479
except Exception as e:
7580
error_response = {"type": "error", "message": "An unexpected error occurred. Please try again later."}
7681
if config.socket:
7782
config.socket.send(error_response)
7883
raise e
7984

8085
return {"statusCode": 200}
81-
82-
83-
def log_metrics(context, metrics, config):
84-
log_group = os.getenv("METRICS_LOG_GROUP")
85-
log_stream = context.log_stream_name
86-
if log_group and ensure_log_stream_exists(log_group, log_stream):
87-
client = log_client()
88-
log_events = [{
89-
"timestamp": timestamp(),
90-
"message": json.dumps({
91-
"answer": metrics.answers,
92-
"is_dev_team": config.api_token.is_dev_team(),
93-
"is_superuser": config.api_token.is_superuser(),
94-
"k": config.k,
95-
"model": config.model,
96-
"question": config.question,
97-
"ref": config.ref,
98-
"artifacts": metrics.artifacts,
99-
"token_counts": metrics.accumulator,
100-
})
101-
}]
102-
client.put_log_events(
103-
logGroupName=log_group, logStreamName=log_stream, logEvents=log_events
104-
)
105-
106-
def ensure_log_stream_exists(log_group, log_stream):
107-
client = log_client()
108-
try:
109-
print(client.create_log_stream(logGroupName=log_group, logStreamName=log_stream))
110-
return True
111-
except client.exceptions.ResourceAlreadyExistsException:
112-
return True
113-
except Exception:
114-
print(f"Could not create log stream: {log_group}:{log_stream}")
115-
return False
116-
117-
def log_client():
118-
return boto3.client("logs", region_name=os.getenv("AWS_REGION", "us-east-1"))
119-
120-
def timestamp():
121-
return round(datetime.timestamp(datetime.now()) * 1000)

chat/template.yaml

-1
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,6 @@ Resources:
5353
ChatWebSocket:
5454
Type: AWS::ApiGatewayV2::Api
5555
Properties:
56-
Name: ChatWebSocket
5756
ProtocolType: WEBSOCKET
5857
RouteSelectionExpression: "$request.body.message"
5958
CheckpointBucket:

chat/test/agent/test_tools.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,17 @@ def test_discover_fields(self, mock_opensearch):
3333

3434
@patch('agent.tools.opensearch_vector_store')
3535
def test_search(self, mock_opensearch):
36-
mock_results = [{"id": "doc1", "text": "example result"}]
36+
class MockDoc:
37+
def __init__(self, content, metadata):
38+
self.content = content
39+
self.metadata = metadata
40+
41+
expected_results = [{"id": "doc1", "text": "example result"}]
42+
mock_results = [MockDoc(content=doc["id"], metadata=doc) for doc in expected_results]
3743
mock_opensearch.return_value.similarity_search.return_value = mock_results
3844

3945
response = search.invoke("test query")
40-
self.assertEqual(response, mock_results)
46+
self.assertEqual(response, expected_results)
4147

4248
@patch('agent.tools.opensearch_vector_store')
4349
def test_aggregate(self, mock_opensearch):

0 commit comments

Comments (0)