Skip to content

Commit

Permalink
refactor: improve clarity of grid & gridfactory
Browse files Browse the repository at this point in the history
  • Loading branch information
anordin95 authored and strakam committed Dec 23, 2024
1 parent af2535d commit 731f7f2
Show file tree
Hide file tree
Showing 7 changed files with 72 additions and 81 deletions.
3 changes: 2 additions & 1 deletion examples/complete_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@

# Initialize grid factory
grid_factory = GridFactory(
grid_dims=(10, 10), # Grid height and width
min_grid_dims=(10, 10), # Grid height and width are randomly selected
max_grid_dims=(15, 15),
mountain_density=0.2, # Expected percentage of mountains
city_density=0.05, # Expected percentage of cities
general_positions=[(1, 2), (7, 8)], # Positions of the generals
Expand Down
3 changes: 2 additions & 1 deletion examples/record_replay_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@

# Initialize grid factory
grid_factory = GridFactory(
grid_dims=(4, 4), # Grid height and width
min_grid_dims=(4, 4), # Grid height and width
max_grid_dims=(4, 4),
mountain_density=0.0, # Expected percentage of mountains
city_density=0.05, # Expected percentage of cities
general_positions=[(0, 0), (3, 3)], # Positions of the generals
Expand Down
108 changes: 45 additions & 63 deletions generals/core/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,36 +3,35 @@
from .config import MOUNTAIN, PASSABLE


class InvalidGridError(Exception):
pass


class Grid:
def __init__(self, grid: str | np.ndarray):
if not isinstance(grid, str | np.ndarray):
raise ValueError(f"grid must be a str or np.ndarray. Received grid with type: {type(grid)}.")

if isinstance(grid, str):
grid = grid.strip()
grid = Grid.numpify_grid(grid)

Grid.ensure_grid_is_valid(grid)
self.grid = grid

def __eq__(self, other):
return np.array_equal(self.grid, other.grid)
@staticmethod
def ensure_grid_is_valid(grid: np.ndarray):
if not Grid.are_generals_connected(grid):
raise InvalidGridError("Invalid grid layout - generals cannot reach each other.")

@property
def grid(self):
return self._grid

@grid.setter
def grid(self, grid: str | np.ndarray):
match grid:
case str(grid):
grid = grid.strip()
grid = Grid.numpify_grid(grid)
case np.ndarray():
pass
case _:
raise ValueError("Grid must be encoded as a string or a numpy array.")
if not Grid.verify_grid_connectivity(grid):
raise ValueError("Invalid grid layout - generals cannot reach each other.")
# check that exactly one 'A' and one 'B' are present in the grid
first_general = np.argwhere(np.isin(grid, ["A"]))
second_general = np.argwhere(np.isin(grid, ["B"]))
if len(first_general) != 1 or len(second_general) != 1:
raise ValueError("Exactly one 'A' and one 'B' should be present in the grid.")
raise InvalidGridError("Exactly one 'A' and one 'B' should be present in the grid.")

self._grid = grid
def __eq__(self, other):
return np.array_equal(self.grid, other.grid)

@staticmethod
def generals_distance(grid: "Grid") -> int:
Expand All @@ -48,10 +47,9 @@ def stringify_grid(grid: np.ndarray) -> str:
return "\n".join(["".join(row) for row in grid])

@staticmethod
def verify_grid_connectivity(grid: np.ndarray | str) -> bool:
def are_generals_connected(grid: np.ndarray | str) -> bool:
"""
Verify grid layout (can generals reach each other?)
Returns True if grid is valid, False otherwise
Returns True if there is a path connecting the two generals.
"""
if isinstance(grid, str):
grid = Grid.numpify_grid(grid)
Expand Down Expand Up @@ -100,37 +98,23 @@ def __init__(
seed: A random seed i.e. a way to make the randomness repeatable.
"""
self.rng = np.random.default_rng(seed)
self.grid_height = self.rng.integers(min_grid_dims[0], max_grid_dims[0] + 1)
self.grid_width = self.rng.integers(min_grid_dims[0], max_grid_dims[0] + 1)
self.min_grid_dims = min_grid_dims
self.max_grid_dims = max_grid_dims
self.mountain_density = mountain_density
self.city_density = city_density
self.general_positions = general_positions

def grid_from_string(self, grid: str) -> Grid:
return Grid(grid)
def set_rng(self, rng: np.random.Generator):
self.rng = rng

def grid_from_generator(
self,
grid_dims: tuple[int, int] | None = None,
mountain_density: float | None = None,
city_density: float | None = None,
general_positions: list[tuple[int, int]] | None = None,
seed: int | None = None,
) -> Grid:
if grid_dims is None:
grid_dims = (self.grid_height, self.grid_width)
if mountain_density is None:
mountain_density = self.mountain_density
if city_density is None:
city_density = self.city_density
if general_positions is None:
general_positions = self.general_positions
if seed is not None:
self.rng = np.random.default_rng(seed)
def generate(self) -> Grid:
grid_height = self.rng.integers(self.min_grid_dims[0], self.max_grid_dims[0] + 1)
grid_width = self.rng.integers(self.min_grid_dims[0], self.max_grid_dims[0] + 1)
grid_dims = (grid_height, grid_width)

# Probabilities of each cell type
p_neutral = 1 - mountain_density - city_density
probs = [p_neutral, mountain_density] + [city_density / 10] * 10
p_neutral = 1 - self.mountain_density - self.city_density
probs = [p_neutral, self.mountain_density] + [self.city_density / 10] * 10

# Place cells on the map
map = self.rng.choice(
Expand All @@ -139,14 +123,17 @@ def grid_from_generator(
p=probs,
)

# Place generals on random squares, they should be atleast some distance apart
min_distance = max(grid_dims) // 2
p1 = self.rng.integers(0, grid_dims[0]), self.rng.integers(0, grid_dims[1])
while True:
p2 = self.rng.integers(0, grid_dims[0]), self.rng.integers(0, grid_dims[1])
if abs(p1[0] - p2[0]) + abs(p1[1] - p2[1]) >= min_distance:
break
general_positions = [p1, p2]
general_positions = self.general_positions
if general_positions is None:
# Select each generals location, they should be atleast some distance apart
min_distance = max(grid_dims) // 2
p1 = self.rng.integers(0, grid_dims[0]), self.rng.integers(0, grid_dims[1])
while True:
p2 = self.rng.integers(0, grid_dims[0]), self.rng.integers(0, grid_dims[1])
if abs(p1[0] - p2[0]) + abs(p1[1] - p2[1]) >= min_distance:
break
general_positions = [p1, p2]

for i, idx in enumerate(general_positions):
map[idx[0], idx[1]] = chr(ord("A") + i)

Expand All @@ -155,11 +142,6 @@ def grid_from_generator(

try:
return Grid(map_string)
except ValueError:
return self.grid_from_generator(
grid_dims=grid_dims,
mountain_density=mountain_density,
city_density=city_density,
general_positions=general_positions,
seed=None,
)
except InvalidGridError:
# Keep randomly generating grids until one works!
return self.generate()
12 changes: 7 additions & 5 deletions generals/envs/gymnasium_generals.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from generals.agents import Agent, AgentFactory
from generals.core.game import Action, Game, Info
from generals.core.grid import GridFactory
from generals.core.grid import Grid, GridFactory
from generals.core.observation import Observation
from generals.core.replay import Replay
from generals.gui import GUI
Expand Down Expand Up @@ -51,7 +51,7 @@ def __init__(
assert self.agent_id != npc.id, "Agent ids must be unique - you can pass custom ids to agent constructors."

# Game
grid = self.grid_factory.grid_from_generator()
grid = self.grid_factory.generate()
self.game = Game(grid, [self.agent_id, self.npc.id])
self.observation_space = self.game.observation_space
self.action_space = self.game.action_space
Expand All @@ -69,10 +69,12 @@ def reset(
options = {}

if "grid" in options:
grid = self.grid_factory.grid_from_string(options["grid"])
grid = Grid(options["grid"])
else:
self.grid_factory.rng = self.np_random
grid = self.grid_factory.grid_from_generator()
# Provide the np.random.Generator instance created in Env.reset()
# as opposed to creating a new one with the same seed.
self.grid_factory.set_rng(rng=self.np_random)
grid = self.grid_factory.generate()

# Create game for current run
self.game = Game(grid, self.agent_ids)
Expand Down
11 changes: 8 additions & 3 deletions generals/envs/pettingzoo_generals.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@
from copy import deepcopy
from typing import Any, TypeAlias

import numpy as np
import pettingzoo # type: ignore
from gymnasium import spaces

from generals.agents.agent import Agent
from generals.core.game import Action, Game, Info, Observation
from generals.core.grid import GridFactory
from generals.core.grid import Grid, GridFactory
from generals.core.replay import Replay
from generals.gui import GUI
from generals.gui.properties import GuiMode
Expand Down Expand Up @@ -81,9 +82,13 @@ def reset(
options = {}
self.agents = deepcopy(self.possible_agents)
if "grid" in options:
grid = self.grid_factory.grid_from_string(options["grid"])
grid = Grid(options["grid"])
else:
grid = self.grid_factory.grid_from_generator(seed=seed)
# The pettingzoo.Parallel_Env's reset() notably differs
# from gymnasium.Env's reset() in that it does not create
# a random generator which should be re-used.
self.grid_factory.set_rng(rng=np.random.default_rng(seed))
grid = self.grid_factory.generate()

self.game = Game(grid, self.agents)

Expand Down
2 changes: 1 addition & 1 deletion tests/test_game.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def get_game(grid=None):
city_density=0.1,
general_positions=[[3, 3], [1, 3]],
)
grid = grid_factory.grid_from_generator()
grid = grid_factory.generate()
return game.Game(grid, ["red", "blue"])


Expand Down
14 changes: 7 additions & 7 deletions tests/test_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def test_verify_grid():
...B.
"""
_grid = Grid(map)
assert Grid.verify_grid_connectivity(_grid.grid)
assert Grid.are_generals_connected(_grid.grid)

map = """
.....
Expand All @@ -36,7 +36,7 @@ def test_verify_grid():
"""

map = Grid.numpify_grid(map)
assert not Grid.verify_grid_connectivity(map)
assert not Grid.are_generals_connected(map)

map = """
.....
Expand All @@ -46,7 +46,7 @@ def test_verify_grid():
.....
"""
map = Grid.numpify_grid(map)
assert Grid.verify_grid_connectivity(map)
assert Grid.are_generals_connected(map)

map = """
...#.
Expand All @@ -56,7 +56,7 @@ def test_verify_grid():
.....
"""
map = Grid.numpify_grid(map)
assert not Grid.verify_grid_connectivity(map)
assert not Grid.are_generals_connected(map)

map = """
...#.
Expand All @@ -66,14 +66,14 @@ def test_verify_grid():
.....
"""
map = Grid.numpify_grid(map)
assert not Grid.verify_grid_connectivity(map)
assert not Grid.are_generals_connected(map)

def test_grid_factory():
generator = GridFactory()
generator.rng = np.random.default_rng()
for _ in range(10):
grid = generator.grid_from_generator()
assert Grid.verify_grid_connectivity(grid.grid)
grid = generator.generate()
assert Grid.are_generals_connected(grid.grid)
height, width = grid.grid.shape
assert Grid.generals_distance(grid) >= max(height, width) // 2

Expand Down

0 comments on commit 731f7f2

Please sign in to comment.