Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: Property layer visualization for HexGrid #2646

Merged
Merged
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 90 additions & 53 deletions mesa/visualization/mpl_space_drawing.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
import contextlib
import itertools
import warnings
from collections.abc import Callable
from collections.abc import Callable, Iterator
from functools import lru_cache
from itertools import pairwise
from typing import Any

Expand All @@ -18,7 +19,7 @@
from matplotlib import pyplot as plt
from matplotlib.axes import Axes
from matplotlib.cm import ScalarMappable
from matplotlib.collections import LineCollection, PatchCollection
from matplotlib.collections import LineCollection, PatchCollection, PolyCollection
from matplotlib.colors import LinearSegmentedColormap, Normalize, to_rgba
from matplotlib.patches import Polygon

Expand Down Expand Up @@ -159,6 +160,37 @@ def draw_space(
return ax


# Helper function for getting the vertices of a hexagon given the center and size
def _get_hex_vertices(
Sahil-Chhoker marked this conversation as resolved.
Show resolved Hide resolved
center_x: float, center_y: float, size: float
) -> list[tuple[float, float]]:
"""Get vertices for a hexagon centered at (center_x, center_y)."""
vertices = [
(center_x, center_y + size), # top
(center_x + size * np.sqrt(3) / 2, center_y + size / 2), # top right
(center_x + size * np.sqrt(3) / 2, center_y - size / 2), # bottom right
(center_x, center_y - size), # bottom
(center_x - size * np.sqrt(3) / 2, center_y - size / 2), # bottom left
(center_x - size * np.sqrt(3) / 2, center_y + size / 2), # top left
]
return vertices


@lru_cache(maxsize=1024, typed=True)
def _get_hexmesh(
width: int, height: int, size: float
Sahil-Chhoker marked this conversation as resolved.
Show resolved Hide resolved
) -> Iterator[list[tuple[float, float]]]:
"""Generate hexagon vertices for the mesh. Yields list of vertex coordinates for each hexagon."""
x_spacing = np.sqrt(3) * size
y_spacing = 1.5 * size

for row, col in itertools.product(range(height), range(width)):
# Calculate center position with offset for even rows
x = col * x_spacing + (row % 2 == 0) * (x_spacing / 2)
y = row * y_spacing
yield _get_hex_vertices(x, y, size)


def draw_property_layers(
space, propertylayer_portrayal: dict[str, dict[str, Any]], ax: Axes
):
Expand Down Expand Up @@ -205,46 +237,74 @@ def draw_property_layers(
vmax = portrayal.get("vmax", np.max(data))
colorbar = portrayal.get("colorbar", True)

# Draw the layer
# Prepare colormap
if "color" in portrayal:
data = data.T
rgba_color = to_rgba(portrayal["color"])
normalized_data = (data - vmin) / (vmax - vmin)
rgba_data = np.full((*data.shape, 4), rgba_color)
rgba_data[..., 3] *= normalized_data * alpha
rgba_data = np.clip(rgba_data, 0, 1)
cmap = LinearSegmentedColormap.from_list(
layer_name, [(0, 0, 0, 0), (*rgba_color[:3], alpha)]
)
im = ax.imshow(
rgba_data,
origin="lower",
)
if colorbar:
norm = Normalize(vmin=vmin, vmax=vmax)
sm = ScalarMappable(norm=norm, cmap=cmap)
sm.set_array([])
ax.figure.colorbar(sm, ax=ax, orientation="vertical")

elif "colormap" in portrayal:
cmap = portrayal.get("colormap", "viridis")
if isinstance(cmap, list):
cmap = LinearSegmentedColormap.from_list(layer_name, cmap)
im = ax.imshow(
data.T,
cmap=cmap,
alpha=alpha,
vmin=vmin,
vmax=vmax,
origin="lower",
)
if colorbar:
plt.colorbar(im, ax=ax, label=layer_name)
elif isinstance(cmap, str):
cmap = plt.get_cmap(cmap)
else:
raise ValueError(
f"PropertyLayer {layer_name} portrayal must include 'color' or 'colormap'."
)

if isinstance(space, OrthogonalGrid):
if "color" in portrayal:
data = data.T
normalized_data = (data - vmin) / (vmax - vmin)
rgba_data = np.full((*data.shape, 4), rgba_color)
rgba_data[..., 3] *= normalized_data * alpha
rgba_data = np.clip(rgba_data, 0, 1)
ax.imshow(rgba_data, origin="lower")
else:
ax.imshow(
data.T,
cmap=cmap,
alpha=alpha,
vmin=vmin,
vmax=vmax,
origin="lower",
)

elif isinstance(space, HexGrid):
width, height = data.shape

# Generate hexagon mesh
hexagons = _get_hexmesh(width, height, size=1)

# Normalize colors
norm = Normalize(vmin=vmin, vmax=vmax)
colors = data.ravel() # flatten data to 1D array

if "color" in portrayal:
normalized_colors = np.clip(norm(colors), 0, 1)
rgba_colors = np.full((len(colors), 4), rgba_color)
rgba_colors[:, 3] = normalized_colors * alpha
else:
rgba_colors = cmap(norm(colors))
Sahil-Chhoker marked this conversation as resolved.
Show resolved Hide resolved

# Draw hexagons
collection = PolyCollection(hexagons, facecolors=rgba_colors, zorder=-1)
ax.add_collection(collection)

else:
raise NotImplementedError(
f"PropertyLayer visualization not implemented for {type(space)}."
)

# Add colorbar if requested
if colorbar:
norm = Normalize(vmin=vmin, vmax=vmax)
sm = ScalarMappable(norm=norm, cmap=cmap)
sm.set_array([])
plt.colorbar(sm, ax=ax, label=layer_name)


def draw_orthogonal_grid(
space: OrthogonalGrid,
Expand Down Expand Up @@ -350,38 +410,15 @@ def setup_hexmesh(width, height):
"""Helper function for creating the hexmesh with unique edges."""
edges = set()
size = 1.0
x_spacing = np.sqrt(3) * size
y_spacing = 1.5 * size

def get_hex_vertices(
center_x: float, center_y: float
) -> list[tuple[float, float]]:
"""Get vertices for a hexagon centered at (center_x, center_y)."""
vertices = [
(center_x, center_y + size), # top
(center_x + size * np.sqrt(3) / 2, center_y + size / 2), # top right
(center_x + size * np.sqrt(3) / 2, center_y - size / 2), # bottom right
(center_x, center_y - size), # bottom
(center_x - size * np.sqrt(3) / 2, center_y - size / 2), # bottom left
(center_x - size * np.sqrt(3) / 2, center_y + size / 2), # top left
]
return vertices

# Generate edges for each hexagon
for row, col in itertools.product(range(height), range(width)):
# Calculate center position for each hexagon with offset for even rows
x = col * x_spacing + (row % 2 == 0) * (x_spacing / 2)
y = row * y_spacing

vertices = get_hex_vertices(x, y)

for vertices in _get_hexmesh(width, height, size):
# Edge logic, connecting each vertex to the next
for v1, v2 in pairwise([*vertices, vertices[0]]):
# Sort vertices to ensure consistent edge representation and avoid duplicates.
# Sort vertices to ensure consistent edge representation
Sahil-Chhoker marked this conversation as resolved.
Show resolved Hide resolved
edge = tuple(sorted([tuple(np.round(v1, 6)), tuple(np.round(v2, 6))]))
edges.add(edge)

# Return LineCollection for hexmesh
return LineCollection(edges, linestyle=":", color="black", linewidth=1, alpha=1)

if draw_grid:
Expand Down
Loading