Skip to content

Commit

Permalink
Add streamlit demo for all deployable agents (#85)
Browse files Browse the repository at this point in the history
  • Loading branch information
kongzii authored Apr 18, 2024
1 parent 17b3950 commit eb01089
Show file tree
Hide file tree
Showing 9 changed files with 195 additions and 31 deletions.
2 changes: 1 addition & 1 deletion mypy.ini
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[mypy]
python_version = 3.10
files = main.py, prediction_market_agent/, tests/
files = main.py, prediction_market_agent/, tests/, scripts/
plugins = pydantic.mypy
warn_redundant_casts = True
warn_unused_ignores = True
Expand Down
17 changes: 9 additions & 8 deletions prediction_market_agent/agents/known_outcome_agent/deploy.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@


class DeployableKnownOutcomeAgent(DeployableAgent):
model = "gpt-4-turbo-preview"
model = "gpt-4-1106-preview"
min_liquidity = 5

def load(self) -> None:
self.markets_with_known_outcomes: dict[str, Result] = {}
Expand All @@ -30,18 +31,16 @@ def pick_markets(self, markets: t.Sequence[AgentMarket]) -> list[AgentMarket]:
"This agent only supports predictions on Omen markets"
)

logger.info(f"Looking at market {market.id=} {market.question=}")

# Assume very high probability markets are already known, and have
# been correctly bet on, and therefore the value of betting on them
# is low.
if market_is_saturated(market=market):
logger.info(
f"Skipping market {market.id=} {market.question=}, because it is already saturated."
f"Skipping market {market.url} with the question '{market.question}', because it is already saturated."
)
elif market.get_liquidity_in_xdai() < 5:
elif market.get_liquidity_in_xdai() < self.min_liquidity:
logger.info(
f"Skipping market {market.id=} {market.question=}, because it has insufficient liquidity."
f"Skipping market {market.url} with the question '{market.question}', because it has insufficient liquidity (at least {self.min_liquidity} required)."
)
else:
picked_markets.append(market)
Expand All @@ -68,15 +67,17 @@ def answer_binary_market(self, market: AgentMarket) -> bool | None:
)
except Exception as e:
logger.error(
f"Failed to predict market {market.id=} {market.question=}: {e}"
f"Failed to predict market {market.url} with the question '{market.question}' because of '{e}'."
)
answer = None
if answer and answer.has_known_result():
logger.info(
f"Picking market {market.id=} {market.question=} with answer {answer.result=}"
f"Picking market {market.url} with the question '{market.question}' with answer '{answer.result}'"
)
return answer.result.to_boolean()

logger.info(f"No definite answer found for the market {market.url}.")

return None

def calculate_bet_amount(self, answer: bool, market: AgentMarket) -> BetAmount:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from prediction_market_agent.tools.web_scrape.basic_summary import _summary
from prediction_market_agent.tools.web_scrape.markdown import web_scrape
from prediction_market_agent.tools.web_search.tavily import web_search
from prediction_market_agent.utils import completion_str_to_json
from prediction_market_agent.utils import APIKeys, completion_str_to_json


class Result(str, Enum):
Expand Down Expand Up @@ -149,7 +149,9 @@ def summarize_if_required(content: str, model: str, question: str) -> str:
elif model == "gpt-4-1106-preview": # 128k context length
max_length = 100000
else:
raise ValueError(f"Unknown model: {model}")
raise ValueError(
f"Unknown model `{model}`, please add him to the `summarize_if_required` function."
)

if len(content) > max_length:
return _summary(content=content, objective=question, separators=[" "])
Expand All @@ -162,7 +164,11 @@ def has_question_event_happened_in_the_past(model: str, question: str) -> bool:
current date) (returning 1), if the event has not yet finished (returning 0) or
if it cannot be sure (returning -1)."""
date_str = utcnow().strftime("%Y-%m-%d %H:%M:%S %Z")
llm = ChatOpenAI(model=model, temperature=0.0)
llm = ChatOpenAI(
model=model,
temperature=0.0,
api_key=APIKeys().openai_api_key.get_secret_value(),
)
prompt = ChatPromptTemplate.from_template(
template=HAS_QUESTION_HAPPENED_IN_THE_PAST_PROMPT
).format_messages(
Expand All @@ -176,7 +182,7 @@ def has_question_event_happened_in_the_past(model: str, question: str) -> bool:
return True
except Exception as e:
logger.error(
"Exception occured, cannot assert if title happened in the past. ", e
f"Exception occured, cannot assert if title happened in the past because of '{e}'."
)

return False
Expand All @@ -191,14 +197,18 @@ def get_known_outcome(model: str, question: str, max_tries: int) -> Answer:
tries = 0
date_str = datetime.now().strftime("%d %B %Y")
previous_urls = []
llm = ChatOpenAI(model=model, temperature=0.4)
llm = ChatOpenAI(
model=model,
temperature=0.4,
api_key=APIKeys().openai_api_key.get_secret_value(),
)
while tries < max_tries:
search_prompt = ChatPromptTemplate.from_template(
template=GENERATE_SEARCH_QUERY_PROMPT
).format_messages(date_str=date_str, question=question)
logger.debug(f"Invoking LLM for {search_prompt=}")
logger.debug(f"Invoking LLM for the prompt '{search_prompt[0]}'")
search_query = str(llm.invoke(search_prompt).content).strip('"')
logger.debug(f"Searching for {search_query=}")
logger.debug(f"Searching web for the search query '{search_query}'")
search_results = web_search(query=search_query, max_results=5)
if not search_results:
raise ValueError("No search results found.")
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import random
import typing as t

from loguru import logger
from prediction_market_agent_tooling.deploy.agent import DeployableAgent
from prediction_market_agent_tooling.markets.agent_market import AgentMarket
from prediction_market_agent_tooling.markets.markets import MarketType
Expand All @@ -26,6 +27,10 @@ def pick_markets(self, markets: t.Sequence[AgentMarket]) -> t.Sequence[AgentMark
picked_markets.append(market)
if len(picked_markets) == 5:
break
else:
logger.info(
f"Market {market.url} is too saturated to bet on with p_yes {market.p_yes}."
)

return picked_markets

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,9 @@ def split_research_into_outcomes(self, question: str) -> Outcomes:
tasks=[create_outcomes_task],
)
result = report_crew.kickoff(inputs={"scenario": question})
return Outcomes.model_validate_json(result)
outcomes = Outcomes.model_validate_json(result)
logger.info(f"Created possible outcomes: {outcomes.outcomes}")
return outcomes

def build_tasks_for_outcome(self, input_dict: dict[str, t.Any] = {}) -> list[Task]:
task_research_one_outcome = Task(
Expand Down Expand Up @@ -116,7 +118,11 @@ def generate_prediction_for_one_outcome(self, sentence: str) -> ProbabilityOutpu
)

result = crew.kickoff(inputs={"sentence": sentence})
return ProbabilityOutput.model_validate_json(result)
output = ProbabilityOutput.model_validate_json(result)
logger.info(
f"For the sentence '{sentence}', the prediction is '{output.decision}', with p_yes={output.p_yes}, p_no={output.p_no}, and confidence={output.confidence}"
)
return output

def generate_final_decision(
self, outcomes_with_probabilities: list[t.Tuple[str, ProbabilityOutput]]
Expand All @@ -143,13 +149,16 @@ def generate_final_decision(
"outcome_to_assess": outcomes_with_probabilities[0][0],
}
)
return ProbabilityOutput.model_validate_json(
output = ProbabilityOutput.model_validate_json(
task_final_decision.output.raw_output
)
logger.info(
f"The final prediction is '{output.decision}', with p_yes={output.p_yes}, p_no={output.p_no}, and confidence={output.confidence}"
)
return output

def answer_binary_market(self, question: str) -> ProbabilityOutput:
outcomes = self.split_research_into_outcomes(question)
logger.debug("outcomes ", outcomes)

outcomes_with_probs = []
task_map = {}
Expand All @@ -173,12 +182,13 @@ def answer_binary_market(self, question: str) -> ProbabilityOutput:

# We parse individual task results to build outcomes_with_probs
for outcome, tasks in task_map.items():
raw_output = tasks[1].output.raw_output
try:
prediction_result = ProbabilityOutput.model_validate_json(
tasks[1].output.raw_output
)
prediction_result = ProbabilityOutput.model_validate_json(raw_output)
except Exception as e:
logger.error("Could not parse result as ProbabilityOutput ", e)
logger.error(
f"Could not parse the result ('{raw_output}') as ProbabilityOutput because of {e}"
)
prediction_result = ProbabilityOutput(
p_yes=0.5, p_no=0.5, confidence=0, decision=""
)
Expand Down
35 changes: 35 additions & 0 deletions prediction_market_agent/tools/streamlit_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import typing as t

import streamlit as st
from loguru import logger

if t.TYPE_CHECKING:
from loguru import Message


def loguru_streamlit_sink(log: "Message") -> None:
record = log.record
level = record["level"].name

message = record["message"]
# Replace escaped newlines with actual newlines.
message = message.replace("\\n", "\n")

if level == "ERROR":
st.error(message, icon="❌")

elif level == "WARNING":
st.warning(message, icon="⚠️")

else:
st.info(message, icon="ℹ️")


@st.cache_resource
def add_sink_to_logger() -> None:
"""
Adds streamlit as a sink to the loguru, so any loguru logs will be shown in the streamlit app.
Needs to be behind a cache decorator, so it only runs once per streamlit session (otherwise we would see duplicated messages).
"""
logger.add(loguru_streamlit_sink)
8 changes: 7 additions & 1 deletion prediction_market_agent/tools/web_scrape/basic_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,17 @@
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_openai import ChatOpenAI

from prediction_market_agent.utils import APIKeys


def _summary(
objective: str, content: str, separators: list[str] = ["\n\n", "\n"]
) -> str:
llm = ChatOpenAI(temperature=0, model="gpt-3.5-turbo-0125")
llm = ChatOpenAI(
temperature=0,
model="gpt-3.5-turbo-0125",
api_key=APIKeys().openai_api_key.get_secret_value(),
)
text_splitter = RecursiveCharacterTextSplitter(
separators=separators, chunk_size=10000, chunk_overlap=500
)
Expand Down
9 changes: 3 additions & 6 deletions prediction_market_agent/tools/web_search/tavily.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
import tenacity
from prediction_market_agent_tooling.tools.cache import persistent_inmemory_cache
from prediction_market_agent_tooling.tools.utils import (
check_not_none,
secret_str_from_env,
)
from pydantic import BaseModel
from tavily import TavilyClient

from prediction_market_agent.utils import APIKeys


class WebSearchResult(BaseModel):
url: str
Expand All @@ -21,8 +19,7 @@ def web_search(query: str, max_results: int) -> list[WebSearchResult]:
"""
Web search using Tavily API.
"""
tavily_api_key = check_not_none(secret_str_from_env("TAVILY_API_KEY"))
tavily = TavilyClient(api_key=tavily_api_key.get_secret_value())
tavily = TavilyClient(api_key=APIKeys().tavily_api_key.get_secret_value())
response = tavily.search(
query=query,
search_depth="advanced",
Expand Down
100 changes: 100 additions & 0 deletions scripts/agent_app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
"""
PYTHONPATH=. streamlit run scripts/agent_app.py
Tip: if you specify PYTHONPATH=., streamlit will watch for the changes in all files, isntead of just this one.
"""

import typing as t

import streamlit as st
from prediction_market_agent_tooling.deploy.agent import DeployableAgent
from prediction_market_agent_tooling.markets.markets import (
MarketType,
get_binary_markets,
)

from prediction_market_agent.agents.known_outcome_agent.deploy import (
DeployableKnownOutcomeAgent,
)
from prediction_market_agent.agents.think_thoroughly_agent.deploy import (
DeployableThinkThoroughlyAgent,
)
from prediction_market_agent.tools.streamlit_utils import add_sink_to_logger

AGENTS = [DeployableKnownOutcomeAgent, DeployableThinkThoroughlyAgent]

st.set_page_config(layout="wide")
add_sink_to_logger()

st.title("Agent's decision-making process")

# Fetch markets from the selected market type.
market_source = MarketType(
st.selectbox(
"Select a market source", [market_source.value for market_source in MarketType]
)
)
markets = get_binary_markets(42, market_source)

# Select an agent from the list of available agents.
agent_class_names = st.multiselect(
"Select agents", [agent_class.__name__ for agent_class in AGENTS]
)
if not agent_class_names:
st.warning("Please select at least one agent.")
st.stop()

# Get the agent classes from the names.
agent_classes: list[t.Type[DeployableAgent]] = []
for AgentClass in AGENTS:
if AgentClass.__name__ in agent_class_names:
agent_classes.append(AgentClass)

# Ask the user to provide a question.
custom_question_input = st.checkbox("Provide a custom question", value=False)
question = (
st.text_input("Question")
if custom_question_input
else st.selectbox("Select a question", [m.question for m in markets])
)
if not question:
st.warning("Please enter a question.")
st.stop()

market = (
[m for m in markets if m.question == question][0]
if not custom_question_input
# If custom question is provided, just take some random market and update its question.
else markets[0].model_copy(update={"question": question, "p_yes": 0.5})
)

for idx, (column, AgentClass) in enumerate(
zip(st.columns(len(agent_classes)), agent_classes)
):
with column:
# Show the agent's title.
st.write(
f"## {AgentClass.__name__.replace('Deployable', '').replace('Agent', '')}"
)

# Simulate deployable agent logic.
agent = AgentClass()

if not agent.pick_markets([market]):
st.warning("Agent wouldn't pick this market to bet on.")
if not st.checkbox(
"Continue with the prediction anyway",
value=False,
key=f"continue_{idx}",
):
continue

answer = agent.answer_binary_market(market)

if answer is None:
st.error("Agent failed to answer this market.")
continue

bet_amount = agent.calculate_bet_amount(answer, market)

st.success(f"Would bet {bet_amount.amount} {bet_amount.currency} on {answer}!")

0 comments on commit eb01089

Please sign in to comment.