1
+ import os
2
+ from typing import Optional
3
+ from langchain_core .messages import AIMessage , HumanMessage , ToolMessage
4
+ from langchain_core .runnables import RunnableConfig
5
+ from langgraph .checkpoint .base import (
6
+ ChannelVersions ,
7
+ Checkpoint ,
8
+ CheckpointMetadata
9
+ )
10
+ from persistence .s3_checkpointer import S3Checkpointer
11
+
12
+ def _is_tool_message (message ):
13
+ if isinstance (message , ToolMessage ):
14
+ return True
15
+ if isinstance (message , AIMessage ) and message .additional_kwargs .get ('stop_reason' , '' ) == 'tool_use' :
16
+ return True
17
+ return False
18
+
19
+ def _prune_messages (messages ):
20
+ if messages is None :
21
+ return messages
22
+
23
+ last_human_message = None
24
+ for i , message in reversed (list (enumerate (messages ))):
25
+ if isinstance (message , HumanMessage ):
26
+ last_human_message = i
27
+ break
28
+
29
+ if last_human_message is not None :
30
+ return [
31
+ msg
32
+ for i , msg in enumerate (messages )
33
+ if not _is_tool_message (msg ) or i > last_human_message
34
+ ]
35
+
36
+ return messages
37
+
38
+ class SelectiveCheckpointer (S3Checkpointer ):
39
+ """S3 Checkpointer that discards ToolMessages from previous checkpoints."""
40
+
41
+ def __init__ (
42
+ self ,
43
+ bucket_name : str ,
44
+ region_name : str = os .getenv ("AWS_REGION" ),
45
+ endpoint_url : Optional [str ] = None ,
46
+ compression : Optional [str ] = None ,
47
+ retain_history : Optional [bool ] = False ,
48
+ ) -> None :
49
+ super ().__init__ (bucket_name , region_name , endpoint_url , compression )
50
+ self .retain_history = retain_history
51
+
52
+ def put (
53
+ self ,
54
+ config : RunnableConfig ,
55
+ checkpoint : Checkpoint ,
56
+ metadata : CheckpointMetadata ,
57
+ new_versions : ChannelVersions ,
58
+ ) -> RunnableConfig :
59
+ # Remove previous checkpoints
60
+ thread_id = config ["configurable" ]["thread_id" ]
61
+ if not self .retain_history :
62
+ self .delete_checkpoints (thread_id )
63
+
64
+ # Remove all ToolMessages except those related to the most
65
+ # recent question (HumanMessage)
66
+ messages = checkpoint .get ("channel_values" , {}).get ("messages" , [])
67
+ checkpoint ["channel_values" ]["messages" ] = _prune_messages (messages )
68
+
69
+ return super ().put (config , checkpoint , metadata , new_versions )
0 commit comments