Skip to content

Commit 47b3bd9

Browse files
authored
Visualize PropertyLayers (#2336)
This PR adds support for visualizing PropertyLayers in the Matplotlib-based space visualization component. It allows users to overlay PropertyLayer data on top of the existing grid and agent visualizations, or on its own. It introduces a new `propertylayer_portrayal` parameter to customize the appearance of PropertyLayers and refactors the existing space visualization code for better modularity and flexibility.
1 parent a7dc9b2 commit 47b3bd9

File tree

2 files changed

+180
-88
lines changed

2 files changed

+180
-88
lines changed
Lines changed: 176 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -1,116 +1,190 @@
11
"""Matplotlib based solara components for visualization MESA spaces and plots."""
22

3-
from collections import defaultdict
3+
import warnings
44

5+
import matplotlib.pyplot as plt
56
import networkx as nx
7+
import numpy as np
68
import solara
9+
from matplotlib.cm import ScalarMappable
10+
from matplotlib.colors import LinearSegmentedColormap, Normalize, to_rgba
711
from matplotlib.figure import Figure
8-
from matplotlib.ticker import MaxNLocator
912

1013
import mesa
1114
from mesa.experimental.cell_space import VoronoiGrid
15+
from mesa.space import PropertyLayer
1216
from mesa.visualization.utils import update_counter
1317

1418

15-
def make_space_matplotlib(agent_portrayal=None): # noqa: D103
19+
def make_space_matplotlib(agent_portrayal=None, propertylayer_portrayal=None):
20+
"""Create a Matplotlib-based space visualization component.
21+
22+
Args:
23+
agent_portrayal (function): Function to portray agents
24+
propertylayer_portrayal (dict): Dictionary of PropertyLayer portrayal specifications
25+
26+
Returns:
27+
function: A function that creates a SpaceMatplotlib component
28+
"""
1629
if agent_portrayal is None:
1730

1831
def agent_portrayal(a):
1932
return {"id": a.unique_id}
2033

2134
def MakeSpaceMatplotlib(model):
22-
return SpaceMatplotlib(model, agent_portrayal)
35+
return SpaceMatplotlib(model, agent_portrayal, propertylayer_portrayal)
2336

2437
return MakeSpaceMatplotlib
2538

2639

2740
@solara.component
28-
def SpaceMatplotlib(model, agent_portrayal, dependencies: list[any] | None = None): # noqa: D103
41+
def SpaceMatplotlib(
42+
model,
43+
agent_portrayal,
44+
propertylayer_portrayal,
45+
dependencies: list[any] | None = None,
46+
):
47+
"""Create a Matplotlib-based space visualization component."""
2948
update_counter.get()
3049
space_fig = Figure()
3150
space_ax = space_fig.subplots()
3251
space = getattr(model, "grid", None)
3352
if space is None:
34-
# Sometimes the space is defined as model.space instead of model.grid
35-
space = model.space
36-
if isinstance(space, mesa.space.NetworkGrid):
37-
_draw_network_grid(space, space_ax, agent_portrayal)
53+
space = getattr(model, "space", None)
54+
55+
if isinstance(space, mesa.space._Grid):
56+
_draw_grid(space, space_ax, agent_portrayal, propertylayer_portrayal, model)
3857
elif isinstance(space, mesa.space.ContinuousSpace):
39-
_draw_continuous_space(space, space_ax, agent_portrayal)
58+
_draw_continuous_space(space, space_ax, agent_portrayal, model)
59+
elif isinstance(space, mesa.space.NetworkGrid):
60+
_draw_network_grid(space, space_ax, agent_portrayal)
4061
elif isinstance(space, VoronoiGrid):
4162
_draw_voronoi(space, space_ax, agent_portrayal)
42-
else:
43-
_draw_grid(space, space_ax, agent_portrayal)
44-
solara.FigureMatplotlib(space_fig, format="png", dependencies=dependencies)
45-
63+
elif space is None and propertylayer_portrayal:
64+
draw_property_layers(space_ax, space, propertylayer_portrayal, model)
4665

47-
# matplotlib scatter does not allow for multiple shapes in one call
48-
def _split_and_scatter(portray_data, space_ax):
49-
grouped_data = defaultdict(lambda: {"x": [], "y": [], "s": [], "c": []})
50-
51-
# Extract data from the dictionary
52-
x = portray_data["x"]
53-
y = portray_data["y"]
54-
s = portray_data["s"]
55-
c = portray_data["c"]
56-
m = portray_data["m"]
57-
58-
if not (len(x) == len(y) == len(s) == len(c) == len(m)):
59-
raise ValueError(
60-
"Length mismatch in portrayal data lists: "
61-
f"x: {len(x)}, y: {len(y)}, size: {len(s)}, "
62-
f"color: {len(c)}, marker: {len(m)}"
63-
)
64-
65-
# Group the data by marker
66-
for i in range(len(x)):
67-
marker = m[i]
68-
grouped_data[marker]["x"].append(x[i])
69-
grouped_data[marker]["y"].append(y[i])
70-
grouped_data[marker]["s"].append(s[i])
71-
grouped_data[marker]["c"].append(c[i])
66+
solara.FigureMatplotlib(space_fig, format="png", dependencies=dependencies)
7267

73-
# Plot each group with the same marker
74-
for marker, data in grouped_data.items():
75-
space_ax.scatter(data["x"], data["y"], s=data["s"], c=data["c"], marker=marker)
7668

69+
def draw_property_layers(ax, space, propertylayer_portrayal, model):
70+
"""Draw PropertyLayers on the given axes.
71+
72+
Args:
73+
ax (matplotlib.axes.Axes): The axes to draw on.
74+
space (mesa.space._Grid): The space containing the PropertyLayers.
75+
propertylayer_portrayal (dict): Dictionary of PropertyLayer portrayal specifications.
76+
model (mesa.Model): The model instance.
77+
"""
78+
for layer_name, portrayal in propertylayer_portrayal.items():
79+
layer = getattr(model, layer_name, None)
80+
if not isinstance(layer, PropertyLayer):
81+
continue
82+
83+
data = layer.data.astype(float) if layer.data.dtype == bool else layer.data
84+
width, height = data.shape if space is None else (space.width, space.height)
85+
86+
if space and data.shape != (width, height):
87+
warnings.warn(
88+
f"Layer {layer_name} dimensions ({data.shape}) do not match space dimensions ({width}, {height}).",
89+
UserWarning,
90+
stacklevel=2,
91+
)
92+
93+
# Get portrayal properties, or use defaults
94+
alpha = portrayal.get("alpha", 1)
95+
vmin = portrayal.get("vmin", np.min(data))
96+
vmax = portrayal.get("vmax", np.max(data))
97+
colorbar = portrayal.get("colorbar", True)
98+
99+
# Draw the layer
100+
if "color" in portrayal:
101+
rgba_color = to_rgba(portrayal["color"])
102+
normalized_data = (data - vmin) / (vmax - vmin)
103+
rgba_data = np.full((*data.shape, 4), rgba_color)
104+
rgba_data[..., 3] *= normalized_data * alpha
105+
rgba_data = np.clip(rgba_data, 0, 1)
106+
cmap = LinearSegmentedColormap.from_list(
107+
layer_name, [(0, 0, 0, 0), (*rgba_color[:3], alpha)]
108+
)
109+
im = ax.imshow(
110+
rgba_data.transpose(1, 0, 2),
111+
extent=(0, width, 0, height),
112+
origin="lower",
113+
)
114+
if colorbar:
115+
norm = Normalize(vmin=vmin, vmax=vmax)
116+
sm = ScalarMappable(norm=norm, cmap=cmap)
117+
sm.set_array([])
118+
ax.figure.colorbar(sm, ax=ax, orientation="vertical")
119+
120+
elif "colormap" in portrayal:
121+
cmap = portrayal.get("colormap", "viridis")
122+
if isinstance(cmap, list):
123+
cmap = LinearSegmentedColormap.from_list(layer_name, cmap)
124+
im = ax.imshow(
125+
data.T,
126+
cmap=cmap,
127+
alpha=alpha,
128+
vmin=vmin,
129+
vmax=vmax,
130+
extent=(0, width, 0, height),
131+
origin="lower",
132+
)
133+
if colorbar:
134+
plt.colorbar(im, ax=ax, label=layer_name)
135+
else:
136+
raise ValueError(
137+
f"PropertyLayer {layer_name} portrayal must include 'color' or 'colormap'."
138+
)
139+
140+
141+
def _draw_grid(space, space_ax, agent_portrayal, propertylayer_portrayal, model):
142+
if propertylayer_portrayal:
143+
draw_property_layers(space_ax, space, propertylayer_portrayal, model)
144+
145+
agent_data = _get_agent_data(space, agent_portrayal)
146+
147+
space_ax.set_xlim(0, space.width)
148+
space_ax.set_ylim(0, space.height)
149+
_split_and_scatter(agent_data, space_ax)
150+
151+
# Draw grid lines
152+
for x in range(space.width + 1):
153+
space_ax.axvline(x, color="gray", linestyle=":")
154+
for y in range(space.height + 1):
155+
space_ax.axhline(y, color="gray", linestyle=":")
156+
157+
158+
def _get_agent_data(space, agent_portrayal):
159+
"""Helper function to get agent data for visualization."""
160+
x, y, s, c, m = [], [], [], [], []
161+
for agents, pos in space.coord_iter():
162+
if not agents:
163+
continue
164+
if not isinstance(agents, list):
165+
agents = [agents] # noqa PLW2901
166+
for agent in agents:
167+
data = agent_portrayal(agent)
168+
x.append(pos[0] + 0.5) # Center the agent in the cell
169+
y.append(pos[1] + 0.5) # Center the agent in the cell
170+
default_size = (180 / max(space.width, space.height)) ** 2
171+
s.append(data.get("size", default_size))
172+
c.append(data.get("color", "tab:blue"))
173+
m.append(data.get("shape", "o"))
174+
return {"x": x, "y": y, "s": s, "c": c, "m": m}
77175

78-
def _draw_grid(space, space_ax, agent_portrayal):
79-
def portray(g):
80-
x = []
81-
y = []
82-
s = [] # size
83-
c = [] # color
84-
m = [] # shape
85-
for i in range(g.width):
86-
for j in range(g.height):
87-
content = g._grid[i][j]
88-
if not content:
89-
continue
90-
if not hasattr(content, "__iter__"):
91-
# Is a single grid
92-
content = [content]
93-
for agent in content:
94-
data = agent_portrayal(agent)
95-
x.append(i)
96-
y.append(j)
97-
98-
# This is the default value for the marker size, which auto-scales
99-
# according to the grid area.
100-
default_size = (180 / max(g.width, g.height)) ** 2
101-
# establishing a default prevents misalignment if some agents are not given size, color, etc.
102-
size = data.get("size", default_size)
103-
s.append(size)
104-
color = data.get("color", "tab:blue")
105-
c.append(color)
106-
mark = data.get("shape", "o")
107-
m.append(mark)
108-
out = {"x": x, "y": y, "s": s, "c": c, "m": m}
109-
return out
110176

111-
space_ax.set_xlim(-1, space.width)
112-
space_ax.set_ylim(-1, space.height)
113-
_split_and_scatter(portray(space), space_ax)
177+
def _split_and_scatter(portray_data, space_ax):
178+
"""Helper function to split and scatter agent data."""
179+
for marker in set(portray_data["m"]):
180+
mask = [m == marker for m in portray_data["m"]]
181+
space_ax.scatter(
182+
[x for x, show in zip(portray_data["x"], mask) if show],
183+
[y for y, show in zip(portray_data["y"], mask) if show],
184+
s=[s for s, show in zip(portray_data["s"], mask) if show],
185+
c=[c for c, show in zip(portray_data["c"], mask) if show],
186+
marker=marker,
187+
)
114188

115189

116190
def _draw_network_grid(space, space_ax, agent_portrayal):
@@ -124,7 +198,7 @@ def _draw_network_grid(space, space_ax, agent_portrayal):
124198
)
125199

126200

127-
def _draw_continuous_space(space, space_ax, agent_portrayal):
201+
def _draw_continuous_space(space, space_ax, agent_portrayal, model):
128202
def portray(space):
129203
x = []
130204
y = []
@@ -139,15 +213,13 @@ def portray(space):
139213

140214
# This is matplotlib's default marker size
141215
default_size = 20
142-
# establishing a default prevents misalignment if some agents are not given size, color, etc.
143216
size = data.get("size", default_size)
144217
s.append(size)
145218
color = data.get("color", "tab:blue")
146219
c.append(color)
147220
mark = data.get("shape", "o")
148221
m.append(mark)
149-
out = {"x": x, "y": y, "s": s, "c": c, "m": m}
150-
return out
222+
return {"x": x, "y": y, "s": s, "c": c, "m": m}
151223

152224
# Determine border style based on space.torus
153225
border_style = "solid" if not space.torus else (0, (5, 10))
@@ -186,8 +258,6 @@ def portray(g):
186258
if "color" in data:
187259
c.append(data["color"])
188260
out = {"x": x, "y": y}
189-
# This is the default value for the marker size, which auto-scales
190-
# according to the grid area.
191261
out["s"] = s
192262
if len(c) > 0:
193263
out["c"] = c
@@ -216,18 +286,37 @@ def portray(g):
216286
alpha=min(1, cell.properties[space.cell_coloring_property]),
217287
c="red",
218288
) # Plot filled polygon
219-
space_ax.plot(*zip(*polygon), color="black") # Plot polygon edges in red
289+
space_ax.plot(*zip(*polygon), color="black") # Plot polygon edges in black
290+
291+
292+
def make_plot_measure(measure: str | dict[str, str] | list[str] | tuple[str]):
293+
"""Create a plotting function for a specified measure.
220294
295+
Args:
296+
measure (str | dict[str, str] | list[str] | tuple[str]): Measure(s) to plot.
297+
298+
Returns:
299+
function: A function that creates a PlotMatplotlib component.
300+
"""
221301

222-
def make_plot_measure(measure: str | dict[str, str] | list[str] | tuple[str]): # noqa: D103
223302
def MakePlotMeasure(model):
224303
return PlotMatplotlib(model, measure)
225304

226305
return MakePlotMeasure
227306

228307

229308
@solara.component
230-
def PlotMatplotlib(model, measure, dependencies: list[any] | None = None): # noqa: D103
309+
def PlotMatplotlib(model, measure, dependencies: list[any] | None = None):
310+
"""Create a Matplotlib-based plot for a measure or measures.
311+
312+
Args:
313+
model (mesa.Model): The model instance.
314+
measure (str | dict[str, str] | list[str] | tuple[str]): Measure(s) to plot.
315+
dependencies (list[any] | None): Optional dependencies for the plot.
316+
317+
Returns:
318+
solara.FigureMatplotlib: A component for rendering the plot.
319+
"""
231320
update_counter.get()
232321
fig = Figure()
233322
ax = fig.subplots()
@@ -244,5 +333,5 @@ def PlotMatplotlib(model, measure, dependencies: list[any] | None = None): # no
244333
ax.plot(df.loc[:, m], label=m)
245334
fig.legend()
246335
# Set integer x axis
247-
ax.xaxis.set_major_locator(MaxNLocator(integer=True))
336+
ax.xaxis.set_major_locator(plt.MaxNLocator(integer=True))
248337
solara.FigureMatplotlib(fig, dependencies=dependencies)

tests/test_solara_viz.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,11 +100,14 @@ def test_call_space_drawer(mocker): # noqa: D103
100100
"Shape": "circle",
101101
"color": "gray",
102102
}
103+
propertylayer_portrayal = None
103104
# initialize with space drawer unspecified (use default)
104105
# component must be rendered for code to run
105106
solara.render(SolaraViz(model, components=[make_space_matplotlib(agent_portrayal)]))
106107
# should call default method with class instance and agent portrayal
107-
mock_space_matplotlib.assert_called_with(model, agent_portrayal)
108+
mock_space_matplotlib.assert_called_with(
109+
model, agent_portrayal, propertylayer_portrayal
110+
)
108111

109112
# specify no space should be drawn
110113
mock_space_matplotlib.reset_mock()

0 commit comments

Comments
 (0)