Skip to content

Commit

Permalink
Adds Long-Term-memory to microchain agent (#129)
Browse files Browse the repository at this point in the history
* Adding sqlalchemy support

* Added long term storage

* Using SQLTable for db interactions

* Added test for db_storage.py | tested microchain manually

* Updated poetry.lock

* Minor improvements

* Updated poetry.lock

* Fixing mypy etc

* Added PR comments

* Changes after merge

* Updated poetry.lock

* Fixing mypy

* Fixed isort

* Implemented PR comments

* Fixed isort

* Fixed black
  • Loading branch information
gabrielfior authored May 7, 2024
1 parent dd79d47 commit de6c70b
Show file tree
Hide file tree
Showing 13 changed files with 268 additions and 23 deletions.
98 changes: 97 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions prediction_market_agent/agents/microchain_agent/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from prediction_market_agent_tooling.tools.costs import openai_costs
from prediction_market_agent_tooling.tools.utils import check_not_none

from prediction_market_agent.agents.microchain_agent.microchain_agent import get_agent
from prediction_market_agent.agents.microchain_agent.microchain_agent import build_agent
from prediction_market_agent.agents.microchain_agent.utils import (
has_been_run_past_initialization,
)
Expand Down Expand Up @@ -65,7 +65,7 @@ def execute_reasoning(agent: Agent, reasoning: str, model: str) -> None:

# Initialize the agent
if "agent" not in st.session_state:
st.session_state.agent = get_agent(market_type=MarketType.OMEN, model=model)
st.session_state.agent = build_agent(market_type=MarketType.OMEN, model=model)
st.session_state.agent.reset()
st.session_state.agent.build_initial_messages()
st.session_state.running_cost = 0.0
Expand Down
4 changes: 2 additions & 2 deletions prediction_market_agent/agents/microchain_agent/deploy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from prediction_market_agent_tooling.deploy.agent import DeployableAgent
from prediction_market_agent_tooling.markets.markets import MarketType

from prediction_market_agent.agents.microchain_agent.microchain_agent import get_agent
from prediction_market_agent.agents.microchain_agent.microchain_agent import build_agent


class DeployableMicrochainAgent(DeployableAgent):
Expand All @@ -14,7 +14,7 @@ def run(self, market_type: MarketType) -> None:
Override main 'run' method, as the all logic from the helper methods
is handed over to the agent.
"""
agent: Agent = get_agent(
agent: Agent = build_agent(
market_type=market_type,
model=self.model,
)
Expand Down
39 changes: 23 additions & 16 deletions prediction_market_agent/agents/microchain_agent/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from prediction_market_agent_tooling.markets.data_models import Currency, TokenAmount
from prediction_market_agent_tooling.markets.markets import MarketType

from prediction_market_agent.agents.microchain_agent.memory import LongTermMemory
from prediction_market_agent.agents.microchain_agent.utils import (
MicroMarket,
get_balance,
Expand All @@ -15,6 +16,7 @@
get_no_outcome,
get_yes_outcome,
)
from prediction_market_agent.db.models import LongTermMemories
from prediction_market_agent.tools.mech.utils import (
MechResponse,
MechTool,
Expand Down Expand Up @@ -297,21 +299,6 @@ def __init__(self, market_type: MarketType) -> None:
)


class SummarizeLearning(Function):
@property
def description(self) -> str:
return "Use this function summarize your learnings and save them so that you can access them later."

@property
def example_args(self) -> list[str]:
return [
"Today I learned that I need to check my balance fore making decisions about how much to invest."
]

def __call__(self, summary: str) -> str:
return summary


class GetBalance(MarketFunction):
@property
def description(self) -> str:
Expand Down Expand Up @@ -351,10 +338,30 @@ def __call__(self) -> list[str]:
return [str(position) for position in positions]


class RememberPastLearnings(Function):
def __init__(self, long_term_memory: LongTermMemory) -> None:
self.long_term_memory = long_term_memory
super().__init__()

@property
def description(self) -> str:
return """Use this function to fetch information about the previous actions you executed. Examples of past
activities include previous bets you placed, previous markets you redeemed from, balances you requested,
market positions you requested, markets you fetched, tokens you bought, tokens you sold, probabilities for
markets you requested, among others.
"""

@property
def example_args(self) -> list[str]:
return []

def __call__(self) -> t.Sequence[LongTermMemories]:
return self.long_term_memory.search()


MISC_FUNCTIONS = [
Sum,
Product,
# SummarizeLearning,
]

# Functions that interact with the prediction markets
Expand Down
22 changes: 22 additions & 0 deletions prediction_market_agent/agents/microchain_agent/memory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# inspired by crewAI's LongTermMemory (https://github.com/joaomdmoura/crewAI/blob/main/src/crewai/memory/long_term/long_term_memory.py)
from typing import Any, Dict, Sequence

from prediction_market_agent.db.db_storage import DBStorage
from prediction_market_agent.db.models import LongTermMemories


# In the future, create a base class which this class extends.
class LongTermMemory:
def __init__(self, task_description: str):
self.task_description = task_description
self.storage = DBStorage()

def save_history(self, history: list[Dict[str, Any]]) -> None:
"""Save item to storage. Note that score allows many types for easier handling by agent."""
self.storage.save_multiple(
task_description=self.task_description,
history=history,
)

def search(self, latest_n: int = 5) -> Sequence[LongTermMemories]:
return self.storage.load(self.task_description, latest_n)
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import typer
from functions import RememberPastLearnings
from microchain import LLM, Agent, Engine, OpenAIChatGenerator
from microchain.functions import Reasoning, Stop
from prediction_market_agent_tooling.markets.markets import MarketType
Expand All @@ -7,6 +8,7 @@
MARKET_FUNCTIONS,
MISC_FUNCTIONS,
)
from prediction_market_agent.agents.microchain_agent.memory import LongTermMemory
from prediction_market_agent.agents.microchain_agent.omen_functions import (
OMEN_FUNCTIONS,
)
Expand All @@ -22,10 +24,11 @@
"""


def get_agent(
def build_agent(
market_type: MarketType,
model: str,
api_base: str = "https://api.openai.com/v1",
long_term_memory: LongTermMemory | None = None,
) -> Agent:
engine = Engine()
engine.register(Reasoning())
Expand All @@ -36,6 +39,10 @@ def get_agent(
engine.register(function(market_type=market_type))
for function in OMEN_FUNCTIONS:
engine.register(function())

if long_term_memory:
engine.register(RememberPastLearnings(long_term_memory))

generator = OpenAIChatGenerator(
model=model,
api_key=APIKeys().openai_api_key.get_secret_value(),
Expand All @@ -58,15 +65,22 @@ def main(
iterations: int = 10,
seed_prompt: str | None = None,
) -> None:
agent = get_agent(
# This description below serves to unique identify agent entries on the LTM, and should be
# unique across instances (i.e. markets).
unique_task_description = f"microchain-agent-demo-{market_type}"
long_term_memory = LongTermMemory(unique_task_description)

agent = build_agent(
market_type=market_type,
api_base=api_base,
model=model,
long_term_memory=long_term_memory,
)
if seed_prompt:
agent.bootstrap = [f'Reasoning("{seed_prompt}")']
agent.run(iterations=iterations)
# generator.print_usage() # Waiting for microchain release
long_term_memory.save_history(agent.history)


if __name__ == "__main__":
Expand Down
Empty file.
Loading

0 comments on commit de6c70b

Please sign in to comment.