diff --git a/api/service/mempoolTransactionApiService.go b/api/service/mempoolTransactionApiService.go index 3d863dff5..4dd30072e 100644 --- a/api/service/mempoolTransactionApiService.go +++ b/api/service/mempoolTransactionApiService.go @@ -72,7 +72,7 @@ func (ut *MempoolTransactionService) GetMempoolTransactions( err error count uint64 selectQuery, countQuery string - rows *sql.Rows + rowCount *sql.Row rows2 *sql.Rows txs []*model.MempoolTransaction response *model.GetMempoolTransactionsResponse @@ -103,17 +103,13 @@ func (ut *MempoolTransactionService) GetMempoolTransactions( selectQuery, args = caseQuery.Build() countQuery = query.GetTotalRecordOfSelect(selectQuery) - rows, err = ut.Query.ExecuteSelect(countQuery, false, args...) + rowCount, err = ut.Query.ExecuteSelectRow(countQuery, false, args...) if err != nil { return nil, status.Error(codes.Internal, err.Error()) } - defer rows.Close() - - if rows.Next() { - err = rows.Scan(&count) - if err != nil { - return response, status.Error(codes.Internal, err.Error()) - } + err = rowCount.Scan(&count) + if err != nil { + return nil, status.Error(codes.Internal, err.Error()) } // select records diff --git a/api/service/mempoolTransactionApiService_test.go b/api/service/mempoolTransactionApiService_test.go index 4adfd5500..4c3a03857 100644 --- a/api/service/mempoolTransactionApiService_test.go +++ b/api/service/mempoolTransactionApiService_test.go @@ -54,13 +54,13 @@ type ( } ) -func (*mockQueryExecutorGetMempoolTXsFail) ExecuteSelect(query string, tx bool, args ...interface{}) (*sql.Rows, error) { +func (*mockQueryExecutorGetMempoolTXsFail) ExecuteSelectRow(query string, tx bool, args ...interface{}) (*sql.Row, error) { return nil, errors.New("want error") } -func (*mockQueryExecutorGetMempoolTXsScanFail) ExecuteSelect(qStr string, tx bool, args ...interface{}) (*sql.Rows, error) { +func (*mockQueryExecutorGetMempoolTXsScanFail) ExecuteSelectRow(qStr string, tx bool, args ...interface{}) (*sql.Row, error) { db, mock, _ := sqlmock.New() mock.ExpectQuery(regexp.QuoteMeta(qStr)).WillReturnRows(sqlmock.NewRows([]string{"one", "two"}).AddRow(1, 2)) - return db.Query(qStr) + return db.QueryRow(qStr), nil } func (*mockQueryExecutorGetMempoolTXs) ExecuteSelect(qStr string, tx bool, args ...interface{}) (*sql.Rows, error) { db, mock, _ := sqlmock.New() @@ -82,6 +82,12 @@ func (*mockQueryExecutorGetMempoolTXs) ExecuteSelect(qStr string, tx bool, args } return db.Query(qStr) } +func (*mockQueryExecutorGetMempoolTXs) ExecuteSelectRow(qStr string, tx bool, args ...interface{}) (*sql.Row, error) { + db, mock, _ := sqlmock.New() + defer db.Close() + mock.ExpectQuery(regexp.QuoteMeta(qStr)).WillReturnRows(sqlmock.NewRows([]string{"total_record"}).AddRow(1)) + return db.QueryRow(qStr), nil +} func TestMempoolTransactionService_GetMempoolTransactions(t *testing.T) { type fields struct { Query query.ExecutorInterface diff --git a/api/service/nodeRegistryApiService.go b/api/service/nodeRegistryApiService.go index 234eb78fb..550911f55 100644 --- a/api/service/nodeRegistryApiService.go +++ b/api/service/nodeRegistryApiService.go @@ -33,7 +33,7 @@ func (ns NodeRegistryService) GetNodeRegistrations(params *model.GetNodeRegistra var ( err error - rows *sql.Rows + rowCount *sql.Row rows2 *sql.Rows selectQuery string args []interface{} @@ -57,19 +57,15 @@ func (ns NodeRegistryService) GetNodeRegistrations(params *model.GetNodeRegistra // count first selectQuery, args = caseQuery.Build() countQuery := query.GetTotalRecordOfSelect(selectQuery) - rows, err = ns.Query.ExecuteSelect(countQuery, false, args...) + rowCount, err = ns.Query.ExecuteSelectRow(countQuery, false, args...) if err != nil { return nil, status.Error(codes.Internal, err.Error()) } - defer rows.Close() - - if rows.Next() { - err = rows.Scan( - &totalRecords, - ) - if err != nil { - return nil, status.Error(codes.Internal, err.Error()) - } + err = rowCount.Scan( + &totalRecords, + ) + if err != nil { + return nil, status.Error(codes.Internal, err.Error()) } if page.GetOrderField() == "" { diff --git a/api/service/nodeRegistryApiService_test.go b/api/service/nodeRegistryApiService_test.go index 1a26ab476..666aa8462 100644 --- a/api/service/nodeRegistryApiService_test.go +++ b/api/service/nodeRegistryApiService_test.go @@ -60,30 +60,41 @@ func (*mockQueryGetNodeRegistrationsFail) ExecuteSelect(query string, tx bool, a return nil, errors.New("want error") } +func (*mockQueryGetNodeRegistrationsFail) ExecuteSelectRow(query string, tx bool, args ...interface{}) (*sql.Row, error) { + return nil, errors.New("want error") +} + func (*mockQueryGetNodeRegistrationsSuccess) ExecuteSelect(qStr string, tx bool, args ...interface{}) (*sql.Rows, error) { db, mock, _ := sqlmock.New() defer db.Close() + mock.ExpectQuery(""). + WillReturnRows(sqlmock.NewRows(query.NewNodeRegistrationQuery().Fields). + AddRow( + 1, + []byte{1, 2}, + "AccountA", + 1, + "127.0.0.1", + 1, + uint32(model.NodeRegistrationState_NodeQueued), + true, + 1, + ), + ) + return db.Query("") +} + +func (*mockQueryGetNodeRegistrationsSuccess) ExecuteSelectRow(qStr string, tx bool, args ...interface{}) (*sql.Row, error) { + db, mock, _ := sqlmock.New() + defer db.Close() switch strings.Contains(qStr, "total_record") { case true: - mock.ExpectQuery("").WillReturnRows(sqlmock.NewRows([]string{"total_record"}).AddRow(1)) + mock.ExpectQuery(regexp.QuoteMeta(qStr)).WillReturnRows(sqlmock.NewRows([]string{"total_record"}).AddRow(1)) default: - mock.ExpectQuery(""). - WillReturnRows(sqlmock.NewRows(query.NewNodeRegistrationQuery().Fields). - AddRow( - 1, - []byte{1, 2}, - "AccountA", - 1, - "127.0.0.1", - 1, - uint32(model.NodeRegistrationState_NodeQueued), - true, - 1, - ), - ) + return nil, nil } - return db.Query("") + return db.QueryRow(qStr), nil } func TestNodeRegistryService_GetNodeRegistrations(t *testing.T) { diff --git a/api/service/transactionApiService.go b/api/service/transactionApiService.go index fc189514d..f923e29ee 100644 --- a/api/service/transactionApiService.go +++ b/api/service/transactionApiService.go @@ -99,7 +99,7 @@ func (ts *TransactionService) GetTransactions( ) (*model.GetTransactionsResponse, error) { var ( err error - rows *sql.Rows + rowCount *sql.Row rows2 *sql.Rows txs []*model.Transaction selectQuery string @@ -146,19 +146,15 @@ func (ts *TransactionService) GetTransactions( // count first countQuery := query.GetTotalRecordOfSelect(selectQuery) - rows, err = ts.Query.ExecuteSelect(countQuery, false, args...) + rowCount, err = ts.Query.ExecuteSelectRow(countQuery, false, args...) if err != nil { return nil, status.Error(codes.Internal, err.Error()) } - defer rows.Close() - - if rows.Next() { - err = rows.Scan( - &totalRecords, - ) - if err != nil { - return nil, status.Error(codes.Internal, err.Error()) - } + err = rowCount.Scan( + &totalRecords, + ) + if err != nil { + return nil, status.Error(codes.Internal, err.Error()) } // Get Transactions with Pagination diff --git a/api/service/transactionApiService_test.go b/api/service/transactionApiService_test.go index 7b770aa3a..03873945a 100644 --- a/api/service/transactionApiService_test.go +++ b/api/service/transactionApiService_test.go @@ -6,7 +6,6 @@ import ( "database/sql" "errors" "reflect" - "strings" "testing" "github.com/DATA-DOG/go-sqlmock" @@ -568,35 +567,38 @@ type ( func (*mockQueryGetTransactionsFail) ExecuteSelect(query string, tx bool, args ...interface{}) (*sql.Rows, error) { return nil, errors.New("want error") } +func (*mockQueryGetTransactionsFail) ExecuteSelectRow(query string, tx bool, args ...interface{}) (*sql.Row, error) { + return nil, errors.New("want error") +} func (*mockQueryGetTransactionsSuccess) ExecuteSelect(qStr string, tx bool, args ...interface{}) (*sql.Rows, error) { db, mock, _ := sqlmock.New() - switch strings.Contains(qStr, "total_record") { - case true: - mock.ExpectQuery("").WillReturnRows(sqlmock.NewRows([]string{"total_record"}).AddRow(1)) - default: - mock.ExpectQuery(""). - WillReturnRows(sqlmock.NewRows(query.NewTransactionQuery(&chaintype.MainChain{}).Fields). - AddRow( - 4545420970999433273, - 1, - 1, - "senderA", - "recipientA", - 1, - 1, - 10000, - []byte{1, 1}, - 8, - []byte{1, 2, 3, 4, 5, 6, 7, 8}, - []byte{0, 0, 0, 0, 0, 0, 0}, - 1, - 1, - false, - ), - ) - } + mock.ExpectQuery(""). + WillReturnRows(sqlmock.NewRows(query.NewTransactionQuery(&chaintype.MainChain{}).Fields). + AddRow( + 4545420970999433273, + 1, + 1, + "senderA", + "recipientA", + 1, + 1, + 10000, + []byte{1, 1}, + 8, + []byte{1, 2, 3, 4, 5, 6, 7, 8}, + []byte{0, 0, 0, 0, 0, 0, 0}, + 1, + 1, + false, + ), + ) return db.Query("") } +func (*mockQueryGetTransactionsSuccess) ExecuteSelectRow(qStr string, tx bool, args ...interface{}) (*sql.Row, error) { + db, mock, _ := sqlmock.New() + mock.ExpectQuery("").WillReturnRows(sqlmock.NewRows([]string{"total_record"}).AddRow(1)) + return db.QueryRow(""), nil +} func TestTransactionService_GetTransactions(t *testing.T) { type fields struct { Query query.ExecutorInterface diff --git a/cmd/genesisblock/genesisGenerator.go b/cmd/genesisblock/genesisGenerator.go index 97f91769e..4459c035c 100644 --- a/cmd/genesisblock/genesisGenerator.go +++ b/cmd/genesisblock/genesisGenerator.go @@ -257,19 +257,29 @@ func getDbLastState(dbPath string) (bcEntries []genesisEntry, err error) { if acc.AccountAddress == constant.MainchainGenesisAccountAddress { continue } + + var nodeRegistrations []*model.NodeRegistration + bcEntry := new(genesisEntry) bcEntry.AccountAddress = acc.AccountAddress bcEntry.AccountBalance = acc.Balance - // get node registration for this account, if exists - qry, args := nodeRegistrationQuery.GetNodeRegistrationByAccountAddress(acc.AccountAddress) - nrRows, err := queryExecutor.ExecuteSelect(qry, false, args...) - if err != nil { - return nil, err - } - defer nrRows.Close() + err := func() error { + // get node registration for this account, if exists + qry, args := nodeRegistrationQuery.GetNodeRegistrationByAccountAddress(acc.AccountAddress) + nrRows, err := queryExecutor.ExecuteSelect(qry, false, args...) + if err != nil { + return err + } + defer nrRows.Close() + + nodeRegistrations, err = nodeRegistrationQuery.BuildModel([]*model.NodeRegistration{}, nrRows) + if err != nil { + return err + } + return nil + }() - nodeRegistrations, err := nodeRegistrationQuery.BuildModel([]*model.NodeRegistration{}, nrRows) if err != nil { return nil, err } @@ -284,18 +294,24 @@ func getDbLastState(dbPath string) (bcEntries []genesisEntry, err error) { } bcEntry.NodePublicKey = nr.NodePublicKey bcEntry.NodePublicKeyB64 = base64.StdEncoding.EncodeToString(nr.NodePublicKey) - // get the participation score for this node registration - qry, args := participationScoreQuery.GetParticipationScoreByNodeID(nr.NodeID) - psRows, err := queryExecutor.ExecuteSelect(qry, false, args...) + err := func() error { + // get the participation score for this node registration + qry, args := participationScoreQuery.GetParticipationScoreByNodeID(nr.NodeID) + psRows, err := queryExecutor.ExecuteSelect(qry, false, args...) + if err != nil { + return err + } + defer psRows.Close() + + participationScores, err := participationScoreQuery.BuildModel([]*model.ParticipationScore{}, psRows) + if (err != nil) || len(participationScores) > 0 { + bcEntry.ParticipationScore = participationScores[0].Score + } + return nil + }() if err != nil { return nil, err } - defer psRows.Close() - - participationScores, err := participationScoreQuery.BuildModel([]*model.ParticipationScore{}, psRows) - if (err != nil) || len(participationScores) > 0 { - bcEntry.ParticipationScore = participationScores[0].Score - } } bcEntries = append(bcEntries, *bcEntry) } diff --git a/common/transaction/nodeRegistration.go b/common/transaction/nodeRegistration.go index 9d388b862..9053b5531 100644 --- a/common/transaction/nodeRegistration.go +++ b/common/transaction/nodeRegistration.go @@ -209,10 +209,10 @@ func (tx *NodeRegistration) UndoApplyUnconfirmed() error { // Validate validate node registration transaction and tx body func (tx *NodeRegistration) Validate(dbTx bool) error { - var ( accountBalance model.AccountBalance nodeRegistrations, nodeRegistrations2 []*model.NodeRegistration + err error ) // no need to validate node registration transaction for genesis block @@ -233,58 +233,78 @@ func (tx *NodeRegistration) Validate(dbTx bool) error { return blocker.NewBlocker(blocker.ValidationErr, err.Error()) } - // check balance - qry, args := tx.AccountBalanceQuery.GetAccountBalanceByAccountAddress(tx.SenderAddress) - rows, err := tx.QueryExecutor.ExecuteSelect(qry, dbTx, args...) - if err != nil { - return blocker.NewBlocker(blocker.DBErr, err.Error()) - } - defer rows.Close() - if rows.Next() { - err = rows.Scan( - &accountBalance.AccountAddress, - &accountBalance.BlockHeight, - &accountBalance.SpendableBalance, - &accountBalance.Balance, - &accountBalance.PopRevenue, - &accountBalance.Latest, - ) + err = func() error { + // check balance + qry, args := tx.AccountBalanceQuery.GetAccountBalanceByAccountAddress(tx.SenderAddress) + rows, err := tx.QueryExecutor.ExecuteSelect(qry, dbTx, args...) if err != nil { - return err + return blocker.NewBlocker(blocker.DBErr, err.Error()) + } + defer rows.Close() + if rows.Next() { + err = rows.Scan( + &accountBalance.AccountAddress, + &accountBalance.BlockHeight, + &accountBalance.SpendableBalance, + &accountBalance.Balance, + &accountBalance.PopRevenue, + &accountBalance.Latest, + ) + if err != nil { + return err + } } + return nil + }() + if err != nil { + return err } if accountBalance.SpendableBalance < tx.Body.LockedBalance+tx.Fee { return blocker.NewBlocker(blocker.AppErr, "UserBalanceNotEnough") } - // check for public key duplication - nodeRow, err := tx.QueryExecutor.ExecuteSelect(tx.NodeRegistrationQuery.GetNodeRegistrationByNodePublicKey(), - dbTx, tx.Body.NodePublicKey) - if err != nil { - return err - } - defer nodeRow.Close() - nodeRegistrations, err = tx.NodeRegistrationQuery.BuildModel(nodeRegistrations, nodeRow) + err = func() error { + // check for public key duplication + nodeRow, err := tx.QueryExecutor.ExecuteSelect(tx.NodeRegistrationQuery.GetNodeRegistrationByNodePublicKey(), + dbTx, tx.Body.NodePublicKey) + if err != nil { + return err + } + defer nodeRow.Close() + nodeRegistrations, err = tx.NodeRegistrationQuery.BuildModel(nodeRegistrations, nodeRow) + if err != nil { + return err + } + return nil + }() if err != nil { return err } + // in case a node with same pub key exists, validation must pass only if that node is tagged as deleted // if any other state validation should fail if len(nodeRegistrations) > 0 && nodeRegistrations[0].RegistrationStatus != uint32(model.NodeRegistrationState_NodeDeleted) { return blocker.NewBlocker(blocker.AuthErr, "NodeAlreadyRegistered") } - // check for account address duplication (accounts can register one node at the time) - qryNodeByAccount, args := tx.NodeRegistrationQuery.GetNodeRegistrationByAccountAddress(tx.Body.AccountAddress) - nodeRow2, err := tx.QueryExecutor.ExecuteSelect(qryNodeByAccount, dbTx, args...) - if err != nil { - return err - } - defer nodeRow2.Close() - nodeRegistrations2, err = tx.NodeRegistrationQuery.BuildModel(nodeRegistrations2, nodeRow2) + err = func() error { + // check for account address duplication (accounts can register one node at the time) + qryNodeByAccount, args := tx.NodeRegistrationQuery.GetNodeRegistrationByAccountAddress(tx.Body.AccountAddress) + nodeRow2, err := tx.QueryExecutor.ExecuteSelect(qryNodeByAccount, dbTx, args...) + if err != nil { + return err + } + defer nodeRow2.Close() + nodeRegistrations2, err = tx.NodeRegistrationQuery.BuildModel(nodeRegistrations2, nodeRow2) + if err != nil { + return err + } + return nil + }() if err != nil { return err } + // in case a node with same account address, validation must pass only if that node is tagged as deleted // if any other state validation should fail if len(nodeRegistrations2) > 0 && nodeRegistrations2[0].RegistrationStatus != uint32(model.NodeRegistrationState_NodeDeleted) { diff --git a/common/transaction/nodeRegistrationUpdate.go b/common/transaction/nodeRegistrationUpdate.go index 09e193a99..da5d41065 100644 --- a/common/transaction/nodeRegistrationUpdate.go +++ b/common/transaction/nodeRegistrationUpdate.go @@ -238,38 +238,50 @@ func (tx *UpdateNodeRegistration) Validate(dbTx bool) error { tx.BlockQuery); err != nil { return err } - // check that sender is node's owner - qry, args := tx.NodeRegistrationQuery.GetNodeRegistrationByAccountAddress(tx.SenderAddress) - rows, err := tx.QueryExecutor.ExecuteSelect(qry, dbTx, args...) - if err != nil { - return err - } - defer rows.Close() + err := func() error { + // check that sender is node's owner + qry, args := tx.NodeRegistrationQuery.GetNodeRegistrationByAccountAddress(tx.SenderAddress) + rows, err := tx.QueryExecutor.ExecuteSelect(qry, dbTx, args...) + if err != nil { + return err + } + defer rows.Close() - tempNodeRegistrationResult, err = tx.NodeRegistrationQuery.BuildModel(tempNodeRegistrationResult, rows) - if (err != nil) || len(tempNodeRegistrationResult) > 0 { - prevNodeRegistration = tempNodeRegistrationResult[0] - if prevNodeRegistration.RegistrationStatus == uint32(model.NodeRegistrationState_NodeDeleted) { - return blocker.NewBlocker(blocker.AuthErr, "NodeDeleted") + tempNodeRegistrationResult, err = tx.NodeRegistrationQuery.BuildModel(tempNodeRegistrationResult, rows) + if (err != nil) || len(tempNodeRegistrationResult) > 0 { + prevNodeRegistration = tempNodeRegistrationResult[0] + if prevNodeRegistration.RegistrationStatus == uint32(model.NodeRegistrationState_NodeDeleted) { + return blocker.NewBlocker(blocker.AuthErr, "NodeDeleted") + } + } else { + return blocker.NewBlocker(blocker.ValidationErr, "SenderAccountNotNodeOwner") } - } else { - return blocker.NewBlocker(blocker.ValidationErr, "SenderAccountNotNodeOwner") + return nil + }() + if err != nil { + return err } // validate node public key, if we are updating that field // note: node pub key must be not already registered for another node if len(tx.Body.NodePublicKey) > 0 && !bytes.Equal(prevNodeRegistration.NodePublicKey, tx.Body.NodePublicKey) { - rows2, err := tx.QueryExecutor.ExecuteSelect(tx.NodeRegistrationQuery. - GetNodeRegistrationByNodePublicKey(), false, tx.Body.NodePublicKey) + err := func() error { + rows2, err := tx.QueryExecutor.ExecuteSelect(tx.NodeRegistrationQuery. + GetNodeRegistrationByNodePublicKey(), false, tx.Body.NodePublicKey) + if err != nil { + return err + } + defer rows2.Close() + + tempNodeRegistrationResult2, err = tx.NodeRegistrationQuery.BuildModel(tempNodeRegistrationResult2, rows2) + if (err != nil) || len(tempNodeRegistrationResult2) > 0 { + return blocker.NewBlocker(blocker.ValidationErr, "NodePublicKeyAlredyRegistered") + } + return nil + }() if err != nil { return err } - defer rows2.Close() - - tempNodeRegistrationResult2, err = tx.NodeRegistrationQuery.BuildModel(tempNodeRegistrationResult2, rows2) - if (err != nil) || len(tempNodeRegistrationResult2) > 0 { - return blocker.NewBlocker(blocker.ValidationErr, "NodePublicKeyAlredyRegistered") - } } // delta amount to be locked @@ -280,7 +292,7 @@ func (tx *UpdateNodeRegistration) Validate(dbTx bool) error { } // check balance - qry, args = tx.AccountBalanceQuery.GetAccountBalanceByAccountAddress(tx.SenderAddress) + qry, args := tx.AccountBalanceQuery.GetAccountBalanceByAccountAddress(tx.SenderAddress) row3, err := tx.QueryExecutor.ExecuteSelectRow(qry, dbTx, args...) if err != nil { return blocker.NewBlocker(blocker.DBErr, err.Error()) diff --git a/common/transaction/transactionGeneral.go b/common/transaction/transactionGeneral.go index b4a7e8876..0150f682b 100644 --- a/common/transaction/transactionGeneral.go +++ b/common/transaction/transactionGeneral.go @@ -411,19 +411,26 @@ func (mtu *MultisigTransactionUtil) CheckMultisigComplete( BlockHeight: txHeight, }) } - q, args := mtu.PendingTransactionQuery.GetPendingTransactionsBySenderAddress( - multisigAddress, model.PendingTransactionStatus_PendingTransactionPending, - txHeight, constant.MinRollbackBlocks, - ) - pendingTxRows, err := mtu.QueryExecutor.ExecuteSelect(q, false, args...) - if err != nil { - return nil, err - } - defer pendingTxRows.Close() - dbPendingTxs, err = mtu.PendingTransactionQuery.BuildModel(dbPendingTxs, pendingTxRows) + err := func() error { + q, args := mtu.PendingTransactionQuery.GetPendingTransactionsBySenderAddress( + multisigAddress, model.PendingTransactionStatus_PendingTransactionPending, + txHeight, constant.MinRollbackBlocks, + ) + pendingTxRows, err := mtu.QueryExecutor.ExecuteSelect(q, false, args...) + if err != nil { + return err + } + defer pendingTxRows.Close() + dbPendingTxs, err = mtu.PendingTransactionQuery.BuildModel(dbPendingTxs, pendingTxRows) + if err != nil { + return err + } + return nil + }() if err != nil { return nil, err } + pendingTxs = append(pendingTxs, dbPendingTxs...) if len(pendingTxs) < 1 { return nil, nil @@ -436,23 +443,30 @@ func (mtu *MultisigTransactionUtil) CheckMultisigComplete( signatures = make(map[string][]byte) validSignatureCounter uint32 ) - q, args := mtu.PendingSignatureQuery.GetPendingSignatureByHash( - v.TransactionHash, - txHeight, constant.MinRollbackBlocks, - ) - pendingSigRows, err := mtu.QueryExecutor.ExecuteSelect(q, false, args...) - if err != nil { - return nil, err - } - pendingSigs, err = mtu.PendingSignatureQuery.BuildModel(pendingSigs, pendingSigRows) - if err != nil { - pendingSigRows.Close() - return nil, err - } - pendingSigRows.Close() + + err := func() error { + q, args := mtu.PendingSignatureQuery.GetPendingSignatureByHash( + v.TransactionHash, + txHeight, constant.MinRollbackBlocks, + ) + pendingSigRows, err := mtu.QueryExecutor.ExecuteSelect(q, false, args...) + if err != nil { + return err + } + pendingSigs, err = mtu.PendingSignatureQuery.BuildModel(pendingSigs, pendingSigRows) + if err != nil { + return err + } + defer pendingSigRows.Close() + if err != nil { + return err + } + return nil + }() if err != nil { return nil, err } + for _, sig := range pendingSigs { signatures[sig.AccountAddress] = sig.Signature } @@ -494,6 +508,7 @@ func (mtu *MultisigTransactionUtil) CheckMultisigComplete( multisigInfo model.MultiSignatureInfo pendingSigs []*model.PendingSignature validSignatureCounter uint32 + err error ) txHash := sha3.Sum256(body.UnsignedTransactionBytes) innerTx, err := mtu.TransactionUtil.ParseTransactionBytes(body.UnsignedTransactionBytes, false) @@ -528,19 +543,26 @@ func (mtu *MultisigTransactionUtil) CheckMultisigComplete( } } var dbPendingSigs []*model.PendingSignature - q, args = mtu.PendingSignatureQuery.GetPendingSignatureByHash( - txHash[:], - txHeight, constant.MinRollbackBlocks, - ) - rows, err := mtu.QueryExecutor.ExecuteSelect(q, false, args...) - if err != nil { - return nil, err - } - defer rows.Close() - dbPendingSigs, err = mtu.PendingSignatureQuery.BuildModel(dbPendingSigs, rows) + err = func() error { + q, args = mtu.PendingSignatureQuery.GetPendingSignatureByHash( + txHash[:], + txHeight, constant.MinRollbackBlocks, + ) + rows, err := mtu.QueryExecutor.ExecuteSelect(q, false, args...) + if err != nil { + return err + } + defer rows.Close() + dbPendingSigs, err = mtu.PendingSignatureQuery.BuildModel(dbPendingSigs, rows) + if err != nil { + return err + } + return nil + }() if err != nil { return nil, err } + pendingSigs = append(pendingSigs, dbPendingSigs...) body.SignatureInfo = &model.SignatureInfo{ TransactionHash: txHash[:], @@ -582,6 +604,7 @@ func (mtu *MultisigTransactionUtil) CheckMultisigComplete( pendingSigs []*model.PendingSignature multisigInfo model.MultiSignatureInfo validSignatureCounter uint32 + err error ) txHash := body.SignatureInfo.TransactionHash @@ -608,19 +631,25 @@ func (mtu *MultisigTransactionUtil) CheckMultisigComplete( "FailToParseTransactionBytes", ) } - q, args = mtu.PendingSignatureQuery.GetPendingSignatureByHash( - txHash, - txHeight, constant.MinRollbackBlocks, - ) - rowsPendingSigs, err := mtu.QueryExecutor.ExecuteSelect(q, false, args...) - if err != nil { - return nil, err - } - pendingSigs, err = mtu.PendingSignatureQuery.BuildModel(pendingSigs, rowsPendingSigs) + err = func() error { + q, args = mtu.PendingSignatureQuery.GetPendingSignatureByHash( + txHash, + txHeight, constant.MinRollbackBlocks, + ) + rowsPendingSigs, err := mtu.QueryExecutor.ExecuteSelect(q, false, args...) + if err != nil { + return err + } + defer rowsPendingSigs.Close() + pendingSigs, err = mtu.PendingSignatureQuery.BuildModel(pendingSigs, rowsPendingSigs) + if err != nil { + return err + } + return nil + }() if err != nil { return nil, err } - defer rowsPendingSigs.Close() for _, sig := range pendingSigs { body.SignatureInfo.Signatures[sig.AccountAddress] = sig.Signature } diff --git a/core/service/blockMainService.go b/core/service/blockMainService.go index 1b0a8d114..5eb455c5f 100644 --- a/core/service/blockMainService.go +++ b/core/service/blockMainService.go @@ -1214,16 +1214,24 @@ func (bs *BlockService) GetBlockExtendedInfo(block *model.Block, includeReceipts } else { blExt.BlocksmithAccountAddress = constant.MainchainGenesisAccountAddress } - skippedBlocksmithsQuery := bs.SkippedBlocksmithQuery.GetSkippedBlocksmithsByBlockHeight(block.Height) - skippedBlocksmithsRows, err := bs.QueryExecutor.ExecuteSelect(skippedBlocksmithsQuery, false) - if err != nil { - return nil, err - } - defer skippedBlocksmithsRows.Close() - blExt.SkippedBlocksmiths, err = bs.SkippedBlocksmithQuery.BuildModel(skippedBlocksmiths, skippedBlocksmithsRows) + + err = func() error { + skippedBlocksmithsQuery := bs.SkippedBlocksmithQuery.GetSkippedBlocksmithsByBlockHeight(block.Height) + skippedBlocksmithsRows, err := bs.QueryExecutor.ExecuteSelect(skippedBlocksmithsQuery, false) + if err != nil { + return err + } + defer skippedBlocksmithsRows.Close() + blExt.SkippedBlocksmiths, err = bs.SkippedBlocksmithQuery.BuildModel(skippedBlocksmiths, skippedBlocksmithsRows) + if err != nil { + return err + } + return nil + }() if err != nil { return nil, err } + publishedReceipts, err = bs.PublishedReceiptUtil.GetPublishedReceiptsByBlockHeight(block.GetHeight()) if err != nil { return nil, err @@ -1236,16 +1244,24 @@ func (bs *BlockService) GetBlockExtendedInfo(block *model.Block, includeReceipts unLinkedPublishedReceiptCount++ } } - nodeRegistryAtHeightQ := bs.NodeRegistrationQuery.GetNodeRegistryAtHeight(block.Height) - nodeRegistryAtHeightRows, err := bs.QueryExecutor.ExecuteSelect(nodeRegistryAtHeightQ, false) - if err != nil { - return nil, err - } - defer nodeRegistryAtHeightRows.Close() - nodeRegistryAtHeight, err = bs.NodeRegistrationQuery.BuildModel(nodeRegistryAtHeight, nodeRegistryAtHeightRows) + + err = func() error { + nodeRegistryAtHeightQ := bs.NodeRegistrationQuery.GetNodeRegistryAtHeight(block.Height) + nodeRegistryAtHeightRows, err := bs.QueryExecutor.ExecuteSelect(nodeRegistryAtHeightQ, false) + if err != nil { + return err + } + defer nodeRegistryAtHeightRows.Close() + nodeRegistryAtHeight, err = bs.NodeRegistrationQuery.BuildModel(nodeRegistryAtHeight, nodeRegistryAtHeightRows) + if err != nil { + return err + } + return nil + }() if err != nil { return nil, err } + blExt.ReceiptValue = commonUtils.GetReceiptValue(linkedPublishedReceiptCount, unLinkedPublishedReceiptCount) blExt.PopChange, err = util.CalculateParticipationScore( linkedPublishedReceiptCount, diff --git a/core/service/receiptService.go b/core/service/receiptService.go index 214a9a41f..9ee4316c0 100644 --- a/core/service/receiptService.go +++ b/core/service/receiptService.go @@ -95,8 +95,9 @@ func (rs *ReceiptService) SelectReceipts( var ( linkedReceiptList = make(map[string][]*model.Receipt) // this variable is to store picked receipt recipient to avoid duplicates - pickedRecipients = make(map[string]bool) - lowerBlockHeight uint32 + pickedRecipients = make(map[string]bool) + lowerBlockHeight uint32 + linkedReceiptTree = make(map[string][]byte) ) if numberOfReceipt < 1 { // possible no connected node @@ -106,17 +107,24 @@ func (rs *ReceiptService) SelectReceipts( if lastBlockHeight > constant.NodeReceiptExpiryBlockHeight { lowerBlockHeight = lastBlockHeight - constant.NodeReceiptExpiryBlockHeight } - treeQ := rs.MerkleTreeQuery.SelectMerkleTree( - lowerBlockHeight, - lastBlockHeight, - numberOfReceipt*constant.ReceiptBatchPickMultiplier) - linkedTreeRows, err := rs.QueryExecutor.ExecuteSelect(treeQ, false) - if err != nil { - return nil, err - } - defer linkedTreeRows.Close() - linkedReceiptTree, err := rs.MerkleTreeQuery.BuildTree(linkedTreeRows) + err := func() error { + treeQ := rs.MerkleTreeQuery.SelectMerkleTree( + lowerBlockHeight, + lastBlockHeight, + numberOfReceipt*constant.ReceiptBatchPickMultiplier) + linkedTreeRows, err := rs.QueryExecutor.ExecuteSelect(treeQ, false) + if err != nil { + return err + } + defer linkedTreeRows.Close() + + linkedReceiptTree, err = rs.MerkleTreeQuery.BuildTree(linkedTreeRows) + if err != nil { + return err + } + return nil + }() if err != nil { return nil, err }