Skip to content

Commit

Permalink
fix _generate_base_runnable in setup.py to make a custom input interface available for each player
Browse files Browse the repository at this point in the history
  • Loading branch information
hmasdev committed Oct 27, 2024
1 parent 7251c6c commit a5ecd16
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 38 deletions.
49 changes: 31 additions & 18 deletions langchain_werewolf/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from .const import (
BASE_LANGUAGE,
CLI_PROMPT_COLOR,
CLI_PROMPT_SUFFIX,
CLI_ECHO_COLORS,
DEFAULT_MODEL,
DEFAULT_PLAYER_PREFIX,
Expand Down Expand Up @@ -49,13 +48,30 @@ def _generate_base_runnable(
model: str | None,
input_func: Callable[[str], Any] | EInputOutputType | None = None,
seed: int | None = None,
*,
logger: Logger = getLogger(__name__),
) -> BaseChatModel | Runnable[str, str]:
if model is None:
return create_chat_model(
DEFAULT_MODEL,
seed=seed if seed is not None and seed >= 0 else None,
)
elif (
"""Generate a BaseChatModel instance or a Runnable instance.
Args:
model (str | None): model string
input_func (Callable[[str], Any] | EInputOutputType | None, optional): input function. Defaults to None.
seed (int | None, optional): random seed. Defaults to None.
logger (Logger, optional): logger. Defaults to getLogger(__name__).
Raises:
ValueError: model not in MODEL_SERVICE_MAP, input_func is None and model is not None
Returns:
BaseChatModel | Runnable[str, str]: runnable instance for generate a string
NOTE:
priority:
1. models in MODEL_SERVICE_MAP
2. input_func is not None
3. DEFAULT_MODEL
""" # noqa
if (
model in MODEL_SERVICE_MAP
and MODEL_SERVICE_MAP[model] in {
EChatService.OpenAI,
Expand All @@ -67,18 +83,15 @@ def _generate_base_runnable(
model,
seed=seed if seed is not None and seed >= 0 else None,
)
elif (
model in MODEL_SERVICE_MAP
and MODEL_SERVICE_MAP[model] == EChatService.CLI
and input_func is not None
):
return create_input_runnable(
input_func=input_func,
styler=partial(click.style, fg=CLI_PROMPT_COLOR),
prompt_suffix=CLI_PROMPT_SUFFIX,
)
elif input_func is not None:
return create_input_runnable(input_func=input_func)
else:
raise ValueError(f'Unsupported: model={model}, input_func={input_func}') # noqa
if model is not None:
logger.warning(f'Invalid a pair of model={model} and input_func={input_func}. So use a chat model with {DEFAULT_MODEL}') # noqa
return create_chat_model(
DEFAULT_MODEL,
seed=seed if seed is not None and seed >= 0 else None,
)


def generate_players(
Expand Down
57 changes: 37 additions & 20 deletions tests/test_setup.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,12 @@
from functools import partial
from typing import Callable
from unittest import mock
import click
from langchain_core.language_models import BaseChatModel
from langchain_core.runnables import Runnable, RunnableLambda
import pytest
from pytest_mock import MockerFixture
from langchain_werewolf.const import (
BASE_LANGUAGE,
DEFAULT_MODEL,
CLI_PROMPT_COLOR,
CLI_PROMPT_SUFFIX,
GAME_MASTER_NAME,
)
from langchain_werewolf.enums import (
Expand All @@ -19,7 +15,10 @@
ESystemOutputType,
ETimeSpan,
)
from langchain_werewolf.game_players.base import BaseGamePlayer
from langchain_werewolf.game_players.base import (
BaseGamePlayer,
GamePlayerRunnableInputModel,
)
from langchain_werewolf.models.config import PlayerConfig
from langchain_werewolf.models.state import (
ChatHistoryModel,
Expand Down Expand Up @@ -150,8 +149,6 @@ def test__generate_base_runnable_with_cli_player_config(
expected_args = [] # type: ignore
expected_kwargs = {
'input_func': player_config.player_input_interface,
'styler': partial(click.style, fg=CLI_PROMPT_COLOR),
'prompt_suffix': CLI_PROMPT_SUFFIX,
}
create_input_runnable_mock = mocker.patch(
'langchain_werewolf.setup.create_input_runnable',
Expand All @@ -164,19 +161,6 @@ def test__generate_base_runnable_with_cli_player_config(
actual_args, actual_kwargs = create_input_runnable_mock.call_args
assert actual_args == tuple(expected_args)
assert actual_kwargs['input_func'] == expected_kwargs['input_func']
assert actual_kwargs['prompt_suffix'] == expected_kwargs['prompt_suffix']
assert all([
actual_kwargs['styler'].func == expected_kwargs['styler'].func, # type: ignore # noqa
actual_kwargs['styler'].args == expected_kwargs['styler'].args, # type: ignore # noqa
actual_kwargs['styler'].keywords == expected_kwargs['styler'].keywords, # type: ignore # noqa
])
# NOTE: partial object is not equal to another partial object with the same function and arguments # noqa


def test__generate_base_runnable_with_unsupported_config() -> None:
    """An unsupported model/input_func pair should raise ValueError.

    'cli' is a CLI pseudo-model, which requires an input function; passing
    ``input_func=None`` alongside it is an invalid combination.
    NOTE(review): per this diff, the commit removes this test because the
    new behavior falls back to DEFAULT_MODEL with a warning instead of
    raising — confirm against the updated setup.py.
    """
    # execution
    with pytest.raises(ValueError):
        _generate_base_runnable('cli', None)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -266,6 +250,39 @@ def test_generate_players(mocker: MockerFixture) -> None:
assert sum([player.role == ERole.Villager for player in actual]) == n_players - n_werewolves - n_knights - n_fortune_tellers # noqa


def test_generate_players_with_custom_input_interface(mocker: MockerFixture) -> None:  # noqa
    """Each custom player's own input interface should back that player's runnable.

    Exercises the fix in ``_generate_base_runnable``: when ``custom_players``
    supplies a per-player ``player_input_interface``, invoking a generated
    player's runnable must call that specific player's input function.
    """
    # patch to avoid creating a real chat model
    mocker.patch('langchain_werewolf.setup.create_chat_model', mocker.Mock(return_value=mocker.Mock(BaseChatModel)))  # noqa
    # create mocks for the input interfaces of players
    # side_effect is the identity, so each mock acts as a pass-through input func
    mocks_input_interface = [mocker.Mock(side_effect=lambda s: s) for _ in range(4)]  # noqa

    n_players = 4
    n_werewolves = 1
    n_knights = 0
    n_fortune_tellers = 0
    roles = [ERole.Werewolf, ERole.Villager, ERole.Villager, ERole.Villager]
    # one config per player, each wired to its own mock input interface
    custom_players = [
        PlayerConfig(role=role, model='', player_input_interface=mock, formatter=None)  # noqa
        for role, mock in zip(roles, mocks_input_interface)
    ]
    # execute
    actual = generate_players(
        n_players,
        n_werewolves,
        n_knights,
        n_fortune_tellers,
        seed=0,
        player_input_interface=None,
        custom_players=custom_players,
    )
    # assert
    # assumes generate_players preserves the order of custom_players — TODO confirm
    input_for_player_runnable = GamePlayerRunnableInputModel(prompt='test', system_prompt='test2')  # noqa
    for mock_, player in zip(mocks_input_interface, actual):
        player.runnable.invoke(input_for_player_runnable)
        # only the prompt (not the system prompt) should reach the input func
        mock_.assert_called_once_with(input_for_player_runnable.prompt)


@pytest.mark.parametrize(
'player_name, formatter',
[
Expand Down

0 comments on commit a5ecd16

Please sign in to comment.