Skip to content

Commit

Permalink
refactor: Improve handle events
Browse files Browse the repository at this point in the history
  • Loading branch information
strakam committed Oct 2, 2024
1 parent d7f9924 commit 4258b79
Show file tree
Hide file tree
Showing 7 changed files with 212 additions and 79 deletions.
18 changes: 12 additions & 6 deletions generals/core/replay.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from generals.core.grid import Grid
from generals.core.game import Game
from generals.gui.event_handler import ReplayCommand
from generals.gui import GUI
from copy import deepcopy

Expand Down Expand Up @@ -33,31 +34,36 @@ def load(cls, path):
def play(self):
agents = [agent for agent in self.agent_data.keys()]
game = Game(self.grid, agents)
gui = GUI(game, self.agent_data, from_replay=True)
gui = GUI(game, self.agent_data, mode="replay")
gui_properties = gui.properties

game_step, last_input_time, last_move_time = 0, 0, 0
while 1:
_t = time.time()
# Check inputs
if _t - last_input_time > 0.008: # check for input every 8ms
control_events = gui.tick()
command = gui.tick()
last_input_time = _t
else:
control_events = {"time_change": 0}
if "restart" in control_events:
command = ReplayCommand()
if command.quit:
import pygame

pygame.quit()
quit()
if command.restart:
game_step = 0
# If we control replay, change game state
game_step = max(
0, min(len(self.game_states) - 1, game_step + control_events["time_change"])
0, min(len(self.game_states) - 1, game_step + command.frame_change)
)
if gui_properties.paused and game_step != game.time:
game.channels = deepcopy(self.game_states[game_step])
game.time = game_step
last_move_time = _t
# If we are not paused, play the game
elif (
_t - last_move_time > gui_properties.game_speed * 0.512
_t - last_move_time > (1/gui_properties.game_speed) * 0.512
and not gui_properties.paused
):
if game.is_done():
Expand Down
20 changes: 15 additions & 5 deletions generals/envs/gymnasium_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,13 @@ def action_space(self) -> gym.Space:

def render(self, fps: int = 6) -> None:
if self.render_mode == "human":
self.gui.tick(fps=fps)
command = self.gui.tick(fps=fps)
if command.quit:
self.close()

def reset(self, seed: int | None = None, options: dict[str, Any] | None = None) -> tuple[
Observation, dict[str, Any]
]:
def reset(
self, seed: int | None = None, options: dict[str, Any] | None = None
) -> tuple[Observation, dict[str, Any]]:
if options is None:
options = {}
super().reset(seed=seed)
Expand Down Expand Up @@ -95,7 +97,9 @@ def reset(self, seed: int | None = None, options: dict[str, Any] | None = None)
info = {}
return observation, info

def step(self, action: Action) -> tuple[Observation, SupportsFloat, bool, bool, dict[str, Any]]:
def step(
self, action: Action
) -> tuple[Observation, SupportsFloat, bool, bool, dict[str, Any]]:
# get action of NPC
npc_action = self.npc.play(self.game._agent_observation(self.npc.name))
actions = {self.agent_name: action, self.npc.name: npc_action}
Expand Down Expand Up @@ -133,3 +137,9 @@ def _default_reward(
else:
reward = 0
return reward

def close(self) -> None:
import pygame

pygame.quit()
quit()
13 changes: 9 additions & 4 deletions generals/envs/pettingzoo_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,9 @@ def action_space(self, agent: AgentID) -> spaces.Space:

def render(self, fps=6) -> None:
if self.render_mode == "human":
self.gui.tick(fps=fps)
command = self.gui.tick(fps=fps)
if command.quit:
self.close()

def reset(
self, seed: int | None = None, options: dict | None = None
Expand All @@ -76,7 +78,7 @@ def reset(
self.game = Game(grid, self.agents)

if self.render_mode == "human":
self.gui = GUI(self.game, self.agent_data)
self.gui = GUI(self.game, self.agent_data, "train")

if "replay_file" in options:
self.replay = Replay(
Expand Down Expand Up @@ -146,5 +148,8 @@ def _default_reward(
reward = 0
return reward

def close(self):
print("Closing environment")
def close(self) -> None:
import pygame

pygame.quit()
quit()
180 changes: 141 additions & 39 deletions generals/gui/event_handler.py
Original file line number Diff line number Diff line change
@@ -1,77 +1,179 @@
import pygame
from pygame.event import Event

from .properties import Properties
from generals.core import config as c

######################
# Replay keybindings #
######################
RIGHT = pygame.K_RIGHT
LEFT = pygame.K_LEFT
SPACE = pygame.K_SPACE
Q = pygame.K_q
R = pygame.K_r
H = pygame.K_h
L = pygame.K_l


class Command:
def __init__(self):
self.quit: bool = False


class ReplayCommand(Command):
def __init__(self):
super().__init__()
self.frame_change: int = 0
self.speed_change: float = 1.0
self.restart: bool = False
self.pause: bool = False


class GameCommand(Command):
def __init__(self):
super().__init__()
raise NotImplementedError


class TrainCommand(Command):
def __init__(self):
super().__init__()


class EventHandler:
def __init__(self, properties: Properties, from_replay=False):
def __init__(self, properties: Properties):
"""
Initialize the event handler.
Args:
properties: the Properties object
from_replay: bool, whether the game is from a replay
"""
self.properties = properties
self.from_replay = from_replay
self.mode = properties.mode
self.handler_fn = self.initialize_handler_fn()
self.command = self.initialize_command()

def initialize_handler_fn(self):
"""
Initialize the handler function based on the mode.
"""
if self.mode == "replay":
return self.__handle_replay_key_controls
elif self.mode == "game":
return self.__handle_game_key_controls
elif self.mode == "train":
return self.__handle_train_key_controls
raise ValueError("Invalid mode")

def handle_events(self):
def initialize_command(self):
"""
Initialize the command type based on the mode.
"""
if self.mode == "replay":
return ReplayCommand
elif self.mode == "game":
return GameCommand
elif self.mode == "train":
return TrainCommand
raise ValueError("Invalid mode")

def handle_events(self) -> Command:
"""
Handle pygame GUI events
"""
control_events = {
"time_change": 0,
}
command = self.command()
for event in pygame.event.get():
if event.type == pygame.QUIT or (
event.type == pygame.KEYDOWN and event.key == pygame.K_q
):
pygame.quit()
quit()

if event.type == pygame.KEYDOWN and self.from_replay:
self.__handle_key_controls(event, control_events)
if event.type == pygame.QUIT:
command.quit = True
if event.type == pygame.KEYDOWN:
command = self.handler_fn(event, command)
elif event.type == pygame.MOUSEBUTTONDOWN:
self.__handle_mouse_click()
return command

return control_events


def __handle_key_controls(self, event, control_events):
def __handle_replay_key_controls(self, event: Event, command: Command) -> Command:
"""
Handle key controls for replay mode.
Control game speed, pause, and replay frames.
"""
match event.key:
# Speed up game right arrow is pressed
case pygame.K_RIGHT:
self.properties.game_speed = max(1 / 128, self.properties.game_speed / 2)
# Slow down game left arrow is pressed
case pygame.K_LEFT:
self.properties.game_speed = min(32.0, self.properties.game_speed * 2)
# Toggle play/pause
case pygame.K_SPACE:
self.properties.paused = not self.properties.paused
case pygame.K_r:
control_events["restart"] = True
# Control replay frames
case pygame.K_h:
control_events["time_change"] = -1
self.properties.paused = True
case pygame.K_l:
control_events["time_change"] = 1
self.properties.paused = True
if event.key == Q:
command.quit = True
elif event.key == RIGHT:
command.speed_change = 2.0
elif event.key == LEFT:
command.speed_change = 0.5
elif event.key == SPACE:
command.pause = True
elif event.key == R:
command.restart = True
command.pause = True
elif event.key == H:
command.frame_change = -1
command.pause = True
elif event.key == L:
command.frame_change = 1
self.properties.paused = True
return command

def __handle_game_key_controls(
self, event: Event, command: Command
) -> dict[str, any]:
raise NotImplementedError

def __handle_train_key_controls(
self, event: Event, command: Command
) -> dict[str, any]:
if event.key == Q:
command.quit = True
return command

def __handle_mouse_click(self):
"""
Handle mouse click event.
"""
if self.properties.mode == "replay":
self.__handle_replay_clicks()
elif self.properties.mode == "game":
self.__handle_game_clicks()
elif self.properties.mode == "train":
self.__handle_train_clicks()

def __handle_game_clicks(self):
"""
Handle mouse clicks in game mode.
"""
pass

def __handle_train_clicks(self):
"""
Handle mouse clicks in training mode.
"""
pass

def __handle_replay_clicks(self):
"""
Handle mouse clicks in replay mode.
"""
agents = self.properties.game.agents
agent_fov = self.properties.agent_fov

x, y = pygame.mouse.get_pos()
for i, agent in enumerate(agents):
if self.properties.is_click_on_agents_row(x, y, i):
if self.is_click_on_agents_row(x, y, i):
agent_fov[agent] = not agent_fov[agent]
break

def is_click_on_agents_row(self, x: int, y: int, i: int) -> bool:
"""
Check if the click is on an agent's row.
Args:
x: int, x-coordinate of the click
y: int, y-coordinate of the click
i: int, index of the row
"""
return (
x >= self.properties.display_grid_width
and (i + 1) * c.GUI_ROW_HEIGHT <= y < (i + 2) * c.GUI_ROW_HEIGHT
)
19 changes: 13 additions & 6 deletions generals/gui/gui.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any
from typing import Any, Literal

from generals.core.game import Game
from .properties import Properties
Expand All @@ -8,13 +8,20 @@

class GUI:
def __init__(
self, game: Game, agent_data: dict[str, dict[str, Any]], from_replay=False
self,
game: Game,
agent_data: dict[str, dict[str, Any]],
mode: Literal["train", "game", "replay"] = "train",
):
self.properties = Properties(game, agent_data)
self.properties = Properties(game, agent_data, mode)
self.__renderer = Renderer(self.properties)
self.__event_handler = EventHandler(self.properties, from_replay)
self.__event_handler = EventHandler(self.properties)

def tick(self, fps=None):
control_events = self.__event_handler.handle_events()
command = self.__event_handler.handle_events()
if self.properties.mode == "replay":
command = self.__event_handler.handle_replay_command(command)
self.properties.update_speed(command.speed_change)
self.properties.paused = command.pause
self.__renderer.render(fps)
return control_events
return command
Loading

0 comments on commit 4258b79

Please sign in to comment.