Skip to content

Commit 423ce92

Browse files
authored
Added Segment Trees (#10)
1 parent 353e34a commit 423ce92

File tree

4 files changed

+256
-1
lines changed

4 files changed

+256
-1
lines changed

.travis.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,3 @@ install:
66
- pip install -r requirements.txt
77
script:
88
- pytest --doctest-modules
9-
- pytest

pydatastructs/trees/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,9 @@
55
Node, BinaryTree, BinarySearchTree
66
)
77
__all__.extend(binary_trees.__all__)
8+
9+
from . import space_partitioning_trees
10+
from .space_partitioning_trees import (
11+
OneDimensionalSegmentTree
12+
)
13+
__all__.extend(space_partitioning_trees.__all__)
Lines changed: 229 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,229 @@
1+
from pydatastructs.trees.binary_trees import Node
2+
from collections import deque as Queue
3+
from pydatastructs.linear_data_structures.arrays import _check_type
4+
5+
__all__ = [
6+
'OneDimensionalSegmentTree'
7+
]
8+
9+
class OneDimensionalSegmentTree(object):
10+
"""
11+
Represents one dimensional segment trees.
12+
13+
Parameters
14+
==========
15+
16+
segs: list/tuple/set
17+
The segs should contains tuples/list/set of size 2
18+
denoting the start and end points of the intervals.
19+
20+
Examples
21+
========
22+
23+
>>> from pydatastructs import OneDimensionalSegmentTree as ODST
24+
>>> segt = ODST([(3, 8), (9, 20)])
25+
>>> segt.build()
26+
>>> segt.tree[0].key
27+
[False, 2, 3, False]
28+
>>> len(segt.query(4))
29+
1
30+
31+
Note
32+
====
33+
34+
All the segments are assumed to be closed intervals,
35+
i.e., the ends are points of segments are also included in
36+
computation.
37+
38+
References
39+
==========
40+
41+
.. [1] https://en.wikipedia.org/wiki/Segment_tree
42+
43+
"""
44+
45+
__slots__ = ['segments', 'tree', 'root_idx', 'cache']
46+
47+
def __new__(cls, segs):
48+
obj = object.__new__(cls)
49+
if any((not isinstance(seg, (tuple, list, set)) or len(seg) != 2)
50+
for seg in segs):
51+
raise ValueError('%s is invalid set of intervals'%(segs))
52+
for i in range(len(segs)):
53+
segs[i] = list(segs[i])
54+
segs[i].sort()
55+
obj.segments = [seg for seg in segs]
56+
obj.tree, obj.root_idx, obj.cache = [], None, False
57+
return obj
58+
59+
def _union(self, i1, i2):
60+
"""
61+
Helper function for taking union of two
62+
intervals.
63+
"""
64+
return Node([i1.key[0], i1.key[1], i2.key[2], i2.key[3]], None)
65+
66+
def _intersect(self, i1, i2):
67+
"""
68+
Helper function for finding intersection of two
69+
intervals.
70+
"""
71+
if i1 == None or i2 == None:
72+
return False
73+
if i1.key[2] < i2.key[1] or i2.key[2] < i1.key[1]:
74+
return False
75+
c1, c2 = None, None
76+
if i1.key[2] == i2.key[1]:
77+
c1 = (i1.key[3] and i2.key[0])
78+
if i2.key[2] == i1.key[1]:
79+
c2 = (i2.key[3] and i1.key[0])
80+
if c1 == False and c2 == False:
81+
return False
82+
return True
83+
84+
def _contains(self, i1, i2):
85+
"""
86+
Helper function for checking if the first interval
87+
is contained in second interval.
88+
"""
89+
if i1 == None or i2 == None:
90+
return False
91+
if i1.key[1] < i2.key[1] and i1.key[2] > i2.key[2]:
92+
return True
93+
if i1.key[1] == i2.key[1] and i1.key[2] > i2.key[2]:
94+
return (i1.key[0] or not i2.key[0])
95+
if i1.key[1] < i2.key[1] and i1.key[2] == i2.key[2]:
96+
return i1.key[3] or not i2.key[3]
97+
if i1.key[1] == i2.key[1] and i1.key[2] == i2.key[2]:
98+
return not ((not i1.key[3] and i2.key[3]) or (not i1.key[0] and i2.key[0]))
99+
return False
100+
101+
def _iterate(self, calls, I, idx):
102+
"""
103+
Helper function for filling the calls
104+
stack. Used for imitating the stack based
105+
approach used in recursion.
106+
"""
107+
if self.tree[idx].right == None:
108+
rc = None
109+
else:
110+
rc = self.tree[self.tree[idx].right]
111+
if self.tree[idx].left == None:
112+
lc = None
113+
else:
114+
lc = self.tree[self.tree[idx].left]
115+
if self._intersect(I, rc):
116+
calls.append(self.tree[idx].right)
117+
if self._intersect(I, lc):
118+
calls.append(self.tree[idx].left)
119+
return calls
120+
121+
def build(self):
122+
"""
123+
Builds the segment tree from the segments,
124+
using iterative algorithm based on stacks.
125+
"""
126+
if self.cache:
127+
return None
128+
endpoints = []
129+
for segment in self.segments:
130+
endpoints.extend(segment)
131+
endpoints.sort()
132+
133+
elem_int = Queue()
134+
elem_int.append(Node([False, endpoints[0] - 1, endpoints[0], False], None))
135+
i = 0
136+
while i < len(endpoints) - 1:
137+
elem_int.append(Node([True, endpoints[i], endpoints[i], True], None))
138+
elem_int.append(Node([False, endpoints[i], endpoints[i+1], False], None))
139+
i += 1
140+
elem_int.append(Node([True, endpoints[i], endpoints[i], True], None))
141+
elem_int.append(Node([False, endpoints[i], endpoints[i] + 1, False], None))
142+
143+
self.tree = []
144+
while len(elem_int) > 1:
145+
m = len(elem_int)
146+
while m >= 2:
147+
I1 = elem_int.popleft()
148+
I2 = elem_int.popleft()
149+
I = self._union(I1, I2)
150+
I.left = len(self.tree)
151+
I.right = len(self.tree) + 1
152+
self.tree.append(I1), self.tree.append(I2)
153+
elem_int.append(I)
154+
m -= 2
155+
if m & 1 == 1:
156+
Il = elem_int.popleft()
157+
elem_int.append(Il)
158+
159+
Ir = elem_int.popleft()
160+
Ir.left, Ir.right = -3, -2
161+
self.tree.append(Ir)
162+
self.root_idx = -1
163+
164+
for segment in self.segments:
165+
I = Node([True, segment[0], segment[1], True], None)
166+
calls = [self.root_idx]
167+
while calls:
168+
idx = calls.pop()
169+
if self._contains(I, self.tree[idx]):
170+
if self.tree[idx].data == None:
171+
self.tree[idx].data = []
172+
self.tree[idx].data.append(I)
173+
continue
174+
calls = self._iterate(calls, I, idx)
175+
self.cache = True
176+
177+
def query(self, qx, init_node=None):
178+
"""
179+
Queries the segment tree.
180+
181+
Parameters
182+
==========
183+
184+
qx: int/float
185+
The query point
186+
init_node: int
187+
The index of the node from which the query process
188+
is to be started.
189+
190+
Returns
191+
=======
192+
193+
intervals: set
194+
The set of the intervals which contain the query
195+
point.
196+
197+
References
198+
==========
199+
200+
.. [1] https://en.wikipedia.org/wiki/Segment_tree
201+
"""
202+
if not self.cache:
203+
self.build()
204+
if init_node == None:
205+
init_node = self.root_idx
206+
qn = Node([True, qx, qx, True], None)
207+
intervals = []
208+
calls = [init_node]
209+
while calls:
210+
idx = calls.pop()
211+
if _check_type(self.tree[idx].data, list):
212+
intervals.extend(self.tree[idx].data)
213+
calls = self._iterate(calls, qn, idx)
214+
return set(intervals)
215+
216+
def __str__(self):
217+
"""
218+
Used for printing.
219+
"""
220+
if not self.cache:
221+
self.build()
222+
str_tree = []
223+
for seg in self.tree:
224+
if seg.data == None:
225+
data = None
226+
else:
227+
data = [str(sd) for sd in seg.data]
228+
str_tree.append((seg.left, seg.key, data, seg.right))
229+
return str(str_tree)
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
from pydatastructs import OneDimensionalSegmentTree
2+
from pydatastructs.utils.raises_util import raises
3+
4+
def test_OneDimensionalSegmentTree():
5+
ODST = OneDimensionalSegmentTree
6+
segt = ODST([(0, 5), (1, 6), (9, 13), (1, 2), (3, 8), (9, 20)])
7+
assert segt.cache == False
8+
segt.build()
9+
assert segt.cache == True
10+
segt2 = ODST([(1, 4)])
11+
assert str(segt2) == ("[(None, [False, 0, 1, False], None, None), "
12+
"(None, [True, 1, 1, True], ['(None, [True, 1, 4, True], None, None)'], "
13+
"None), (None, [False, 1, 4, False], None, None), (None, [True, 4, 4, True], "
14+
"None, None), (0, [False, 0, 1, True], None, 1), (2, [False, 1, 4, True], "
15+
"['(None, [True, 1, 4, True], None, None)'], 3), (4, [False, 0, 4, True], "
16+
"None, 5), (None, [False, 4, 5, False], None, None), (-3, [False, 0, 5, "
17+
"False], None, -2)]")
18+
assert len(segt.query(1.5)) == 3
19+
assert len(segt.query(-1)) == 0
20+
assert len(segt.query(2.8)) == 2
21+
raises(ValueError, lambda: ODST([(1, 2, 3)]))

0 commit comments

Comments
 (0)