diff --git a/generals/core/replay.py b/generals/core/replay.py index 64d17f4..e623a6b 100644 --- a/generals/core/replay.py +++ b/generals/core/replay.py @@ -3,6 +3,7 @@ from generals.core.grid import Grid from generals.core.game import Game +from generals.gui.event_handler import ReplayCommand from generals.gui import GUI from copy import deepcopy @@ -33,7 +34,7 @@ def load(cls, path): def play(self): agents = [agent for agent in self.agent_data.keys()] game = Game(self.grid, agents) - gui = GUI(game, self.agent_data, from_replay=True) + gui = GUI(game, self.agent_data, mode="replay") gui_properties = gui.properties game_step, last_input_time, last_move_time = 0, 0, 0 @@ -41,15 +42,20 @@ def play(self): _t = time.time() # Check inputs if _t - last_input_time > 0.008: # check for input every 8ms - control_events = gui.tick() + command = gui.tick() last_input_time = _t else: - control_events = {"time_change": 0} - if "restart" in control_events: + command = ReplayCommand() + if command.quit: + import pygame + + pygame.quit() + quit() + if command.restart: game_step = 0 # If we control replay, change game state game_step = max( - 0, min(len(self.game_states) - 1, game_step + control_events["time_change"]) + 0, min(len(self.game_states) - 1, game_step + command.frame_change) ) if gui_properties.paused and game_step != game.time: game.channels = deepcopy(self.game_states[game_step]) @@ -57,7 +63,7 @@ def play(self): last_move_time = _t # If we are not paused, play the game elif ( - _t - last_move_time > gui_properties.game_speed * 0.512 + _t - last_move_time > (1/gui_properties.game_speed) * 0.512 and not gui_properties.paused ): if game.is_done(): diff --git a/generals/envs/gymnasium_integration.py b/generals/envs/gymnasium_integration.py index e9b4942..49c98d0 100644 --- a/generals/envs/gymnasium_integration.py +++ b/generals/envs/gymnasium_integration.py @@ -58,11 +58,13 @@ def action_space(self) -> gym.Space: def render(self, fps: int = 6) -> None: if self.render_mode == "human": - self.gui.tick(fps=fps) + command = self.gui.tick(fps=fps) + if command.quit: + self.close() - def reset(self, seed: int | None = None, options: dict[str, Any] | None = None) -> tuple[ - Observation, dict[str, Any] - ]: + def reset( + self, seed: int | None = None, options: dict[str, Any] | None = None + ) -> tuple[Observation, dict[str, Any]]: if options is None: options = {} super().reset(seed=seed) @@ -95,7 +97,9 @@ def reset(self, seed: int | None = None, options: dict[str, Any] | None = None) info = {} return observation, info - def step(self, action: Action) -> tuple[Observation, SupportsFloat, bool, bool, dict[str, Any]]: + def step( + self, action: Action + ) -> tuple[Observation, SupportsFloat, bool, bool, dict[str, Any]]: # get action of NPC npc_action = self.npc.play(self.game._agent_observation(self.npc.name)) actions = {self.agent_name: action, self.npc.name: npc_action} @@ -133,3 +137,9 @@ def _default_reward( else: reward = 0 return reward + + def close(self) -> None: + import pygame + + pygame.quit() + quit() diff --git a/generals/envs/pettingzoo_integration.py b/generals/envs/pettingzoo_integration.py index 9afa322..6f9d2e6 100644 --- a/generals/envs/pettingzoo_integration.py +++ b/generals/envs/pettingzoo_integration.py @@ -59,7 +59,9 @@ def action_space(self, agent: AgentID) -> spaces.Space: def render(self, fps=6) -> None: if self.render_mode == "human": - self.gui.tick(fps=fps) + command = self.gui.tick(fps=fps) + if command.quit: + self.close() def reset( self, seed: int | None = None, options: dict | None = None @@ -76,7 +78,7 @@ def reset( self.game = Game(grid, self.agents) if self.render_mode == "human": - self.gui = GUI(self.game, self.agent_data) + self.gui = GUI(self.game, self.agent_data, "train") if "replay_file" in options: self.replay = Replay( @@ -146,5 +148,8 @@ def _default_reward( reward = 0 return reward - def close(self): - print("Closing environment") + def close(self) -> None: + import pygame + + pygame.quit() + quit() diff --git a/generals/gui/event_handler.py b/generals/gui/event_handler.py index fc56a62..82bc11d 100644 --- a/generals/gui/event_handler.py +++ b/generals/gui/event_handler.py @@ -1,77 +1,179 @@ import pygame +from pygame.event import Event from .properties import Properties +from generals.core import config as c + +###################### +# Replay keybindings # +###################### +RIGHT = pygame.K_RIGHT +LEFT = pygame.K_LEFT +SPACE = pygame.K_SPACE +Q = pygame.K_q +R = pygame.K_r +H = pygame.K_h +L = pygame.K_l + + +class Command: + def __init__(self): + self.quit: bool = False + + +class ReplayCommand(Command): + def __init__(self): + super().__init__() + self.frame_change: int = 0 + self.speed_change: float = 1.0 + self.restart: bool = False + self.pause: bool = False + + +class GameCommand(Command): + def __init__(self): + super().__init__() + raise NotImplementedError + + +class TrainCommand(Command): + def __init__(self): + super().__init__() class EventHandler: - def __init__(self, properties: Properties, from_replay=False): + def __init__(self, properties: Properties): """ Initialize the event handler. Args: properties: the Properties object - from_replay: bool, whether the game is from a replay """ self.properties = properties - self.from_replay = from_replay + self.mode = properties.mode + self.handler_fn = self.initialize_handler_fn() + self.command = self.initialize_command() + + def initialize_handler_fn(self): + """ + Initialize the handler function based on the mode. + """ + if self.mode == "replay": + return self.__handle_replay_key_controls + elif self.mode == "game": + return self.__handle_game_key_controls + elif self.mode == "train": + return self.__handle_train_key_controls + raise ValueError("Invalid mode") - def handle_events(self): + def initialize_command(self): + """ + Initialize the command type based on the mode. + """ + if self.mode == "replay": + return ReplayCommand + elif self.mode == "game": + return GameCommand + elif self.mode == "train": + return TrainCommand + raise ValueError("Invalid mode") + + def handle_events(self) -> Command: """ Handle pygame GUI events """ - control_events = { - "time_change": 0, - } + command = self.command() for event in pygame.event.get(): - if event.type == pygame.QUIT or ( - event.type == pygame.KEYDOWN and event.key == pygame.K_q - ): - pygame.quit() - quit() - - if event.type == pygame.KEYDOWN and self.from_replay: - self.__handle_key_controls(event, control_events) + if event.type == pygame.QUIT: + command.quit = True + if event.type == pygame.KEYDOWN: + command = self.handler_fn(event, command) elif event.type == pygame.MOUSEBUTTONDOWN: self.__handle_mouse_click() + return command - return control_events - - - def __handle_key_controls(self, event, control_events): + def __handle_replay_key_controls(self, event: Event, command: Command) -> Command: """ Handle key controls for replay mode. Control game speed, pause, and replay frames. """ - match event.key: - # Speed up game right arrow is pressed - case pygame.K_RIGHT: - self.properties.game_speed = max(1 / 128, self.properties.game_speed / 2) - # Slow down game left arrow is pressed - case pygame.K_LEFT: - self.properties.game_speed = min(32.0, self.properties.game_speed * 2) - # Toggle play/pause - case pygame.K_SPACE: - self.properties.paused = not self.properties.paused - case pygame.K_r: - control_events["restart"] = True - # Control replay frames - case pygame.K_h: - control_events["time_change"] = -1 - self.properties.paused = True - case pygame.K_l: - control_events["time_change"] = 1 - self.properties.paused = True + if event.key == Q: + command.quit = True + elif event.key == RIGHT: + command.speed_change = 2.0 + elif event.key == LEFT: + command.speed_change = 0.5 + elif event.key == SPACE: + command.pause = True + elif event.key == R: + command.restart = True + command.pause = True + elif event.key == H: + command.frame_change = -1 + command.pause = True + elif event.key == L: + command.frame_change = 1 + self.properties.paused = True + return command + + def __handle_game_key_controls( + self, event: Event, command: Command + ) -> dict[str, any]: + raise NotImplementedError + def __handle_train_key_controls( + self, event: Event, command: Command + ) -> dict[str, any]: + if event.key == Q: + command.quit = True + return command def __handle_mouse_click(self): """ Handle mouse click event. """ + if self.properties.mode == "replay": + self.__handle_replay_clicks() + elif self.properties.mode == "game": + self.__handle_game_clicks() + elif self.properties.mode == "train": + self.__handle_train_clicks() + + def __handle_game_clicks(self): + """ + Handle mouse clicks in game mode. + """ + pass + + def __handle_train_clicks(self): + """ + Handle mouse clicks in training mode. + """ + pass + + def __handle_replay_clicks(self): + """ + Handle mouse clicks in replay mode. + """ agents = self.properties.game.agents agent_fov = self.properties.agent_fov x, y = pygame.mouse.get_pos() for i, agent in enumerate(agents): - if self.properties.is_click_on_agents_row(x, y, i): + if self.is_click_on_agents_row(x, y, i): agent_fov[agent] = not agent_fov[agent] break + + def is_click_on_agents_row(self, x: int, y: int, i: int) -> bool: + """ + Check if the click is on an agent's row. + + Args: + x: int, x-coordinate of the click + y: int, y-coordinate of the click + i: int, index of the row + """ + return ( + x >= self.properties.display_grid_width + and (i + 1) * c.GUI_ROW_HEIGHT <= y < (i + 2) * c.GUI_ROW_HEIGHT + ) diff --git a/generals/gui/gui.py b/generals/gui/gui.py index 43f24da..e51033d 100644 --- a/generals/gui/gui.py +++ b/generals/gui/gui.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import Any, Literal from generals.core.game import Game from .properties import Properties @@ -8,13 +8,20 @@ class GUI: def __init__( - self, game: Game, agent_data: dict[str, dict[str, Any]], from_replay=False + self, + game: Game, + agent_data: dict[str, dict[str, Any]], + mode: Literal["train", "game", "replay"] = "train", ): - self.properties = Properties(game, agent_data) + self.properties = Properties(game, agent_data, mode) self.__renderer = Renderer(self.properties) - self.__event_handler = EventHandler(self.properties, from_replay) + self.__event_handler = EventHandler(self.properties) def tick(self, fps=None): - control_events = self.__event_handler.handle_events() + command = self.__event_handler.handle_events() + if self.properties.mode == "replay": + command = self.__event_handler.handle_replay_command(command) + self.properties.update_speed(command.speed_change) + self.properties.paused = command.pause self.__renderer.render(fps) - return control_events + return command diff --git a/generals/gui/properties.py b/generals/gui/properties.py index 6bc43d2..f620aa2 100644 --- a/generals/gui/properties.py +++ b/generals/gui/properties.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Any +from typing import Any, Literal from pygame.time import Clock @@ -11,7 +11,7 @@ class Properties: __game: Game __agent_data: dict[str, dict[str, Any]] - __paused: bool = False + __mode: Literal["train", "game", "replay"] __game_speed: int = 1 __clock: Clock = Clock() @@ -22,21 +22,11 @@ def __post_init__(self): self.__display_grid_height: int = c.SQUARE_SIZE * self.grid_height self.__right_panel_width: int = 4 * c.GUI_CELL_WIDTH - self.__agent_fov: dict[str, bool] = {name: True for name in self.agent_data.keys()} + self.__paused: bool = False - def is_click_on_agents_row(self, x: int, y: int, i: int) -> bool: - """ - Check if the click is on an agent's row. - - Args: - x: int, x-coordinate of the click - y: int, y-coordinate of the click - i: int, index of the row - """ - return ( - x >= self.display_grid_width - and (i + 1) * c.GUI_ROW_HEIGHT <= y < (i + 2) * c.GUI_ROW_HEIGHT - ) + self.__agent_fov: dict[str, bool] = { + name: True for name in self.agent_data.keys() + } @property def game(self): @@ -46,6 +36,10 @@ def game(self): def agent_data(self): return self.__agent_data + @property + def mode(self): + return self.__mode + @property def paused(self): return self.__paused @@ -59,8 +53,9 @@ def game_speed(self): return self.__game_speed @game_speed.setter - def game_speed(self, value: int): - self.__game_speed = value + def game_speed(self, value: float): + new_speed = min(32.0, max(0.25, value)) # clip speed + self.__game_speed = new_speed @property def clock(self): @@ -89,3 +84,8 @@ def display_grid_height(self): @property def right_panel_width(self): return self.__right_panel_width + + def update_speed(self, change: float) -> None: + """change: multiplier how much to change, usually 2.0 or 0.5""" + new_speed = self.game_speed * change + self.game_speed = new_speed diff --git a/generals/gui/rendering.py b/generals/gui/rendering.py index 334d92a..4cc41a6 100644 --- a/generals/gui/rendering.py +++ b/generals/gui/rendering.py @@ -16,6 +16,7 @@ def __init__(self, properties: Properties): self.properties = properties + self.mode = self.properties.mode self.game = self.properties.game self.agent_data = self.properties.agent_data @@ -140,7 +141,9 @@ def render_stats(self): info_text = { "time": f"Time: {str(self.game.time // 2) + ('.' if self.game.time % 2 == 1 else '')}", - "speed": "Paused" if self.properties.paused else f"Speed: {str(1 / self.properties.game_speed)}x", + "speed": "Paused" + if self.mode == "replay" and self.properties.paused + else f"Speed: {str(self.properties.game_speed)}x", } # Write additional info