From 968a7805b1e83b9ad9d8b0950dd9f63b2f90ec73 Mon Sep 17 00:00:00 2001 From: hmasdev Date: Tue, 15 Oct 2024 13:23:07 +0900 Subject: [PATCH] add system_font_color and player_font_colors arg to main function --- langchain_werewolf/main.py | 12 +++++++++++- langchain_werewolf/models/config.py | 4 +++- langchain_werewolf/setup.py | 8 +++++--- 3 files changed, 19 insertions(+), 5 deletions(-) diff --git a/langchain_werewolf/main.py b/langchain_werewolf/main.py index 28b34d8..2108b21 100644 --- a/langchain_werewolf/main.py +++ b/langchain_werewolf/main.py @@ -1,10 +1,12 @@ +from itertools import cycle import logging import random -from typing import Any, Callable +from typing import Any, Callable, Iterable import click from dotenv import load_dotenv from langchain.globals import set_verbose, set_debug import pydantic +from .const import CLI_PROMPT_COLOR, CLI_ECHO_COLORS from .enums import ESystemOutputType, EInputOutputType from .game.main import create_game_graph from .models.config import Config, GeneralConfig @@ -26,6 +28,8 @@ system_output_interface=EInputOutputType.standard, system_input_interface=EInputOutputType.standard, system_formatter=None, + system_font_color=CLI_PROMPT_COLOR, + player_font_colors=cycle(CLI_ECHO_COLORS), seed=-1, model='gpt-4o-mini', recursion_limit=1000, @@ -46,6 +50,8 @@ def main( system_output_interface: Callable[[str], None] | EInputOutputType = DEFAULT_GENERAL_CONFIG.system_output_interface, # type: ignore # noqa system_input_interface: Callable[[str], Any] | EInputOutputType = DEFAULT_GENERAL_CONFIG.system_input_interface, # type: ignore # noqa system_formatter: str | None = DEFAULT_GENERAL_CONFIG.system_formatter, # type: ignore # noqa + system_font_color: str | None = DEFAULT_GENERAL_CONFIG.system_font_color, # type: ignore # noqa + player_font_colors: Iterable[str] | str | None = DEFAULT_GENERAL_CONFIG.player_font_colors, # type: ignore # noqa config: Config | str | None = None, seed: int = DEFAULT_GENERAL_CONFIG.seed, # type: ignore # noqa model: str = DEFAULT_GENERAL_CONFIG.model, # type: ignore # noqa @@ -79,6 +85,8 @@ def main( system_input_interface=config.general.system_input_interface if config.general.system_input_interface is not None else system_input_interface, # noqa system_output_interface=config.general.system_output_interface if config.general.system_output_interface is not None else system_output_interface, # noqa system_formatter=config.general.system_formatter if config.general.system_formatter is not None else system_formatter, # noqa + system_font_color=config.general.system_font_color if config.general.system_font_color is not None else system_font_color, # noqa + player_font_colors=config.general.player_font_colors if config.general.player_font_colors is not None else player_font_colors, # noqa seed=config.general.seed if config.general.seed is not None else seed, # noqa model=config.general.model if config.general.model is not None else model, # noqa recursion_limit=config.general.recursion_limit if config.general.recursion_limit is not None else recursion_limit, # noqa @@ -121,6 +129,8 @@ def main( players=players, model=config_used.general.model, # type: ignore system_formatter=config_used.general.system_formatter, # type: ignore # noqa + system_color=config_used.general.system_font_color, # type: ignore # noqa + player_colors=config_used.general.player_font_colors, # type: ignore # noqa seed=config_used.general.seed, # type: ignore ), ) diff --git a/langchain_werewolf/models/config.py b/langchain_werewolf/models/config.py index 2e349b7..f113a12 100644 --- a/langchain_werewolf/models/config.py +++ b/langchain_werewolf/models/config.py @@ -1,4 +1,4 @@ -from typing import Any, Callable +from typing import Any, Callable, Iterable from pydantic import BaseModel, Field from ..const import DEFAULT_MODEL, CUSTOM_PLAYER_PREFIX from ..enums import ( @@ -22,6 +22,8 @@ class GeneralConfig(BaseModel, frozen=True): system_output_interface: Callable[[str], None] | EInputOutputType | None = Field(default=None, title="The system output interface. Default is None.") # noqa system_input_interface: Callable[[str], Any] | EInputOutputType | None = Field(default=None, title="The system input interface. Default is None.") # noqa system_formatter: str | None = Field(default=None, title="The system formatter. The format should not include anything other than " + ', '.join('"{'+k+'}"' for k in MsgModel.model_fields.keys())) # noqa + system_font_color: str | None = Field(default=None, title="The system font color. Default is None.") # noqa + player_font_colors: Iterable | str | None = Field(default=None, title="The player font colors. Default is None.") # noqa seed: int | None = Field(default=None, title="The random seed. Defaults to None.") # noqa model: str | None = Field(default=None, title=f"The model to use. Default is None.") # noqa recursion_limit: int | None = Field(default=None, title="The recursion limit. Default is None.") # noqa diff --git a/langchain_werewolf/setup.py b/langchain_werewolf/setup.py index 0713d2c..a67cc40 100644 --- a/langchain_werewolf/setup.py +++ b/langchain_werewolf/setup.py @@ -338,14 +338,16 @@ def create_echo_runnable( players: Iterable[BaseGamePlayer] = tuple(), model: str = DEFAULT_MODEL, system_formatter: Callable[[MsgModel], str] | str | None = None, - system_color: str = CLI_PROMPT_COLOR, - player_colors: Iterable[str] = cycle(CLI_ECHO_COLORS), + system_color: str | None = CLI_PROMPT_COLOR, + player_colors: Iterable[str] | str | None = cycle(CLI_ECHO_COLORS), language: ELanguage = BASE_LANGUAGE, seed: int = -1, ) -> Runnable[StateModel, None]: # initialize player_names: list[str] = [player.name for player in players] - player_colors_ = {player.name: color for player, color in zip(players, player_colors)} # noqa + player_colors = player_colors or cycle([None]) + player_colors = [player_colors] if isinstance(player_colors, str) else player_colors # noqa + player_colors_ = {player.name: color or None for player, color in zip(players, player_colors)} # noqa # create cache caches: dict[str, set[str]] = ( {player.name: {''} for player in players}