Skip to content

Commit

Permalink
add system_font_color and player_font_colors arg to main function
Browse files Browse the repository at this point in the history
  • Loading branch information
hmasdev committed Oct 15, 2024
1 parent bbdcf35 commit 968a780
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 5 deletions.
12 changes: 11 additions & 1 deletion langchain_werewolf/main.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from itertools import cycle
import logging
import random
from typing import Any, Callable
from typing import Any, Callable, Iterable
import click
from dotenv import load_dotenv
from langchain.globals import set_verbose, set_debug
import pydantic
from .const import CLI_PROMPT_COLOR, CLI_ECHO_COLORS
from .enums import ESystemOutputType, EInputOutputType
from .game.main import create_game_graph
from .models.config import Config, GeneralConfig
Expand All @@ -26,6 +28,8 @@
system_output_interface=EInputOutputType.standard,
system_input_interface=EInputOutputType.standard,
system_formatter=None,
system_font_color=CLI_PROMPT_COLOR,
player_font_colors=cycle(CLI_ECHO_COLORS),
seed=-1,
model='gpt-4o-mini',
recursion_limit=1000,
Expand All @@ -46,6 +50,8 @@ def main(
system_output_interface: Callable[[str], None] | EInputOutputType = DEFAULT_GENERAL_CONFIG.system_output_interface, # type: ignore # noqa
system_input_interface: Callable[[str], Any] | EInputOutputType = DEFAULT_GENERAL_CONFIG.system_input_interface, # type: ignore # noqa
system_formatter: str | None = DEFAULT_GENERAL_CONFIG.system_formatter, # type: ignore # noqa
system_font_color: str | None = DEFAULT_GENERAL_CONFIG.system_font_color, # type: ignore # noqa
player_font_colors: Iterable[str] | str | None = DEFAULT_GENERAL_CONFIG.player_font_colors, # type: ignore # noqa
config: Config | str | None = None,
seed: int = DEFAULT_GENERAL_CONFIG.seed, # type: ignore # noqa
model: str = DEFAULT_GENERAL_CONFIG.model, # type: ignore # noqa
Expand Down Expand Up @@ -79,6 +85,8 @@ def main(
system_input_interface=config.general.system_input_interface if config.general.system_input_interface is not None else system_input_interface, # noqa
system_output_interface=config.general.system_output_interface if config.general.system_output_interface is not None else system_output_interface, # noqa
system_formatter=config.general.system_formatter if config.general.system_formatter is not None else system_formatter, # noqa
system_font_color=config.general.system_font_color if config.general.system_font_color is not None else system_font_color, # noqa
player_font_colors=config.general.player_font_colors if config.general.player_font_colors is not None else player_font_colors, # noqa
seed=config.general.seed if config.general.seed is not None else seed, # noqa
model=config.general.model if config.general.model is not None else model, # noqa
recursion_limit=config.general.recursion_limit if config.general.recursion_limit is not None else recursion_limit, # noqa
Expand Down Expand Up @@ -121,6 +129,8 @@ def main(
players=players,
model=config_used.general.model, # type: ignore
system_formatter=config_used.general.system_formatter, # type: ignore # noqa
system_color=config_used.general.system_font_color, # type: ignore # noqa
player_colors=config_used.general.player_font_colors, # type: ignore # noqa
seed=config_used.general.seed, # type: ignore
),
)
Expand Down
4 changes: 3 additions & 1 deletion langchain_werewolf/models/config.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Callable
from typing import Any, Callable, Iterable
from pydantic import BaseModel, Field
from ..const import DEFAULT_MODEL, CUSTOM_PLAYER_PREFIX
from ..enums import (
Expand All @@ -22,6 +22,8 @@ class GeneralConfig(BaseModel, frozen=True):
system_output_interface: Callable[[str], None] | EInputOutputType | None = Field(default=None, title="The system output interface. Default is None.") # noqa
system_input_interface: Callable[[str], Any] | EInputOutputType | None = Field(default=None, title="The system input interface. Default is None.") # noqa
system_formatter: str | None = Field(default=None, title="The system formatter. The format should not include anything other than " + ', '.join('"{'+k+'}"' for k in MsgModel.model_fields.keys())) # noqa
system_font_color: str | None = Field(default=None, title="The system font color. Default is None.") # noqa
player_font_colors: Iterable | str | None = Field(default=None, title="The player font colors. Default is None.") # noqa
seed: int | None = Field(default=None, title="The random seed. Defaults to None.") # noqa
model: str | None = Field(default=None, title=f"The model to use. Default is None.") # noqa
recursion_limit: int | None = Field(default=None, title="The recursion limit. Default is None.") # noqa
Expand Down
8 changes: 5 additions & 3 deletions langchain_werewolf/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,14 +338,16 @@ def create_echo_runnable(
players: Iterable[BaseGamePlayer] = tuple(),
model: str = DEFAULT_MODEL,
system_formatter: Callable[[MsgModel], str] | str | None = None,
system_color: str = CLI_PROMPT_COLOR,
player_colors: Iterable[str] = cycle(CLI_ECHO_COLORS),
system_color: str | None = CLI_PROMPT_COLOR,
player_colors: Iterable[str] | str | None = cycle(CLI_ECHO_COLORS),
language: ELanguage = BASE_LANGUAGE,
seed: int = -1,
) -> Runnable[StateModel, None]:
# initialize
player_names: list[str] = [player.name for player in players]
player_colors_ = {player.name: color for player, color in zip(players, player_colors)} # noqa
player_colors = player_colors or cycle([None])
player_colors = [player_colors] if isinstance(player_colors, str) else player_colors # noqa
player_colors_ = {player.name: color or None for player, color in zip(players, player_colors)} # noqa
# create cache
caches: dict[str, set[str]] = (
{player.name: {''} for player in players}
Expand Down

0 comments on commit 968a780

Please sign in to comment.