9
9
)
10
10
from persistence .s3_checkpointer import S3Checkpointer
11
11
12
+
13
+ # Split messages into interactions, each one starting with a HumanMessage
14
+ def _split_interactions (messages ):
15
+ if messages is None :
16
+ return []
17
+
18
+ interactions = []
19
+ current_interaction = []
20
+
21
+ for message in messages :
22
+ if isinstance (message , HumanMessage ) and current_interaction :
23
+ interactions .append (current_interaction )
24
+ current_interaction = []
25
+ current_interaction .append (message )
26
+
27
+ if current_interaction :
28
+ interactions .append (current_interaction )
29
+
30
+ return interactions
31
+
12
32
def _is_tool_message (message ):
13
33
if isinstance (message , ToolMessage ):
14
34
return True
15
- if isinstance (message , AIMessage ) and message .additional_kwargs .get ('stop_reason' , '' ) == 'tool_use' :
35
+ if isinstance (message , AIMessage ) and message .response_metadata .get ('stop_reason' , '' ) == 'tool_use' :
16
36
return True
17
37
return False
18
38
19
39
def _prune_messages (messages ):
20
- if messages is None :
21
- return messages
22
-
23
- last_human_message = None
24
- for i , message in reversed (list (enumerate (messages ))):
25
- if isinstance (message , HumanMessage ):
26
- last_human_message = i
27
- break
28
-
29
- if last_human_message is not None :
30
- return [
31
- msg
32
- for i , msg in enumerate (messages )
33
- if not _is_tool_message (msg ) or i > last_human_message
34
- ]
40
+ interactions = _split_interactions (messages )
41
+ # Remove all tool-related messages except those related to the most recent interaction
42
+ for i , interaction in enumerate (interactions [:- 1 ]):
43
+ interactions [i ] = [message for message in interaction if not _is_tool_message (message )]
35
44
36
- return messages
45
+ # Return the flattened list of messages
46
+ return [message for interaction in interactions for message in interaction ]
37
47
38
48
class SelectiveCheckpointer (S3Checkpointer ):
39
49
"""S3 Checkpointer that discards ToolMessages from previous checkpoints."""
@@ -44,7 +54,7 @@ def __init__(
44
54
region_name : str = os .getenv ("AWS_REGION" ),
45
55
endpoint_url : Optional [str ] = None ,
46
56
compression : Optional [str ] = None ,
47
- retain_history : Optional [bool ] = False ,
57
+ retain_history : Optional [bool ] = True ,
48
58
) -> None :
49
59
super ().__init__ (bucket_name , region_name , endpoint_url , compression )
50
60
self .retain_history = retain_history
0 commit comments