diff --git a/roller/core/prover/prover.go b/roller/core/prover/prover.go index 32762039b0..4a61fa492d 100644 --- a/roller/core/prover/prover.go +++ b/roller/core/prover/prover.go @@ -42,11 +42,7 @@ func NewProver(cfg *config.ProverConfig) (*Prover, error) { // Prove call rust ffi to generate proof, if first failed, try again. func (p *Prover) Prove(traces *types.BlockResult) (*message.AggProof, error) { - proof, err := p.prove(traces) - if err != nil { - return p.prove(traces) - } - return proof, nil + return p.prove(traces) } func (p *Prover) prove(traces *types.BlockResult) (*message.AggProof, error) { diff --git a/roller/core/roller.go b/roller/core/roller.go index 134635114a..1f3103fb01 100644 --- a/roller/core/roller.go +++ b/roller/core/roller.go @@ -111,12 +111,7 @@ func (r *Roller) Register() error { return fmt.Errorf("sign auth message failed %v", err) } - msgByt, err := MakeMsgByt(message.Register, authMsg) - if err != nil { - return err - } - - return r.conn.WriteMessage(websocket.BinaryMessage, msgByt) + return r.sendMessage(message.Register, authMsg) } // HandleScroll accepts block-traces from Scroll through the Websocket and store it into Stack. @@ -187,6 +182,14 @@ func (r *Roller) ProveLoop() (err error) { } } +func (r *Roller) sendMessage(msgType message.MsgType, payload interface{}) error { + msgByt, err := MakeMsgByt(msgType, payload) + if err != nil { + return err + } + return r.conn.WriteMessage(websocket.BinaryMessage, msgByt) +} + func (r *Roller) handMessage() error { mt, msg, err := r.conn.ReadMessage() if err != nil { @@ -207,32 +210,49 @@ func (r *Roller) prove() error { if err != nil { return err } - log.Info("start to prove block", "block-id", traces.ID) var proofMsg *message.ProofMsg - proof, err := r.prover.Prove(traces.Traces) + if traces.Times > 2 { + proofMsg = &message.ProofMsg{ + Status: message.StatusProofError, + Error: "prover has retried several times due to FFI panic", + ID: traces.Traces.ID, + Proof: &message.AggProof{}, + } + return r.sendMessage(message.Proof, proofMsg) + } + + err = r.stack.Push(traces) + if err != nil { + return err + } + + log.Info("start to prove block", "block-id", traces.Traces.ID) + + proof, err := r.prover.Prove(traces.Traces.Traces) if err != nil { proofMsg = &message.ProofMsg{ Status: message.StatusProofError, Error: err.Error(), - ID: traces.ID, + ID: traces.Traces.ID, Proof: &message.AggProof{}, } - log.Error("prove block failed!", "block-id", traces.ID) + log.Error("prove block failed!", "block-id", traces.Traces.ID) } else { + proofMsg = &message.ProofMsg{ Status: message.StatusOk, - ID: traces.ID, + ID: traces.Traces.ID, Proof: proof, } - log.Info("prove block successfully!", "block-id", traces.ID) + log.Info("prove block successfully!", "block-id", traces.Traces.ID) } - - msgByt, err := MakeMsgByt(message.Proof, proofMsg) + _, err = r.stack.Pop() if err != nil { return err } - return r.conn.WriteMessage(websocket.BinaryMessage, msgByt) + + return r.sendMessage(message.Proof, proofMsg) } // Close closes the websocket connection. @@ -266,7 +286,10 @@ func (r *Roller) persistTrace(byt []byte) error { return err } log.Info("Accept BlockTrace from Scroll", "ID", traces.ID) - return r.stack.Push(traces) + return r.stack.Push(&store.ProvingTraces{ + Traces: traces, + Times: 0, + }) } func (r *Roller) loadOrCreateKey() (*ecdsa.PrivateKey, error) { diff --git a/roller/store/stack.go b/roller/store/stack.go index 83a4fda6a2..be658da69c 100644 --- a/roller/store/stack.go +++ b/roller/store/stack.go @@ -21,6 +21,14 @@ type Stack struct { *bbolt.DB } +// ProvingTraces is the value in stack. +// It contains traces and proved times. +type ProvingTraces struct { + Traces *rollertypes.BlockTraces `json:"traces"` + // Times is how many times roller proved. + Times int `json:"times"` +} + var bucket = []byte("stack") // NewStack new a Stack object. @@ -40,20 +48,20 @@ func NewStack(path string) (*Stack, error) { } // Push appends the block-traces on the top of Stack. -func (s *Stack) Push(traces *rollertypes.BlockTraces) error { +func (s *Stack) Push(traces *ProvingTraces) error { byt, err := json.Marshal(traces) if err != nil { return err } key := make([]byte, 8) - binary.BigEndian.PutUint64(key, traces.ID) + binary.BigEndian.PutUint64(key, traces.Traces.ID) return s.Update(func(tx *bbolt.Tx) error { return tx.Bucket(bucket).Put(key, byt) }) } // Pop pops the block-traces on the top of Stack. -func (s *Stack) Pop() (*rollertypes.BlockTraces, error) { +func (s *Stack) Pop() (*ProvingTraces, error) { var value []byte if err := s.Update(func(tx *bbolt.Tx) error { var key []byte @@ -68,6 +76,11 @@ func (s *Stack) Pop() (*rollertypes.BlockTraces, error) { return nil, ErrEmpty } - traces := &rollertypes.BlockTraces{} - return traces, json.Unmarshal(value, traces) + traces := &ProvingTraces{} + err := json.Unmarshal(value, traces) + if err != nil { + return nil, err + } + traces.Times++ + return traces, nil } diff --git a/roller/store/stack_test.go b/roller/store/stack_test.go index 76952cc619..c914e387ec 100644 --- a/roller/store/stack_test.go +++ b/roller/store/stack_test.go @@ -4,11 +4,10 @@ import ( "io/ioutil" "os" "path/filepath" + "scroll-tech/common/message" "testing" "github.com/stretchr/testify/assert" - - "scroll-tech/common/message" ) func TestStack(t *testing.T) { @@ -23,17 +22,41 @@ func TestStack(t *testing.T) { defer s.Close() for i := 0; i < 3; i++ { - trace := &message.BlockTraces{ - ID: uint64(i), - Traces: nil, + trace := &ProvingTraces{ + Traces: &message.BlockTraces{ + ID: uint64(i), + Traces: nil, + }, + Times: 0, } - err := s.Push(trace) + + err = s.Push(trace) assert.NoError(t, err) } for i := 2; i >= 0; i-- { - trace, err := s.Pop() + var pop *ProvingTraces + pop, err = s.Pop() assert.NoError(t, err) - assert.Equal(t, uint64(i), trace.ID) + assert.Equal(t, uint64(i), pop.Traces.ID) } + + // test times + trace := &ProvingTraces{ + Traces: &message.BlockTraces{ + ID: 1, + Traces: nil, + }, + Times: 0, + } + err = s.Push(trace) + assert.NoError(t, err) + pop, err := s.Pop() + assert.NoError(t, err) + err = s.Push(pop) + assert.NoError(t, err) + + pop2, err := s.Pop() + assert.NoError(t, err) + assert.Equal(t, 2, pop2.Times) }