Skip to content

Commit

Permalink
Add function to retrieve actions given context for general agent (#616)
Browse files Browse the repository at this point in the history
  • Loading branch information
kongzii authored Jan 8, 2025
1 parent 8d52457 commit 9a914ef
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 10 deletions.
Original file line number Diff line number Diff line change
@@ -1,23 +1,31 @@
from datetime import timedelta

from langchain.vectorstores.chroma import Chroma
from langchain_openai import OpenAIEmbeddings
from microchain import Function
from prediction_market_agent_tooling.tools.utils import utcnow
from prediction_market_agent_tooling.tools.utils import check_not_none, utcnow

from prediction_market_agent.agents.microchain_agent.memory import DatedChatMessage
from prediction_market_agent.agents.microchain_agent.microchain_agent_keys import (
MicrochainAgentKeys,
)
from prediction_market_agent.agents.utils import memories_to_learnings
from prediction_market_agent.db.long_term_memory_table_handler import (
LongTermMemories,
LongTermMemoryTableHandler,
)


class LookAtPastActions(Function):
class LongTermMemoryBasedFunction(Function):
def __init__(
self, long_term_memory: LongTermMemoryTableHandler, model: str
) -> None:
self.long_term_memory = long_term_memory
self.model = model
super().__init__()


class LookAtPastActionsFromLastDay(LongTermMemoryBasedFunction):
@property
def description(self) -> str:
return (
Expand All @@ -42,3 +50,49 @@ def __call__(self) -> str:
DatedChatMessage.from_long_term_memory(ltm) for ltm in memories
]
return memories_to_learnings(memories=simple_memories, model=self.model)


class CheckAllPastActionsGivenContext(LongTermMemoryBasedFunction):
@property
def description(self) -> str:
return (
"Use this function to fetch information about the actions you executed with respect to a specific context. "
"For example, you can use this function to look into all your past actions if you ever did form a coalition with another agent."
)

@property
def example_args(self) -> list[str]:
return ["What coalitions did I form?"]

def __call__(self, context: str) -> str:
keys = MicrochainAgentKeys()
all_memories = self.long_term_memory.search()

collection = Chroma(
embedding_function=OpenAIEmbeddings(
api_key=keys.openai_api_key_secretstr_v1
)
)
collection.add_texts(
texts=[
f"From: {check_not_none(x.metadata_dict)['role']} Content: {check_not_none(x.metadata_dict)['content']}"
for x in all_memories
],
metadatas=[{"json": x.model_dump_json()} for x in all_memories],
)

top_k_per_query_results = collection.similarity_search(context, k=50)
results = [
DatedChatMessage.from_long_term_memory(
LongTermMemories.model_validate_json(x.metadata["json"])
)
for x in top_k_per_query_results
]

return memories_to_learnings(memories=results, model=self.model)


MEMORY_FUNCTIONS: list[type[LongTermMemoryBasedFunction]] = [
LookAtPastActionsFromLastDay,
CheckAllPastActionsGivenContext,
]
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
MARKET_FUNCTIONS,
)
from prediction_market_agent.agents.microchain_agent.memory_functions import (
LookAtPastActions,
MEMORY_FUNCTIONS,
)
from prediction_market_agent.agents.microchain_agent.nft_functions import NFT_FUNCTIONS
from prediction_market_agent.agents.microchain_agent.nft_treasury_game.messages_functions import (
Expand Down Expand Up @@ -171,8 +171,8 @@ def build_agent_functions(
functions.extend(f() for f in BALANCE_FUNCTIONS)

if long_term_memory:
functions.append(
LookAtPastActions(long_term_memory=long_term_memory, model=model)
functions.extend(
f(long_term_memory=long_term_memory, model=model) for f in MEMORY_FUNCTIONS
)

return functions
Expand Down
4 changes: 2 additions & 2 deletions prediction_market_agent/agents/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@


MEMORIES_TO_LEARNINGS_TEMPLATE = """
You are an agent that trades in prediction markets. You are aiming to improve
You are an agent that does actions on its own. You are aiming to improve
your strategy over time. You have a collection of memories that record your
actions, and your reasoning behind them.
Expand All @@ -32,7 +32,7 @@
Each memory comes with a timestamp. If the memories are clustered into
different times, then make a separate list for each cluster. Refer to each
cluster as a 'Trading Session', and display the range of timestamps for each.
cluster as a 'Session', and display the range of timestamps for each.
MEMORIES:
{memories}
Expand Down
33 changes: 30 additions & 3 deletions tests/agents/microchain/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
SellYes,
)
from prediction_market_agent.agents.microchain_agent.memory_functions import (
LookAtPastActions,
CheckAllPastActionsGivenContext,
LookAtPastActionsFromLastDay,
)
from prediction_market_agent.agents.microchain_agent.utils import (
get_balance,
Expand Down Expand Up @@ -163,7 +164,7 @@ def test_predict_probability(market_type: MarketType) -> None:


@pytest.mark.skipif(not RUN_PAID_TESTS, reason="This test costs money to run.")
def test_remember_past_learnings(
def test_look_at_past_actions(
long_term_memory_table_handler: LongTermMemoryTableHandler,
) -> None:
long_term_memory_table_handler.save_history(
Expand All @@ -175,13 +176,39 @@ def test_remember_past_learnings(
)
## Uncomment below to test with the memories accrued from use of https://autonomous-trader-agent.streamlit.app/
# long_term_memory = LongTermMemoryTableHandler(task_description="microchain-streamlit-app")
past_actions = LookAtPastActions(
past_actions = LookAtPastActionsFromLastDay(
long_term_memory=long_term_memory_table_handler,
model=DEFAULT_OPENAI_MODEL,
)
print(past_actions())


@pytest.mark.skipif(not RUN_PAID_TESTS, reason="This test costs money to run.")
def test_check_past_actions_given_context(
long_term_memory_table_handler: LongTermMemoryTableHandler,
) -> None:
long_term_memory_table_handler.save_history(
history=[
{
"role": "user",
"content": "Agent X sent me a message asking for a coalition.",
},
{
"role": "user",
"content": "I agreed with agent X to form a coalition, I'll send him my NFT key if he sends me 5 xDai",
},
{"role": "user", "content": "I went to the park and saw a bird."},
]
)
## Uncomment below to test with the memories accrued from use of https://autonomous-trader-agent.streamlit.app/
# long_term_memory = LongTermMemoryTableHandler(task_description="microchain-streamlit-app")
past_actions = CheckAllPastActionsGivenContext(
long_term_memory=long_term_memory_table_handler,
model=DEFAULT_OPENAI_MODEL,
)
print(past_actions(context="What coalitions did I form?"))


@pytest.mark.parametrize("market_type", [MarketType.OMEN])
def test_kelly_bet(market_type: MarketType) -> None:
get_kelly_bet = GetKellyBet(market_type=market_type, keys=APIKeys())
Expand Down

0 comments on commit 9a914ef

Please sign in to comment.