From 9a914efe6f91a53c010a0dee9015d0c89ce9b21f Mon Sep 17 00:00:00 2001 From: Peter Jung Date: Wed, 8 Jan 2025 15:19:46 +0100 Subject: [PATCH] Add function to retrieve actions given context for general agent (#616) --- .../microchain_agent/memory_functions.py | 58 ++++++++++++++++++- .../microchain_agent/microchain_agent.py | 6 +- prediction_market_agent/agents/utils.py | 4 +- tests/agents/microchain/test_functions.py | 33 ++++++++++- 4 files changed, 91 insertions(+), 10 deletions(-) diff --git a/prediction_market_agent/agents/microchain_agent/memory_functions.py b/prediction_market_agent/agents/microchain_agent/memory_functions.py index afe98a23..01f5e709 100644 --- a/prediction_market_agent/agents/microchain_agent/memory_functions.py +++ b/prediction_market_agent/agents/microchain_agent/memory_functions.py @@ -1,16 +1,22 @@ from datetime import timedelta +from langchain.vectorstores.chroma import Chroma +from langchain_openai import OpenAIEmbeddings from microchain import Function -from prediction_market_agent_tooling.tools.utils import utcnow +from prediction_market_agent_tooling.tools.utils import check_not_none, utcnow from prediction_market_agent.agents.microchain_agent.memory import DatedChatMessage +from prediction_market_agent.agents.microchain_agent.microchain_agent_keys import ( + MicrochainAgentKeys, +) from prediction_market_agent.agents.utils import memories_to_learnings from prediction_market_agent.db.long_term_memory_table_handler import ( + LongTermMemories, LongTermMemoryTableHandler, ) -class LookAtPastActions(Function): +class LongTermMemoryBasedFunction(Function): def __init__( self, long_term_memory: LongTermMemoryTableHandler, model: str ) -> None: @@ -18,6 +24,8 @@ def __init__( self.model = model super().__init__() + +class LookAtPastActionsFromLastDay(LongTermMemoryBasedFunction): @property def description(self) -> str: return ( @@ -42,3 +50,49 @@ def __call__(self) -> str: DatedChatMessage.from_long_term_memory(ltm) for ltm in memories ] return memories_to_learnings(memories=simple_memories, model=self.model) + + +class CheckAllPastActionsGivenContext(LongTermMemoryBasedFunction): + @property + def description(self) -> str: + return ( + "Use this function to fetch information about the actions you executed with respect to a specific context. " + "For example, you can use this function to look into all your past actions if you ever did form a coalition with another agent." + ) + + @property + def example_args(self) -> list[str]: + return ["What coalitions did I form?"] + + def __call__(self, context: str) -> str: + keys = MicrochainAgentKeys() + all_memories = self.long_term_memory.search() + + collection = Chroma( + embedding_function=OpenAIEmbeddings( + api_key=keys.openai_api_key_secretstr_v1 + ) + ) + collection.add_texts( + texts=[ + f"From: {check_not_none(x.metadata_dict)['role']} Content: {check_not_none(x.metadata_dict)['content']}" + for x in all_memories + ], + metadatas=[{"json": x.model_dump_json()} for x in all_memories], + ) + + top_k_per_query_results = collection.similarity_search(context, k=50) + results = [ + DatedChatMessage.from_long_term_memory( + LongTermMemories.model_validate_json(x.metadata["json"]) + ) + for x in top_k_per_query_results + ] + + return memories_to_learnings(memories=results, model=self.model) + + +MEMORY_FUNCTIONS: list[type[LongTermMemoryBasedFunction]] = [ + LookAtPastActionsFromLastDay, + CheckAllPastActionsGivenContext, +] diff --git a/prediction_market_agent/agents/microchain_agent/microchain_agent.py b/prediction_market_agent/agents/microchain_agent/microchain_agent.py index 7413d62f..9ff7305c 100644 --- a/prediction_market_agent/agents/microchain_agent/microchain_agent.py +++ b/prediction_market_agent/agents/microchain_agent/microchain_agent.py @@ -42,7 +42,7 @@ MARKET_FUNCTIONS, ) from prediction_market_agent.agents.microchain_agent.memory_functions import ( - LookAtPastActions, + MEMORY_FUNCTIONS, ) from prediction_market_agent.agents.microchain_agent.nft_functions import NFT_FUNCTIONS from prediction_market_agent.agents.microchain_agent.nft_treasury_game.messages_functions import ( @@ -171,8 +171,8 @@ def build_agent_functions( functions.extend(f() for f in BALANCE_FUNCTIONS) if long_term_memory: - functions.append( - LookAtPastActions(long_term_memory=long_term_memory, model=model) + functions.extend( + f(long_term_memory=long_term_memory, model=model) for f in MEMORY_FUNCTIONS ) return functions diff --git a/prediction_market_agent/agents/utils.py b/prediction_market_agent/agents/utils.py index 92e4d5ab..744aed73 100644 --- a/prediction_market_agent/agents/utils.py +++ b/prediction_market_agent/agents/utils.py @@ -22,7 +22,7 @@ MEMORIES_TO_LEARNINGS_TEMPLATE = """ -You are an agent that trades in prediction markets. You are aiming to improve +You are an agent that does actions on its own. You are aiming to improve your strategy over time. You have a collection of memories that record your actions, and your reasoning behind them. @@ -32,7 +32,7 @@ Each memory comes with a timestamp. If the memories are clustered into different times, then make a separate list for each cluster. Refer to each -cluster as a 'Trading Session', and display the range of timestamps for each. +cluster as a 'Session', and display the range of timestamps for each. MEMORIES: {memories} diff --git a/tests/agents/microchain/test_functions.py b/tests/agents/microchain/test_functions.py index ff25a84b..c13c6fd1 100644 --- a/tests/agents/microchain/test_functions.py +++ b/tests/agents/microchain/test_functions.py @@ -18,7 +18,8 @@ SellYes, ) from prediction_market_agent.agents.microchain_agent.memory_functions import ( - LookAtPastActions, + CheckAllPastActionsGivenContext, + LookAtPastActionsFromLastDay, ) from prediction_market_agent.agents.microchain_agent.utils import ( get_balance, @@ -163,7 +164,7 @@ def test_predict_probability(market_type: MarketType) -> None: @pytest.mark.skipif(not RUN_PAID_TESTS, reason="This test costs money to run.") -def test_remember_past_learnings( +def test_look_at_past_actions( long_term_memory_table_handler: LongTermMemoryTableHandler, ) -> None: long_term_memory_table_handler.save_history( @@ -175,13 +176,39 @@ def test_remember_past_learnings( ) ## Uncomment below to test with the memories accrued from use of https://autonomous-trader-agent.streamlit.app/ # long_term_memory = LongTermMemoryTableHandler(task_description="microchain-streamlit-app") - past_actions = LookAtPastActions( + past_actions = LookAtPastActionsFromLastDay( long_term_memory=long_term_memory_table_handler, model=DEFAULT_OPENAI_MODEL, ) print(past_actions()) +@pytest.mark.skipif(not RUN_PAID_TESTS, reason="This test costs money to run.") +def test_check_past_actions_given_context( + long_term_memory_table_handler: LongTermMemoryTableHandler, +) -> None: + long_term_memory_table_handler.save_history( + history=[ + { + "role": "user", + "content": "Agent X sent me a message asking for a coalition.", + }, + { + "role": "user", + "content": "I agreed with agent X to form a coalition, I'll send him my NFT key if he sends me 5 xDai", + }, + {"role": "user", "content": "I went to the park and saw a bird."}, + ] + ) + ## Uncomment below to test with the memories accrued from use of https://autonomous-trader-agent.streamlit.app/ + # long_term_memory = LongTermMemoryTableHandler(task_description="microchain-streamlit-app") + past_actions = CheckAllPastActionsGivenContext( + long_term_memory=long_term_memory_table_handler, + model=DEFAULT_OPENAI_MODEL, + ) + print(past_actions(context="What coalitions did I form?")) + + @pytest.mark.parametrize("market_type", [MarketType.OMEN]) def test_kelly_bet(market_type: MarketType) -> None: get_kelly_bet = GetKellyBet(market_type=market_type, keys=APIKeys())