diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml
index 8f9a24aef..45cc22a6d 100644
--- a/.github/workflows/main.yml
+++ b/.github/workflows/main.yml
@@ -20,7 +20,7 @@ env:
   # If you change this value, please change it in the following files as well:
   # /Dockerfile
-  GO_VERSION: 1.21.10
+  GO_VERSION: 1.24.0
 
 jobs:
   ########################
diff --git a/Dockerfile b/Dockerfile
index 488b3c63b..53862c735 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -1,4 +1,4 @@
-FROM --platform=${BUILDPLATFORM} golang:1.22-alpine as builder
+FROM --platform=${BUILDPLATFORM} golang:1.24-alpine as builder
 
 # Copy in the local repository to build from.
 COPY . /go/src/github.com/lightningnetwork/loop
diff --git a/sweepbatcher/greedy_batch_selection.go b/sweepbatcher/greedy_batch_selection.go
index 30f1cb33a..163e0c3a3 100644
--- a/sweepbatcher/greedy_batch_selection.go
+++ b/sweepbatcher/greedy_batch_selection.go
@@ -92,8 +92,8 @@ func (b *Batcher) greedyAddSweep(ctx context.Context, sweep *sweep) error {
 			return nil
 		}
 
-		log.Debugf("Batch selection algorithm returned batch id %d for"+
-			" sweep %x, but acceptance failed.", batchId,
+		debugf("Batch selection algorithm returned batch id %d "+
+			"for sweep %x, but acceptance failed.", batchId,
 			sweep.swapHash[:6])
 	}
 
diff --git a/sweepbatcher/log.go b/sweepbatcher/log.go
index 24d6cc297..29b230994 100644
--- a/sweepbatcher/log.go
+++ b/sweepbatcher/log.go
@@ -2,15 +2,21 @@ package sweepbatcher
 
 import (
 	"fmt"
+	"sync/atomic"
 
 	"github.com/btcsuite/btclog"
 	"github.com/lightningnetwork/lnd/build"
 )
 
-// log is a logger that is initialized with no output filters. This
+// log_ stores a logger that is initialized with no output filters. This
 // means the package will not perform any logging by default until the
 // caller requests it.
-var log btclog.Logger
+var log_ atomic.Pointer[btclog.Logger]
+
+// log returns the active logger.
+func log() btclog.Logger {
+	return *log_.Load()
+}
 
 // The default amount of logging is none.
 func init() {
@@ -20,12 +26,32 @@ func init() {
 // batchPrefixLogger returns a logger that prefixes all log messages with
 // the ID.
 func batchPrefixLogger(batchID string) btclog.Logger {
-	return build.NewPrefixLog(fmt.Sprintf("[Batch %s]", batchID), log)
+	return build.NewPrefixLog(fmt.Sprintf("[Batch %s]", batchID), log())
 }
 
 // UseLogger uses a specified Logger to output package logging info.
 // This should be used in preference to SetLogWriter if the caller is also
 // using btclog.
 func UseLogger(logger btclog.Logger) {
-	log = logger
+	log_.Store(&logger)
+}
+
+// debugf logs a message with level DEBUG.
+func debugf(format string, params ...interface{}) {
+	log().Debugf(format, params...)
 }
+
+// infof logs a message with level INFO.
+func infof(format string, params ...interface{}) {
+	log().Infof(format, params...)
+}
+
+// warnf logs a message with level WARN.
+func warnf(format string, params ...interface{}) {
+	log().Warnf(format, params...)
+}
+
+// errorf logs a message with level ERROR.
+func errorf(format string, params ...interface{}) {
+	log().Errorf(format, params...)
} diff --git a/sweepbatcher/store_mock.go b/sweepbatcher/store_mock.go index 57cdd34b7..815b19917 100644 --- a/sweepbatcher/store_mock.go +++ b/sweepbatcher/store_mock.go @@ -4,6 +4,7 @@ import ( "context" "errors" "sort" + "sync" "github.com/btcsuite/btcd/btcutil" "github.com/lightningnetwork/lnd/lntypes" @@ -13,6 +14,7 @@ import ( type StoreMock struct { batches map[int32]dbBatch sweeps map[lntypes.Hash]dbSweep + mu sync.Mutex } // NewStoreMock instantiates a new mock store. @@ -28,6 +30,9 @@ func NewStoreMock() *StoreMock { func (s *StoreMock) FetchUnconfirmedSweepBatches(ctx context.Context) ( []*dbBatch, error) { + s.mu.Lock() + defer s.mu.Unlock() + result := []*dbBatch{} for _, batch := range s.batches { batch := batch @@ -44,6 +49,9 @@ func (s *StoreMock) FetchUnconfirmedSweepBatches(ctx context.Context) ( func (s *StoreMock) InsertSweepBatch(ctx context.Context, batch *dbBatch) (int32, error) { + s.mu.Lock() + defer s.mu.Unlock() + var id int32 if len(s.batches) == 0 { @@ -66,12 +74,18 @@ func (s *StoreMock) DropBatch(ctx context.Context, id int32) error { func (s *StoreMock) UpdateSweepBatch(ctx context.Context, batch *dbBatch) error { + s.mu.Lock() + defer s.mu.Unlock() + s.batches[batch.ID] = *batch return nil } // ConfirmBatch confirms a batch. func (s *StoreMock) ConfirmBatch(ctx context.Context, id int32) error { + s.mu.Lock() + defer s.mu.Unlock() + batch, ok := s.batches[id] if !ok { return errors.New("batch not found") @@ -87,6 +101,9 @@ func (s *StoreMock) ConfirmBatch(ctx context.Context, id int32) error { func (s *StoreMock) FetchBatchSweeps(ctx context.Context, id int32) ([]*dbSweep, error) { + s.mu.Lock() + defer s.mu.Unlock() + result := []*dbSweep{} for _, sweep := range s.sweeps { sweep := sweep @@ -104,7 +121,11 @@ func (s *StoreMock) FetchBatchSweeps(ctx context.Context, // UpsertSweep inserts a sweep into the database, or updates an existing sweep. func (s *StoreMock) UpsertSweep(ctx context.Context, sweep *dbSweep) error { + s.mu.Lock() + defer s.mu.Unlock() + s.sweeps[sweep.SwapHash] = *sweep + return nil } @@ -112,6 +133,9 @@ func (s *StoreMock) UpsertSweep(ctx context.Context, sweep *dbSweep) error { func (s *StoreMock) GetSweepStatus(ctx context.Context, swapHash lntypes.Hash) (bool, error) { + s.mu.Lock() + defer s.mu.Unlock() + sweep, ok := s.sweeps[swapHash] if !ok { return false, nil @@ -127,6 +151,9 @@ func (s *StoreMock) Close() error { // AssertSweepStored asserts that a sweep is stored. 
func (s *StoreMock) AssertSweepStored(id lntypes.Hash) bool {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
 	_, ok := s.sweeps[id]
 	return ok
 }
@@ -135,6 +162,9 @@ func (s *StoreMock) AssertSweepStored(id lntypes.Hash) bool {
 func (s *StoreMock) GetParentBatch(ctx context.Context, swapHash lntypes.Hash) (
 	*dbBatch, error) {
 
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
 	for _, sweep := range s.sweeps {
 		if sweep.SwapHash == swapHash {
 			batch, ok := s.batches[sweep.BatchID]
@@ -153,6 +183,9 @@ func (s *StoreMock) GetParentBatch(ctx context.Context, swapHash lntypes.Hash) (
 func (s *StoreMock) TotalSweptAmount(ctx context.Context, batchID int32) (
 	btcutil.Amount, error) {
 
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
 	batch, ok := s.batches[batchID]
 	if !ok {
 		return 0, errors.New("batch not found")
diff --git a/sweepbatcher/sweep_batch.go b/sweepbatcher/sweep_batch.go
index 6a4989cd9..b4e6c65fd 100644
--- a/sweepbatcher/sweep_batch.go
+++ b/sweepbatcher/sweep_batch.go
@@ -9,6 +9,7 @@ import (
 	"math"
 	"strings"
 	"sync"
+	"sync/atomic"
 	"time"
 
 	"github.com/btcsuite/btcd/blockchain"
@@ -214,6 +215,12 @@ type batch struct {
 	// reorgChan is the channel over which reorg notifications are received.
 	reorgChan chan struct{}
 
+	// testReqs is a channel where test requests are received.
+	// This is used only in unit tests! The reason to have this is to
+	// avoid data races in require.Eventually calls running in parallel
+	// to the event loop. See method testRunInEventLoop().
+	testReqs chan *testRequest
+
 	// errChan is the channel over which errors are received.
 	errChan chan error
@@ -284,8 +291,8 @@ type batch struct {
 	// cfg is the configuration for this batch.
 	cfg *batchConfig
 
-	// log is the logger for this batch.
-	log btclog.Logger
+	// log_ holds the logger for this batch.
+	log_ atomic.Pointer[btclog.Logger]
 
 	wg sync.WaitGroup
 }
@@ -351,6 +358,7 @@ func NewBatch(cfg batchConfig, bk batchKit) *batch {
 		spendChan: make(chan *chainntnfs.SpendDetail),
 		confChan: make(chan *chainntnfs.TxConfirmation, 1),
 		reorgChan: make(chan struct{}, 1),
+		testReqs: make(chan *testRequest),
 		errChan: make(chan error, 1),
 		callEnter: make(chan struct{}),
 		callLeave: make(chan struct{}),
@@ -387,7 +395,7 @@ func NewBatchFromDB(cfg batchConfig, bk batchKit) (*batch, error) {
 		}
 	}
 
-	return &batch{
+	b := &batch{
 		id: bk.id,
 		state: bk.state,
 		primarySweepID: bk.primaryID,
@@ -395,6 +403,7 @@ func NewBatchFromDB(cfg batchConfig, bk batchKit) (*batch, error) {
 		spendChan: make(chan *chainntnfs.SpendDetail),
 		confChan: make(chan *chainntnfs.TxConfirmation, 1),
 		reorgChan: make(chan struct{}, 1),
+		testReqs: make(chan *testRequest),
 		errChan: make(chan error, 1),
 		callEnter: make(chan struct{}),
 		callLeave: make(chan struct{}),
@@ -412,9 +421,42 @@ func NewBatchFromDB(cfg batchConfig, bk batchKit) (*batch, error) {
 		publishErrorHandler: bk.publishErrorHandler,
 		purger: bk.purger,
 		store: bk.store,
-		log: bk.log,
 		cfg: &cfg,
-	}, nil
+	}
+
+	b.setLog(bk.log)
+
+	return b, nil
+}
+
+// log returns the current logger.
+func (b *batch) log() btclog.Logger {
+	return *b.log_.Load()
+}
+
+// setLog atomically replaces the logger.
+func (b *batch) setLog(logger btclog.Logger) {
+	b.log_.Store(&logger)
+}
+
+// Debugf logs a message with level DEBUG.
+func (b *batch) Debugf(format string, params ...interface{}) {
+	b.log().Debugf(format, params...)
+}
+
+// Infof logs a message with level INFO.
+func (b *batch) Infof(format string, params ...interface{}) {
+	b.log().Infof(format, params...)
+}
+
+// Warnf logs a message with level WARN.
+func (b *batch) Warnf(format string, params ...interface{}) { + b.log().Warnf(format, params...) +} + +// Errorf logs a message with level ERROR. +func (b *batch) Errorf(format string, params ...interface{}) { + b.log().Errorf(format, params...) } // addSweep tries to add a sweep to the batch. If this is the first sweep being @@ -430,7 +472,7 @@ func (b *batch) addSweep(ctx context.Context, sweep *sweep) (bool, error) { // If the provided sweep is nil, we can't proceed with any checks, so // we just return early. if sweep == nil { - b.log.Infof("the sweep is nil") + b.Infof("the sweep is nil") return false, nil } @@ -473,7 +515,7 @@ func (b *batch) addSweep(ctx context.Context, sweep *sweep) (bool, error) { // the batch, do not add another sweep to prevent the tx from becoming // non-standard. if len(b.sweeps) >= MaxSweepsPerBatch { - b.log.Infof("the batch has already too many sweeps (%d >= %d)", + b.Infof("the batch has already too many sweeps %d >= %d", len(b.sweeps), MaxSweepsPerBatch) return false, nil @@ -483,7 +525,7 @@ func (b *batch) addSweep(ctx context.Context, sweep *sweep) (bool, error) { // arrive here after the batch got closed because of a spend. In this // case we cannot add the sweep to this batch. if b.state != Open { - b.log.Infof("the batch state (%v) is not open", b.state) + b.Infof("the batch state (%v) is not open", b.state) return false, nil } @@ -493,15 +535,15 @@ func (b *batch) addSweep(ctx context.Context, sweep *sweep) (bool, error) { // we cannot add this sweep to the batch. for _, s := range b.sweeps { if s.isExternalAddr { - b.log.Infof("the batch already has a sweep (%x) with "+ + b.Infof("the batch already has a sweep %x with "+ "an external address", s.swapHash[:6]) return false, nil } if sweep.isExternalAddr { - b.log.Infof("the batch is not empty and new sweep (%x)"+ - " has an external address", sweep.swapHash[:6]) + b.Infof("the batch is not empty and new sweep %x "+ + "has an external address", sweep.swapHash[:6]) return false, nil } @@ -515,7 +557,7 @@ func (b *batch) addSweep(ctx context.Context, sweep *sweep) (bool, error) { int32(math.Abs(float64(sweep.timeout - s.timeout))) if timeoutDistance > b.cfg.maxTimeoutDistance { - b.log.Infof("too long timeout distance between the "+ + b.Infof("too long timeout distance between the "+ "batch and sweep %x: %d > %d", sweep.swapHash[:6], timeoutDistance, b.cfg.maxTimeoutDistance) @@ -544,7 +586,7 @@ func (b *batch) addSweep(ctx context.Context, sweep *sweep) (bool, error) { } // Add the sweep to the batch's sweeps. - b.log.Infof("adding sweep %x", sweep.swapHash[:6]) + b.Infof("adding sweep %x", sweep.swapHash[:6]) b.sweeps[sweep.swapHash] = *sweep // Update FeeRate. Max(sweep.minFeeRate) for all the sweeps of @@ -572,7 +614,7 @@ func (b *batch) sweepExists(hash lntypes.Hash) bool { // Wait waits for the batch to gracefully stop. func (b *batch) Wait() { - b.log.Infof("Stopping") + b.Infof("Stopping") <-b.finished } @@ -613,8 +655,7 @@ func (b *batch) Run(ctx context.Context) error { // Set currentHeight here, because it may be needed in monitorSpend. select { case b.currentHeight = <-blockChan: - b.log.Debugf("initial height for the batch is %v", - b.currentHeight) + b.Debugf("initial height for the batch is %v", b.currentHeight) case <-runCtx.Done(): return runCtx.Err() @@ -652,7 +693,7 @@ func (b *batch) Run(ctx context.Context) error { // completes. 
 	timerChan := clock.TickAfter(b.cfg.batchPublishDelay)
 
-	b.log.Infof("started, primary %x, total sweeps %v",
+	b.Infof("started, primary %x, total sweeps %v",
 		b.primarySweepID[0:6], len(b.sweeps))
 
 	for {
@@ -662,7 +703,7 @@ func (b *batch) Run(ctx context.Context) error {
 		// blockChan provides immediately the current tip.
 		case height := <-blockChan:
-			b.log.Debugf("received block %v", height)
+			b.Debugf("received block %v", height)
 
 			// Set the timer to publish the batch transaction after
 			// the configured delay.
@@ -670,7 +711,7 @@ func (b *batch) Run(ctx context.Context) error {
 			b.currentHeight = height
 
 		case <-initialDelayChan:
-			b.log.Debugf("initial delay of duration %v has ended",
+			b.Debugf("initial delay of duration %v has ended",
 				b.cfg.initialDelay)
 
 			// Set the timer to publish the batch transaction after
@@ -680,8 +721,8 @@ func (b *batch) Run(ctx context.Context) error {
 		case <-timerChan:
 			// Check that batch is still open.
 			if b.state != Open {
-				b.log.Debugf("Skipping publishing, because the"+
-					" batch is not open (%v).", b.state)
+				b.Debugf("Skipping publishing, because "+
+					"the batch is not open (%v).", b.state)
 
 				continue
 			}
@@ -695,7 +736,7 @@ func (b *batch) Run(ctx context.Context) error {
 			// initialDelayChan has just fired, this check passes.
 			now := clock.Now()
 			if skipBefore.After(now) {
-				b.log.Debugf(stillWaitingMsg, skipBefore, now)
+				b.Debugf(stillWaitingMsg, skipBefore, now)
 
 				continue
 			}
@@ -715,14 +756,18 @@ func (b *batch) Run(ctx context.Context) error {
 		case <-b.reorgChan:
 			b.state = Open
-			b.log.Warnf("reorg detected, batch is able to accept " +
-				"new sweeps")
+			b.Warnf("reorg detected, batch is able to " +
+				"accept new sweeps")
 
 			err := b.monitorSpend(ctx, b.sweeps[b.primarySweepID])
 			if err != nil {
 				return err
 			}
 
+		case testReq := <-b.testReqs:
+			testReq.handler()
+			close(testReq.quit)
+
 		case err := <-blockErrChan:
 			return err
 
@@ -735,6 +780,36 @@
 	}
 }
 
+// testRunInEventLoop runs a function in the event loop, blocking until
+// the function returns. For unit tests only!
+func (b *batch) testRunInEventLoop(ctx context.Context, handler func()) {
+	// If the event loop is finished, run the function.
+	select {
+	case <-b.stopping:
+		handler()
+
+		return
+	default:
+	}
+
+	quit := make(chan struct{})
+	req := &testRequest{
+		handler: handler,
+		quit:    quit,
+	}
+
+	select {
+	case b.testReqs <- req:
+	case <-ctx.Done():
+		return
+	}
+
+	select {
+	case <-quit:
+	case <-ctx.Done():
+	}
+}
+
 // timeout returns minimum timeout as block height among sweeps of the batch.
 // If the batch is empty, return -1.
 func (b *batch) timeout() int32 {
@@ -755,8 +830,10 @@ func (b *batch) isUrgent(skipBefore time.Time) bool {
 	timeout := b.timeout()
 	if timeout <= 0 {
-		b.log.Warnf("Method timeout() returned %v. Number of"+
-			" sweeps: %d. It may be an empty batch.",
+		// This may happen if the batch is empty or if SweepInfo.Timeout
+		// is not set, which is possible in tests or if there is a bug.
+		b.Warnf("Method timeout() returned %v. Number of "+
+			"sweeps: %d. It may be an empty batch.",
 			timeout, len(b.sweeps))
 		return false
 	}
@@ -779,7 +856,7 @@ func (b *batch) isUrgent(skipBefore time.Time) bool {
 		return false
 	}
 
-	b.log.Debugf("cancelling waiting for urgent sweep (timeBank is %v, "+
+	b.Debugf("cancelling waiting for urgent sweep (timeBank is %v, "+
 		"remainingWaiting is %v)", timeBank, remainingWaiting)
 
 	// Signal to the caller to cancel initialDelay.
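
The testRunInEventLoop helper above is the core of the race-free test pattern introduced by this patch: the closure is handed to the event loop over testReqs and executed on the loop goroutine itself, so it may touch batch fields without extra synchronization, and the quit channel makes the call blocking. A minimal sketch of a test-side accessor built on top of it (the function name exampleNumSweeps is hypothetical; batch, testRunInEventLoop and sweeps are as in the diff):

// exampleNumSweeps reads len(b.sweeps) without racing the event loop:
// the closure runs on the loop goroutine, so no locking is needed.
func exampleNumSweeps(ctx context.Context, b *batch) int {
	var n int
	b.testRunInEventLoop(ctx, func() {
		n = len(b.sweeps)
	})

	return n
}

If the loop has already exited (b.stopping is closed), the helper runs the closure inline instead, which is safe because nothing can mutate the batch anymore at that point.
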
@@ -795,7 +872,7 @@ func (b *batch) publish(ctx context.Context) error { ) if len(b.sweeps) == 0 { - b.log.Debugf("skipping publish: no sweeps in the batch") + b.Debugf("skipping publish: no sweeps in the batch") return nil } @@ -808,7 +885,7 @@ func (b *batch) publish(ctx context.Context) error { // logPublishError is a function which logs publish errors. logPublishError := func(errMsg string, err error) { - b.publishErrorHandler(err, errMsg, b.log) + b.publishErrorHandler(err, errMsg, b.log()) } fee, err, signSuccess = b.publishMixedBatch(ctx) @@ -830,9 +907,9 @@ func (b *batch) publish(ctx context.Context) error { } } - b.log.Infof("published, total sweeps: %v, fees: %v", len(b.sweeps), fee) + b.Infof("published, total sweeps: %v, fees: %v", len(b.sweeps), fee) for _, sweep := range b.sweeps { - b.log.Infof("published sweep %x, value: %v", + b.Infof("published sweep %x, value: %v", sweep.swapHash[:6], sweep.value) } @@ -1026,7 +1103,7 @@ func (b *batch) publishMixedBatch(ctx context.Context) (btcutil.Amount, error, coopInputs int ) for attempt := 1; ; attempt++ { - b.log.Infof("Attempt %d of collecting cooperative signatures.", + b.Infof("Attempt %d of collecting cooperative signatures.", attempt) // Construct unsigned batch transaction. @@ -1062,7 +1139,7 @@ func (b *batch) publishMixedBatch(ctx context.Context) (btcutil.Amount, error, ctx, i, sweep, tx, prevOutsMap, psbtBytes, ) if err != nil { - b.log.Infof("cooperative signing failed for "+ + b.Infof("cooperative signing failed for "+ "sweep %x: %v", sweep.swapHash[:6], err) // Set coopFailed flag for this sweep in all the @@ -1201,7 +1278,7 @@ func (b *batch) publishMixedBatch(ctx context.Context) (btcutil.Amount, error, } } txHash := tx.TxHash() - b.log.Infof("attempting to publish batch tx=%v with feerate=%v, "+ + b.Infof("attempting to publish batch tx=%v with feerate=%v, "+ "weight=%v, feeForWeight=%v, fee=%v, sweeps=%d, "+ "%d cooperative: (%s) and %d non-cooperative (%s), destAddr=%s", txHash, b.rbfCache.FeeRate, weight, feeForWeight, fee, @@ -1215,7 +1292,7 @@ func (b *batch) publishMixedBatch(ctx context.Context) (btcutil.Amount, error, blockchain.GetTransactionWeight(btcutil.NewTx(tx)), ) if realWeight != weight { - b.log.Warnf("actual weight of tx %v is %v, estimated as %d", + b.Warnf("actual weight of tx %v is %v, estimated as %d", txHash, realWeight, weight) } @@ -1239,11 +1316,11 @@ func (b *batch) debugLogTx(msg string, tx *wire.MsgTx) { // Serialize the transaction and convert to hex string. buf := bytes.NewBuffer(make([]byte, 0, tx.SerializeSize())) if err := tx.Serialize(buf); err != nil { - b.log.Errorf("failed to serialize tx for debug log: %v", err) + b.Errorf("failed to serialize tx for debug log: %v", err) return } - b.log.Debugf("%s: %s", msg, hex.EncodeToString(buf.Bytes())) + b.Debugf("%s: %s", msg, hex.EncodeToString(buf.Bytes())) } // musig2sign signs one sweep using musig2. @@ -1405,15 +1482,16 @@ func (b *batch) updateRbfRate(ctx context.Context) error { if b.rbfCache.FeeRate == 0 { // We set minFeeRate in each sweep, so fee rate is expected to // be initiated here. 
- b.log.Warnf("rbfCache.FeeRate is 0, which must not happen.") + b.Warnf("rbfCache.FeeRate is 0, which must not happen.") if b.cfg.batchConfTarget == 0 { - b.log.Warnf("updateRbfRate called with zero " + + b.Warnf("updateRbfRate called with zero " + "batchConfTarget") } - b.log.Infof("initializing rbf fee rate for conf target=%v", + b.Infof("initializing rbf fee rate for conf target=%v", b.cfg.batchConfTarget) + rate, err := b.wallet.EstimateFeeRate( ctx, b.cfg.batchConfTarget, ) @@ -1453,6 +1531,7 @@ func (b *batch) monitorSpend(ctx context.Context, primarySweep sweep) error { ) if err != nil { cancel() + return err } @@ -1461,7 +1540,7 @@ func (b *batch) monitorSpend(ctx context.Context, primarySweep sweep) error { defer cancel() defer b.wg.Done() - b.log.Infof("monitoring spend for outpoint %s", + b.Infof("monitoring spend for outpoint %s", primarySweep.outpoint.String()) for { @@ -1584,7 +1663,7 @@ func (b *batch) handleSpend(ctx context.Context, spendTx *wire.MsgTx) error { if len(spendTx.TxOut) > 0 { b.batchPkScript = spendTx.TxOut[0].PkScript } else { - b.log.Warnf("transaction %v has no outputs", txHash) + b.Warnf("transaction %v has no outputs", txHash) } // As a previous version of the batch transaction may get confirmed, @@ -1666,13 +1745,13 @@ func (b *batch) handleSpend(ctx context.Context, spendTx *wire.MsgTx) error { err := b.purger(&sweep) if err != nil { - b.log.Errorf("unable to purge sweep %x: %v", + b.Errorf("unable to purge sweep %x: %v", sweep.SwapHash[:6], err) } } }() - b.log.Infof("spent, total sweeps: %v, purged sweeps: %v", + b.Infof("spent, total sweeps: %v, purged sweeps: %v", len(notifyList), len(purgeList)) err := b.monitorConfirmations(ctx) @@ -1690,7 +1769,7 @@ func (b *batch) handleSpend(ctx context.Context, spendTx *wire.MsgTx) error { // handleConf handles a confirmation notification. This is the final step of the // batch. Here we signal to the batcher that this batch was completed. func (b *batch) handleConf(ctx context.Context) error { - b.log.Infof("confirmed in txid %s", b.batchTxid) + b.Infof("confirmed in txid %s", b.batchTxid) b.state = Confirmed return b.store.ConfirmBatch(ctx, b.id) @@ -1769,7 +1848,7 @@ func (b *batch) insertAndAcquireID(ctx context.Context) (int32, error) { } b.id = id - b.log = batchPrefixLogger(fmt.Sprintf("%d", b.id)) + b.setLog(batchPrefixLogger(fmt.Sprintf("%d", b.id))) return id, nil } diff --git a/sweepbatcher/sweep_batcher.go b/sweepbatcher/sweep_batcher.go index 3fe9fe9c4..ef5d808d6 100644 --- a/sweepbatcher/sweep_batcher.go +++ b/sweepbatcher/sweep_batcher.go @@ -225,6 +225,16 @@ var ( ErrBatcherShuttingDown = errors.New("batcher shutting down") ) +// testRequest is a function passed to an event loop and a channel used to +// wait until the function is executed. This is used in unit tests only! +type testRequest struct { + // handler is the function to an event loop. + handler func() + + // quit is closed when the handler completes. + quit chan struct{} +} + // Batcher is a system that is responsible for accepting sweep requests and // placing them in appropriate batches. It will spin up new batches as needed. type Batcher struct { @@ -234,6 +244,12 @@ type Batcher struct { // sweepReqs is a channel where sweep requests are received. sweepReqs chan SweepRequest + // testReqs is a channel where test requests are received. + // This is used only in unit tests! The reason to have this is to + // avoid data races in require.Eventually calls running in parallel + // to the event loop. See method testRunInEventLoop(). 
+ testReqs chan *testRequest + // errChan is a channel where errors are received. errChan chan error @@ -461,6 +477,7 @@ func NewBatcher(wallet lndclient.WalletKitClient, return &Batcher{ batches: make(map[int32]*batch), sweepReqs: make(chan SweepRequest), + testReqs: make(chan *testRequest), errChan: make(chan error, 1), quit: make(chan struct{}), initDone: make(chan struct{}), @@ -518,22 +535,30 @@ func (b *Batcher) Run(ctx context.Context) error { case sweepReq := <-b.sweepReqs: sweep, err := b.fetchSweep(runCtx, sweepReq) if err != nil { - log.Warnf("fetchSweep failed: %v.", err) + warnf("fetchSweep failed: %v.", err) + return err } err = b.handleSweep(runCtx, sweep, sweepReq.Notifier) if err != nil { - log.Warnf("handleSweep failed: %v.", err) + warnf("handleSweep failed: %v.", err) + return err } + case testReq := <-b.testReqs: + testReq.handler() + close(testReq.quit) + case err := <-b.errChan: - log.Warnf("Batcher received an error: %v.", err) + warnf("Batcher received an error: %v.", err) + return err case <-runCtx.Done(): - log.Infof("Stopping Batcher: run context cancelled.") + infof("Stopping Batcher: run context cancelled.") + return runCtx.Err() } } @@ -551,6 +576,36 @@ func (b *Batcher) AddSweep(sweepReq *SweepRequest) error { } } +// testRunInEventLoop runs a function in the event loop blocking until +// the function returns. For unit tests only! +func (b *Batcher) testRunInEventLoop(ctx context.Context, handler func()) { + // If the event loop is finished, run the function. + select { + case <-b.quit: + handler() + + return + default: + } + + quit := make(chan struct{}) + req := &testRequest{ + handler: handler, + quit: quit, + } + + select { + case b.testReqs <- req: + case <-ctx.Done(): + return + } + + select { + case <-quit: + case <-ctx.Done(): + } +} + // handleSweep handles a sweep request by either placing it in an existing // batch, or by spinning up a new batch for it. func (b *Batcher) handleSweep(ctx context.Context, sweep *sweep, @@ -561,8 +616,8 @@ func (b *Batcher) handleSweep(ctx context.Context, sweep *sweep, return err } - log.Infof("Batcher handling sweep %x, completed=%v", sweep.swapHash[:6], - completed) + infof("Batcher handling sweep %x, completed=%v", + sweep.swapHash[:6], completed) // If the sweep has already been completed in a confirmed batch then we // can't attach its notifier to the batch as that is no longer running. @@ -573,8 +628,8 @@ func (b *Batcher) handleSweep(ctx context.Context, sweep *sweep, // on-chain confirmations to prevent issues caused by reorgs. parentBatch, err := b.store.GetParentBatch(ctx, sweep.swapHash) if err != nil { - log.Errorf("unable to get parent batch for sweep %x: "+ - "%v", sweep.swapHash[:6], err) + errorf("unable to get parent batch for sweep %x:"+ + " %v", sweep.swapHash[:6], err) return err } @@ -590,16 +645,17 @@ func (b *Batcher) handleSweep(ctx context.Context, sweep *sweep, sweep.notifier = notifier - // Check if the sweep is already in a batch. If that is the case, we - // provide the sweep to that batch and return. + // This is a check to see if a batch is completed. In that case we just + // lazily delete it. for _, batch := range b.batches { - // This is a check to see if a batch is completed. In that case - // we just lazily delete it and continue our scan. if batch.isComplete() { delete(b.batches, batch.id) - continue } + } + // Check if the sweep is already in a batch. If that is the case, we + // provide the sweep to that batch and return. 
+	for _, batch := range b.batches {
 		if batch.sweepExists(sweep.swapHash) {
 			accepted, err := batch.addSweep(ctx, sweep)
 			if err != nil && !errors.Is(err, ErrBatchShuttingDown) {
@@ -624,8 +680,8 @@ func (b *Batcher) handleSweep(ctx context.Context, sweep *sweep,
 		return nil
 	}
 
-	log.Warnf("Greedy batch selection algorithm failed for sweep %x: %v. "+
-		"Falling back to old approach.", sweep.swapHash[:6], err)
+	warnf("Greedy batch selection algorithm failed for sweep %x: %v."+
+		" Falling back to old approach.", sweep.swapHash[:6], err)
 
 	// If one of the batches accepts the sweep, we provide it to that batch.
 	for _, batch := range b.batches {
@@ -730,13 +786,13 @@ func (b *Batcher) spinUpBatchFromDB(ctx context.Context, batch *batch) error {
 	}
 
 	if len(dbSweeps) == 0 {
-		log.Infof("skipping restored batch %d as it has no sweeps",
+		infof("skipping restored batch %d as it has no sweeps",
 			batch.id)
 
 		// It is safe to drop this empty batch as it has no sweeps.
 		err := b.store.DropBatch(ctx, batch.id)
 		if err != nil {
-			log.Warnf("unable to drop empty batch %d: %v",
+			warnf("unable to drop empty batch %d: %v",
 				batch.id, err)
 		}
 
@@ -878,7 +934,7 @@ func (b *Batcher) monitorSpendAndNotify(ctx context.Context, sweep *sweep,
 	b.wg.Add(1)
 	go func() {
 		defer b.wg.Done()
-		log.Infof("Batcher monitoring spend for swap %x",
+		infof("Batcher monitoring spend for swap %x",
 			sweep.swapHash[:6])
 
 		for {
@@ -1057,7 +1113,7 @@ func (b *Batcher) loadSweep(ctx context.Context, swapHash lntypes.Hash,
 		}
 	} else {
 		if s.ConfTarget == 0 {
-			log.Warnf("Fee estimation was requested for zero "+
+			warnf("Fee estimation was requested for zero "+
 				"confTarget for sweep %x.", swapHash[:6])
 		}
 		minFeeRate, err = b.wallet.EstimateFeeRate(ctx, s.ConfTarget)
diff --git a/sweepbatcher/sweep_batcher_test.go b/sweepbatcher/sweep_batcher_test.go
index fbfb0d418..addd90420 100644
--- a/sweepbatcher/sweep_batcher_test.go
+++ b/sweepbatcher/sweep_batcher_test.go
@@ -39,6 +39,8 @@ const (
 	eventuallyCheckFrequency = 100 * time.Millisecond
 
 	ntfnBufferSize = 1024
+
+	confTarget = 123
 )
 
 // destAddr is a dummy p2wkh address to use as the destination address for
@@ -109,18 +111,82 @@ func checkBatcherError(t *testing.T, err error) {
 	}
 }
 
-// getOnlyBatch makes sure the batcher has exactly one batch and returns it.
-func getOnlyBatch(batcher *Batcher) *batch {
-	if len(batcher.batches) != 1 {
-		panic(fmt.Sprintf("getOnlyBatch called on a batcher having "+
-			"%d batches", len(batcher.batches)))
-	}
+// getBatches returns the batches in a thread-safe way.
+func getBatches(ctx context.Context, batcher *Batcher) []*batch {
+	var batches []*batch
+	batcher.testRunInEventLoop(ctx, func() {
+		for _, batch := range batcher.batches {
+			batches = append(batches, batch)
+		}
+	})
+
+	return batches
+}
 
-	for _, batch := range batcher.batches {
-		return batch
+// tryGetOnlyBatch returns the only batch if there is exactly one, or nil.
+func tryGetOnlyBatch(ctx context.Context, batcher *Batcher) *batch {
+	batches := getBatches(ctx, batcher)
+
+	if len(batches) != 1 {
+		return nil
 	}
+
+	return batches[0]
+}
 
-	panic("unreachable")
+// getOnlyBatch makes sure the batcher has exactly one batch and returns it.
+func getOnlyBatch(t *testing.T, ctx context.Context, batcher *Batcher) *batch {
+	batches := getBatches(ctx, batcher)
+	require.Len(t, batches, 1)
+
+	return batches[0]
+}
+
+// numBatches returns the number of batches in the batcher.
+func (b *Batcher) numBatches(ctx context.Context) int { + return len(getBatches(ctx, b)) +} + +// numSweeps returns the number of sweeps in the batch. +func (b *batch) numSweeps(ctx context.Context) int { + var numSweeps int + b.testRunInEventLoop(ctx, func() { + numSweeps = len(b.sweeps) + }) + + return numSweeps +} + +// snapshot returns the snapshot of the batch. It is safe to read in parallel +// with the event loop running. +func (b *batch) snapshot(ctx context.Context) *batch { + var snapshot *batch + b.testRunInEventLoop(ctx, func() { + // Deep copy sweeps. + sweeps := make(map[lntypes.Hash]sweep, len(b.sweeps)) + for h, s := range b.sweeps { + sweeps[h] = s + } + + // Deep copy cfg. + cfg := *b.cfg + + // Deep copy the batch, only data fields. + snapshot = &batch{ + id: b.id, + state: b.state, + primarySweepID: b.primarySweepID, + sweeps: sweeps, + currentHeight: b.currentHeight, + batchTxid: b.batchTxid, + batchPkScript: b.batchPkScript, + batchAddress: b.batchAddress, + rbfCache: b.rbfCache, + cfg: &cfg, + } + }) + + return snapshot } // testSweepBatcherBatchCreation tests that sweep requests enter the expected @@ -186,7 +252,7 @@ func testSweepBatcherBatchCreation(t *testing.T, store testStore, // Once batcher receives sweep request it will eventually spin up a // batch. require.Eventually(t, func() bool { - return len(batcher.batches) == 1 + return batcher.numBatches(ctx) == 1 }, test.Timeout, eventuallyCheckFrequency) // Wait for tx to be published. @@ -236,7 +302,7 @@ func testSweepBatcherBatchCreation(t *testing.T, store testStore, // Batcher should not create a second batch as timeout distance is small // enough. require.Eventually(t, func() bool { - return len(batcher.batches) == 1 + return batcher.numBatches(ctx) == 1 }, test.Timeout, eventuallyCheckFrequency) // Create a third sweep request that has more timeout distance than @@ -273,23 +339,26 @@ func testSweepBatcherBatchCreation(t *testing.T, store testStore, require.NoError(t, batcher.AddSweep(&sweepReq3)) + // Since the second batch got created we check that it registered its + // primary sweep's spend. + <-lnd.RegisterSpendChannel + // Batcher should create a second batch as timeout distance is greater // than the threshold require.Eventually(t, func() bool { - return len(batcher.batches) == 2 + return batcher.numBatches(ctx) == 2 }, test.Timeout, eventuallyCheckFrequency) - // Since the second batch got created we check that it registered its - // primary sweep's spend. - <-lnd.RegisterSpendChannel - // Wait for tx to be published. <-lnd.TxPublishChannel require.Eventually(t, func() bool { // Verify that each batch has the correct number of sweeps // in it. - for _, batch := range batcher.batches { + batches := getBatches(ctx, batcher) + + for _, batch := range batches { + batch := batch.snapshot(ctx) switch batch.primarySweepID { case sweepReq1.SwapHash: if len(batch.sweeps) != 2 { @@ -480,28 +549,30 @@ func testTxLabeler(t *testing.T, store testStore, // Deliver sweep request to batcher. require.NoError(t, batcher.AddSweep(&sweepReq1)) - // Eventually request will be consumed and a new batch will spin up. - require.Eventually(t, func() bool { - return len(batcher.batches) == 1 - }, test.Timeout, eventuallyCheckFrequency) - // When batch is successfully created it will execute it's first step, // which leads to a spend monitor of the primary sweep. <-lnd.RegisterSpendChannel + // Eventually request will be consumed and a new batch will spin up. 
+ require.Eventually(t, func() bool { + return batcher.numBatches(ctx) == 1 + }, test.Timeout, eventuallyCheckFrequency) + // Wait for tx to be published. <-lnd.TxPublishChannel // Find the batch and assign it to a local variable for easier access. - var theBatch *batch - for _, btch := range batcher.batches { + var wantLabel string + for _, btch := range getBatches(ctx, batcher) { + btch := btch.snapshot(ctx) if btch.primarySweepID == sweepReq1.SwapHash { - theBatch = btch + wantLabel = fmt.Sprintf( + "BatchOutSweepSuccess -- %d", btch.id, + ) } } // Now test the label. - wantLabel := fmt.Sprintf("BatchOutSweepSuccess -- %d", theBatch.id) require.Equal(t, wantLabel, walletKit.lastLabel) // Now make the batcher quit by canceling the context. @@ -632,15 +703,15 @@ func testPublishErrorHandler(t *testing.T, store testStore, // Deliver sweep request to batcher. require.NoError(t, batcher.AddSweep(&sweepReq1)) - // Eventually request will be consumed and a new batch will spin up. - require.Eventually(t, func() bool { - return len(batcher.batches) == 1 - }, test.Timeout, eventuallyCheckFrequency) - // When batch is successfully created it will execute it's first step, // which leads to a spend monitor of the primary sweep. <-lnd.RegisterSpendChannel + // Eventually request will be consumed and a new batch will spin up. + require.Eventually(t, func() bool { + return batcher.numBatches(ctx) == 1 + }, test.Timeout, eventuallyCheckFrequency) + // The first attempt to publish the batch tx is expected to fail. require.ErrorIs(t, <-publishErrorChan, testPublishError) @@ -710,26 +781,28 @@ func testSweepBatcherSimpleLifecycle(t *testing.T, store testStore, // Deliver sweep request to batcher. require.NoError(t, batcher.AddSweep(&sweepReq1)) - // Eventually request will be consumed and a new batch will spin up. - require.Eventually(t, func() bool { - return len(batcher.batches) == 1 - }, test.Timeout, eventuallyCheckFrequency) - // When batch is successfully created it will execute it's first step, // which leads to a spend monitor of the primary sweep. <-lnd.RegisterSpendChannel + // Eventually request will be consumed and a new batch will spin up. + require.Eventually(t, func() bool { + return batcher.numBatches(ctx) == 1 + }, test.Timeout, eventuallyCheckFrequency) + // Find the batch and assign it to a local variable for easier access. batch := &batch{} - for _, btch := range batcher.batches { - if btch.primarySweepID == sweepReq1.SwapHash { - batch = btch - } + for _, btch := range getBatches(ctx, batcher) { + btch.testRunInEventLoop(ctx, func() { + if btch.primarySweepID == sweepReq1.SwapHash { + batch = btch + } + }) } require.Eventually(t, func() bool { // Batch should have the sweep stored. - return len(batch.sweeps) == 1 + return batch.numSweeps(ctx) == 1 }, test.Timeout, eventuallyCheckFrequency) // The primary sweep id should be that of the first inserted sweep. @@ -744,6 +817,8 @@ func testSweepBatcherSimpleLifecycle(t *testing.T, store testStore, // After receiving a height notification the batch will step again, // leading to a new spend monitoring. require.Eventually(t, func() bool { + batch := batch.snapshot(ctx) + return batch.currentHeight == 601 }, test.Timeout, eventuallyCheckFrequency) @@ -788,6 +863,8 @@ func testSweepBatcherSimpleLifecycle(t *testing.T, store testStore, // The batch should eventually read the spend notification and progress // its state to closed. 
require.Eventually(t, func() bool { + batch := batch.snapshot(ctx) + return batch.state == Closed }, test.Timeout, eventuallyCheckFrequency) @@ -811,18 +888,26 @@ func testSweepBatcherSimpleLifecycle(t *testing.T, store testStore, type wrappedLogger struct { btclog.Logger + mu sync.Mutex + debugMessages []string infoMessages []string } // Debugf logs debug message. func (l *wrappedLogger) Debugf(format string, params ...interface{}) { + l.mu.Lock() + defer l.mu.Unlock() + l.debugMessages = append(l.debugMessages, format) l.Logger.Debugf(format, params...) } // Infof logs info message. func (l *wrappedLogger) Infof(format string, params ...interface{}) { + l.mu.Lock() + defer l.mu.Unlock() + l.infoMessages = append(l.infoMessages, format) l.Logger.Infof(format, params...) } @@ -887,7 +972,7 @@ func testDelays(t *testing.T, store testStore, batcherStore testBatcherStore) { DestAddr: destAddr, SwapInvoice: swapInvoice, - SweepConfTarget: 123, + SweepConfTarget: confTarget, } err = store.CreateLoopOut(ctx, sweepReq.SwapHash, swap) @@ -930,17 +1015,15 @@ func testDelays(t *testing.T, store testStore, batcherStore testBatcherStore) { // Eventually the batch is launched. require.Eventually(t, func() bool { - return len(batcher.batches) == 1 + return batcher.numBatches(ctx) == 1 }, test.Timeout, eventuallyCheckFrequency) // Replace the logger in the batch with wrappedLogger to watch messages. - var batch1 *batch - for _, batch := range batcher.batches { - batch1 = batch + batch1 := getOnlyBatch(t, ctx, batcher) + testLogger := &wrappedLogger{ + Logger: batch1.log(), } - require.NotNil(t, batch1) - testLogger := &wrappedLogger{Logger: batch1.log} - batch1.log = testLogger + batch1.setLog(testLogger) // Advance the clock to publishDelay. It will trigger the publishDelay // timer, but won't result in publishing, because of initialDelay. @@ -950,7 +1033,10 @@ func testDelays(t *testing.T, store testStore, batcherStore testBatcherStore) { // Wait for batch publishing to be skipped, because initialDelay has not // ended. require.EventuallyWithT(t, func(c *assert.CollectT) { - require.Contains(t, testLogger.debugMessages, stillWaitingMsg) + testLogger.mu.Lock() + defer testLogger.mu.Unlock() + + assert.Contains(c, testLogger.debugMessages, stillWaitingMsg) }, test.Timeout, eventuallyCheckFrequency) // Advance the clock to the end of initialDelay. @@ -975,16 +1061,13 @@ func testDelays(t *testing.T, store testStore, batcherStore testBatcherStore) { return false } - // Make sure there is exactly one active batch. - if len(batcher.batches) != 1 { + batch := tryGetOnlyBatch(ctx, batcher) + if batch == nil { return false } - // Get the batch. - batch := getOnlyBatch(batcher) - // Make sure the batch has one sweep. - return len(batch.sweeps) == 1 + return batch.numSweeps(ctx) == 1 }, test.Timeout, eventuallyCheckFrequency) // Make sure we have stored the batch. @@ -1020,25 +1103,6 @@ func testDelays(t *testing.T, store testStore, batcherStore testBatcherStore) { // Wait for the batcher to be initialized. <-batcher.initDone - // Wait for batch to load. - require.Eventually(t, func() bool { - // Make sure that the sweep was stored - if !batcherStore.AssertSweepStored(sweepReq.SwapHash) { - return false - } - - // Make sure there is exactly one active batch. - if len(batcher.batches) != 1 { - return false - } - - // Get the batch. - batch := getOnlyBatch(batcher) - - // Make sure the batch has one sweep. 
- return len(batch.sweeps) == 1 - }, test.Timeout, eventuallyCheckFrequency) - // Expect a timer to be set: 0 (instead of publishDelay), and // RegisterSpend to be called. The order is not determined, so catch // these actions from two separate goroutines. @@ -1051,6 +1115,9 @@ func testDelays(t *testing.T, store testStore, batcherStore testBatcherStore) { // Since a batch was created we check that it registered for its // primary sweep's spend. <-lnd.RegisterSpendChannel + + // Wait for tx to be published. + <-lnd.TxPublishChannel }() wg3.Add(1) @@ -1065,6 +1132,22 @@ func testDelays(t *testing.T, store testStore, batcherStore testBatcherStore) { // Wait for RegisterSpend and for timer registration. wg3.Wait() + // Wait for batch to load. + require.Eventually(t, func() bool { + // Make sure that the sweep was stored + if !batcherStore.AssertSweepStored(sweepReq.SwapHash) { + return false + } + + batch := tryGetOnlyBatch(ctx, batcher) + if batch == nil { + return false + } + + // Make sure the batch has one sweep. + return batch.numSweeps(ctx) == 1 + }, test.Timeout, eventuallyCheckFrequency) + // Expect one timer: publishDelay (0). wantDelays = []time.Duration{0} require.Equal(t, wantDelays, delays) @@ -1073,9 +1156,6 @@ func testDelays(t *testing.T, store testStore, batcherStore testBatcherStore) { now = now.Add(time.Millisecond) testClock.SetTime(now) - // Wait for tx to be published. - <-lnd.TxPublishChannel - // Tick tock next block. err = lnd.NotifyHeight(601) require.NoError(t, err) @@ -1184,7 +1264,7 @@ func testDelays(t *testing.T, store testStore, batcherStore testBatcherStore) { DestAddr: destAddr, SwapInvoice: swapInvoice, - SweepConfTarget: 123, + SweepConfTarget: confTarget, } err = store.CreateLoopOut(ctx, sweepReq2.SwapHash, swap2) @@ -1226,15 +1306,16 @@ func testDelays(t *testing.T, store testStore, batcherStore testBatcherStore) { require.Equal(t, wantDelays, delays) // Replace the logger in the batch with wrappedLogger to watch messages. - var batch2 *batch - for _, batch := range batcher.batches { + var testLogger2 *wrappedLogger + for _, batch := range getBatches(ctx, batcher) { if batch.id != batch1.id { - batch2 = batch + testLogger2 = &wrappedLogger{ + Logger: batch.log(), + } + batch.setLog(testLogger2) } } - require.NotNil(t, batch2) - testLogger2 := &wrappedLogger{Logger: batch2.log} - batch2.log = testLogger2 + require.NotNil(t, testLogger2) // Add another sweep which is urgent. It will go to the same batch // to make sure minimum timeout is calculated properly. @@ -1262,7 +1343,7 @@ func testDelays(t *testing.T, store testStore, batcherStore testBatcherStore) { DestAddr: destAddr, SwapInvoice: swapInvoice, - SweepConfTarget: 123, + SweepConfTarget: confTarget, } err = store.CreateLoopOut(ctx, sweepReq3.SwapHash, swap3) @@ -1274,7 +1355,10 @@ func testDelays(t *testing.T, store testStore, batcherStore testBatcherStore) { // Wait for sweep to be added to the batch. require.EventuallyWithT(t, func(c *assert.CollectT) { - require.Contains(t, testLogger2.infoMessages, "adding sweep %x") + testLogger2.mu.Lock() + defer testLogger2.mu.Unlock() + + assert.Contains(c, testLogger2.infoMessages, "adding sweep %x") }, test.Timeout, eventuallyCheckFrequency) // Advance the clock by publishDelay. Don't wait largeInitialDelay. @@ -1283,7 +1367,7 @@ func testDelays(t *testing.T, store testStore, batcherStore testBatcherStore) { // Wait for tx to be published. 
tx := <-lnd.TxPublishChannel - require.Equal(t, 2, len(tx.TxIn)) + require.Len(t, tx.TxIn, 2) // Now make the batcher quit by canceling the context. cancel() @@ -1298,7 +1382,7 @@ func testMaxSweepsPerBatch(t *testing.T, store testStore, batcherStore testBatcherStore) { // Disable logging, because this test is very noisy. - oldLogger := log + oldLogger := log() UseLogger(build.NewSubLogger("SWEEP", nil)) defer UseLogger(oldLogger) @@ -1380,7 +1464,7 @@ func testMaxSweepsPerBatch(t *testing.T, store testStore, DestAddr: destAddr, SwapInvoice: swapInvoice, - SweepConfTarget: 123, + SweepConfTarget: confTarget, } err = store.CreateLoopOut(ctx, swapHash, swap) @@ -1399,15 +1483,17 @@ func testMaxSweepsPerBatch(t *testing.T, store testStore, // Eventually the batches are launched and all the sweeps are added. require.Eventually(t, func() bool { // Make sure all the batches have started. - if len(batcher.batches) != expectedBatches { + batches := getBatches(ctx, batcher) + if len(batches) != expectedBatches { return false } // Make sure all the sweeps were added. sweepsNum := 0 - for _, batch := range batcher.batches { - sweepsNum += len(batch.sweeps) + for _, batch := range batches { + sweepsNum += batch.numSweeps(ctx) } + return sweepsNum == swapsNum }, test.Timeout, eventuallyCheckFrequency) @@ -1588,20 +1674,22 @@ func testSweepBatcherSweepReentry(t *testing.T, store testStore, // Batcher should create a batch for the sweeps. require.Eventually(t, func() bool { - return len(batcher.batches) == 1 + return batcher.numBatches(ctx) == 1 }, test.Timeout, eventuallyCheckFrequency) // Find the batch and store it in a local variable for easier access. b := &batch{} - for _, btch := range batcher.batches { - if btch.primarySweepID == sweepReq1.SwapHash { - b = btch - } + for _, btch := range getBatches(ctx, batcher) { + btch.testRunInEventLoop(ctx, func() { + if btch.primarySweepID == sweepReq1.SwapHash { + b = btch + } + }) } // Batcher should contain all sweeps. require.Eventually(t, func() bool { - return len(b.sweeps) == 3 + return b.numSweeps(ctx) == 3 }, test.Timeout, eventuallyCheckFrequency) // Verify that the batch has a primary sweep id that matches the first @@ -1650,20 +1738,22 @@ func testSweepBatcherSweepReentry(t *testing.T, store testStore, // Eventually the batch reads the notification and proceeds to a closed // state. require.Eventually(t, func() bool { + b := b.snapshot(ctx) + return b.state == Closed }, test.Timeout, eventuallyCheckFrequency) + // Since second batch was created we check that it registered for its + // primary sweep's spend. + <-lnd.RegisterSpendChannel + // While handling the spend notification the batch should detect that // some sweeps did not appear in the spending tx, therefore it redirects // them back to the batcher and the batcher inserts them in a new batch. require.Eventually(t, func() bool { - return len(batcher.batches) == 2 + return batcher.numBatches(ctx) == 2 }, test.Timeout, eventuallyCheckFrequency) - // Since second batch was created we check that it registered for its - // primary sweep's spend. - <-lnd.RegisterSpendChannel - // We mock the confirmation notification. lnd.ConfChannel <- &chainntnfs.TxConfirmation{ Tx: spendingTx, @@ -1678,26 +1768,28 @@ func testSweepBatcherSweepReentry(t *testing.T, store testStore, // confirmation forever. <-lnd.TxPublishChannel + // Re-add one of remaining sweeps to trigger removing the completed + // batch from the batcher. 
+ require.NoError(t, batcher.AddSweep(&sweepReq3)) + // Eventually the batch receives the confirmation notification, // gracefully exits and the batcher deletes it. require.Eventually(t, func() bool { - return len(batcher.batches) == 1 + return batcher.numBatches(ctx) == 1 }, test.Timeout, eventuallyCheckFrequency) // Find the other batch, which includes the sweeps that did not appear // in the spending tx. - b = &batch{} - for _, btch := range batcher.batches { - b = btch - } + b = getOnlyBatch(t, ctx, batcher) // After all the sweeps enter, it should contain 2 sweeps. require.Eventually(t, func() bool { - return len(b.sweeps) == 2 + return b.numSweeps(ctx) == 2 }, test.Timeout, eventuallyCheckFrequency) // The batch should be in an open state. - require.Equal(t, b.state, Open) + b1 := b.snapshot(ctx) + require.Equal(t, b1.state, Open) } // testSweepBatcherNonWalletAddr tests that sweep requests that sweep to a non @@ -1753,16 +1845,16 @@ func testSweepBatcherNonWalletAddr(t *testing.T, store testStore, // Deliver sweep request to batcher. require.NoError(t, batcher.AddSweep(&sweepReq1)) + // Since a batch was created we check that it registered for its primary + // sweep's spend. + <-lnd.RegisterSpendChannel + // Once batcher receives sweep request it will eventually spin up a // batch. require.Eventually(t, func() bool { - return len(batcher.batches) == 1 + return batcher.numBatches(ctx) == 1 }, test.Timeout, eventuallyCheckFrequency) - // Since a batch was created we check that it registered for its primary - // sweep's spend. - <-lnd.RegisterSpendChannel - // Wait for tx to be published. <-lnd.TxPublishChannel @@ -1803,16 +1895,16 @@ func testSweepBatcherNonWalletAddr(t *testing.T, store testStore, require.NoError(t, batcher.AddSweep(&sweepReq2)) + // Since a batch was created we check that it registered for its primary + // sweep's spend. + <-lnd.RegisterSpendChannel + // Batcher should create a second batch as first batch is a non wallet // addr batch. require.Eventually(t, func() bool { - return len(batcher.batches) == 2 + return batcher.numBatches(ctx) == 2 }, test.Timeout, eventuallyCheckFrequency) - // Since a batch was created we check that it registered for its primary - // sweep's spend. - <-lnd.RegisterSpendChannel - // Wait for second batch to be published. <-lnd.TxPublishChannel @@ -1850,23 +1942,25 @@ func testSweepBatcherNonWalletAddr(t *testing.T, store testStore, require.NoError(t, batcher.AddSweep(&sweepReq3)) + // Since a batch was created we check that it registered for its primary + // sweep's spend. + <-lnd.RegisterSpendChannel + // Batcher should create a new batch as timeout distance is greater than // the threshold require.Eventually(t, func() bool { - return len(batcher.batches) == 3 + return batcher.numBatches(ctx) == 3 }, test.Timeout, eventuallyCheckFrequency) - // Since a batch was created we check that it registered for its primary - // sweep's spend. - <-lnd.RegisterSpendChannel - // Wait for tx to be published for 3rd batch. <-lnd.TxPublishChannel require.Eventually(t, func() bool { // Verify that each batch has the correct number of sweeps // in it. - for _, batch := range batcher.batches { + batches := getBatches(ctx, batcher) + for _, batch := range batches { + batch := batch.snapshot(ctx) switch batch.primarySweepID { case sweepReq1.SwapHash: if len(batch.sweeps) != 1 { @@ -2103,16 +2197,16 @@ func testSweepBatcherComposite(t *testing.T, store testStore, // Deliver sweep request to batcher. 
require.NoError(t, batcher.AddSweep(&sweepReq1)) + // Since a batch was created we check that it registered for its primary + // sweep's spend. + <-lnd.RegisterSpendChannel + // Once batcher receives sweep request it will eventually spin up a // batch. require.Eventually(t, func() bool { - return len(batcher.batches) == 1 + return batcher.numBatches(ctx) == 1 }, test.Timeout, eventuallyCheckFrequency) - // Since a batch was created we check that it registered for its primary - // sweep's spend. - <-lnd.RegisterSpendChannel - // Wait for tx to be published. <-lnd.TxPublishChannel @@ -2124,7 +2218,7 @@ func testSweepBatcherComposite(t *testing.T, store testStore, // Batcher should not create a second batch as timeout distance is small // enough. require.Eventually(t, func() bool { - return len(batcher.batches) == 1 + return batcher.numBatches(ctx) == 1 }, test.Timeout, eventuallyCheckFrequency) // Publish a block to trigger batch 1 republishing. @@ -2133,39 +2227,39 @@ func testSweepBatcherComposite(t *testing.T, store testStore, // Wait for tx for the first batch to be published (2 sweeps). tx := <-lnd.TxPublishChannel - require.Equal(t, 2, len(tx.TxIn)) + require.Len(t, tx.TxIn, 2) require.NoError(t, batcher.AddSweep(&sweepReq3)) + // Since a batch was created we check that it registered for its primary + // sweep's spend. + <-lnd.RegisterSpendChannel + // Batcher should create a second batch as this sweep pays to a non // wallet address. require.Eventually(t, func() bool { - return len(batcher.batches) == 2 + return batcher.numBatches(ctx) == 2 }, test.Timeout, eventuallyCheckFrequency) - // Since a batch was created we check that it registered for its primary - // sweep's spend. - <-lnd.RegisterSpendChannel - // Wait for tx for the second batch to be published (1 sweep). tx = <-lnd.TxPublishChannel - require.Equal(t, 1, len(tx.TxIn)) + require.Len(t, tx.TxIn, 1) require.NoError(t, batcher.AddSweep(&sweepReq4)) + // Since a batch was created we check that it registered for its primary + // sweep's spend. + <-lnd.RegisterSpendChannel + // Batcher should create a third batch as timeout distance is greater // than the threshold. require.Eventually(t, func() bool { - return len(batcher.batches) == 3 + return batcher.numBatches(ctx) == 3 }, test.Timeout, eventuallyCheckFrequency) - // Since a batch was created we check that it registered for its primary - // sweep's spend. - <-lnd.RegisterSpendChannel - // Wait for tx for the third batch to be published (1 sweep). tx = <-lnd.TxPublishChannel - require.Equal(t, 1, len(tx.TxIn)) + require.Len(t, tx.TxIn, 1) require.NoError(t, batcher.AddSweep(&sweepReq5)) @@ -2181,29 +2275,31 @@ func testSweepBatcherComposite(t *testing.T, store testStore, // Batcher should not create a fourth batch as timeout distance is small // enough for it to join the last batch. require.Eventually(t, func() bool { - return len(batcher.batches) == 3 + return batcher.numBatches(ctx) == 3 }, test.Timeout, eventuallyCheckFrequency) require.NoError(t, batcher.AddSweep(&sweepReq6)) + // Since a batch was created we check that it registered for its primary + // sweep's spend. + <-lnd.RegisterSpendChannel + // Batcher should create a fourth batch as this sweep pays to a non // wallet address. require.Eventually(t, func() bool { - return len(batcher.batches) == 4 + return batcher.numBatches(ctx) == 4 }, test.Timeout, eventuallyCheckFrequency) - // Since a batch was created we check that it registered for its primary - // sweep's spend. 
-	<-lnd.RegisterSpendChannel
-
 	// Wait for tx for the 4th batch to be published (1 sweep).
 	tx = <-lnd.TxPublishChannel
-	require.Equal(t, 1, len(tx.TxIn))
+	require.Len(t, tx.TxIn, 1)
 
 	require.Eventually(t, func() bool {
 		// Verify that each batch has the correct number of sweeps in
 		// it.
-		for _, batch := range batcher.batches {
+		batches := getBatches(ctx, batcher)
+		for _, batch := range batches {
+			batch := batch.snapshot(ctx)
 			switch batch.primarySweepID {
 			case sweepReq1.SwapHash:
 				if len(batch.sweeps) != 2 {
@@ -2360,8 +2456,11 @@ func testRestoringEmptyBatch(t *testing.T, store testStore,
 	require.Eventually(t, func() bool {
 		// Make sure that the sweep was stored and we have exactly one
 		// active batch.
-		return batcherStore.AssertSweepStored(sweepReq.SwapHash) &&
-			len(batcher.batches) == 1
+		if !batcherStore.AssertSweepStored(sweepReq.SwapHash) {
+			return false
+		}
+
+		return batcher.numBatches(ctx) == 1
 	}, test.Timeout, eventuallyCheckFrequency)
 
 	// Make sure we have only one batch stored (as we dropped the dormant
@@ -2577,9 +2676,14 @@ func testHandleSweepTwice(t *testing.T, backend testStore,
 	require.Eventually(t, func() bool {
 		// Make sure that the sweep was stored and we have exactly one
 		// active batch.
-		return batcherStore.AssertSweepStored(sweepReq1.SwapHash) &&
-			batcherStore.AssertSweepStored(sweepReq2.SwapHash) &&
-			len(batcher.batches) == 2
+		if !batcherStore.AssertSweepStored(sweepReq1.SwapHash) {
+			return false
+		}
+		if !batcherStore.AssertSweepStored(sweepReq2.SwapHash) {
+			return false
+		}
+
+		return batcher.numBatches(ctx) == 2
 	}, test.Timeout, eventuallyCheckFrequency)
 
 	// Change the second sweep so that it can be added to the first batch.
@@ -2608,7 +2712,8 @@
 	require.Eventually(t, func() bool {
 		// Make sure there are two batches.
-		batches := batcher.batches
+		batches := getBatches(ctx, batcher)
+
 		if len(batches) != 2 {
 			return false
 		}
@@ -2622,9 +2727,10 @@
 				secondBatch = batch
 			}
 		}
+		snapshot := secondBatch.snapshot(ctx)
 
 		// Make sure the second batch has the second sweep.
-		sweep2, has := secondBatch.sweeps[sweepReq2.SwapHash]
+		sweep2, has := snapshot.sweeps[sweepReq2.SwapHash]
 		if !has {
 			return false
 		}
@@ -2635,8 +2741,10 @@
 	// Make sure each batch has one sweep. If the second sweep was added to
 	// both batches, the following check won't pass.
-	for _, batch := range batcher.batches {
-		require.Equal(t, 1, len(batch.sweeps))
+	batches := getBatches(ctx, batcher)
+	for _, batch := range batches {
+		// Make sure the batch has one sweep.
+		require.Equal(t, 1, batch.numSweeps(ctx))
 	}
 
 	// Publish a block to trigger batch 2 republishing.
@@ -2704,7 +2812,7 @@ func testRestoringPreservesConfTarget(t *testing.T, store testStore,
 		DestAddr: destAddr,
 		SwapInvoice: swapInvoice,
 
-		SweepConfTarget: 123,
+		SweepConfTarget: confTarget,
 	}
 
 	err = store.CreateLoopOut(ctx, sweepReq.SwapHash, swap)
@@ -2729,21 +2837,21 @@
 			return false
 		}
 
-		// Make sure there is exactly one active batch.
-		if len(batcher.batches) != 1 {
+		batch := tryGetOnlyBatch(ctx, batcher)
+		if batch == nil {
 			return false
 		}
 
-		// Get the batch.
-		batch := getOnlyBatch(batcher)
+		// Take a snapshot of the batch to read its fields.
+		snapshot := batch.snapshot(ctx)
 
 		// Make sure the batch has one sweep.
-		if len(batch.sweeps) != 1 {
+		if len(snapshot.sweeps) != 1 {
 			return false
 		}
 
 		// Make sure the batch has proper batchConfTarget.
-		return batch.cfg.batchConfTarget == 123
+		return snapshot.cfg.batchConfTarget == confTarget
 	}, test.Timeout, eventuallyCheckFrequency)
 
 	// Make sure we have stored the batch.
@@ -2772,6 +2880,12 @@
 	// Wait for the batcher to be initialized.
 	<-batcher.initDone
 
+	// Expect registration for spend notification.
+	<-lnd.RegisterSpendChannel
+
+	// Wait for tx to be published.
+	<-lnd.TxPublishChannel
+
 	// Wait for batch to load.
 	require.Eventually(t, func() bool {
 		// Make sure that the sweep was stored
@@ -2779,26 +2893,18 @@
 		if !batcherStore.AssertSweepStored(sweepReq.SwapHash) {
 			return false
 		}
 
-		// Make sure there is exactly one active batch.
-		if len(batcher.batches) != 1 {
+		batch := tryGetOnlyBatch(ctx, batcher)
+		if batch == nil {
 			return false
 		}
 
-		// Get the batch.
-		batch := getOnlyBatch(batcher)
-
 		// Make sure the batch has one sweep.
-		return len(batch.sweeps) == 1
+		return batch.numSweeps(ctx) == 1
 	}, test.Timeout, eventuallyCheckFrequency)
 
 	// Make sure batchConfTarget was preserved.
-	require.Equal(t, 123, int(getOnlyBatch(batcher).cfg.batchConfTarget))
-
-	// Expect registration for spend notification.
-	<-lnd.RegisterSpendChannel
-
-	// Wait for tx to be published.
-	<-lnd.TxPublishChannel
+	batch := getOnlyBatch(t, ctx, batcher).snapshot(ctx)
+	require.Equal(t, int32(confTarget), batch.cfg.batchConfTarget)
 
 	// Now make the batcher quit by canceling the context.
 	cancel()
@@ -2810,11 +2916,22 @@
 type sweepFetcherMock struct {
 	store map[lntypes.Hash]*SweepInfo
+	mu    sync.Mutex
+}
+
+func (f *sweepFetcherMock) setSweep(hash lntypes.Hash, info *SweepInfo) {
+	f.mu.Lock()
+	defer f.mu.Unlock()
+
+	f.store[hash] = info
 }
 
 func (f *sweepFetcherMock) FetchSweep(ctx context.Context, hash lntypes.Hash) (
 	*SweepInfo, error) {
 
+	f.mu.Lock()
+	defer f.mu.Unlock()
+
 	return f.store[hash], nil
 }
@@ -2859,7 +2976,7 @@ func testSweepFetcher(t *testing.T, store testStore,
 	require.NoError(t, err)
 
 	sweepInfo := &SweepInfo{
-		ConfTarget:             123,
+		ConfTarget:             confTarget,
 		Timeout:                111,
 		SwapInvoicePaymentAddr: *swapPaymentAddr,
 		ProtocolVersion:        loopdb.ProtocolVersionMuSig2,
@@ -2932,21 +3049,20 @@
 			return false
 		}
 
-		// Make sure there is exactly one active batch.
-		if len(batcher.batches) != 1 {
+		// Try to get the batch.
+		batch := tryGetOnlyBatch(ctx, batcher)
+		if batch == nil {
 			return false
 		}
 
-		// Get the batch.
-		batch := getOnlyBatch(batcher)
-
 		// Make sure the batch has one sweep.
-		if len(batch.sweeps) != 1 {
+		snapshot := batch.snapshot(ctx)
+		if len(snapshot.sweeps) != 1 {
 			return false
 		}
 
 		// Make sure the batch has proper batchConfTarget.
-		return batch.cfg.batchConfTarget == 123
+		return snapshot.cfg.batchConfTarget == confTarget
 	}, test.Timeout, eventuallyCheckFrequency)
 
 	// Get the published transaction and check the fee rate.
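Note: the assertions in the hunks above no longer read batcher.batches or batch.sweeps directly; they go through accessors such as numBatches(ctx), numSweeps(ctx) and snapshot(ctx), plus the getBatches/tryGetOnlyBatch test helpers, so the test goroutine never touches state owned by the batcher goroutine. Those accessors are not part of this diff, so the sketch below is only one plausible shape for them, assuming a mutex-guarded map; the real implementation may instead serialize such reads through the batcher's event loop, and the type and field names here (safeBatcher, safeBatch, mu) are invented for illustration.

package sweepbatcher

import (
	"context"
	"sync"
)

// safeBatcher is a hypothetical stand-in for the real Batcher type.
type safeBatcher struct {
	mu      sync.Mutex
	batches map[int32]*safeBatch
}

// safeBatch is a hypothetical stand-in for the real batch type.
type safeBatch struct {
	mu     sync.Mutex
	sweeps map[[32]byte]struct{}
}

// batchSnapshot is a read-only copy of a batch's mutable state.
type batchSnapshot struct {
	sweeps map[[32]byte]struct{}
}

// numBatches counts active batches under the lock, so a polling test
// goroutine never races with the batcher goroutine.
func (b *safeBatcher) numBatches(_ context.Context) int {
	b.mu.Lock()
	defer b.mu.Unlock()

	return len(b.batches)
}

// numSweeps counts the sweeps of a single batch under its lock.
func (b *safeBatch) numSweeps(_ context.Context) int {
	b.mu.Lock()
	defer b.mu.Unlock()

	return len(b.sweeps)
}

// snapshot copies the guarded state out, so the caller can inspect it
// at leisure without holding the lock.
func (b *safeBatch) snapshot(_ context.Context) *batchSnapshot {
	b.mu.Lock()
	defer b.mu.Unlock()

	sweeps := make(map[[32]byte]struct{}, len(b.sweeps))
	for hash := range b.sweeps {
		sweeps[hash] = struct{}{}
	}

	return &batchSnapshot{sweeps: sweeps}
}

Returning a copy rather than a reference is what lets the tests call require.Len and compare cfg fields on the snapshot without holding any lock.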
@@ -3266,7 +3382,7 @@ func testWithMixedBatch(t *testing.T, store testStore,
 
 		sweepInfo := &SweepInfo{
 			Preimage:               preimages[i],
-			ConfTarget:             123,
+			ConfTarget:             confTarget,
 			Timeout:                111,
 			SwapInvoicePaymentAddr: *swapPaymentAddr,
 			ProtocolVersion:        loopdb.ProtocolVersionMuSig2,
@@ -3279,7 +3395,7 @@
 		if i == 0 {
 			sweepInfo.NonCoopHint = true
 		}
-		sweepFetcher.store[swapHash] = sweepInfo
+		sweepFetcher.setSweep(swapHash, sweepInfo)
 
 		// Create sweep request.
 		sweepReq := SweepRequest{
@@ -3305,7 +3421,7 @@
 
 		// A transaction is published.
 		tx := <-lnd.TxPublishChannel
-		require.Equal(t, i+1, len(tx.TxIn))
+		require.Len(t, tx.TxIn, i+1)
 
 		// Check types of inputs.
 		var witnessSizes []int
@@ -3433,11 +3549,11 @@ func testWithMixedBatchCustom(t *testing.T, store testStore,
 		)
 		require.NoError(t, err)
 
-		sweepFetcher.store[swapHash] = &SweepInfo{
+		sweepFetcher.setSweep(swapHash, &SweepInfo{
 			Preimage:    preimages[i],
 			NonCoopHint: nonCoopHints[i],
 
-			ConfTarget:             123,
+			ConfTarget:             confTarget,
 			Timeout:                111,
 			SwapInvoicePaymentAddr: *swapPaymentAddr,
 			ProtocolVersion:        loopdb.ProtocolVersionMuSig2,
@@ -3445,7 +3561,7 @@ func testWithMixedBatchCustom(t *testing.T, store testStore,
 			HTLC:                 *htlc,
 			HTLCSuccessEstimator: htlc.AddSuccessToEstimator,
 			DestAddr:             destAddr,
-		}
+		})
 
 		// Create sweep request.
 		sweepReq := SweepRequest{
@@ -3474,7 +3590,7 @@ func testWithMixedBatchCustom(t *testing.T, store testStore,
 
 		// A transaction is published.
 		tx := <-lnd.TxPublishChannel
-		require.Equal(t, len(preimages), len(tx.TxIn))
+		require.Len(t, tx.TxIn, len(preimages))
 
 		// Check types of inputs.
 		var witnessSizes []int
@@ -3794,9 +3910,10 @@ func testFeeRateGrows(t *testing.T, store testStore,
 	<-lnd.TxPublishChannel
 
 	// Make sure the fee rate is feeRateMedium.
-	batch := getOnlyBatch(batcher)
-	require.Len(t, batch.sweeps, 1)
-	require.Equal(t, feeRateMedium, batch.rbfCache.FeeRate)
+	batch := getOnlyBatch(t, ctx, batcher)
+	snapshot := batch.snapshot(ctx)
+	require.Len(t, snapshot.sweeps, 1)
+	require.Equal(t, feeRateMedium, snapshot.rbfCache.FeeRate)
 
 	// Now decrease the fee of sweep1.
 	setFeeRate(swapHash1, feeRateLow)
@@ -3810,7 +3927,8 @@ func testFeeRateGrows(t *testing.T, store testStore,
 	<-lnd.TxPublishChannel
 
 	// Make sure the fee rate is still feeRateMedium.
-	require.Equal(t, feeRateMedium, batch.rbfCache.FeeRate)
+	snapshot = batch.snapshot(ctx)
+	require.Equal(t, feeRateMedium, snapshot.rbfCache.FeeRate)
 
 	// Add sweep2, with feeRateMedium.
 	swapHash2 := lntypes.Hash{2, 2, 2}
@@ -3856,8 +3974,9 @@ func testFeeRateGrows(t *testing.T, store testStore,
 	<-lnd.TxPublishChannel
 
 	// Make sure the fee rate is still feeRateMedium.
-	require.Len(t, batch.sweeps, 2)
-	require.Equal(t, feeRateMedium, batch.rbfCache.FeeRate)
+	snapshot = batch.snapshot(ctx)
+	require.Len(t, snapshot.sweeps, 2)
+	require.Equal(t, feeRateMedium, snapshot.rbfCache.FeeRate)
 
 	// Now update fee rate of second sweep (which is not primary) to
 	// feeRateHigh. Fee rate of sweep 1 is still feeRateLow.
@@ -3873,7 +3992,8 @@ func testFeeRateGrows(t *testing.T, store testStore,
 	<-lnd.TxPublishChannel
 
 	// Make sure the fee rate increased to feeRateHigh.
-	require.Equal(t, feeRateHigh, batch.rbfCache.FeeRate)
+	snapshot = batch.snapshot(ctx)
+	require.Equal(t, feeRateHigh, snapshot.rbfCache.FeeRate)
 }
 
 // TestSweepBatcherBatchCreation tests that sweep requests enter the expected
@@ -4035,6 +4155,8 @@ type loopdbBatcherStore struct {
 	BatcherStore
 
 	sweepsSet map[lntypes.Hash]struct{}
+
+	mu sync.Mutex
 }
 
 // UpsertSweep inserts a sweep into the database, or updates an existing sweep
@@ -4042,6 +4164,9 @@ type loopdbBatcherStore struct {
 func (s *loopdbBatcherStore) UpsertSweep(ctx context.Context,
 	sweep *dbSweep) error {
 
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
 	err := s.BatcherStore.UpsertSweep(ctx, sweep)
 	if err == nil {
 		s.sweepsSet[sweep.SwapHash] = struct{}{}
@@ -4051,7 +4176,11 @@ func (s *loopdbBatcherStore) UpsertSweep(ctx context.Context,
 
 // AssertSweepStored asserts that a sweep is stored.
 func (s *loopdbBatcherStore) AssertSweepStored(id lntypes.Hash) bool {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
 	_, has := s.sweepsSet[id]
+
 	return has
 }
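Note: the test doubles touched by this diff (sweepFetcherMock above and this loopdbBatcherStore wrapper) follow the same discipline: every method takes the struct's mutex before touching the shared map, because UpsertSweep runs on the batcher's goroutine while AssertSweepStored is polled from the test goroutine inside require.Eventually, and unsynchronized concurrent map access in Go is a data race that `go test -race` reports. A minimal, self-contained illustration of the pattern (all identifiers here are invented for the example):

package main

import (
	"fmt"
	"sync"
)

// mockStore mirrors the pattern used by the test doubles: every method
// takes the same mutex before touching the shared map.
type mockStore struct {
	mu        sync.Mutex
	sweepsSet map[string]struct{}
}

// upsertSweep is what the "batcher" goroutine calls.
func (s *mockStore) upsertSweep(hash string) {
	s.mu.Lock()
	defer s.mu.Unlock()

	s.sweepsSet[hash] = struct{}{}
}

// assertSweepStored is what the test goroutine polls.
func (s *mockStore) assertSweepStored(hash string) bool {
	s.mu.Lock()
	defer s.mu.Unlock()

	_, has := s.sweepsSet[hash]

	return has
}

func main() {
	s := &mockStore{sweepsSet: make(map[string]struct{})}

	var wg sync.WaitGroup
	wg.Add(1)

	// Writer goroutine, standing in for the batcher.
	go func() {
		defer wg.Done()
		s.upsertSweep("deadbeef")
	}()

	// Reader, standing in for require.Eventually polling the store.
	for !s.assertSweepStored("deadbeef") {
	}

	wg.Wait()
	fmt.Println("sweep stored")
}

Without the mutex, running this under the race detector (go run -race) flags the concurrent map read and write; with it, both paths serialize on the same lock and the polling loop observes the write safely.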