Skip to content

Commit

Permalink
Merge pull request #35 from hmasdev/apply-mypy
Browse files Browse the repository at this point in the history
Refactor Field definitions in model classes for consistency
  • Loading branch information
hmasdev authored Nov 26, 2024
2 parents 1f7d107 + e8ecb3e commit 8e87b30
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 12 deletions.
2 changes: 1 addition & 1 deletion langchain_werewolf/models/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class GeneralConfig(BaseModel, frozen=True):


class PlayerConfig(BaseModel, frozen=True):
name: str = Field(..., title="The name of the player", default_factory=consecutive_string_generator(CUSTOM_PLAYER_PREFIX).__next__) # noqa
name: str = Field(title="The name of the player", default_factory=consecutive_string_generator(CUSTOM_PLAYER_PREFIX).__next__) # noqa
role: ERole | None = Field(default=None, title="The role of the player") # noqa
model: str = Field(default=DEFAULT_MODEL, title=f"The model to use. Default is {DEFAULT_MODEL}.") # noqa
language: ELanguage | None = Field(default=None, title="The language of the player") # noqa
Expand Down
4 changes: 2 additions & 2 deletions langchain_werewolf/models/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ class IdentifiedModel(PartialFrozenModel, Generic[T]):
# FIXME: frozen_fields should be merged with the parent class's frozen_fields # noqa
frozen_fields: Annotated[set[str], constant_reducer] = {'frozen_fields', 'id'} # noqa

id: str = Field(..., title="object id", default_factory=_generate_unique_string) # noqa
value: T = Field(..., title="the value of the model")
id: str = Field(title="object id", default_factory=_generate_unique_string) # noqa
value: T = Field(title="the value of the model")


def reduce_dict(
Expand Down
18 changes: 9 additions & 9 deletions langchain_werewolf/models/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@ class MsgModel(BaseModel):
name: str \
= Field(..., title="the name of the player")
timestamp: datetime \
= Field(..., title="Timestamp", default_factory=datetime.now)
= Field(title="Timestamp", default_factory=datetime.now)
message: str \
= Field(..., title="Message")
participants: frozenset[str] \
= Field(..., title="the names of the participants", default_factory=frozenset) # noqa
= Field(title="the names of the participants", default_factory=frozenset) # noqa
template: str \
= Field(
default='\n'.join([
Expand Down Expand Up @@ -60,7 +60,7 @@ class ChatHistoryModel(PartialFrozenModel):
names: frozenset[str]\
= Field(..., title="the names of the chat participants")
messages: list[IdentifiedModel[MsgModel]]\
= Field(..., title="Chat Messages", default_factory=list)
= Field(title="Chat Messages", default_factory=list)

@field_serializer('names')
def serialize_names(self, value: frozenset[str]) -> list[str]:
Expand Down Expand Up @@ -124,13 +124,13 @@ class StateModel(PartialFrozenModel):
chat_state: Annotated[
dict[frozenset[str], ChatHistoryModel],
_reduce_chat_state,
] = Field(..., title="the chat state", default_factory=dict) # noqa
] = Field(title="the chat state", default_factory=dict) # noqa

# players information
alive_players_names: Annotated[list[str], overwrite_reducer]\
= Field(..., title="the names of the alive players")
safe_players_names: Annotated[set[str], overwrite_reducer]\
= Field(..., title="the names of the safe players", default_factory=set) # noqa
= Field(title="the names of the safe players", default_factory=set) # noqa

# game chat information
current_speaker: Annotated[str | None, overwrite_reducer]\
Expand All @@ -141,13 +141,13 @@ class StateModel(PartialFrozenModel):
# vote information
# TODO: modify the type of daytime_votes_history and nighttime_votes_history: dict to dict[str, str] # noqa
daytime_vote_result_history: Annotated[list[IdentifiedModel[str | None]], reduce_list]\
= Field(..., title="the history of the daytime vote results", default_factory=list) # noqa
= Field(title="the history of the daytime vote results", default_factory=list) # noqa
daytime_votes_history: Annotated[list[IdentifiedModel[dict]], reduce_list]\
= Field(..., title="the votes of each daytime discussion", default_factory=list) # noqa
= Field(title="the votes of each daytime discussion", default_factory=list) # noqa
nighttime_vote_result_history: Annotated[list[IdentifiedModel[str | None]], reduce_list]\
= Field(..., title="the history of the nighttime vote results", default_factory=list) # noqa
= Field(title="the history of the nighttime vote results", default_factory=list) # noqa
nighttime_votes_history: Annotated[list[IdentifiedModel[dict]], reduce_list]\
= Field(..., title="the votes of each nighttime discussion", default_factory=list) # noqa
= Field(title="the votes of each nighttime discussion", default_factory=list) # noqa
daytime_votes_current: Annotated[dict[str, str], _reduce_votes_current]\
= Field(default_factory=dict, title="the current daytime votes")
nighttime_votes_current: Annotated[dict[str, str], _reduce_votes_current]\
Expand Down

0 comments on commit 8e87b30

Please sign in to comment.