diff --git a/pydatastructs/trees/heaps.py b/pydatastructs/trees/heaps.py index ee69c7b99..4d6eb1ad5 100644 --- a/pydatastructs/trees/heaps.py +++ b/pydatastructs/trees/heaps.py @@ -1,6 +1,6 @@ from pydatastructs.utils.misc_util import _check_type, NoneType, TreeNode, BinomialTreeNode from pydatastructs.linear_data_structures.arrays import (ArrayForTrees, - DynamicOneDimensionalArray) + DynamicOneDimensionalArray, Array) from pydatastructs.miscellaneous_data_structures.binomial_trees import BinomialTree __all__ = [ @@ -24,9 +24,9 @@ class DHeap(Heap): Parameters ========== - elements : list, tuple + elements : list, tuple, Array Optional, by default 'None'. - List/tuple of initial TreeNode in Heap. + list/tuple/Array of initial TreeNode in Heap. heap_property : str @@ -84,9 +84,12 @@ def __new__(cls, elements=None, heap_property="min", d=4): raise ValueError("%s is invalid heap property"%(heap_property)) if elements is None: elements = DynamicOneDimensionalArray(TreeNode, 0) + elif _check_type(elements, (list,tuple)): + elements = DynamicOneDimensionalArray(TreeNode, len(elements), elements) + elif _check_type(elements, Array): + elements = DynamicOneDimensionalArray(TreeNode, len(elements), elements._data) else: - if not all(map(lambda x: _check_type(x, TreeNode), elements)): - raise ValueError("Expect a list/tuple of TreeNode got %s"%(elements)) + raise ValueError(f'Expected a list/tuple/Array of TreeNode got {type(elements)}') obj.heap = elements obj._last_pos_filled = obj.heap._last_pos_filled obj._build() @@ -326,7 +329,7 @@ class BinomialHeap(Heap): Parameters ========== - root_list: list/tuple + root_list: list/tuple/Array By default, [] The list of BinomialTree object references in sorted order. diff --git a/pydatastructs/trees/tests/test_heaps.py b/pydatastructs/trees/tests/test_heaps.py index ed9fedff4..dece2f132 100644 --- a/pydatastructs/trees/tests/test_heaps.py +++ b/pydatastructs/trees/tests/test_heaps.py @@ -41,7 +41,7 @@ def test_BinaryHeap(): TreeNode(1, 1), TreeNode(2, 2), TreeNode(3, 3), TreeNode(17, 17), TreeNode(19, 19), TreeNode(36, 36) ] - min_heap = BinaryHeap(elements=DynamicOneDimensionalArray(TreeNode, 9, elements), heap_property="min") + min_heap = BinaryHeap(elements=elements, heap_property="min") assert min_heap.extract().key == 1 expected_sorted_elements = [2, 3, 7, 17, 19, 25, 36, 100] @@ -53,8 +53,19 @@ def test_BinaryHeap(): TreeNode(1, 1), (2, 2), TreeNode(3, 3), TreeNode(17, 17), TreeNode(19, 19), TreeNode(36, 36) ] - assert raises(ValueError, lambda: + assert raises(TypeError, lambda: + BinaryHeap(elements = non_TreeNode_elements, heap_property='min')) + + non_TreeNode_elements = DynamicOneDimensionalArray(int, 0) + non_TreeNode_elements.append(1) + non_TreeNode_elements.append(2) + assert raises(TypeError, lambda: BinaryHeap(elements = non_TreeNode_elements, heap_property='min')) + + non_heapable = "[1, 2, 3]" + assert raises(ValueError, lambda: + BinaryHeap(elements = non_heapable, heap_property='min')) + def test_TernaryHeap(): max_heap = TernaryHeap(heap_property="max") assert raises(IndexError, lambda: max_heap.extract()) @@ -86,7 +97,7 @@ def test_TernaryHeap(): TreeNode(1, 1), TreeNode(2, 2), TreeNode(3, 3), TreeNode(17, 17), TreeNode(19, 19), TreeNode(36, 36) ] - min_heap = TernaryHeap(elements=DynamicOneDimensionalArray(TreeNode, 9, elements), heap_property="min") + min_heap = TernaryHeap(elements=elements, heap_property="min") expected_extracted_element = min_heap.heap[0].key assert min_heap.extract().key == expected_extracted_element