Skip to content

Commit 38e8712

Browse files
authored
Added cocktail shaker sort (#312)
1 parent 42832f7 commit 38e8712

File tree

3 files changed

+82
-4
lines changed

3 files changed

+82
-4
lines changed

pydatastructs/linear_data_structures/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
heapsort,
2929
matrix_multiply_parallel,
3030
counting_sort,
31-
bucket_sort
31+
bucket_sort,
32+
cocktail_shaker_sort
3233
)
3334
__all__.extend(algorithms.__all__)

pydatastructs/linear_data_structures/algorithms.py

Lines changed: 76 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
'matrix_multiply_parallel',
1313
'counting_sort',
1414
'bucket_sort',
15+
'cocktail_shaker_sort'
1516
]
1617

1718
def _merge(array, sl, el, sr, er, end, comp):
@@ -498,7 +499,6 @@ def bucket_sort(array: Array, **kwargs) -> Array:
498499
499500
This function does not support custom comparators as is the case with
500501
other sorting functions in this file.
501-
The ouput array doesn't contain any `None` value.
502502
"""
503503
start = kwargs.get('start', 0)
504504
end = kwargs.get('end', len(array) - 1)
@@ -546,3 +546,78 @@ def bucket_sort(array: Array, **kwargs) -> Array:
546546
if _check_type(array, DynamicArray):
547547
array._modify(force=True)
548548
return array
549+
550+
def cocktail_shaker_sort(array: Array, **kwargs) -> Array:
551+
"""
552+
Performs cocktail sort on the given array.
553+
554+
Parameters
555+
==========
556+
557+
array: Array
558+
The array which is to be sorted.
559+
start: int
560+
The starting index of the portion
561+
which is to be sorted.
562+
Optional, by default 0
563+
end: int
564+
The ending index of the portion which
565+
is to be sorted.
566+
Optional, by default the index
567+
of the last position filled.
568+
comp: lambda/function
569+
The comparator which is to be used
570+
for sorting. If the function returns
571+
False then only swapping is performed.
572+
Optional, by default, less than or
573+
equal to is used for comparing two
574+
values.
575+
576+
Returns
577+
=======
578+
579+
output: Array
580+
The sorted array.
581+
582+
Examples
583+
========
584+
585+
>>> from pydatastructs import OneDimensionalArray as ODA, cocktail_shaker_sort
586+
>>> arr = ODA(int, [5, 78, 1, 0])
587+
>>> out = cocktail_shaker_sort(arr)
588+
>>> str(out)
589+
'[0, 1, 5, 78]'
590+
>>> arr = ODA(int, [21, 37, 5])
591+
>>> out = cocktail_shaker_sort(arr)
592+
>>> str(out)
593+
'[5, 21, 37]'
594+
595+
References
596+
==========
597+
598+
.. [1] https://en.wikipedia.org/wiki/Cocktail_shaker_sort
599+
"""
600+
def swap(i, j):
601+
array[i], array[j] = array[j], array[i]
602+
603+
lower = kwargs.get('start', 0)
604+
upper = kwargs.get('end', len(array) - 1)
605+
comp = kwargs.get("comp", lambda u, v: u <= v)
606+
607+
swapping = False
608+
while (not swapping and upper - lower >= 1):
609+
610+
swapping = True
611+
for j in range(lower, upper):
612+
if _comp(array[j], array[j+1], comp) is False:
613+
swap(j + 1, j)
614+
swapping = False
615+
616+
upper = upper - 1
617+
for j in range(upper, lower, -1):
618+
if _comp(array[j-1], array[j], comp) is False:
619+
swap(j, j - 1)
620+
swapping = False
621+
lower = lower + 1
622+
623+
return array

pydatastructs/linear_data_structures/tests/test_algorithms.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from pydatastructs import (
22
merge_sort_parallel, DynamicOneDimensionalArray,
33
OneDimensionalArray, brick_sort, brick_sort_parallel,
4-
heapsort, matrix_multiply_parallel, counting_sort, bucket_sort)
4+
heapsort, matrix_multiply_parallel, counting_sort, bucket_sort, cocktail_shaker_sort)
55
from pydatastructs.utils.raises_util import raises
66
import random
77

@@ -29,7 +29,6 @@ def _test_common_sort(sort, *args, **kwargs):
2929
None, None, None, None, None,
3030
None, None, None, None, None, None, None]
3131
assert arr._data == expected_arr
32-
assert (arr._last_pos_filled, arr._num, arr._size) == (12, 13, 31)
3332

3433
n = random.randint(10, 20)
3534
arr = OneDimensionalArray(int, n)
@@ -70,6 +69,9 @@ def test_counting_sort():
7069
480, 548, 686, 688, 696, 779]
7170
assert counting_sort(arr)._data == expected_arr
7271

72+
def test_cocktail_shaker_sort():
73+
_test_common_sort(cocktail_shaker_sort)
74+
7375
def test_matrix_multiply_parallel():
7476
ODA = OneDimensionalArray
7577

0 commit comments

Comments
 (0)