Skip to content

Commit fe1f3a2

Browse files
authored
Added parallel kruskal algorithm (#184)
Following changes have been made, * Adding parallel kruskal algorithm * Allowed custom comparator in `merge_sort_parallel` * Allowed, `len(array_object)`
1 parent aebae4c commit fe1f3a2

File tree

6 files changed

+142
-22
lines changed

6 files changed

+142
-22
lines changed

pydatastructs/graphs/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@
1010
from .algorithms import (
1111
breadth_first_search,
1212
breadth_first_search_parallel,
13-
minimum_spanning_tree
13+
minimum_spanning_tree,
14+
minimum_spanning_tree_parallel
1415
)
1516

1617
__all__.extend(algorithms.__all__)

pydatastructs/graphs/algorithms.py

Lines changed: 92 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,13 @@
77
from pydatastructs.utils import GraphEdge
88
from pydatastructs.miscellaneous_data_structures import DisjointSetForest
99
from pydatastructs.graphs.graph import Graph
10+
from pydatastructs.linear_data_structures.algorithms import merge_sort_parallel
1011

1112
__all__ = [
1213
'breadth_first_search',
1314
'breadth_first_search_parallel',
14-
'minimum_spanning_tree'
15+
'minimum_spanning_tree',
16+
'minimum_spanning_tree_parallel'
1517
]
1618

1719
def breadth_first_search(
@@ -190,36 +192,109 @@ def _breadth_first_search_parallel_adjacency_list(
190192

191193
_breadth_first_search_parallel_adjacency_matrix = _breadth_first_search_parallel_adjacency_list
192194

195+
def _generate_mst_object(graph):
196+
mst = Graph(*[getattr(graph, str(v)) for v in graph.vertices])
197+
return mst
198+
199+
def _sort_edges(graph, num_threads=None):
200+
edges = list(graph.edge_weights.items())
201+
if num_threads is None:
202+
sort_key = lambda item: item[1].value
203+
return sorted(edges, key=sort_key)
204+
205+
merge_sort_parallel(edges, num_threads,
206+
comp=lambda u,v: u[1].value <= v[1].value)
207+
return edges
208+
193209
def _minimum_spanning_tree_kruskal_adjacency_list(graph):
194-
mst = Graph(*[getattr(graph, v) for v in graph.vertices])
195-
sort_key = lambda item: item[1].value
210+
mst = _generate_mst_object(graph)
196211
dsf = DisjointSetForest()
197212
for v in graph.vertices:
198213
dsf.make_set(v)
199-
for _, edge in sorted(graph.edge_weights.items(), key=sort_key):
214+
for _, edge in _sort_edges(graph):
200215
u, v = edge.source.name, edge.target.name
201216
if dsf.find_root(u) is not dsf.find_root(v):
202217
mst.add_edge(u, v, edge.value)
203218
dsf.union(u, v)
204219
return mst
205220

206-
def _minimum_spanning_tree_kruskal_adjacency_matrix(graph):
207-
mst = Graph(*[getattr(graph, str(v)) for v in graph.vertices])
208-
sort_key = lambda item: item[1].value
221+
_minimum_spanning_tree_kruskal_adjacency_matrix = \
222+
_minimum_spanning_tree_kruskal_adjacency_list
223+
224+
def minimum_spanning_tree(graph, algorithm):
225+
"""
226+
Computes a minimum spanning tree for the given
227+
graph and algorithm.
228+
229+
Parameters
230+
==========
231+
232+
graph: Graph
233+
The graph whose minimum spanning tree
234+
has to be computed.
235+
algorithm: str
236+
The algorithm which should be used for
237+
computing a minimum spanning tree.
238+
Currently the following algorithms are
239+
supported,
240+
'kruskal' -> Kruskal's algorithm as given in
241+
[1].
242+
243+
Returns
244+
=======
245+
246+
mst: Graph
247+
A minimum spanning tree using the implementation
248+
same as the graph provided in the input.
249+
250+
Examples
251+
========
252+
253+
>>> from pydatastructs import Graph, AdjacencyListGraphNode
254+
>>> from pydatastructs import minimum_spanning_tree
255+
>>> u = AdjacencyListGraphNode('u')
256+
>>> v = AdjacencyListGraphNode('v')
257+
>>> G = Graph(u, v)
258+
>>> G.add_edge(u.name, v.name, 3)
259+
>>> mst = minimum_spanning_tree(G, 'kruskal')
260+
>>> u_n = mst.neighbors(u.name)
261+
>>> mst.get_edge(u.name, u_n[0].name).value
262+
3
263+
264+
References
265+
==========
266+
267+
.. [1] https://en.wikipedia.org/wiki/Kruskal%27s_algorithm
268+
"""
269+
import pydatastructs.graphs.algorithms as algorithms
270+
func = "_minimum_spanning_tree_" + algorithm + "_" + graph._impl
271+
if not hasattr(algorithms, func):
272+
raise NotImplementedError(
273+
"Currently %s algoithm for %s implementation of graphs "
274+
"isn't implemented for finding minimum spanning trees."
275+
%(algorithm, graph._impl))
276+
return getattr(algorithms, func)(graph)
277+
278+
def _minimum_spanning_tree_parallel_kruskal_adjacency_list(graph, num_threads):
279+
mst = _generate_mst_object(graph)
209280
dsf = DisjointSetForest()
210281
for v in graph.vertices:
211282
dsf.make_set(v)
212-
for _, edge in sorted(graph.edge_weights.items(), key=sort_key):
283+
edges = _sort_edges(graph, num_threads)
284+
for _, edge in edges:
213285
u, v = edge.source.name, edge.target.name
214286
if dsf.find_root(u) is not dsf.find_root(v):
215287
mst.add_edge(u, v, edge.value)
216288
dsf.union(u, v)
217289
return mst
218290

219-
def minimum_spanning_tree(graph, algorithm):
291+
_minimum_spanning_tree_parallel_kruskal_adjacency_matrix = \
292+
_minimum_spanning_tree_parallel_kruskal_adjacency_list
293+
294+
def minimum_spanning_tree_parallel(graph, algorithm, num_threads):
220295
"""
221296
Computes a minimum spanning tree for the given
222-
graph and algorithm.
297+
graph and algorithm using the given number of threads.
223298
224299
Parameters
225300
==========
@@ -234,6 +309,8 @@ def minimum_spanning_tree(graph, algorithm):
234309
supported,
235310
'kruskal' -> Kruskal's algorithm as given in
236311
[1].
312+
num_threads: int
313+
The number of threads to be used.
237314
238315
Returns
239316
=======
@@ -246,26 +323,26 @@ def minimum_spanning_tree(graph, algorithm):
246323
========
247324
248325
>>> from pydatastructs import Graph, AdjacencyListGraphNode
249-
>>> from pydatastructs import minimum_spanning_tree
326+
>>> from pydatastructs import minimum_spanning_tree_parallel
250327
>>> u = AdjacencyListGraphNode('u')
251328
>>> v = AdjacencyListGraphNode('v')
252329
>>> G = Graph(u, v)
253330
>>> G.add_edge(u.name, v.name, 3)
254-
>>> mst = minimum_spanning_tree(G, 'kruskal')
331+
>>> mst = minimum_spanning_tree_parallel(G, 'kruskal', 3)
255332
>>> u_n = mst.neighbors(u.name)
256333
>>> mst.get_edge(u.name, u_n[0].name).value
257334
3
258335
259336
References
260337
==========
261338
262-
.. [1] https://en.wikipedia.org/wiki/Kruskal%27s_algorithm
339+
.. [1] https://en.wikipedia.org/wiki/Kruskal%27s_algorithm#Parallel_algorithm
263340
"""
264341
import pydatastructs.graphs.algorithms as algorithms
265-
func = "_minimum_spanning_tree_" + algorithm + "_" + graph._impl
342+
func = "_minimum_spanning_tree_parallel_" + algorithm + "_" + graph._impl
266343
if not hasattr(algorithms, func):
267344
raise NotImplementedError(
268345
"Currently %s algoithm for %s implementation of graphs "
269346
"isn't implemented for finding minimum spanning trees."
270347
%(algorithm, graph._impl))
271-
return getattr(algorithms, func)(graph)
348+
return getattr(algorithms, func)(graph, num_threads)

pydatastructs/graphs/tests/test_algorithms.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from pydatastructs import (breadth_first_search, Graph,
2-
breadth_first_search_parallel, minimum_spanning_tree)
2+
breadth_first_search_parallel, minimum_spanning_tree,
3+
minimum_spanning_tree_parallel)
34

45

56
def test_breadth_first_search():
@@ -148,3 +149,30 @@ def _test_minimum_spanning_tree(ds, algorithm):
148149

149150
_test_minimum_spanning_tree("List", "kruskal")
150151
_test_minimum_spanning_tree("Matrix", "kruskal")
152+
153+
def test_minimum_spanning_tree_parallel():
154+
155+
def _test_minimum_spanning_tree_parallel(ds, algorithm):
156+
import pydatastructs.utils.misc_util as utils
157+
GraphNode = getattr(utils, "Adjacency" + ds + "GraphNode")
158+
a, b, c, d, e = [GraphNode(x) for x in [0, 1, 2, 3, 4]]
159+
graph = Graph(a, b, c, d, e)
160+
graph.add_edge(a.name, c.name, 10)
161+
graph.add_edge(c.name, a.name, 10)
162+
graph.add_edge(a.name, d.name, 7)
163+
graph.add_edge(d.name, a.name, 7)
164+
graph.add_edge(c.name, d.name, 9)
165+
graph.add_edge(d.name, c.name, 9)
166+
graph.add_edge(d.name, b.name, 32)
167+
graph.add_edge(b.name, d.name, 32)
168+
graph.add_edge(d.name, e.name, 23)
169+
graph.add_edge(e.name, d.name, 23)
170+
mst = minimum_spanning_tree_parallel(graph, algorithm, 3)
171+
expected_mst = [('0_3', 7), ('2_3', 9), ('3_4', 23), ('3_1', 32),
172+
('3_0', 7), ('3_2', 9), ('4_3', 23), ('1_3', 32)]
173+
assert len(expected_mst) == 2*len(mst.edge_weights.items())
174+
for k, v in mst.edge_weights.items():
175+
assert (k, v.value) in expected_mst
176+
177+
_test_minimum_spanning_tree_parallel("List", "kruskal")
178+
_test_minimum_spanning_tree_parallel("Matrix", "kruskal")

pydatastructs/linear_data_structures/algorithms.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
'merge_sort_parallel'
99
]
1010

11-
def _merge(array, sl, el, sr, er, end):
11+
def _merge(array, sl, el, sr, er, end, comp):
1212
l, r = [], []
1313
for i in range(sl, el + 1):
1414
if (i <= end and
@@ -22,7 +22,7 @@ def _merge(array, sl, el, sr, er, end):
2222
array[i] = None
2323
i, j, k = 0, 0, sl
2424
while i < len(l) and j < len(r):
25-
if l[i] <= r[j]:
25+
if comp(l[i], r[j]):
2626
array[k] = l[i]
2727
i += 1
2828
else:
@@ -61,6 +61,13 @@ def merge_sort_parallel(array, num_threads, **kwargs):
6161
is to be sorted.
6262
Optional, by default the index
6363
of the last position filled.
64+
comp: lambda/function
65+
The comparator which is to be used
66+
for sorting. If the function returns
67+
False then only swapping is performed.
68+
Optional, by default, less than or
69+
equal to is used for comparing two
70+
values.
6471
6572
Examples
6673
========
@@ -70,14 +77,18 @@ def merge_sort_parallel(array, num_threads, **kwargs):
7077
>>> merge_sort_parallel(arr, 3)
7178
>>> [arr[0], arr[1], arr[2]]
7279
[1, 2, 3]
80+
>>> merge_sort_parallel(arr, 3, comp=lambda u, v: u > v)
81+
>>> [arr[0], arr[1], arr[2]]
82+
[3, 2, 1]
7383
7484
References
7585
==========
7686
7787
.. [1] https://en.wikipedia.org/wiki/Merge_sort
7888
"""
7989
start = kwargs.get('start', 0)
80-
end = kwargs.get('end', array._size - 1)
90+
end = kwargs.get('end', len(array) - 1)
91+
comp = kwargs.get("comp", lambda u, v: u <= v)
8192
for size in range(floor(log(end - start + 1, 2)) + 1):
8293
pow_2 = 2**size
8394
with ThreadPoolExecutor(max_workers=num_threads) as Executor:
@@ -88,7 +99,7 @@ def merge_sort_parallel(array, num_threads, **kwargs):
8899
array,
89100
i, i + pow_2 - 1,
90101
i + pow_2, i + 2*pow_2 - 1,
91-
end).result()
102+
end, comp).result()
92103
i = i + 2*pow_2
93104

94105
if _check_type(array, DynamicArray):

pydatastructs/linear_data_structures/arrays.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,9 @@ def fill(self, elem):
127127
for i in range(self._size):
128128
self._data[i] = elem
129129

130+
def __len__(self):
131+
return self._size
132+
130133

131134
class DynamicArray(Array):
132135
"""

pydatastructs/trees/binary_trees.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ class BinaryTree(object):
2828
key
2929
Required if tree is to be instantiated with
3030
root otherwise not needed.
31-
comp: lambda
31+
comp: lambda/function
3232
Optional, A lambda function which will be used
3333
for comparison of keys. Should return a
3434
bool value. By default it implements less

0 commit comments

Comments
 (0)