Skip to content

Commit

Permalink
Microchain streamlit app improvements (#147)
Browse files Browse the repository at this point in the history
* Display agent history each iteration
  • Loading branch information
evangriffiths authored May 7, 2024
1 parent a11c2ea commit 6ae8e2d
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 49 deletions.
123 changes: 79 additions & 44 deletions prediction_market_agent/agents/microchain_agent/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,10 @@

from prediction_market_agent.agents.microchain_agent.microchain_agent import build_agent
from prediction_market_agent.agents.microchain_agent.utils import (
get_initial_history_length,
has_been_run_past_initialization,
)
from prediction_market_agent.streamlit_utils import check_required_api_keys
from prediction_market_agent.utils import APIKeys


def check_api_keys() -> None:
keys = APIKeys()
if not keys.OPENAI_API_KEY:
st.error("No OpenAI API Key provided via env var/secret.")
st.stop()
elif not keys.BET_FROM_PRIVATE_KEY:
st.error("No wallet private key provided via env var/secret.")
st.stop()


def run_agent(agent: Agent, iterations: int, model: str) -> None:
Expand All @@ -40,9 +30,48 @@ def run_agent(agent: Agent, iterations: int, model: str) -> None:
def execute_reasoning(agent: Agent, reasoning: str, model: str) -> None:
with openai_costs(model) as costs:
agent.execute_command(f'Reasoning("{reasoning}")')
display_new_history_callback(agent) # Run manually after `execute_command`
st.session_state.running_cost += costs.cost


def display_all_history(agent: Agent) -> None:
"""
Display the agent's history in the Streamlit app.
"""
# Skip the initial messages
history = agent.history[get_initial_history_length(agent) :]

for h in history:
st.chat_message(h["role"]).write(h["content"])


def display_new_history_callback(agent: Agent) -> None:
"""
A callback to display the agent's history in the Streamlit app after a run
with a single interation.
"""
history_depth = 2 # One for the user input, one for the agent's reply
history = agent.history[-history_depth:]
for h in history:
st.chat_message(h["role"]).write(h["content"])


def agent_is_initialized() -> bool:
return "agent" in st.session_state


def maybe_initialize_agent(model: str) -> None:
# Initialize the agent
if not agent_is_initialized():
st.session_state.agent = build_agent(market_type=MarketType.OMEN, model=model)
st.session_state.agent.reset()
st.session_state.agent.build_initial_messages()
st.session_state.running_cost = 0.0

# Add a callback to display the agent's history after each run
st.session_state.agent.on_iteration_end = display_new_history_callback


st.set_page_config(layout="wide")
st.title("Microchain Agent")
st.write(
Expand All @@ -54,55 +83,61 @@ def execute_reasoning(agent: Agent, reasoning: str, model: str) -> None:
check_required_api_keys(["OPENAI_API_KEY", "BET_FROM_PRIVATE_KEY"])

# Ask the user to choose a model
model = st.selectbox(
"Model",
["gpt-4-turbo-2024-04-09", "gpt-3.5-turbo-0125"],
index=0,
)
if model is None:
st.error("Please select a model.")
if not agent_is_initialized():
model = st.selectbox(
"Model",
["gpt-4-turbo-2024-04-09", "gpt-3.5-turbo-0125"],
index=0,
)
if model is None:
st.error("Please select a model.")
else:
model = st.selectbox(
"Model",
[st.session_state.agent.llm.generator.model],
index=0,
disabled=True,
)
model = check_not_none(model)

# Initialize the agent
if "agent" not in st.session_state:
st.session_state.agent = build_agent(market_type=MarketType.OMEN, model=model)
st.session_state.agent.reset()
st.session_state.agent.build_initial_messages()
st.session_state.running_cost = 0.0

# Interactive settings
user_reasoning = st.text_input("Reasoning")
if st.button("Intervene by adding reasoning"):
execute_reasoning(
agent=st.session_state.agent,
reasoning=user_reasoning,
model=model,
)

# Allow the user to run the
add_reasoning_button = st.button("Intervene by adding reasoning")
iterations = st.number_input(
"Run iterations",
value=1,
step=1,
min_value=1,
max_value=100,
)
if st.button("Run the agent"):
run_agent_button = st.button("Run the agent")

# Execution
if agent_is_initialized():
display_all_history(st.session_state.agent)
if add_reasoning_button:
maybe_initialize_agent(model)
execute_reasoning(
agent=st.session_state.agent,
reasoning=user_reasoning,
model=model,
)
if run_agent_button:
maybe_initialize_agent(model)
run_agent(
agent=st.session_state.agent,
iterations=int(iterations),
model=model,
)

if not has_been_run_past_initialization(st.session_state.agent):
if agent_is_initialized() and has_been_run_past_initialization(st.session_state.agent):
st.info(
"Run complete. Click 'Run' to allow the agent to continue, or add your "
"own reasoning."
)
# Display running cost
# st.info(f"Running OpenAPI credits cost: ${st.session_state.running_cost:.2f}") # TODO debug why always == 0.0
else:
st.info(
"Start by clicking 'Run' to see the agent in action. Alternatively "
"bootstrap the agent with your own reasoning before running."
)

# Display the agent's history
history = st.session_state.agent.history[3:] # Skip the initial messages
for h in history:
st.chat_message(h["role"]).write(h["content"])

# Display running cost
# st.info(f"Running OpenAPI credits cost: ${st.session_state.running_cost:.2f}") # TODO debug why always == 0.0
13 changes: 8 additions & 5 deletions prediction_market_agent/agents/microchain_agent/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,12 +94,15 @@ def get_example_market_id(market_type: MarketType) -> str:
raise ValueError(f"Market type '{market_type}' not supported")


def has_been_run_past_initialization(agent: Agent) -> bool:
if not hasattr(agent, "history"):
return False

def get_initial_history_length(agent: Agent) -> int:
initialized_history_length = 1
if agent.bootstrap:
initialized_history_length += len(agent.bootstrap) * 2
return initialized_history_length


def has_been_run_past_initialization(agent: Agent) -> bool:
if not hasattr(agent, "history"):
return False

return len(agent.history) > initialized_history_length
return len(agent.history) > get_initial_history_length(agent)

0 comments on commit 6ae8e2d

Please sign in to comment.