Skip to content

Commit f97a010

Browse files
committed
Change how message filtering happens
Split the checkpoint into interactions and filter tool-related messages out of all but the last one. Identify tool-related AIMessages by response_metadata.
1 parent fdcf8fb commit f97a010

File tree

3 files changed

+30
-19
lines changed

3 files changed

+30
-19
lines changed

chat/src/core/setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ def chat_model(**kwargs) -> BaseModel:
1414

1515
def checkpoint_saver(**kwargs) -> BaseCheckpointSaver:
1616
checkpoint_bucket: str = os.getenv("CHECKPOINT_BUCKET_NAME")
17-
return SelectiveCheckpointer(bucket_name=checkpoint_bucket, **kwargs)
17+
return SelectiveCheckpointer(bucket_name=checkpoint_bucket, retain_history=False, **kwargs)
1818

1919
def prefix(value):
2020
env_prefix = os.getenv("ENV_PREFIX")

chat/src/persistence/selective_checkpointer.py

+28-18
Original file line numberDiff line numberDiff line change
@@ -9,31 +9,41 @@
99
)
1010
from persistence.s3_checkpointer import S3Checkpointer
1111

12+
13+
# Split messages into interactions, each one starting with a HumanMessage
14+
def _split_interactions(messages):
15+
if messages is None:
16+
return []
17+
18+
interactions = []
19+
current_interaction = []
20+
21+
for message in messages:
22+
if isinstance(message, HumanMessage) and current_interaction:
23+
interactions.append(current_interaction)
24+
current_interaction = []
25+
current_interaction.append(message)
26+
27+
if current_interaction:
28+
interactions.append(current_interaction)
29+
30+
return interactions
31+
1232
def _is_tool_message(message):
1333
if isinstance(message, ToolMessage):
1434
return True
15-
if isinstance(message, AIMessage) and message.additional_kwargs.get('stop_reason', '') == 'tool_use':
35+
if isinstance(message, AIMessage) and message.response_metadata.get('stop_reason', '') == 'tool_use':
1636
return True
1737
return False
1838

1939
def _prune_messages(messages):
20-
if messages is None:
21-
return messages
22-
23-
last_human_message = None
24-
for i, message in reversed(list(enumerate(messages))):
25-
if isinstance(message, HumanMessage):
26-
last_human_message = i
27-
break
28-
29-
if last_human_message is not None:
30-
return [
31-
msg
32-
for i, msg in enumerate(messages)
33-
if not _is_tool_message(msg) or i > last_human_message
34-
]
40+
interactions = _split_interactions(messages)
41+
# Remove all tool-related messages except those related to the most recent interaction
42+
for i, interaction in enumerate(interactions[:-1]):
43+
interactions[i] = [message for message in interaction if not _is_tool_message(message)]
3544

36-
return messages
45+
# Return the flattened list of messages
46+
return [message for interaction in interactions for message in interaction]
3747

3848
class SelectiveCheckpointer(S3Checkpointer):
3949
"""S3 Checkpointer that discards ToolMessages from previous checkpoints."""
@@ -44,7 +54,7 @@ def __init__(
4454
region_name: str = os.getenv("AWS_REGION"),
4555
endpoint_url: Optional[str] = None,
4656
compression: Optional[str] = None,
47-
retain_history: Optional[bool] = False,
57+
retain_history: Optional[bool] = True,
4858
) -> None:
4959
super().__init__(bucket_name, region_name, endpoint_url, compression)
5060
self.retain_history = retain_history

chat/test/core/test_setup.py

+1
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ def test_checkpoint_saver_initialization(self, mock_checkpointer):
2323

2424
mock_checkpointer.assert_called_once_with(
2525
bucket_name="test-bucket",
26+
retain_history=False,
2627
**kwargs
2728
)
2829
self.assertEqual(result, mock_checkpointer.return_value)

0 commit comments

Comments
 (0)