Skip to content

Commit

Permalink
Use AgentMarket.get_positions in microchain Function (#91)
Browse files Browse the repository at this point in the history
* Use AgentMarket.get_positions in microchain Function

* Decimal -> float

* Bump PMAT version
  • Loading branch information
evangriffiths authored Apr 19, 2024
1 parent eb01089 commit 95ff9d7
Show file tree
Hide file tree
Showing 7 changed files with 222 additions and 231 deletions.
6 changes: 1 addition & 5 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from decimal import Decimal

import typer
from loguru import logger
from prediction_market_agent_tooling.markets.agent_market import SortBy
Expand Down Expand Up @@ -47,9 +45,7 @@ def main(
logger.info(
f"Placing bet with position {pma.utils.parse_result_to_str(result)} on market '{market.question}'"
)
amount = Decimal(
input(f"How much do you want to bet? (in {market.currency}): ")
)
amount = float(input(f"How much do you want to bet? (in {market.currency}): "))
market.place_bet(
amount=market.get_bet_amount(amount),
outcome=result,
Expand Down
387 changes: 196 additions & 191 deletions poetry.lock

Large diffs are not rendered by default.

3 changes: 1 addition & 2 deletions prediction_market_agent/agents/known_outcome_agent/deploy.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import random
import typing as t
from decimal import Decimal

from loguru import logger
from prediction_market_agent_tooling.deploy.agent import DeployableAgent
Expand Down Expand Up @@ -82,6 +81,6 @@ def answer_binary_market(self, market: AgentMarket) -> bool | None:

def calculate_bet_amount(self, answer: bool, market: AgentMarket) -> BetAmount:
if isinstance(market, OmenAgentMarket):
return BetAmount(amount=(Decimal(1.0)), currency=market.currency)
return BetAmount(amount=1.0, currency=market.currency)
else:
raise NotImplementedError("This agent only supports xDai markets")
33 changes: 17 additions & 16 deletions prediction_market_agent/agents/microchain_agent/functions.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,9 @@
import typing as t
from decimal import Decimal

from eth_utils import to_checksum_address
from microchain import Function
from prediction_market_agent_tooling.markets.agent_market import AgentMarket
from prediction_market_agent_tooling.markets.data_models import Currency, TokenAmount
from prediction_market_agent_tooling.markets.markets import MarketType
from prediction_market_agent_tooling.markets.omen.data_models import OmenUserPosition
from prediction_market_agent_tooling.markets.omen.omen_subgraph_handler import (
OmenSubgraphHandler,
)

from prediction_market_agent.agents.microchain_agent.utils import (
MechResult,
Expand Down Expand Up @@ -207,7 +201,7 @@ def __call__(self, market_id: str, amount: float) -> str:
)
market.buy_tokens(
outcome=self.outcome_bool,
amount=TokenAmount(amount=Decimal(amount), currency=self.currency),
amount=TokenAmount(amount=amount, currency=self.currency),
)
after_balance = market.get_token_balance(
user_id=self.user_address,
Expand Down Expand Up @@ -263,7 +257,7 @@ def __call__(self, market_id: str, amount: float) -> str:

market.sell_tokens(
outcome=self.outcome_bool,
amount=TokenAmount(amount=Decimal(amount), currency=self.currency),
amount=TokenAmount(amount=amount, currency=self.currency),
)

after_balance = market.get_token_balance(
Expand Down Expand Up @@ -314,25 +308,32 @@ def description(self) -> str:
def example_args(self) -> list[str]:
return []

def __call__(self) -> Decimal:
def __call__(self) -> float:
return get_balance(market_type=self.market_type).amount


class GetUserPositions(MarketFunction):
class GetPositions(MarketFunction):
def __init__(self, market_type: MarketType) -> None:
self.user_address = MicrochainAPIKeys().bet_from_address
super().__init__(market_type=market_type)

@property
def description(self) -> str:
return (
"Use this function to fetch the markets where the user has previously bet."
"Use this function to fetch the live markets where you have "
"previously bet, and the token amounts you hold for each outcome."
)

@property
def example_args(self) -> list[str]:
return ["0x2DD9f5678484C1F59F97eD334725858b938B4102"]
return []

def __call__(self, user_address: str) -> list[OmenUserPosition]:
return OmenSubgraphHandler().get_user_positions(
better_address=to_checksum_address(user_address)
def __call__(self) -> list[str]:
self.user_address = MicrochainAPIKeys().bet_from_address
positions = self.market_type.market_class.get_positions(
user_id=self.user_address
)
return [str(position) for position in positions]


MISC_FUNCTIONS = [
Expand All @@ -352,5 +353,5 @@ def __call__(self, user_address: str) -> list[OmenUserPosition]:
BuyNo,
SellYes,
SellNo,
GetUserPositions,
GetPositions,
]
3 changes: 1 addition & 2 deletions prediction_market_agent/agents/microchain_agent/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import tempfile
import typing as t
from contextlib import contextmanager
from decimal import Decimal
from enum import Enum

from mech_client.interact import ConfirmationType, interact
Expand Down Expand Up @@ -94,7 +93,7 @@ def get_balance(market_type: MarketType) -> BetAmount:
if market_type == MarketType.OMEN:
# We focus solely on xDAI balance for now to avoid the agent having to wrap/unwrap xDAI.
return BetAmount(
amount=Decimal(get_balances(MicrochainAPIKeys().bet_from_address).xdai),
amount=get_balances(MicrochainAPIKeys().bet_from_address).xdai,
currency=currency,
)
else:
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ poetry = "^1.7.1"
poetry-plugin-export = "^1.6.0"
functions-framework = "^3.5.0"
cron-validator = "^1.0.8"
prediction-market-agent-tooling = { version = "^0.13.2", extras = ["langchain", "google"] }
prediction-market-agent-tooling = { version = "^0.14.1", extras = ["langchain", "google"] }
pydantic-settings = "^2.1.0"
autoflake = "^2.2.1"
isort = "^5.13.2"
Expand Down
19 changes: 5 additions & 14 deletions tests/agents/microchain/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from microchain.functions import Reasoning, Stop
from prediction_market_agent_tooling.markets.agent_market import AgentMarket
from prediction_market_agent_tooling.markets.markets import MarketType
from prediction_market_agent_tooling.tools.hexbytes_custom import HexBytes

from prediction_market_agent.agents.microchain_agent.functions import (
MARKET_FUNCTIONS,
Expand All @@ -14,7 +13,7 @@
GetBalance,
GetMarketProbability,
GetMarkets,
GetUserPositions,
GetPositions,
MarketFunction,
PredictProbabilityForQuestionLocal,
PredictProbabilityForQuestionRemote,
Expand Down Expand Up @@ -64,18 +63,10 @@ def test_replicator_has_balance_gt_0(market_type: MarketType) -> None:


@pytest.mark.parametrize("market_type", [MarketType.OMEN])
def test_agent_0_has_bet_on_market(market_type: MarketType) -> None:
user_positions = GetUserPositions(market_type=market_type)(AGENT_0_ADDRESS)
# Assert 3 conditionIds are included
expected_condition_ids = [
HexBytes("0x9c7711bee0902cc8e6838179058726a7ba769cc97d4d0ea47b31370d2d7a117b"),
HexBytes("0xe2bf80af2a936cdabeef4f511620a2eec46f1caf8e75eb5dc189372367a9154c"),
HexBytes("0x3f8153364001b26b983dd92191a084de8230f199b5ad0b045e9e1df61089b30d"),
]
unique_condition_ids: list[HexBytes] = sum(
[u.position.conditionIds for u in user_positions], []
)
assert set(expected_condition_ids).issubset(unique_condition_ids)
def test_get_positions(market_type: MarketType) -> None:
get_positions = GetPositions(market_type=market_type)
positions = get_positions()
assert len(positions) > 0


@pytest.mark.parametrize("market_type", [MarketType.OMEN])
Expand Down

0 comments on commit 95ff9d7

Please sign in to comment.