Skip to content

Commit

Permalink
fix: Fix deterministicity.. seeds now work properly
Browse files Browse the repository at this point in the history
  • Loading branch information
strakam committed Oct 18, 2024
1 parent ccbb1dc commit d56c63c
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 22 deletions.
5 changes: 3 additions & 2 deletions generals/core/game.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def action_mask(self, agent: str) -> np.ndarray:
owned_cells_indices = self.channel_to_indices(more_than_1_army)
valid_action_mask = np.zeros((self.grid_dims[0], self.grid_dims[1], 4), dtype=bool)

if self.is_done():
if self.is_done() and not self.agent_won(agent): # if you lost, return all zeros
return valid_action_mask

for channel_index, direction in enumerate(DIRECTIONS):
Expand All @@ -107,6 +107,7 @@ def action_mask(self, agent: str) -> np.ndarray:
# get valid action mask for a given direction
valid_source_indices = action_destinations - direction.value
valid_action_mask[valid_source_indices[:, 0], valid_source_indices[:, 1], channel_index] = 1.0
# assert False
return valid_action_mask

def channel_to_indices(self, channel: np.ndarray) -> np.ndarray:
Expand Down Expand Up @@ -197,7 +198,7 @@ def step(self, actions: dict[str, Action]) -> tuple[dict[str, Observation], dict
self.time += 1

if self.is_done():
# Give all cells of loser to winner
# give all cells of loser to winner
winner = self.agents[0] if self.agent_won(self.agents[0]) else self.agents[1]
loser = self.agents[1] if winner == self.agents[0] else self.agents[0]
self.channels.ownership[winner] += self.channels.ownership[loser]
Expand Down
30 changes: 17 additions & 13 deletions generals/core/grid.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np
from numpy.random import Generator

from .config import MOUNTAIN, PASSABLE

Expand Down Expand Up @@ -94,7 +95,15 @@ def __init__(
self.mountain_density = mountain_density
self.city_density = city_density
self.general_positions = general_positions
self.seed = seed
self._rng = np.random.default_rng(seed)

@property
def rng(self):
return self._rng

@rng.setter
def rng(self, number_generator: Generator):
self._rng = number_generator

def grid_from_string(self, grid: str) -> Grid:
return Grid(grid)
Expand All @@ -115,32 +124,28 @@ def grid_from_generator(
city_density = self.city_density
if general_positions is None:
general_positions = self.general_positions
if seed is None:
if self.seed is None:
seed = np.random.randint(0, 2**20)
else:
seed = self.seed
if seed is not None:
self.rng = np.random.default_rng(seed)

# Probabilities of each cell type
p_neutral = 1 - mountain_density - city_density
probs = [p_neutral, mountain_density] + [city_density / 10] * 10

# Place cells on the map
rng = np.random.default_rng(seed)
map = rng.choice(
map = self.rng.choice(
[PASSABLE, MOUNTAIN, "0", "1", "2", "3", "4", "5", "6", "7", "8", "9"],
size=grid_dims,
p=probs,
)

# Place generals on random squares, they should be atleast some distance apart
min_distance = max(grid_dims) // 2
p1 = np.random.randint(0, grid_dims[0]), np.random.randint(0, grid_dims[1])
p1 = self.rng.integers(0, grid_dims[0]), self.rng.integers(0, grid_dims[1])
while True:
p2 = np.random.randint(0, grid_dims[0]), np.random.randint(0, grid_dims[1])
p2 = self.rng.integers(0, grid_dims[0]), self.rng.integers(0, grid_dims[1])
if abs(p1[0] - p2[0]) + abs(p1[1] - p2[1]) >= min_distance:
break
general_positions = np.array([p1, p2])
general_positions = [p1, p2]
for i, idx in enumerate(general_positions):
map[idx[0], idx[1]] = chr(ord("A") + i)

Expand All @@ -150,11 +155,10 @@ def grid_from_generator(
try:
return Grid(map_string)
except ValueError:
seed += 1 # Increase seed to generate a different map
return self.grid_from_generator(
grid_dims=grid_dims,
mountain_density=mountain_density,
city_density=city_density,
general_positions=general_positions,
seed=seed,
seed=None,
)
4 changes: 2 additions & 2 deletions generals/envs/gymnasium_generals.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,8 @@ def reset(
if "grid" in options:
grid = self.grid_factory.grid_from_string(options["grid"])
else:
map_seed = self.np_random.integers(0, 2**20)
grid = self.grid_factory.grid_from_generator(seed=map_seed)
self.grid_factory.rng = self.np_random
grid = self.grid_factory.grid_from_generator()

# Create game for current run
self.game = Game(grid, self.agent_ids)
Expand Down
18 changes: 13 additions & 5 deletions tests/test_map.py → tests/test_grid.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import numpy as np

from generals.core.grid import Grid
from generals.core.grid import Grid, GridFactory


def test_grid_creation():
Expand Down Expand Up @@ -46,7 +46,6 @@ def test_verify_grid():
.....
"""
map = Grid.numpify_grid(map)
grid = Grid(map)
assert Grid.verify_grid_connectivity(map)

map = """
Expand All @@ -57,8 +56,7 @@ def test_verify_grid():
.....
"""
map = Grid.numpify_grid(map)
grid = Grid(map)
assert not Grid.verify_grid_connectivity(grid)
assert not Grid.verify_grid_connectivity(map)

map = """
...#.
Expand All @@ -68,7 +66,17 @@ def test_verify_grid():
.....
"""
map = Grid.numpify_grid(map)
assert Grid.verify_grid_connectivity(map)
assert not Grid.verify_grid_connectivity(map)

def test_grid_factory():
generator = GridFactory()
generator.rng = np.random.default_rng()
for _ in range(10):
grid = generator.grid_from_generator()
assert Grid.verify_grid_connectivity(grid.grid)
height, width = grid.grid.shape
assert Grid.generals_distance(grid) >= max(height, width) // 2



def test_numpify_map():
Expand Down

0 comments on commit d56c63c

Please sign in to comment.