From 32e8cf7daf73199027f2d802d1a580b85c31fee2 Mon Sep 17 00:00:00 2001 From: Filip Karnis Date: Tue, 1 Oct 2024 19:23:48 +0200 Subject: [PATCH] refactor: Make default reward methods static --- generals/integrations/gymnasium_integration.py | 7 ++++--- generals/integrations/pettingzoo_integration.py | 6 +++--- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/generals/integrations/gymnasium_integration.py b/generals/integrations/gymnasium_integration.py index 05f3df4..44a730d 100644 --- a/generals/integrations/gymnasium_integration.py +++ b/generals/integrations/gymnasium_integration.py @@ -31,7 +31,7 @@ def __init__( self.replay = None self.render_mode = render_mode - self.reward_fn = self.default_reward if reward_fn is None else reward_fn + self.reward_fn = self._default_reward if reward_fn is None else reward_fn self.grid_factory = grid_factory self.agent_name = agent.name @@ -118,8 +118,9 @@ def step(self, action: Action) -> tuple[Observation, SupportsFloat, bool, bool, return observation, reward, terminated, truncated, info - def default_reward( - self, observation: dict[str, Observation], + @staticmethod + def _default_reward( + observation: dict[str, Observation], action: Action, done: bool, info: Info, diff --git a/generals/integrations/pettingzoo_integration.py b/generals/integrations/pettingzoo_integration.py index 0b66c65..3941d6a 100644 --- a/generals/integrations/pettingzoo_integration.py +++ b/generals/integrations/pettingzoo_integration.py @@ -47,7 +47,7 @@ def __init__( len(self.possible_agents) == len(set(self.possible_agents)) ), "Agent names must be unique - you can pass custom names to agent constructors." 
- self.reward_fn = self.default_reward if reward_fn is None else reward_fn + self.reward_fn = self._default_reward if reward_fn is None else reward_fn @functools.lru_cache(maxsize=None) def observation_space(self, agent: AgentID) -> spaces.Space: @@ -126,8 +126,8 @@ def step(self, actions: dict[AgentID, Action]) -> tuple[ return observations, rewards, terminated, truncated, infos - def default_reward( - self, + @staticmethod + def _default_reward( observation: dict[str, Observation], action: Action, done: bool,