Skip to content

Commit 0469b29

Browse files
committed
try better parallelization
1 parent 0eb1fb4 commit 0469b29

File tree

2 files changed

+43
-64
lines changed

2 files changed

+43
-64
lines changed

cmd/migration-checker/main.go

+33-54
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@ import (
2424
)
2525

2626
var accountsDone atomic.Uint64
27-
var trieCheckers chan struct{}
27+
var numLoaders int
28+
var trieLoaders chan struct{}
2829

2930
type dbs struct {
3031
zkDb *leveldb.Database
@@ -54,33 +55,19 @@ func main() {
5455
zkRootHash := common.HexToHash(*zkRoot)
5556
mptRootHash := common.HexToHash(*mptRoot)
5657

57-
numTrieCheckers := runtime.GOMAXPROCS(0) * (*parallelismMultipler)
58-
trieCheckers = make(chan struct{}, numTrieCheckers)
59-
for i := 0; i < numTrieCheckers; i++ {
60-
trieCheckers <- struct{}{}
58+
numLoaders = runtime.GOMAXPROCS(0) * (*parallelismMultipler)
59+
trieLoaders = make(chan struct{}, numLoaders)
60+
for i := 0; i < numLoaders; i++ {
61+
trieLoaders <- struct{}{}
6162
}
6263

63-
done := make(chan struct{})
64-
totalCheckers := len(trieCheckers)
65-
go func() {
66-
for {
67-
select {
68-
case <-done:
69-
return
70-
case <-time.After(time.Minute):
71-
fmt.Println("Active checkers:", totalCheckers-len(trieCheckers))
72-
}
73-
}
74-
}()
75-
defer close(done)
76-
7764
checkTrieEquality(&dbs{
7865
zkDb: zkDb,
7966
mptDb: mptDb,
8067
}, zkRootHash, mptRootHash, "", checkAccountEquality, true, *paranoid)
8168

82-
for i := 0; i < numTrieCheckers; i++ {
83-
<-trieCheckers
69+
for i := 0; i < numLoaders; i++ {
70+
<-trieLoaders
8471
}
8572
}
8673

@@ -177,8 +164,9 @@ func checkAccountEquality(label string, dbs *dbs, zkAccountBytes, mptAccountByte
177164
} else if zkAccount.Root != (common.Hash{}) {
178165
zkRoot := common.BytesToHash(zkAccount.Root[:])
179166
mptRoot := common.BytesToHash(mptAccount.Root[:])
180-
<-trieCheckers
167+
<-trieLoaders
181168
go func() {
169+
trieLoaders <- struct{}{}
182170
defer func() {
183171
if p := recover(); p != nil {
184172
fmt.Println(p)
@@ -189,7 +177,6 @@ func checkAccountEquality(label string, dbs *dbs, zkAccountBytes, mptAccountByte
189177
checkTrieEquality(dbs, zkRoot, mptRoot, label, checkStorageEquality, false, paranoid)
190178
accountsDone.Add(1)
191179
fmt.Println("Accounts done:", accountsDone.Load())
192-
trieCheckers <- struct{}{}
193180
}()
194181
} else {
195182
accountsDone.Add(1)
@@ -207,42 +194,33 @@ func checkStorageEquality(label string, _ *dbs, zkStorageBytes, mptStorageBytes
207194
}
208195
}
209196

210-
func loadMPT(mptTrie *trie.SecureTrie, parallel bool) chan map[string][]byte {
197+
func loadMPT(mptTrie *trie.SecureTrie, top bool) chan map[string][]byte {
211198
startKey := make([]byte, 32)
212-
workers := 1 << 5
213-
if !parallel {
214-
workers = 1
215-
}
216-
step := byte(0xFF) / byte(workers)
217-
218199
mptLeafMap := make(map[string][]byte, 1000)
219200
var mptLeafMutex sync.Mutex
220201

221202
var mptWg sync.WaitGroup
222-
for i := 0; i < workers; i++ {
223-
startKey[0] = byte(i) * step
203+
for i := 0; i < 255; i++ {
204+
startKey[0] = byte(i)
224205
trieIt := trie.NewIterator(mptTrie.NodeIterator(startKey))
225206

226207
mptWg.Add(1)
208+
<-trieLoaders
227209
go func() {
228-
defer mptWg.Done()
210+
defer func() {
211+
mptWg.Done()
212+
trieLoaders <- struct{}{}
213+
}()
229214
for trieIt.Next() {
230-
if parallel {
231-
mptLeafMutex.Lock()
232-
}
233-
215+
mptLeafMutex.Lock()
234216
if _, ok := mptLeafMap[string(trieIt.Key)]; ok {
235217
mptLeafMutex.Unlock()
236218
break
237219
}
238-
239220
mptLeafMap[string(dup(trieIt.Key))] = dup(trieIt.Value)
221+
mptLeafMutex.Unlock()
240222

241-
if parallel {
242-
mptLeafMutex.Unlock()
243-
}
244-
245-
if parallel && len(mptLeafMap)%10000 == 0 {
223+
if top && len(mptLeafMap)%10000 == 0 {
246224
fmt.Println("MPT Accounts Loaded:", len(mptLeafMap))
247225
}
248226
}
@@ -257,7 +235,8 @@ func loadMPT(mptTrie *trie.SecureTrie, parallel bool) chan map[string][]byte {
257235
return respChan
258236
}
259237

260-
func loadZkTrie(zkTrie *trie.ZkTrie, parallel, paranoid bool) chan map[string][]byte {
238+
func loadZkTrie(zkTrie *trie.ZkTrie, top, paranoid bool) chan map[string][]byte {
239+
parallelismCutoffDepth := 8
261240
zkLeafMap := make(map[string][]byte, 1000)
262241
var zkLeafMutex sync.Mutex
263242
zkDone := make(chan map[string][]byte)
@@ -268,20 +247,20 @@ func loadZkTrie(zkTrie *trie.ZkTrie, parallel, paranoid bool) chan map[string][]
268247
panic(fmt.Sprintf("preimage not found zk trie %s", hex.EncodeToString(key)))
269248
}
270249

271-
if parallel {
272-
zkLeafMutex.Lock()
273-
}
274-
250+
zkLeafMutex.Lock()
275251
zkLeafMap[string(dup(preimageKey))] = value
252+
zkLeafMutex.Unlock()
276253

277-
if parallel {
278-
zkLeafMutex.Unlock()
279-
}
280-
281-
if parallel && len(zkLeafMap)%10000 == 0 {
254+
if top && len(zkLeafMap)%10000 == 0 {
282255
fmt.Println("ZK Accounts Loaded:", len(zkLeafMap))
283256
}
284-
}, parallel, paranoid)
257+
}, func(f func()) {
258+
<-trieLoaders
259+
go func() {
260+
f()
261+
trieLoaders <- struct{}{}
262+
}()
263+
}, parallelismCutoffDepth, paranoid)
285264
zkDone <- zkLeafMap
286265
}()
287266
return zkDone

trie/zk_trie.go

+10-10
Original file line numberDiff line numberDiff line change
@@ -239,15 +239,15 @@ func (t *ZkTrie) Witness() map[string]struct{} {
239239
panic("not implemented")
240240
}
241241

242-
func (t *ZkTrie) CountLeaves(cb func(key, value []byte), parallel, verifyNodeHashes bool) uint64 {
242+
func (t *ZkTrie) CountLeaves(cb func(key, value []byte), spawnWorker func(func()), parallelismCutoffDepth int, verifyNodeHashes bool) uint64 {
243243
root, err := t.ZkTrie.Tree().Root()
244244
if err != nil {
245245
panic("CountLeaves cannot get root")
246246
}
247-
return t.countLeaves(root, cb, 0, parallel, verifyNodeHashes)
247+
return t.countLeaves(0, root, cb, spawnWorker, parallelismCutoffDepth, verifyNodeHashes)
248248
}
249249

250-
func (t *ZkTrie) countLeaves(root *zkt.Hash, cb func(key, value []byte), depth int, parallel, verifyNodeHashes bool) uint64 {
250+
func (t *ZkTrie) countLeaves(depth int, root *zkt.Hash, cb func(key, value []byte), spawnWorker func(func()), parallelismCutoffDepth int, verifyNodeHashes bool) uint64 {
251251
if root == nil {
252252
return 0
253253
}
@@ -271,19 +271,19 @@ func (t *ZkTrie) countLeaves(root *zkt.Hash, cb func(key, value []byte), depth i
271271
cb(append([]byte{}, rootNode.NodeKey.Bytes()...), append([]byte{}, rootNode.Data()...))
272272
return 1
273273
} else {
274-
if parallel && depth < 5 {
274+
if depth < parallelismCutoffDepth {
275275
count := make(chan uint64)
276276
leftT := t.Copy()
277277
rightT := t.Copy()
278278
go func() {
279-
count <- leftT.countLeaves(rootNode.ChildL, cb, depth+1, parallel, verifyNodeHashes)
279+
spawnWorker(func() {
280+
count <- leftT.countLeaves(depth+1, rootNode.ChildL, cb, spawnWorker, parallelismCutoffDepth, verifyNodeHashes)
281+
})
280282
}()
281-
go func() {
282-
count <- rightT.countLeaves(rootNode.ChildR, cb, depth+1, parallel, verifyNodeHashes)
283-
}()
284-
return <-count + <-count
283+
return rightT.countLeaves(depth+1, rootNode.ChildR, cb, spawnWorker, parallelismCutoffDepth, verifyNodeHashes) + <-count
285284
} else {
286-
return t.countLeaves(rootNode.ChildL, cb, depth+1, parallel, verifyNodeHashes) + t.countLeaves(rootNode.ChildR, cb, depth+1, parallel, verifyNodeHashes)
285+
return t.countLeaves(depth+1, rootNode.ChildL, cb, spawnWorker, parallelismCutoffDepth, verifyNodeHashes) +
286+
t.countLeaves(depth+1, rootNode.ChildR, cb, spawnWorker, parallelismCutoffDepth, verifyNodeHashes)
287287
}
288288
}
289289
}

0 commit comments

Comments
 (0)