1
1
"""Matplotlib based solara components for visualization MESA spaces and plots."""
2
2
3
- from collections import defaultdict
3
+ import warnings
4
4
5
+ import matplotlib .pyplot as plt
5
6
import networkx as nx
7
+ import numpy as np
6
8
import solara
9
+ from matplotlib .cm import ScalarMappable
10
+ from matplotlib .colors import LinearSegmentedColormap , Normalize , to_rgba
7
11
from matplotlib .figure import Figure
8
- from matplotlib .ticker import MaxNLocator
9
12
10
13
import mesa
11
14
from mesa .experimental .cell_space import VoronoiGrid
15
+ from mesa .space import PropertyLayer
12
16
from mesa .visualization .utils import update_counter
13
17
14
18
15
- def make_space_matplotlib (agent_portrayal = None ): # noqa: D103
19
+ def make_space_matplotlib (agent_portrayal = None , propertylayer_portrayal = None ):
20
+ """Create a Matplotlib-based space visualization component.
21
+
22
+ Args:
23
+ agent_portrayal (function): Function to portray agents
24
+ propertylayer_portrayal (dict): Dictionary of PropertyLayer portrayal specifications
25
+
26
+ Returns:
27
+ function: A function that creates a SpaceMatplotlib component
28
+ """
16
29
if agent_portrayal is None :
17
30
18
31
def agent_portrayal (a ):
19
32
return {"id" : a .unique_id }
20
33
21
34
def MakeSpaceMatplotlib (model ):
22
- return SpaceMatplotlib (model , agent_portrayal )
35
+ return SpaceMatplotlib (model , agent_portrayal , propertylayer_portrayal )
23
36
24
37
return MakeSpaceMatplotlib
25
38
26
39
27
40
@solara .component
28
- def SpaceMatplotlib (model , agent_portrayal , dependencies : list [any ] | None = None ): # noqa: D103
41
+ def SpaceMatplotlib (
42
+ model ,
43
+ agent_portrayal ,
44
+ propertylayer_portrayal ,
45
+ dependencies : list [any ] | None = None ,
46
+ ):
47
+ """Create a Matplotlib-based space visualization component."""
29
48
update_counter .get ()
30
49
space_fig = Figure ()
31
50
space_ax = space_fig .subplots ()
32
51
space = getattr (model , "grid" , None )
33
52
if space is None :
34
- # Sometimes the space is defined as model. space instead of model.grid
35
- space = model . space
36
- if isinstance (space , mesa .space .NetworkGrid ):
37
- _draw_network_grid (space , space_ax , agent_portrayal )
53
+ space = getattr ( model , " space" , None )
54
+
55
+ if isinstance (space , mesa .space ._Grid ):
56
+ _draw_grid (space , space_ax , agent_portrayal , propertylayer_portrayal , model )
38
57
elif isinstance (space , mesa .space .ContinuousSpace ):
39
- _draw_continuous_space (space , space_ax , agent_portrayal )
58
+ _draw_continuous_space (space , space_ax , agent_portrayal , model )
59
+ elif isinstance (space , mesa .space .NetworkGrid ):
60
+ _draw_network_grid (space , space_ax , agent_portrayal )
40
61
elif isinstance (space , VoronoiGrid ):
41
62
_draw_voronoi (space , space_ax , agent_portrayal )
42
- else :
43
- _draw_grid (space , space_ax , agent_portrayal )
44
- solara .FigureMatplotlib (space_fig , format = "png" , dependencies = dependencies )
45
-
63
+ elif space is None and propertylayer_portrayal :
64
+ draw_property_layers (space_ax , space , propertylayer_portrayal , model )
46
65
47
- # matplotlib scatter does not allow for multiple shapes in one call
48
- def _split_and_scatter (portray_data , space_ax ):
49
- grouped_data = defaultdict (lambda : {"x" : [], "y" : [], "s" : [], "c" : []})
50
-
51
- # Extract data from the dictionary
52
- x = portray_data ["x" ]
53
- y = portray_data ["y" ]
54
- s = portray_data ["s" ]
55
- c = portray_data ["c" ]
56
- m = portray_data ["m" ]
57
-
58
- if not (len (x ) == len (y ) == len (s ) == len (c ) == len (m )):
59
- raise ValueError (
60
- "Length mismatch in portrayal data lists: "
61
- f"x: { len (x )} , y: { len (y )} , size: { len (s )} , "
62
- f"color: { len (c )} , marker: { len (m )} "
63
- )
64
-
65
- # Group the data by marker
66
- for i in range (len (x )):
67
- marker = m [i ]
68
- grouped_data [marker ]["x" ].append (x [i ])
69
- grouped_data [marker ]["y" ].append (y [i ])
70
- grouped_data [marker ]["s" ].append (s [i ])
71
- grouped_data [marker ]["c" ].append (c [i ])
66
+ solara .FigureMatplotlib (space_fig , format = "png" , dependencies = dependencies )
72
67
73
- # Plot each group with the same marker
74
- for marker , data in grouped_data .items ():
75
- space_ax .scatter (data ["x" ], data ["y" ], s = data ["s" ], c = data ["c" ], marker = marker )
76
68
69
+ def draw_property_layers (ax , space , propertylayer_portrayal , model ):
70
+ """Draw PropertyLayers on the given axes.
71
+
72
+ Args:
73
+ ax (matplotlib.axes.Axes): The axes to draw on.
74
+ space (mesa.space._Grid): The space containing the PropertyLayers.
75
+ propertylayer_portrayal (dict): Dictionary of PropertyLayer portrayal specifications.
76
+ model (mesa.Model): The model instance.
77
+ """
78
+ for layer_name , portrayal in propertylayer_portrayal .items ():
79
+ layer = getattr (model , layer_name , None )
80
+ if not isinstance (layer , PropertyLayer ):
81
+ continue
82
+
83
+ data = layer .data .astype (float ) if layer .data .dtype == bool else layer .data
84
+ width , height = data .shape if space is None else (space .width , space .height )
85
+
86
+ if space and data .shape != (width , height ):
87
+ warnings .warn (
88
+ f"Layer { layer_name } dimensions ({ data .shape } ) do not match space dimensions ({ width } , { height } )." ,
89
+ UserWarning ,
90
+ stacklevel = 2 ,
91
+ )
92
+
93
+ # Get portrayal properties, or use defaults
94
+ alpha = portrayal .get ("alpha" , 1 )
95
+ vmin = portrayal .get ("vmin" , np .min (data ))
96
+ vmax = portrayal .get ("vmax" , np .max (data ))
97
+ colorbar = portrayal .get ("colorbar" , True )
98
+
99
+ # Draw the layer
100
+ if "color" in portrayal :
101
+ rgba_color = to_rgba (portrayal ["color" ])
102
+ normalized_data = (data - vmin ) / (vmax - vmin )
103
+ rgba_data = np .full ((* data .shape , 4 ), rgba_color )
104
+ rgba_data [..., 3 ] *= normalized_data * alpha
105
+ rgba_data = np .clip (rgba_data , 0 , 1 )
106
+ cmap = LinearSegmentedColormap .from_list (
107
+ layer_name , [(0 , 0 , 0 , 0 ), (* rgba_color [:3 ], alpha )]
108
+ )
109
+ im = ax .imshow (
110
+ rgba_data .transpose (1 , 0 , 2 ),
111
+ extent = (0 , width , 0 , height ),
112
+ origin = "lower" ,
113
+ )
114
+ if colorbar :
115
+ norm = Normalize (vmin = vmin , vmax = vmax )
116
+ sm = ScalarMappable (norm = norm , cmap = cmap )
117
+ sm .set_array ([])
118
+ ax .figure .colorbar (sm , ax = ax , orientation = "vertical" )
119
+
120
+ elif "colormap" in portrayal :
121
+ cmap = portrayal .get ("colormap" , "viridis" )
122
+ if isinstance (cmap , list ):
123
+ cmap = LinearSegmentedColormap .from_list (layer_name , cmap )
124
+ im = ax .imshow (
125
+ data .T ,
126
+ cmap = cmap ,
127
+ alpha = alpha ,
128
+ vmin = vmin ,
129
+ vmax = vmax ,
130
+ extent = (0 , width , 0 , height ),
131
+ origin = "lower" ,
132
+ )
133
+ if colorbar :
134
+ plt .colorbar (im , ax = ax , label = layer_name )
135
+ else :
136
+ raise ValueError (
137
+ f"PropertyLayer { layer_name } portrayal must include 'color' or 'colormap'."
138
+ )
139
+
140
+
141
+ def _draw_grid (space , space_ax , agent_portrayal , propertylayer_portrayal , model ):
142
+ if propertylayer_portrayal :
143
+ draw_property_layers (space_ax , space , propertylayer_portrayal , model )
144
+
145
+ agent_data = _get_agent_data (space , agent_portrayal )
146
+
147
+ space_ax .set_xlim (0 , space .width )
148
+ space_ax .set_ylim (0 , space .height )
149
+ _split_and_scatter (agent_data , space_ax )
150
+
151
+ # Draw grid lines
152
+ for x in range (space .width + 1 ):
153
+ space_ax .axvline (x , color = "gray" , linestyle = ":" )
154
+ for y in range (space .height + 1 ):
155
+ space_ax .axhline (y , color = "gray" , linestyle = ":" )
156
+
157
+
158
+ def _get_agent_data (space , agent_portrayal ):
159
+ """Helper function to get agent data for visualization."""
160
+ x , y , s , c , m = [], [], [], [], []
161
+ for agents , pos in space .coord_iter ():
162
+ if not agents :
163
+ continue
164
+ if not isinstance (agents , list ):
165
+ agents = [agents ] # noqa PLW2901
166
+ for agent in agents :
167
+ data = agent_portrayal (agent )
168
+ x .append (pos [0 ] + 0.5 ) # Center the agent in the cell
169
+ y .append (pos [1 ] + 0.5 ) # Center the agent in the cell
170
+ default_size = (180 / max (space .width , space .height )) ** 2
171
+ s .append (data .get ("size" , default_size ))
172
+ c .append (data .get ("color" , "tab:blue" ))
173
+ m .append (data .get ("shape" , "o" ))
174
+ return {"x" : x , "y" : y , "s" : s , "c" : c , "m" : m }
77
175
78
- def _draw_grid (space , space_ax , agent_portrayal ):
79
- def portray (g ):
80
- x = []
81
- y = []
82
- s = [] # size
83
- c = [] # color
84
- m = [] # shape
85
- for i in range (g .width ):
86
- for j in range (g .height ):
87
- content = g ._grid [i ][j ]
88
- if not content :
89
- continue
90
- if not hasattr (content , "__iter__" ):
91
- # Is a single grid
92
- content = [content ]
93
- for agent in content :
94
- data = agent_portrayal (agent )
95
- x .append (i )
96
- y .append (j )
97
-
98
- # This is the default value for the marker size, which auto-scales
99
- # according to the grid area.
100
- default_size = (180 / max (g .width , g .height )) ** 2
101
- # establishing a default prevents misalignment if some agents are not given size, color, etc.
102
- size = data .get ("size" , default_size )
103
- s .append (size )
104
- color = data .get ("color" , "tab:blue" )
105
- c .append (color )
106
- mark = data .get ("shape" , "o" )
107
- m .append (mark )
108
- out = {"x" : x , "y" : y , "s" : s , "c" : c , "m" : m }
109
- return out
110
176
111
- space_ax .set_xlim (- 1 , space .width )
112
- space_ax .set_ylim (- 1 , space .height )
113
- _split_and_scatter (portray (space ), space_ax )
177
+ def _split_and_scatter (portray_data , space_ax ):
178
+ """Helper function to split and scatter agent data."""
179
+ for marker in set (portray_data ["m" ]):
180
+ mask = [m == marker for m in portray_data ["m" ]]
181
+ space_ax .scatter (
182
+ [x for x , show in zip (portray_data ["x" ], mask ) if show ],
183
+ [y for y , show in zip (portray_data ["y" ], mask ) if show ],
184
+ s = [s for s , show in zip (portray_data ["s" ], mask ) if show ],
185
+ c = [c for c , show in zip (portray_data ["c" ], mask ) if show ],
186
+ marker = marker ,
187
+ )
114
188
115
189
116
190
def _draw_network_grid (space , space_ax , agent_portrayal ):
@@ -124,7 +198,7 @@ def _draw_network_grid(space, space_ax, agent_portrayal):
124
198
)
125
199
126
200
127
- def _draw_continuous_space (space , space_ax , agent_portrayal ):
201
+ def _draw_continuous_space (space , space_ax , agent_portrayal , model ):
128
202
def portray (space ):
129
203
x = []
130
204
y = []
@@ -139,15 +213,13 @@ def portray(space):
139
213
140
214
# This is matplotlib's default marker size
141
215
default_size = 20
142
- # establishing a default prevents misalignment if some agents are not given size, color, etc.
143
216
size = data .get ("size" , default_size )
144
217
s .append (size )
145
218
color = data .get ("color" , "tab:blue" )
146
219
c .append (color )
147
220
mark = data .get ("shape" , "o" )
148
221
m .append (mark )
149
- out = {"x" : x , "y" : y , "s" : s , "c" : c , "m" : m }
150
- return out
222
+ return {"x" : x , "y" : y , "s" : s , "c" : c , "m" : m }
151
223
152
224
# Determine border style based on space.torus
153
225
border_style = "solid" if not space .torus else (0 , (5 , 10 ))
@@ -186,8 +258,6 @@ def portray(g):
186
258
if "color" in data :
187
259
c .append (data ["color" ])
188
260
out = {"x" : x , "y" : y }
189
- # This is the default value for the marker size, which auto-scales
190
- # according to the grid area.
191
261
out ["s" ] = s
192
262
if len (c ) > 0 :
193
263
out ["c" ] = c
@@ -216,18 +286,37 @@ def portray(g):
216
286
alpha = min (1 , cell .properties [space .cell_coloring_property ]),
217
287
c = "red" ,
218
288
) # Plot filled polygon
219
- space_ax .plot (* zip (* polygon ), color = "black" ) # Plot polygon edges in red
289
+ space_ax .plot (* zip (* polygon ), color = "black" ) # Plot polygon edges in black
290
+
291
+
292
+ def make_plot_measure (measure : str | dict [str , str ] | list [str ] | tuple [str ]):
293
+ """Create a plotting function for a specified measure.
220
294
295
+ Args:
296
+ measure (str | dict[str, str] | list[str] | tuple[str]): Measure(s) to plot.
297
+
298
+ Returns:
299
+ function: A function that creates a PlotMatplotlib component.
300
+ """
221
301
222
- def make_plot_measure (measure : str | dict [str , str ] | list [str ] | tuple [str ]): # noqa: D103
223
302
def MakePlotMeasure (model ):
224
303
return PlotMatplotlib (model , measure )
225
304
226
305
return MakePlotMeasure
227
306
228
307
229
308
@solara .component
230
- def PlotMatplotlib (model , measure , dependencies : list [any ] | None = None ): # noqa: D103
309
+ def PlotMatplotlib (model , measure , dependencies : list [any ] | None = None ):
310
+ """Create a Matplotlib-based plot for a measure or measures.
311
+
312
+ Args:
313
+ model (mesa.Model): The model instance.
314
+ measure (str | dict[str, str] | list[str] | tuple[str]): Measure(s) to plot.
315
+ dependencies (list[any] | None): Optional dependencies for the plot.
316
+
317
+ Returns:
318
+ solara.FigureMatplotlib: A component for rendering the plot.
319
+ """
231
320
update_counter .get ()
232
321
fig = Figure ()
233
322
ax = fig .subplots ()
@@ -244,5 +333,5 @@ def PlotMatplotlib(model, measure, dependencies: list[any] | None = None): # no
244
333
ax .plot (df .loc [:, m ], label = m )
245
334
fig .legend ()
246
335
# Set integer x axis
247
- ax .xaxis .set_major_locator (MaxNLocator (integer = True ))
336
+ ax .xaxis .set_major_locator (plt . MaxNLocator (integer = True ))
248
337
solara .FigureMatplotlib (fig , dependencies = dependencies )
0 commit comments