Skip to content

Commit fdcf8fb

Browse files
committed
Add selective checkpointer
1 parent 1274ff0 commit fdcf8fb

File tree

3 files changed

+75
-6
lines changed

3 files changed

+75
-6
lines changed

chat/src/core/setup.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from persistence.s3_checkpointer import S3Checkpointer
1+
from persistence.selective_checkpointer import SelectiveCheckpointer
22
from search.opensearch_neural_search import OpenSearchNeuralSearch
33
from langchain_aws import ChatBedrock
44
from langchain_core.language_models.base import BaseModel
@@ -14,7 +14,7 @@ def chat_model(**kwargs) -> BaseModel:
1414

1515
def checkpoint_saver(**kwargs) -> BaseCheckpointSaver:
1616
checkpoint_bucket: str = os.getenv("CHECKPOINT_BUCKET_NAME")
17-
return S3Checkpointer(bucket_name=checkpoint_bucket, **kwargs)
17+
return SelectiveCheckpointer(bucket_name=checkpoint_bucket, **kwargs)
1818

1919
def prefix(value):
2020
env_prefix = os.getenv("ENV_PREFIX")
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
import os
2+
from typing import Optional
3+
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
4+
from langchain_core.runnables import RunnableConfig
5+
from langgraph.checkpoint.base import (
6+
ChannelVersions,
7+
Checkpoint,
8+
CheckpointMetadata
9+
)
10+
from persistence.s3_checkpointer import S3Checkpointer
11+
12+
def _is_tool_message(message):
    """Report whether ``message`` belongs to a tool-call exchange.

    A message counts as tool-related when it is a ``ToolMessage``, or when it
    is an ``AIMessage`` whose ``additional_kwargs['stop_reason']`` equals
    ``'tool_use'``.

    Returns:
        ``True`` for tool-related messages, ``False`` for everything else.
    """
    return isinstance(message, ToolMessage) or (
        isinstance(message, AIMessage)
        and message.additional_kwargs.get('stop_reason', '') == 'tool_use'
    )
18+
19+
def _prune_messages(messages):
20+
if messages is None:
21+
return messages
22+
23+
last_human_message = None
24+
for i, message in reversed(list(enumerate(messages))):
25+
if isinstance(message, HumanMessage):
26+
last_human_message = i
27+
break
28+
29+
if last_human_message is not None:
30+
return [
31+
msg
32+
for i, msg in enumerate(messages)
33+
if not _is_tool_message(msg) or i > last_human_message
34+
]
35+
36+
return messages
37+
38+
class SelectiveCheckpointer(S3Checkpointer):
    """S3 checkpointer that prunes tool traffic from earlier conversation turns.

    Before each checkpoint is written, tool-related messages that precede the
    most recent ``HumanMessage`` are removed from the ``messages`` channel.
    Unless ``retain_history`` is set, the thread's existing checkpoints are
    deleted before the new one is stored.
    """

    def __init__(
        self,
        bucket_name: str,
        region_name: Optional[str] = None,
        endpoint_url: Optional[str] = None,
        compression: Optional[str] = None,
        retain_history: Optional[bool] = False,
    ) -> None:
        """Configure the checkpointer.

        Args:
            bucket_name: Forwarded to ``S3Checkpointer``.
            region_name: Forwarded to ``S3Checkpointer``. When omitted, the
                ``AWS_REGION`` environment variable is read at call time.
            endpoint_url: Forwarded to ``S3Checkpointer``.
            compression: Forwarded to ``S3Checkpointer``.
            retain_history: When falsy (the default), ``put`` deletes the
                thread's existing checkpoints before writing the new one.
        """
        # Resolve the region here rather than in the signature: a default of
        # ``os.getenv(...)`` would be evaluated once at import time and would
        # ignore any later change to the environment.
        if region_name is None:
            region_name = os.getenv("AWS_REGION")
        super().__init__(bucket_name, region_name, endpoint_url, compression)
        self.retain_history = retain_history

    def put(
        self,
        config: RunnableConfig,
        checkpoint: Checkpoint,
        metadata: CheckpointMetadata,
        new_versions: ChannelVersions,
    ) -> RunnableConfig:
        """Prune the checkpoint's messages and delegate storage to the parent.

        Args:
            config: Runnable config; ``config["configurable"]["thread_id"]``
                identifies the thread whose history may be cleared.
            checkpoint: Checkpoint to store. Its ``messages`` channel, when
                present, is replaced in place with the pruned list.
            metadata: Passed through unchanged to ``S3Checkpointer.put``.
            new_versions: Passed through unchanged to ``S3Checkpointer.put``.

        Returns:
            Whatever ``S3Checkpointer.put`` returns.

        Raises:
            KeyError: If ``config`` lacks ``configurable`` or ``thread_id``.
        """
        thread_id = config["configurable"]["thread_id"]
        if not self.retain_history:
            # NOTE(review): history is removed before the new checkpoint is
            # written, so a failure in the parent ``put`` would leave the
            # thread without any checkpoint - confirm this is acceptable.
            # ``delete_checkpoints`` is presumably provided by S3Checkpointer.
            self.delete_checkpoints(thread_id)

        # Only rewrite the messages channel when it actually exists; this
        # avoids a KeyError on checkpoints without ``channel_values`` and
        # avoids injecting an empty ``messages`` channel where none was set.
        channel_values = checkpoint.get("channel_values")
        if channel_values is not None and "messages" in channel_values:
            channel_values["messages"] = _prune_messages(
                channel_values["messages"]
            )

        return super().put(config, checkpoint, metadata, new_versions)

chat/test/core/test_setup.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,16 @@ def test_chat_model_returns_bedrock_instance(self):
1616

1717
class TestCheckpointSaver(unittest.TestCase):
1818
@patch.dict(os.environ, {"CHECKPOINT_BUCKET_NAME": "test-bucket"})
19-
@patch("core.setup.S3Checkpointer")
20-
def test_checkpoint_saver_initialization(self, mock_s3_checkpointer):
19+
@patch("core.setup.SelectiveCheckpointer")
20+
def test_checkpoint_saver_initialization(self, mock_checkpointer):
2121
kwargs = {"prefix": "test"}
2222
result = checkpoint_saver(**kwargs)
2323

24-
mock_s3_checkpointer.assert_called_once_with(
24+
mock_checkpointer.assert_called_once_with(
2525
bucket_name="test-bucket",
2626
**kwargs
2727
)
28-
self.assertEqual(result, mock_s3_checkpointer.return_value)
28+
self.assertEqual(result, mock_checkpointer.return_value)
2929

3030
class TestPrefix(unittest.TestCase):
3131
def test_prefix_with_env_prefix(self):

0 commit comments

Comments
 (0)