Skip to content

Added LCA and tests #88

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Mar 5, 2020
137 changes: 137 additions & 0 deletions pydatastructs/trees/binary_trees.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,143 @@ def rank(self, x):
walk = p
return r

def _simple_path(self, key, root):
"""
Utility funtion to find the simple path between root and node.

Parameter
=========

key: Node.key
Key of the node to be searched

Returns
=======

path: list
"""

stack = Stack()
stack.push(root)
path = []
node_idx = -1

while not stack.is_empty:
node = stack.pop()
if self.tree[node].key == key:
node_idx = node
break
if self.tree[node].left:
stack.push(self.tree[node].left)
if self.tree[node].right:
stack.push(self.tree[node].right)

if node_idx == -1:
return path

while node_idx != 0:
path.append(node_idx)
node_idx = self.tree[node_idx].parent
path.append(0)
path.reverse()

return path

def _lca_1(self, j, k):
root = self.root_idx
path1 = self._simple_path(j, root)
path2 = self._simple_path(k, root)
if not path1 or not path2:
raise ValueError("One of two path doesn't exists. See %s, %s"
%(path1, path2))

n, m = len(path1), len(path2)
i = j = 0
while i < n and j < m:
if path1[i] != path2[j]:
return self.tree[path1[i - 1]].key
i += 1
j += 1
if path1 < path2:
return self.tree[path1[-1]].key
return self.tree[path2[-1]].key

def _lca_2(self, j, k):
curr_root = self.root_idx
u, v = self.search(j), self.search(k)
if (u is None) or (v is None):
raise ValueError("One of the nodes with key %s "
"or %s doesn't exits"%(j, k))
u_left = self.comparator(self.tree[u].key, \
self.tree[curr_root].key)
v_left = self.comparator(self.tree[v].key, \
self.tree[curr_root].key)

while not (u_left ^ v_left):
if u_left and v_left:
curr_root = self.tree[curr_root].left
else:
curr_root = self.tree[curr_root].right

if curr_root == u or curr_root == v:
if curr_root is None:
return None
return self.tree[curr_root].key

u_left = self.comparator(self.tree[u].key, \
self.tree[curr_root].key)
v_left = self.comparator(self.tree[v].key, \
self.tree[curr_root].key)

if curr_root is None:
return curr_root
return self.tree[curr_root].key

def lowest_common_ancestor(self, j, k, algorithm=1):

"""
Computes the lowest common ancestor of two nodes.

Parameters
==========

j: Node.key
Key of first node
k: Node.key
Key of second node
algorithm: int
The algorithm to be used for computing the
lowest common ancestor.
Optional, by default uses algorithm 1.

1 -> Determines the lowest common ancestor by finding
the first intersection of the paths from v and w
to the root.

2 -> Modifed version of the algorithm given in the
following publication,
D. Harel. A linear time algorithm for the
lowest common ancestors problem. In 21s
Annual Symposium On Foundations of
Computer Science, pages 308-319, 1980.

Returns
=======

Node.key
The key of the lowest common ancestor in the tree.
if both the nodes are present in the tree.

References
==========

.. [1] https://en.wikipedia.org/wiki/Lowest_common_ancestor

.. [2] https://pdfs.semanticscholar.org/e75b/386cc554214aa0ebd6bd6dbdd0e490da3739.pdf

"""
return getattr(self, "_lca_"+str(algorithm))(j, k)

class AVLTree(BinarySearchTree):
"""
Represents AVL trees.
Expand Down
33 changes: 32 additions & 1 deletion pydatastructs/trees/tests/test_binary_trees.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,38 @@ def test_BinarySearchTree():
assert b.delete(-10) is True
assert b.delete(-3) is True
assert b.delete(-13) is None
raises(ValueError, lambda: BST(root_data=6))
bl = BST()
nodes = [50, 30, 90, 70, 100, 60, 80, 55, 20, 40, 15, 10, 16, 17, 18]
for node in nodes:
bl.insert(node, node)

assert bl.lowest_common_ancestor(80, 55, 2) == 70
assert bl.lowest_common_ancestor(60, 70, 2) == 70
assert bl.lowest_common_ancestor(18, 18, 2) == 18
assert bl.lowest_common_ancestor(40, 90, 2) == 50

assert bl.lowest_common_ancestor(18, 10, 2) == 15
assert bl.lowest_common_ancestor(55, 100, 2) == 90
assert bl.lowest_common_ancestor(16, 80, 2) == 50
assert bl.lowest_common_ancestor(30, 55, 2) == 50

assert raises(ValueError, lambda: bl.lowest_common_ancestor(60, 200, 2))
assert raises(ValueError, lambda: bl.lowest_common_ancestor(200, 60, 2))
assert raises(ValueError, lambda: bl.lowest_common_ancestor(-3, 4, 2))

assert bl.lowest_common_ancestor(80, 55, 1) == 70
assert bl.lowest_common_ancestor(60, 70, 1) == 70
assert bl.lowest_common_ancestor(18, 18, 1) == 18
assert bl.lowest_common_ancestor(40, 90, 1) == 50

assert bl.lowest_common_ancestor(18, 10, 1) == 15
assert bl.lowest_common_ancestor(55, 100, 1) == 90
assert bl.lowest_common_ancestor(16, 80, 1) == 50
assert bl.lowest_common_ancestor(30, 55, 1) == 50

assert raises(ValueError, lambda: bl.lowest_common_ancestor(60, 200, 1))
assert raises(ValueError, lambda: bl.lowest_common_ancestor(200, 60, 1))
assert raises(ValueError, lambda: bl.lowest_common_ancestor(-3, 4, 1))

def test_BinaryTreeTraversal():
BST = BinarySearchTree
Expand Down