Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed hex-space draw function to avoid overlaps #2609

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
Open
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 67 additions & 57 deletions mesa/visualization/mpl_space_drawing.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,19 @@

import contextlib
import itertools
import math
import warnings
from collections.abc import Callable
from itertools import pairwise
from typing import Any

import networkx as nx
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.axes import Axes
from matplotlib.cm import ScalarMappable
from matplotlib.collections import PatchCollection
from matplotlib.collections import LineCollection, PatchCollection
from matplotlib.colors import LinearSegmentedColormap, Normalize, to_rgba
from matplotlib.patches import Polygon, RegularPolygon
from matplotlib.patches import Polygon

import mesa
from mesa.experimental.cell_space import (
Expand Down Expand Up @@ -308,13 +308,6 @@ def draw_hex_grid(
ax: a Matplotlib Axes instance. If none is provided a new figure and ax will be created using plt.subplots
draw_grid: whether to draw the grid
kwargs: additional keyword arguments passed to ax.scatter

Returns:
Returns the Axes object with the plot drawn onto it.

``agent_portrayal`` is called with an agent and should return a dict. Valid fields in this dict are "color",
"size", "marker", and "zorder". Other field are ignored and will result in a user warning.

"""
if ax is None:
fig, ax = plt.subplots()
Expand All @@ -323,62 +316,79 @@ def draw_hex_grid(
s_default = (180 / max(space.width, space.height)) ** 2
arguments = collect_agent_data(space, agent_portrayal, size=s_default)

# for hexgrids we have to go from logical coordinates to visual coordinates
# this is a bit messy.

# give all even rows an offset in the x direction
# give all rows an offset in the y direction

# numbers here are based on a distance of 1 between centers of hexes
offset = math.sqrt(0.75)
# Parameters for hexagon grid
size = 1.0
x_spacing = np.sqrt(3) * size
y_spacing = 1.5 * size

loc = arguments["loc"].astype(float)

logical = np.mod(loc[:, 1], 2) == 0
loc[:, 0][logical] += 0.5
loc[:, 1] *= offset
# Calculate hexagon centers
loc[:, 0] = loc[:, 0] * x_spacing + (loc[:, 1] % 2) * (x_spacing / 2)
Copy link
Member

@quaquel quaquel Jan 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So I think something is not right. Look at the top row in both figures. You'll see that they are differently aligned. This means that the neigborhood of a given cell as shown in the grid is not the same as in the underlying data. You can also check this by placing an agent in the very first cell.

This is due to this line of code: loc[:, 1] % 2. Note how in the original code, I check for 0, rather than 1.

Suggested change
loc[:, 0] = loc[:, 0] * x_spacing + (loc[:, 1] % 2) * (x_spacing / 2)
loc[:, 0] = loc[:, 0] * x_spacing + ((loc[:, 1]-1) % 2) * (x_spacing / 2)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This issue is still not resolved.

loc[:, 1] = loc[:, 1] * y_spacing
arguments["loc"] = loc

# plot the agents
_scatter(ax, arguments, **kwargs)

# further styling and adding of grid
ax.set_xlim(-1, space.width + 0.5)
ax.set_ylim(-offset, space.height * offset)

def setup_hexmesh(
width,
height,
):
"""Helper function for creating the hexmaesh."""
# fixme: this should be done once, rather than in each update
# fixme check coordinate system in hexgrid (see https://www.redblobgames.com/grids/hexagons/#coordinates-offset)

patches = []
for x, y in itertools.product(range(width), range(height)):
if y % 2 == 0:
x += 0.5 # noqa: PLW2901
y *= offset # noqa: PLW2901
hex = RegularPolygon(
(x, y),
numVertices=6,
radius=math.sqrt(1 / 3),
orientation=np.radians(120),
)
patches.append(hex)
mesh = PatchCollection(
patches, edgecolor="k", facecolor=(1, 1, 1, 0), linestyle="dotted", lw=1
)
return mesh
# Calculate proper bounds that account for the full hexagon width and height
x_max = space.width * x_spacing + (space.height % 2) * (x_spacing / 2)
y_max = space.height * y_spacing

# Add padding that accounts for the hexagon points
x_padding = (
size * np.sqrt(3) / 2
) # Distance from center to rightmost point of hexagon
y_padding = size # Distance from center to topmost point of hexagon

# Plot limits to perfectly contain the hexagonal grid
# Determined through physical testing.
ax.set_xlim(-2 * x_padding, x_max + x_padding)
ax.set_ylim(-2 * y_padding, y_max + y_padding)

def setup_hexmesh(width, height):
"""Helper function for creating the hexmesh with unique edges."""
edges = []
Sahil-Chhoker marked this conversation as resolved.
Show resolved Hide resolved
size = 1.0
x_spacing = np.sqrt(3) * size
y_spacing = 1.5 * size

def get_hex_vertices(
center_x: float, center_y: float
) -> list[tuple[float, float]]:
"""Get vertices for a hexagon centered at (center_x, center_y)."""
vertices = [
(center_x, center_y + size), # top
(center_x + size * np.sqrt(3) / 2, center_y + size / 2), # top right
(center_x + size * np.sqrt(3) / 2, center_y - size / 2), # bottom right
(center_x, center_y - size), # bottom
(center_x - size * np.sqrt(3) / 2, center_y - size / 2), # bottom left
(center_x - size * np.sqrt(3) / 2, center_y + size / 2), # top left
]
return vertices

# Generate edges for each hexagon
for row, col in itertools.product(range(height), range(width)):
# Calculate center position for each hexagon with offset for even rows
x = col * x_spacing + (row % 2) * (x_spacing / 2)
y = row * y_spacing

vertices = get_hex_vertices(x, y)

# Edge logic, connecting each vertex to the next
for v1, v2 in pairwise([*vertices, vertices[0]]):
# Sort vertices to ensure consistent edge representation and avoid duplicates.
edge = tuple(sorted([tuple(np.round(v1, 6)), tuple(np.round(v2, 6))]))
Sahil-Chhoker marked this conversation as resolved.
Show resolved Hide resolved
if edge not in edges:
Sahil-Chhoker marked this conversation as resolved.
Show resolved Hide resolved
edges.append(edge)
Sahil-Chhoker marked this conversation as resolved.
Show resolved Hide resolved

# Return LineCollection for hexmesh
return LineCollection(edges, linestyle=":", color="black", linewidth=1, alpha=1)

if draw_grid:
# add grid
ax.add_collection(
setup_hexmesh(
space.width,
space.height,
)
)
ax.add_collection(setup_hexmesh(space.width, space.height))

# Set aspect ratio to 'equal' to ensure hexagons appear regular
ax.set_aspect("equal")
Sahil-Chhoker marked this conversation as resolved.
Show resolved Hide resolved
return ax


Expand Down
Loading