Skip to content

Commit 6e6d4d7

Browse files
authored
refactor heapsort to support generic (TheAlgorithms#553)
* refactor: generic heapsort * revert: remove generic test * revert: max heap * refactor: make max heap more generic * fix: zero index * revert: generic MaxHeap * revert: heapifyDown
1 parent 03c8ce8 commit 6e6d4d7

File tree

2 files changed

+41
-38
lines changed

2 files changed

+41
-38
lines changed

sort/heapsort.go

Lines changed: 39 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,13 @@
11
package sort
22

3+
import "github.com/TheAlgorithms/Go/constraints"
4+
35
type MaxHeap struct {
46
slice []Comparable
57
heapSize int
68
indices map[int]int
79
}
810

9-
func buildMaxHeap(slice0 []int) MaxHeap {
10-
var slice []Comparable
11-
for _, i := range slice0 {
12-
slice = append(slice, Int(i))
13-
}
14-
h := MaxHeap{}
15-
h.Init(slice)
16-
return h
17-
}
18-
1911
func (h *MaxHeap) Init(slice []Comparable) {
2012
if slice == nil {
2113
slice = make([]Comparable, 0)
@@ -73,62 +65,73 @@ func (h MaxHeap) updateidx(i int) {
7365
h.indices[h.slice[i].Idx()] = i
7466
}
7567

68+
func (h *MaxHeap) swap(i, j int) {
69+
h.slice[i], h.slice[j] = h.slice[j], h.slice[i]
70+
h.updateidx(i)
71+
h.updateidx(j)
72+
}
73+
74+
func (h MaxHeap) more(i, j int) bool {
75+
return h.slice[i].More(h.slice[j])
76+
}
77+
7678
func (h MaxHeap) heapifyUp(i int) {
7779
if i == 0 {
7880
return
7981
}
8082
p := i / 2
8183

8284
if h.slice[i].More(h.slice[p]) {
83-
h.slice[i], h.slice[p] = h.slice[p], h.slice[i]
84-
h.updateidx(i)
85-
h.updateidx(p)
85+
h.swap(i, p)
8686
h.heapifyUp(p)
8787
}
8888
}
8989

9090
func (h MaxHeap) heapifyDown(i int) {
91+
heapifyDown(h.slice, h.heapSize, i, h.more, h.swap)
92+
}
93+
94+
func heapifyDown[T any](slice []T, N, i int, moreFunc func(i, j int) bool, swapFunc func(i, j int)) {
9195
l, r := 2*i+1, 2*i+2
9296
max := i
9397

94-
if l < h.heapSize && h.slice[l].More(h.slice[max]) {
98+
if l < N && moreFunc(l, max) {
9599
max = l
96100
}
97-
if r < h.heapSize && h.slice[r].More(h.slice[max]) {
101+
if r < N && moreFunc(r, max) {
98102
max = r
99103
}
100104
if max != i {
101-
h.slice[i], h.slice[max] = h.slice[max], h.slice[i]
102-
h.updateidx(i)
103-
h.updateidx(max)
104-
h.heapifyDown(max)
105+
swapFunc(i, max)
106+
107+
heapifyDown(slice, N, max, moreFunc, swapFunc)
105108
}
106109
}
107110

108111
type Comparable interface {
109112
Idx() int
110113
More(any) bool
111114
}
112-
type Int int
113115

114-
func (a Int) More(b any) bool {
115-
return a > b.(Int)
116-
}
117-
func (a Int) Idx() int {
118-
return int(a)
119-
}
116+
func HeapSort[T constraints.Ordered](slice []T) []T {
117+
N := len(slice)
120118

121-
func HeapSort(slice []int) []int {
122-
h := buildMaxHeap(slice)
123-
for i := len(h.slice) - 1; i >= 1; i-- {
124-
h.slice[0], h.slice[i] = h.slice[i], h.slice[0]
125-
h.heapSize--
126-
h.heapifyDown(0)
119+
moreFunc := func(i, j int) bool {
120+
return slice[i] > slice[j]
121+
}
122+
swapFunc := func(i, j int) {
123+
slice[i], slice[j] = slice[j], slice[i]
124+
}
125+
126+
// build a maxheap
127+
for i := N/2 - 1; i >= 0; i-- {
128+
heapifyDown(slice, N, i, moreFunc, swapFunc)
127129
}
128130

129-
res := []int{}
130-
for _, i := range h.slice {
131-
res = append(res, int(i.(Int)))
131+
for i := N - 1; i > 0; i-- {
132+
slice[i], slice[0] = slice[0], slice[i]
133+
heapifyDown(slice, i, 0, moreFunc, swapFunc)
132134
}
133-
return res
135+
136+
return slice
134137
}

sort/sorts_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ func TestMergeParallel(t *testing.T) {
118118
}
119119

120120
func TestHeap(t *testing.T) {
121-
testFramework(t, sort.HeapSort)
121+
testFramework(t, sort.HeapSort[int])
122122
}
123123

124124
func TestCount(t *testing.T) {
@@ -227,7 +227,7 @@ func BenchmarkMergeParallel(b *testing.B) {
227227
}
228228

229229
func BenchmarkHeap(b *testing.B) {
230-
benchmarkFramework(b, sort.HeapSort)
230+
benchmarkFramework(b, sort.HeapSort[int])
231231
}
232232

233233
func BenchmarkCount(b *testing.B) {

0 commit comments

Comments
 (0)