diff --git a/pydatastructs/trees/binary_trees.py b/pydatastructs/trees/binary_trees.py index 91fb96528..fce55fb21 100644 --- a/pydatastructs/trees/binary_trees.py +++ b/pydatastructs/trees/binary_trees.py @@ -415,6 +415,143 @@ def rank(self, x): walk = p return r + def _simple_path(self, key, root): + """ + Utility funtion to find the simple path between root and node. + + Parameter + ========= + + key: Node.key + Key of the node to be searched + + Returns + ======= + + path: list + """ + + stack = Stack() + stack.push(root) + path = [] + node_idx = -1 + + while not stack.is_empty: + node = stack.pop() + if self.tree[node].key == key: + node_idx = node + break + if self.tree[node].left: + stack.push(self.tree[node].left) + if self.tree[node].right: + stack.push(self.tree[node].right) + + if node_idx == -1: + return path + + while node_idx != 0: + path.append(node_idx) + node_idx = self.tree[node_idx].parent + path.append(0) + path.reverse() + + return path + + def _lca_1(self, j, k): + root = self.root_idx + path1 = self._simple_path(j, root) + path2 = self._simple_path(k, root) + if not path1 or not path2: + raise ValueError("One of two path doesn't exists. See %s, %s" + %(path1, path2)) + + n, m = len(path1), len(path2) + i = j = 0 + while i < n and j < m: + if path1[i] != path2[j]: + return self.tree[path1[i - 1]].key + i += 1 + j += 1 + if path1 < path2: + return self.tree[path1[-1]].key + return self.tree[path2[-1]].key + + def _lca_2(self, j, k): + curr_root = self.root_idx + u, v = self.search(j), self.search(k) + if (u is None) or (v is None): + raise ValueError("One of the nodes with key %s " + "or %s doesn't exits"%(j, k)) + u_left = self.comparator(self.tree[u].key, \ + self.tree[curr_root].key) + v_left = self.comparator(self.tree[v].key, \ + self.tree[curr_root].key) + + while not (u_left ^ v_left): + if u_left and v_left: + curr_root = self.tree[curr_root].left + else: + curr_root = self.tree[curr_root].right + + if curr_root == u or curr_root == v: + if curr_root is None: + return None + return self.tree[curr_root].key + + u_left = self.comparator(self.tree[u].key, \ + self.tree[curr_root].key) + v_left = self.comparator(self.tree[v].key, \ + self.tree[curr_root].key) + + if curr_root is None: + return curr_root + return self.tree[curr_root].key + + def lowest_common_ancestor(self, j, k, algorithm=1): + + """ + Computes the lowest common ancestor of two nodes. + + Parameters + ========== + + j: Node.key + Key of first node + k: Node.key + Key of second node + algorithm: int + The algorithm to be used for computing the + lowest common ancestor. + Optional, by default uses algorithm 1. + + 1 -> Determines the lowest common ancestor by finding + the first intersection of the paths from v and w + to the root. + + 2 -> Modifed version of the algorithm given in the + following publication, + D. Harel. A linear time algorithm for the + lowest common ancestors problem. In 21s + Annual Symposium On Foundations of + Computer Science, pages 308-319, 1980. + + Returns + ======= + + Node.key + The key of the lowest common ancestor in the tree. + if both the nodes are present in the tree. + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Lowest_common_ancestor + + .. [2] https://pdfs.semanticscholar.org/e75b/386cc554214aa0ebd6bd6dbdd0e490da3739.pdf + + """ + return getattr(self, "_lca_"+str(algorithm))(j, k) + class AVLTree(BinarySearchTree): """ Represents AVL trees. diff --git a/pydatastructs/trees/tests/test_binary_trees.py b/pydatastructs/trees/tests/test_binary_trees.py index 15bb2f79a..1bf391913 100644 --- a/pydatastructs/trees/tests/test_binary_trees.py +++ b/pydatastructs/trees/tests/test_binary_trees.py @@ -54,7 +54,38 @@ def test_BinarySearchTree(): assert b.delete(-10) is True assert b.delete(-3) is True assert b.delete(-13) is None - raises(ValueError, lambda: BST(root_data=6)) + bl = BST() + nodes = [50, 30, 90, 70, 100, 60, 80, 55, 20, 40, 15, 10, 16, 17, 18] + for node in nodes: + bl.insert(node, node) + + assert bl.lowest_common_ancestor(80, 55, 2) == 70 + assert bl.lowest_common_ancestor(60, 70, 2) == 70 + assert bl.lowest_common_ancestor(18, 18, 2) == 18 + assert bl.lowest_common_ancestor(40, 90, 2) == 50 + + assert bl.lowest_common_ancestor(18, 10, 2) == 15 + assert bl.lowest_common_ancestor(55, 100, 2) == 90 + assert bl.lowest_common_ancestor(16, 80, 2) == 50 + assert bl.lowest_common_ancestor(30, 55, 2) == 50 + + assert raises(ValueError, lambda: bl.lowest_common_ancestor(60, 200, 2)) + assert raises(ValueError, lambda: bl.lowest_common_ancestor(200, 60, 2)) + assert raises(ValueError, lambda: bl.lowest_common_ancestor(-3, 4, 2)) + + assert bl.lowest_common_ancestor(80, 55, 1) == 70 + assert bl.lowest_common_ancestor(60, 70, 1) == 70 + assert bl.lowest_common_ancestor(18, 18, 1) == 18 + assert bl.lowest_common_ancestor(40, 90, 1) == 50 + + assert bl.lowest_common_ancestor(18, 10, 1) == 15 + assert bl.lowest_common_ancestor(55, 100, 1) == 90 + assert bl.lowest_common_ancestor(16, 80, 1) == 50 + assert bl.lowest_common_ancestor(30, 55, 1) == 50 + + assert raises(ValueError, lambda: bl.lowest_common_ancestor(60, 200, 1)) + assert raises(ValueError, lambda: bl.lowest_common_ancestor(200, 60, 1)) + assert raises(ValueError, lambda: bl.lowest_common_ancestor(-3, 4, 1)) def test_BinaryTreeTraversal(): BST = BinarySearchTree