|
| 1 | +from datetime import datetime |
1 | 2 | from typing import Any, Dict
|
2 | 3 | from langchain_core.callbacks import BaseCallbackHandler
|
3 | 4 | from langchain_core.outputs import LLMResult
|
4 | 5 | from langchain_core.messages.tool import ToolMessage
|
| 6 | +import boto3 |
5 | 7 | import json
|
| 8 | +import os |
6 | 9 |
|
7 | 10 | class MetricsCallbackHandler(BaseCallbackHandler):
|
8 |
| - def __init__(self, *args, **kwargs): |
9 |
| - self.accumulator = {} |
10 |
| - self.answers = [] |
11 |
| - self.artifacts = [] |
12 |
| - super().__init__(*args, **kwargs) |
13 |
| - |
14 |
| - def on_llm_end(self, response: LLMResult, **kwargs: Dict[str, Any]): |
15 |
| - if response is None: |
16 |
| - return |
17 |
| - |
18 |
| - if not response.generations or not response.generations[0]: |
19 |
| - return |
20 |
| - |
21 |
| - for generation in response.generations[0]: |
22 |
| - if generation.text != "": |
23 |
| - self.answers.append(generation.text) |
24 |
| - |
25 |
| - if not hasattr(generation, 'message') or generation.message is None: |
26 |
| - continue |
27 |
| - |
28 |
| - metadata = getattr(generation.message, 'usage_metadata', None) |
29 |
| - if metadata is None: |
30 |
| - continue |
31 |
| - |
32 |
| - for k, v in metadata.items(): |
33 |
| - self.accumulator[k] = self.accumulator.get(k, 0) + v |
34 |
| - |
35 |
| - def on_tool_end(self, output: ToolMessage, **kwargs: Dict[str, Any]): |
| 11 | + def __init__(self, log_stream = None, *args, extra_data = {}, **kwargs): |
| 12 | + self.accumulator = {} |
| 13 | + self.answers = [] |
| 14 | + self.artifacts = [] |
| 15 | + self.log_stream = log_stream |
| 16 | + self.extra_data = extra_data |
| 17 | + super().__init__(*args, **kwargs) |
| 18 | + |
| 19 | + def on_llm_end(self, response: LLMResult, **kwargs: Dict[str, Any]): |
| 20 | + if response is None: |
| 21 | + return |
| 22 | + |
| 23 | + if not response.generations or not response.generations[0]: |
| 24 | + return |
| 25 | + |
| 26 | + for generation in response.generations[0]: |
| 27 | + if generation.text != "": |
| 28 | + self.answers.append(generation.text) |
| 29 | + |
| 30 | + if not hasattr(generation, "message") or generation.message is None: |
| 31 | + continue |
| 32 | + |
| 33 | + metadata = getattr(generation.message, "usage_metadata", None) |
| 34 | + if metadata is None: |
| 35 | + continue |
| 36 | + |
| 37 | + for k, v in metadata.items(): |
| 38 | + self.accumulator[k] = self.accumulator.get(k, 0) + v |
| 39 | + |
| 40 | + def on_tool_end(self, output: ToolMessage, **kwargs: Dict[str, Any]): |
36 | 41 | content = output.content
|
37 | 42 | if isinstance(content, str):
|
38 | 43 | try:
|
39 | 44 | content = json.loads(content)
|
40 | 45 | except json.decoder.JSONDecodeError as e:
|
41 |
| - print(f"Invalid json ({e}) returned from {output.name} tool: {output.content}") |
| 46 | + print( |
| 47 | + f"Invalid json ({e}) returned from {output.name} tool: {output.content}" |
| 48 | + ) |
42 | 49 | return
|
43 |
| - |
| 50 | + |
44 | 51 | match output.name:
|
45 | 52 | case "aggregate":
|
46 |
| - self.artifacts.append({"type": "aggregation", "artifact": content.get("aggregation_result", {})}) |
| 53 | + self.artifacts.append( |
| 54 | + { |
| 55 | + "type": "aggregation", |
| 56 | + "artifact": content.get("aggregation_result", {}), |
| 57 | + } |
| 58 | + ) |
47 | 59 | case "search":
|
48 | 60 | source_urls = [doc.get("api_link") for doc in content]
|
49 | 61 | self.artifacts.append({"type": "source_urls", "artifact": source_urls})
|
50 | 62 | case "summarize":
|
51 | 63 | print(output)
|
| 64 | + |
| 65 | + def log_metrics(self): |
| 66 | + if self.log_stream is None: |
| 67 | + return |
| 68 | + |
| 69 | + log_group = os.getenv("METRICS_LOG_GROUP") |
| 70 | + if log_group and ensure_log_stream_exists(log_group, self.log_stream): |
| 71 | + client = log_client() |
| 72 | + message = { |
| 73 | + "answer": self.answers, |
| 74 | + "artifacts": self.artifacts, |
| 75 | + "token_counts": self.accumulator, |
| 76 | + } |
| 77 | + message.update(self.extra_data) |
| 78 | + |
| 79 | + log_events = [ |
| 80 | + { |
| 81 | + "timestamp": timestamp(), |
| 82 | + "message": json.dumps(message), |
| 83 | + } |
| 84 | + ] |
| 85 | + client.put_log_events( |
| 86 | + logGroupName=log_group, logStreamName=self.log_stream, logEvents=log_events |
| 87 | + ) |
| 88 | + |
| 89 | + |
def ensure_log_stream_exists(log_group, log_stream):
    """Create the CloudWatch log stream if it does not already exist.

    Returns True when the stream exists (just created or pre-existing),
    False when creation failed for any other reason. Best-effort by
    design: never raises, so metrics logging cannot crash the caller.
    """
    client = log_client()
    try:
        # Fixed: previously printed the raw create_log_stream API
        # response to stdout (debug leftover).
        client.create_log_stream(logGroupName=log_group, logStreamName=log_stream)
        return True
    except client.exceptions.ResourceAlreadyExistsException:
        # Already there — nothing to do.
        return True
    except Exception as e:
        # Fixed: the exception detail was discarded, making failures
        # (permissions, missing group, throttling) undiagnosable.
        print(f"Could not create log stream: {log_group}:{log_stream} ({e})")
        return False
| 102 | + |
| 103 | + |
def log_client():
    """Build a boto3 CloudWatch Logs client for the configured region.

    The region comes from AWS_REGION, defaulting to us-east-1.
    """
    region = os.getenv("AWS_REGION", "us-east-1")
    return boto3.client("logs", region_name=region)
| 106 | + |
| 107 | + |
def timestamp():
    """Return the current epoch time in whole milliseconds."""
    now = datetime.now()
    return round(now.timestamp() * 1000)
0 commit comments