1
+ import os
2
+ from typing import Optional
3
+ from langchain_core .messages import AIMessage , HumanMessage , ToolMessage
4
+ from langchain_core .runnables import RunnableConfig
5
+ from langgraph .checkpoint .base import (
6
+ ChannelVersions ,
7
+ Checkpoint ,
8
+ CheckpointMetadata
9
+ )
10
+ from persistence .s3_checkpointer import S3Checkpointer
11
+
12
+
13
+ # Split messages into interactions, each one starting with a HumanMessage
14
+ def _split_interactions (messages ):
15
+ if messages is None :
16
+ return []
17
+
18
+ interactions = []
19
+ current_interaction = []
20
+
21
+ for message in messages :
22
+ if isinstance (message , HumanMessage ) and current_interaction :
23
+ interactions .append (current_interaction )
24
+ current_interaction = []
25
+ current_interaction .append (message )
26
+
27
+ if current_interaction :
28
+ interactions .append (current_interaction )
29
+
30
+ return interactions
31
+
32
def _is_tool_message(message):
    """Return True when *message* belongs to a tool-call exchange.

    A message counts as tool-related when it is either:

    * a ``ToolMessage`` (the result of a tool invocation), or
    * an ``AIMessage`` that requested a tool invocation.

    Args:
        message: The message object to classify.

    Returns:
        ``True`` for tool-related messages, ``False`` otherwise.
    """
    if isinstance(message, ToolMessage):
        return True
    if not isinstance(message, AIMessage):
        return False

    # Provider-independent signal: the AI turn carries parsed tool calls.
    # Without this check, a model that does not report
    # stop_reason == 'tool_use' would keep its tool-requesting AI turn
    # while the matching ToolMessage replies are pruned.
    if getattr(message, "tool_calls", None):
        return True

    # Original signal, kept for backward compatibility: the response
    # metadata reports that generation stopped to call a tool.
    response_metadata = getattr(message, "response_metadata", None) or {}
    return response_metadata.get('stop_reason', '') == 'tool_use'
38
+
39
def _prune_messages(messages):
    """Strip tool-related messages from all but the latest interaction.

    The messages are grouped with ``_split_interactions``. Every group
    except the last one is filtered with ``_is_tool_message``; the last
    group is kept untouched.

    Args:
        messages: Iterable of message objects, or ``None``.

    Returns:
        A new flat list of the retained messages in their original order.
    """
    groups = _split_interactions(messages)
    if not groups:
        return []

    *earlier, latest = groups

    # Older interactions keep only their non-tool messages.
    kept = [
        msg
        for group in earlier
        for msg in group
        if not _is_tool_message(msg)
    ]
    # The most recent interaction is preserved in full.
    kept.extend(latest)
    return kept
47
+
48
class SelectiveCheckpointer(S3Checkpointer):
    """S3 checkpointer that prunes old tool messages before saving.

    Before a checkpoint is written, the ``messages`` channel is passed
    through ``_prune_messages`` so that tool-related messages survive only
    for the most recent interaction. Optionally, earlier checkpoints of the
    thread are deleted before each write.
    """

    def __init__(
        self,
        bucket_name: str,
        region_name: Optional[str] = None,
        endpoint_url: Optional[str] = None,
        compression: Optional[str] = None,
        retain_history: Optional[bool] = True,
    ) -> None:
        """Create the checkpointer.

        Args:
            bucket_name: Forwarded to ``S3Checkpointer``.
            region_name: Forwarded to ``S3Checkpointer``. When omitted, the
                ``AWS_REGION`` environment variable is read at construction
                time.
            endpoint_url: Forwarded to ``S3Checkpointer``.
            compression: Forwarded to ``S3Checkpointer``.
            retain_history: When falsy, ``put`` deletes the thread's
                existing checkpoints before writing the new one.
        """
        # Resolve the region here rather than in the signature: a default of
        # os.getenv(...) is evaluated once at import time and would miss an
        # environment variable that is set afterwards.
        if region_name is None:
            region_name = os.getenv("AWS_REGION")
        super().__init__(bucket_name, region_name, endpoint_url, compression)
        self.retain_history = retain_history

    def put(
        self,
        config: RunnableConfig,
        checkpoint: Checkpoint,
        metadata: CheckpointMetadata,
        new_versions: ChannelVersions,
    ) -> RunnableConfig:
        """Prune the checkpoint's messages and store it.

        Args:
            config: Runnable config; ``config["configurable"]["thread_id"]``
                must be present.
            checkpoint: Checkpoint to store. Its ``messages`` channel, when
                present, is replaced in place with the pruned list.
            metadata: Forwarded unchanged to the parent ``put``.
            new_versions: Forwarded unchanged to the parent ``put``.

        Returns:
            Whatever the parent ``put`` returns.

        Raises:
            KeyError: If ``config`` has no ``configurable.thread_id``.
        """
        thread_id = config["configurable"]["thread_id"]

        # Remove previous checkpoints when history is not retained.
        # NOTE(review): delete_checkpoints is presumably provided by
        # S3Checkpointer - confirm.
        if not self.retain_history:
            self.delete_checkpoints(thread_id)

        # Remove all tool messages except those related to the most recent
        # question (HumanMessage). Only touch the channel when it actually
        # exists: this avoids a KeyError for checkpoints that have no
        # "channel_values" and avoids injecting an empty "messages" channel
        # into checkpoints that never had one.
        channel_values = checkpoint.get("channel_values")
        if channel_values is not None and "messages" in channel_values:
            channel_values["messages"] = _prune_messages(
                channel_values["messages"]
            )

        return super().put(config, checkpoint, metadata, new_versions)
0 commit comments