Skip to content

Commit 504db15

Browse files
authored
Merge pull request #296 from nulib/5373-checkpoint-bloat
2 parents 167ae10 + f97a010 commit 504db15

File tree

3 files changed

+86
-6
lines changed

3 files changed

+86
-6
lines changed

chat/src/core/setup.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from persistence.s3_checkpointer import S3Checkpointer
1+
from persistence.selective_checkpointer import SelectiveCheckpointer
22
from search.opensearch_neural_search import OpenSearchNeuralSearch
33
from langchain_aws import ChatBedrock
44
from langchain_core.language_models.base import BaseModel
@@ -14,7 +14,7 @@ def chat_model(**kwargs) -> BaseModel:
1414

1515
def checkpoint_saver(**kwargs) -> BaseCheckpointSaver:
1616
checkpoint_bucket: str = os.getenv("CHECKPOINT_BUCKET_NAME")
17-
return S3Checkpointer(bucket_name=checkpoint_bucket, **kwargs)
17+
return SelectiveCheckpointer(bucket_name=checkpoint_bucket, retain_history=False, **kwargs)
1818

1919
def prefix(value):
2020
env_prefix = os.getenv("ENV_PREFIX")
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
import os
2+
from typing import Optional
3+
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
4+
from langchain_core.runnables import RunnableConfig
5+
from langgraph.checkpoint.base import (
6+
ChannelVersions,
7+
Checkpoint,
8+
CheckpointMetadata
9+
)
10+
from persistence.s3_checkpointer import S3Checkpointer
11+
12+
13+
# Split messages into interactions, each one starting with a HumanMessage
14+
def _split_interactions(messages):
15+
if messages is None:
16+
return []
17+
18+
interactions = []
19+
current_interaction = []
20+
21+
for message in messages:
22+
if isinstance(message, HumanMessage) and current_interaction:
23+
interactions.append(current_interaction)
24+
current_interaction = []
25+
current_interaction.append(message)
26+
27+
if current_interaction:
28+
interactions.append(current_interaction)
29+
30+
return interactions
31+
32+
def _is_tool_message(message):
33+
if isinstance(message, ToolMessage):
34+
return True
35+
if isinstance(message, AIMessage) and message.response_metadata.get('stop_reason', '') == 'tool_use':
36+
return True
37+
return False
38+
39+
def _prune_messages(messages):
40+
interactions = _split_interactions(messages)
41+
# Remove all tool-related messages except those related to the most recent interaction
42+
for i, interaction in enumerate(interactions[:-1]):
43+
interactions[i] = [message for message in interaction if not _is_tool_message(message)]
44+
45+
# Return the flattened list of messages
46+
return [message for interaction in interactions for message in interaction]
47+
48+
class SelectiveCheckpointer(S3Checkpointer):
49+
"""S3 Checkpointer that discards ToolMessages from previous checkpoints."""
50+
51+
def __init__(
52+
self,
53+
bucket_name: str,
54+
region_name: str = os.getenv("AWS_REGION"),
55+
endpoint_url: Optional[str] = None,
56+
compression: Optional[str] = None,
57+
retain_history: Optional[bool] = True,
58+
) -> None:
59+
super().__init__(bucket_name, region_name, endpoint_url, compression)
60+
self.retain_history = retain_history
61+
62+
def put(
63+
self,
64+
config: RunnableConfig,
65+
checkpoint: Checkpoint,
66+
metadata: CheckpointMetadata,
67+
new_versions: ChannelVersions,
68+
) -> RunnableConfig:
69+
# Remove previous checkpoints
70+
thread_id = config["configurable"]["thread_id"]
71+
if not self.retain_history:
72+
self.delete_checkpoints(thread_id)
73+
74+
# Remove all ToolMessages except those related to the most
75+
# recent question (HumanMessage)
76+
messages = checkpoint.get("channel_values", {}).get("messages", [])
77+
checkpoint["channel_values"]["messages"] = _prune_messages(messages)
78+
79+
return super().put(config, checkpoint, metadata, new_versions)

chat/test/core/test_setup.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,17 @@ def test_chat_model_returns_bedrock_instance(self):
1616

1717
class TestCheckpointSaver(unittest.TestCase):
1818
@patch.dict(os.environ, {"CHECKPOINT_BUCKET_NAME": "test-bucket"})
19-
@patch("core.setup.S3Checkpointer")
20-
def test_checkpoint_saver_initialization(self, mock_s3_checkpointer):
19+
@patch("core.setup.SelectiveCheckpointer")
20+
def test_checkpoint_saver_initialization(self, mock_checkpointer):
2121
kwargs = {"prefix": "test"}
2222
result = checkpoint_saver(**kwargs)
2323

24-
mock_s3_checkpointer.assert_called_once_with(
24+
mock_checkpointer.assert_called_once_with(
2525
bucket_name="test-bucket",
26+
retain_history=False,
2627
**kwargs
2728
)
28-
self.assertEqual(result, mock_s3_checkpointer.return_value)
29+
self.assertEqual(result, mock_checkpointer.return_value)
2930

3031
class TestPrefix(unittest.TestCase):
3132
def test_prefix_with_env_prefix(self):

0 commit comments

Comments
 (0)