From a5ecd166859e9d956113e64fd865440ee9cc10e1 Mon Sep 17 00:00:00 2001 From: hmasdev Date: Sun, 27 Oct 2024 21:37:49 +0900 Subject: [PATCH] fix _generate_base_runnable in setup.py to make a custom input interface available for each player --- langchain_werewolf/setup.py | 49 +++++++++++++++++++------------ tests/test_setup.py | 57 ++++++++++++++++++++++++------------- 2 files changed, 68 insertions(+), 38 deletions(-) diff --git a/langchain_werewolf/setup.py b/langchain_werewolf/setup.py index 7659637..78c737e 100644 --- a/langchain_werewolf/setup.py +++ b/langchain_werewolf/setup.py @@ -17,7 +17,6 @@ from .const import ( BASE_LANGUAGE, CLI_PROMPT_COLOR, - CLI_PROMPT_SUFFIX, CLI_ECHO_COLORS, DEFAULT_MODEL, DEFAULT_PLAYER_PREFIX, @@ -49,13 +48,30 @@ def _generate_base_runnable( model: str | None, input_func: Callable[[str], Any] | EInputOutputType | None = None, seed: int | None = None, + *, + logger: Logger = getLogger(__name__), ) -> BaseChatModel | Runnable[str, str]: - if model is None: - return create_chat_model( - DEFAULT_MODEL, - seed=seed if seed is not None and seed >= 0 else None, - ) - elif ( + """Generate a BaseChatModel instance or a Runnable instance. + + Args: + model (str | None): model string + input_func (Callable[[str], Any] | EInputOutputType | None, optional): input function. Defaults to None. + seed (int | None, optional): random seed. Defaults to None. + logger (Logger, optional): logger. Defaults to getLogger(__name__). + + Raises: + ValueError: model not in MODEL_SERVICE_MAP, input_func is None and model is not None + + Returns: + BaseChatModel | Runnable[str, str]: runnable instance for generate a string + + NOTE: + priority: + 1. models in MODEL_SERVICE_MAP + 2. input_func is not None + 3. DEFAULT_MODEL + """ # noqa + if ( model in MODEL_SERVICE_MAP and MODEL_SERVICE_MAP[model] in { EChatService.OpenAI, @@ -67,18 +83,15 @@ def _generate_base_runnable( model, seed=seed if seed is not None and seed >= 0 else None, ) - elif ( - model in MODEL_SERVICE_MAP - and MODEL_SERVICE_MAP[model] == EChatService.CLI - and input_func is not None - ): - return create_input_runnable( - input_func=input_func, - styler=partial(click.style, fg=CLI_PROMPT_COLOR), - prompt_suffix=CLI_PROMPT_SUFFIX, - ) + elif input_func is not None: + return create_input_runnable(input_func=input_func) else: - raise ValueError(f'Unsupported: model={model}, input_func={input_func}') # noqa + if model is not None: + logger.warning(f'Invalid a pair of model={model} and input_func={input_func}. So use a chat model with {DEFAULT_MODEL}') # noqa + return create_chat_model( + DEFAULT_MODEL, + seed=seed if seed is not None and seed >= 0 else None, + ) def generate_players( diff --git a/tests/test_setup.py b/tests/test_setup.py index 729f041..5a5e82f 100644 --- a/tests/test_setup.py +++ b/tests/test_setup.py @@ -1,7 +1,5 @@ -from functools import partial from typing import Callable from unittest import mock -import click from langchain_core.language_models import BaseChatModel from langchain_core.runnables import Runnable, RunnableLambda import pytest @@ -9,8 +7,6 @@ from langchain_werewolf.const import ( BASE_LANGUAGE, DEFAULT_MODEL, - CLI_PROMPT_COLOR, - CLI_PROMPT_SUFFIX, GAME_MASTER_NAME, ) from langchain_werewolf.enums import ( @@ -19,7 +15,10 @@ ESystemOutputType, ETimeSpan, ) -from langchain_werewolf.game_players.base import BaseGamePlayer +from langchain_werewolf.game_players.base import ( + BaseGamePlayer, + GamePlayerRunnableInputModel, +) from langchain_werewolf.models.config import PlayerConfig from langchain_werewolf.models.state import ( ChatHistoryModel, @@ -150,8 +149,6 @@ def test__generate_base_runnable_with_cli_player_config( expected_args = [] # type: ignore expected_kwargs = { 'input_func': player_config.player_input_interface, - 'styler': partial(click.style, fg=CLI_PROMPT_COLOR), - 'prompt_suffix': CLI_PROMPT_SUFFIX, } create_input_runnable_mock = mocker.patch( 'langchain_werewolf.setup.create_input_runnable', @@ -164,19 +161,6 @@ def test__generate_base_runnable_with_cli_player_config( actual_args, actual_kwargs = create_input_runnable_mock.call_args assert actual_args == tuple(expected_args) assert actual_kwargs['input_func'] == expected_kwargs['input_func'] - assert actual_kwargs['prompt_suffix'] == expected_kwargs['prompt_suffix'] - assert all([ - actual_kwargs['styler'].func == expected_kwargs['styler'].func, # type: ignore # noqa - actual_kwargs['styler'].args == expected_kwargs['styler'].args, # type: ignore # noqa - actual_kwargs['styler'].keywords == expected_kwargs['styler'].keywords, # type: ignore # noqa - ]) - # NOTE: partial object is not equal to another partial object with the same function and arguments # noqa - - -def test__generate_base_runnable_with_unsupported_config() -> None: - # execution - with pytest.raises(ValueError): - _generate_base_runnable('cli', None) @pytest.mark.parametrize( @@ -266,6 +250,39 @@ def test_generate_players(mocker: MockerFixture) -> None: assert sum([player.role == ERole.Villager for player in actual]) == n_players - n_werewolves - n_knights - n_fortune_tellers # noqa +def test_generate_players_with_custom_input_interface(mocker: MockerFixture) -> None: # noqa + + # patch to avoid creating a real chat model + mocker.patch('langchain_werewolf.setup.create_chat_model', mocker.Mock(return_value=mocker.Mock(BaseChatModel))) # noqa + # create mocks for the input interfaces of players + mocks_input_interface = [mocker.Mock(side_effect=lambda s: s) for _ in range(4)] # noqa + + n_players = 4 + n_werewolves = 1 + n_knights = 0 + n_fortune_tellers = 0 + roles = [ERole.Werewolf, ERole.Villager, ERole.Villager, ERole.Villager] + custom_players = [ + PlayerConfig(role=role, model='', player_input_interface=mock, formatter=None) # noqa + for role, mock in zip(roles, mocks_input_interface) + ] + # execute + actual = generate_players( + n_players, + n_werewolves, + n_knights, + n_fortune_tellers, + seed=0, + player_input_interface=None, + custom_players=custom_players, + ) + # assert + input_for_player_runnable = GamePlayerRunnableInputModel(prompt='test', system_prompt='test2') # noqa + for mock_, player in zip(mocks_input_interface, actual): + player.runnable.invoke(input_for_player_runnable) + mock_.assert_called_once_with(input_for_player_runnable.prompt) + + @pytest.mark.parametrize( 'player_name, formatter', [