sweepbatcher: fix race conditions in unit tests #889

Merged: 8 commits, Mar 4, 2025.

2 changes: 1 addition & 1 deletion .github/workflows/main.yml
@@ -20,7 +20,7 @@ env:

# If you change this value, please change it in the following files as well:
# /Dockerfile
-GO_VERSION: 1.21.10
+GO_VERSION: 1.24.0
Member commented: 🎉

jobs:
########################
2 changes: 1 addition & 1 deletion Dockerfile
@@ -1,4 +1,4 @@
-FROM --platform=${BUILDPLATFORM} golang:1.22-alpine as builder
+FROM --platform=${BUILDPLATFORM} golang:1.24-alpine as builder

# Copy in the local repository to build from.
COPY . /go/src/github.com/lightningnetwork/loop
4 changes: 2 additions & 2 deletions sweepbatcher/greedy_batch_selection.go
@@ -92,8 +92,8 @@ func (b *Batcher) greedyAddSweep(ctx context.Context, sweep *sweep) error {
return nil
}

-log.Debugf("Batch selection algorithm returned batch id %d for"+
-" sweep %x, but acceptance failed.", batchId,
+debugf("Batch selection algorithm returned batch id %d "+
+"for sweep %x, but acceptance failed.", batchId,
sweep.swapHash[:6])
}

34 changes: 30 additions & 4 deletions sweepbatcher/log.go
@@ -2,15 +2,21 @@ package sweepbatcher

import (
"fmt"
"sync/atomic"

"github.com/btcsuite/btclog"
"github.com/lightningnetwork/lnd/build"
)

-// log is a logger that is initialized with no output filters. This
+// log_ is a logger that is initialized with no output filters. This
// means the package will not perform any logging by default until the
// caller requests it.
-var log btclog.Logger
+var log_ atomic.Pointer[btclog.Logger]
+
+// log returns the active logger.
+func log() btclog.Logger {
+return *log_.Load()
+}

// The default amount of logging is none.
func init() {
@@ -20,12 +26,32 @@ func init() {
// batchPrefixLogger returns a logger that prefixes all log messages with
// the ID.
func batchPrefixLogger(batchID string) btclog.Logger {
-return build.NewPrefixLog(fmt.Sprintf("[Batch %s]", batchID), log)
+return build.NewPrefixLog(fmt.Sprintf("[Batch %s]", batchID), log())
}

// UseLogger uses a specified Logger to output package logging info.
// This should be used in preference to SetLogWriter if the caller is also
// using btclog.
func UseLogger(logger btclog.Logger) {
-log = logger
+log_.Store(&logger)

Member commented:

Given that we only really call UseLogger once, it should not race, iiuc?

Collaborator (author) replied:

We call it in the runTests function in each test. It increases logging to include debug messages. I caught a race between tests: in one test a sweepbatcher was shutting down and logging something, while another test was calling UseLogger.

}

// debugf logs a message with level DEBUG.
func debugf(format string, params ...interface{}) {
log().Debugf(format, params...)
}

// infof logs a message with level INFO.
func infof(format string, params ...interface{}) {
log().Infof(format, params...)
}

// warnf logs a message with level WARN.
func warnf(format string, params ...interface{}) {
log().Warnf(format, params...)
}

// errorf logs a message with level ERROR.
func errorf(format string, params ...interface{}) {
log().Errorf(format, params...)
}
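
Taken together, the log.go change replaces a plain package-level logger variable with an atomic pointer plus accessor functions. Below is a minimal, self-contained sketch of that pattern, using a stand-in Logger interface instead of btclog.Logger (all names here are illustrative, not the package's):

package main

import (
	"fmt"
	"sync/atomic"
)

// Logger is a stand-in for btclog.Logger in this sketch.
type Logger interface {
	Debugf(format string, params ...interface{})
}

// prefixLogger is a trivial Logger that writes to stdout.
type prefixLogger struct{ prefix string }

func (l prefixLogger) Debugf(format string, params ...interface{}) {
	fmt.Printf(l.prefix+" "+format+"\n", params...)
}

// activeLog holds the current logger. Both Store (called by UseLogger)
// and Load (called on every log statement) are atomic, so a goroutine
// that is still logging while another test installs a new logger no
// longer constitutes a data race.
var activeLog atomic.Pointer[Logger]

// UseLogger installs a new active logger, mirroring the diff's
// log_.Store(&logger).
func UseLogger(l Logger) {
	activeLog.Store(&l)
}

// log returns the active logger, mirroring the diff's *log_.Load().
func log() Logger {
	return *activeLog.Load()
}

func main() {
	UseLogger(prefixLogger{prefix: "[test1]"})

	done := make(chan struct{})
	go func() {
		defer close(done)
		// With a plain `var log Logger`, this read would race with
		// the UseLogger call below, and `go run -race` would flag it.
		log().Debugf("batch %d still shutting down", 7)
	}()

	UseLogger(prefixLogger{prefix: "[test2]"})
	<-done
}
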
33 changes: 33 additions & 0 deletions sweepbatcher/store_mock.go
@@ -4,6 +4,7 @@ import (
"context"
"errors"
"sort"
"sync"

"github.com/btcsuite/btcd/btcutil"
"github.com/lightningnetwork/lnd/lntypes"
@@ -13,6 +14,7 @@ import (
type StoreMock struct {
batches map[int32]dbBatch
sweeps map[lntypes.Hash]dbSweep
mu sync.Mutex
}

// NewStoreMock instantiates a new mock store.
@@ -28,6 +30,9 @@ func NewStoreMock() *StoreMock {
func (s *StoreMock) FetchUnconfirmedSweepBatches(ctx context.Context) (
[]*dbBatch, error) {

s.mu.Lock()
defer s.mu.Unlock()

result := []*dbBatch{}
for _, batch := range s.batches {
batch := batch
@@ -44,6 +49,9 @@ func (s *StoreMock) FetchUnconfirmedSweepBatches(ctx context.Context) (
func (s *StoreMock) InsertSweepBatch(ctx context.Context,
batch *dbBatch) (int32, error) {

s.mu.Lock()
defer s.mu.Unlock()

var id int32

if len(s.batches) == 0 {
@@ -66,12 +74,18 @@ func (s *StoreMock) DropBatch(ctx context.Context, id int32) error {
func (s *StoreMock) UpdateSweepBatch(ctx context.Context,
batch *dbBatch) error {

s.mu.Lock()
defer s.mu.Unlock()

s.batches[batch.ID] = *batch
return nil
}

// ConfirmBatch confirms a batch.
func (s *StoreMock) ConfirmBatch(ctx context.Context, id int32) error {
s.mu.Lock()
defer s.mu.Unlock()

batch, ok := s.batches[id]
if !ok {
return errors.New("batch not found")
@@ -87,6 +101,9 @@ func (s *StoreMock) ConfirmBatch(ctx context.Context, id int32) error {
func (s *StoreMock) FetchBatchSweeps(ctx context.Context,
id int32) ([]*dbSweep, error) {

s.mu.Lock()
defer s.mu.Unlock()

result := []*dbSweep{}
for _, sweep := range s.sweeps {
sweep := sweep
@@ -104,14 +121,21 @@ func (s *StoreMock) FetchBatchSweeps(ctx context.Context,

// UpsertSweep inserts a sweep into the database, or updates an existing sweep.
func (s *StoreMock) UpsertSweep(ctx context.Context, sweep *dbSweep) error {
s.mu.Lock()
defer s.mu.Unlock()

s.sweeps[sweep.SwapHash] = *sweep

return nil
}

// GetSweepStatus returns the status of a sweep.
func (s *StoreMock) GetSweepStatus(ctx context.Context,
swapHash lntypes.Hash) (bool, error) {

s.mu.Lock()
defer s.mu.Unlock()

sweep, ok := s.sweeps[swapHash]
if !ok {
return false, nil
@@ -127,6 +151,9 @@ func (s *StoreMock) Close() error {

// AssertSweepStored asserts that a sweep is stored.
func (s *StoreMock) AssertSweepStored(id lntypes.Hash) bool {
s.mu.Lock()
defer s.mu.Unlock()

_, ok := s.sweeps[id]
return ok
}
@@ -135,6 +162,9 @@ func (s *StoreMock) AssertSweepStored(id lntypes.Hash) bool {
func (s *StoreMock) GetParentBatch(ctx context.Context, swapHash lntypes.Hash) (
*dbBatch, error) {

s.mu.Lock()
defer s.mu.Unlock()

for _, sweep := range s.sweeps {
if sweep.SwapHash == swapHash {
batch, ok := s.batches[sweep.BatchID]
@@ -153,6 +183,9 @@ func (s *StoreMock) GetParentBatch(ctx context.Context, swapHash lntypes.Hash) (
func (s *StoreMock) TotalSweptAmount(ctx context.Context, batchID int32) (
btcutil.Amount, error) {

s.mu.Lock()
defer s.mu.Unlock()

batch, ok := s.batches[batchID]
if !ok {
return 0, errors.New("batch not found")
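
Every StoreMock method above follows the same shape: take the new mutex before touching the batches or sweeps maps and release it on return. That is what lets test assertions and batcher goroutines hit the mock concurrently without tripping go test -race. A condensed, runnable sketch of the pattern (type and method names here are illustrative):

package main

import (
	"fmt"
	"sync"
)

// mockStore condenses the guarded-map pattern used by StoreMock.
type mockStore struct {
	mu     sync.Mutex
	sweeps map[string]int
}

// upsertSweep inserts or updates a sweep record under the lock.
func (s *mockStore) upsertSweep(hash string, batchID int) {
	s.mu.Lock()
	defer s.mu.Unlock()

	s.sweeps[hash] = batchID
}

// sweepStored reports whether a sweep exists, also under the lock.
func (s *mockStore) sweepStored(hash string) bool {
	s.mu.Lock()
	defer s.mu.Unlock()

	_, ok := s.sweeps[hash]
	return ok
}

func main() {
	s := &mockStore{sweeps: make(map[string]int)}

	done := make(chan struct{})
	go func() {
		defer close(done)
		// Writer: stands in for the batcher persisting a sweep.
		s.upsertSweep("deadbeef", 1)
	}()

	// Reader: stands in for a test assertion polling the store.
	// Without the mutex, this pair is a data race on the map and can
	// even panic with "concurrent map read and map write".
	fmt.Println(s.sweepStored("deadbeef"))
	<-done
}
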
171 changes: 125 additions & 46 deletions sweepbatcher/sweep_batch.go

Large diffs are not rendered by default.

94 changes: 75 additions & 19 deletions sweepbatcher/sweep_batcher.go
@@ -225,6 +225,16 @@ var (
ErrBatcherShuttingDown = errors.New("batcher shutting down")
)

// testRequest bundles a function passed to the event loop with a channel
// used to wait until the function has been executed. This is used in unit
// tests only!
type testRequest struct {
// handler is the function to run in the event loop.
handler func()

// quit is closed when the handler completes.
quit chan struct{}
}

// Batcher is a system that is responsible for accepting sweep requests and
// placing them in appropriate batches. It will spin up new batches as needed.
type Batcher struct {
@@ -234,6 +244,12 @@ type Batcher struct {
// sweepReqs is a channel where sweep requests are received.
sweepReqs chan SweepRequest

// testReqs is a channel where test requests are received.
// This is used only in unit tests! The reason to have this is to
// avoid data races in require.Eventually calls running in parallel
// to the event loop. See method testRunInEventLoop().
testReqs chan *testRequest

// errChan is a channel where errors are received.
errChan chan error

@@ -461,6 +477,7 @@ func NewBatcher(wallet lndclient.WalletKitClient,
return &Batcher{
batches: make(map[int32]*batch),
sweepReqs: make(chan SweepRequest),
testReqs: make(chan *testRequest),
errChan: make(chan error, 1),
quit: make(chan struct{}),
initDone: make(chan struct{}),
@@ -518,22 +535,30 @@ func (b *Batcher) Run(ctx context.Context) error {
case sweepReq := <-b.sweepReqs:
sweep, err := b.fetchSweep(runCtx, sweepReq)
if err != nil {
log.Warnf("fetchSweep failed: %v.", err)
warnf("fetchSweep failed: %v.", err)

return err
}

err = b.handleSweep(runCtx, sweep, sweepReq.Notifier)
if err != nil {
log.Warnf("handleSweep failed: %v.", err)
warnf("handleSweep failed: %v.", err)

return err
}

case testReq := <-b.testReqs:
testReq.handler()
close(testReq.quit)

case err := <-b.errChan:
log.Warnf("Batcher received an error: %v.", err)
warnf("Batcher received an error: %v.", err)

return err

case <-runCtx.Done():
log.Infof("Stopping Batcher: run context cancelled.")
infof("Stopping Batcher: run context cancelled.")

return runCtx.Err()
}
}
@@ -551,6 +576,36 @@ func (b *Batcher) AddSweep(sweepReq *SweepRequest) error {
}
}

// testRunInEventLoop runs a function in the event loop blocking until
// the function returns. For unit tests only!
func (b *Batcher) testRunInEventLoop(ctx context.Context, handler func()) {
// If the event loop is finished, run the function.
select {
case <-b.quit:
handler()

return
default:
}

quit := make(chan struct{})
req := &testRequest{
handler: handler,
quit: quit,
}

select {
case b.testReqs <- req:
case <-ctx.Done():
return
}

select {
case <-quit:
case <-ctx.Done():
}
}
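
For context, a test could use this helper roughly as follows to inspect batcher state without racing the event loop; numBatches is a hypothetical helper for illustration, not part of this diff:

// Hypothetical test-side helper (would live in an in-package _test file).
func numBatches(ctx context.Context, b *Batcher) int {
	var n int

	// The closure executes inside the event loop goroutine, serialized
	// with every other access to b.batches, so the test goroutine can
	// poll the result safely.
	b.testRunInEventLoop(ctx, func() {
		n = len(b.batches)
	})

	return n
}

// Used from require.Eventually in a test, e.g.:
//
//	require.Eventually(t, func() bool {
//		return numBatches(ctx, batcher) == 1
//	}, time.Second, 10*time.Millisecond)
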

// handleSweep handles a sweep request by either placing it in an existing
// batch, or by spinning up a new batch for it.
func (b *Batcher) handleSweep(ctx context.Context, sweep *sweep,
@@ -561,8 +616,8 @@ func (b *Batcher) handleSweep(ctx context.Context, sweep *sweep,
return err
}

log.Infof("Batcher handling sweep %x, completed=%v", sweep.swapHash[:6],
completed)
infof("Batcher handling sweep %x, completed=%v",
sweep.swapHash[:6], completed)

// If the sweep has already been completed in a confirmed batch then we
// can't attach its notifier to the batch as that is no longer running.
@@ -573,8 +628,8 @@ func (b *Batcher) handleSweep(ctx context.Context, sweep *sweep,
// on-chain confirmations to prevent issues caused by reorgs.
parentBatch, err := b.store.GetParentBatch(ctx, sweep.swapHash)
if err != nil {
log.Errorf("unable to get parent batch for sweep %x: "+
"%v", sweep.swapHash[:6], err)
errorf("unable to get parent batch for sweep %x:"+
" %v", sweep.swapHash[:6], err)

return err
}
@@ -590,16 +645,17 @@ func (b *Batcher) handleSweep(ctx context.Context, sweep *sweep,

sweep.notifier = notifier

-// Check if the sweep is already in a batch. If that is the case, we
-// provide the sweep to that batch and return.
+// This is a check to see if a batch is completed. In that case we just
+// lazily delete it.
for _, batch := range b.batches {
-// This is a check to see if a batch is completed. In that case
-// we just lazily delete it and continue our scan.
if batch.isComplete() {
delete(b.batches, batch.id)
-continue
}
+}

+// Check if the sweep is already in a batch. If that is the case, we
+// provide the sweep to that batch and return.
+for _, batch := range b.batches {

Member commented:

Nice catch!

Collaborator commented:

+1, now I got it :)


if batch.sweepExists(sweep.swapHash) {
accepted, err := batch.addSweep(ctx, sweep)
if err != nil && !errors.Is(err, ErrBatchShuttingDown) {
@@ -624,8 +680,8 @@ func (b *Batcher) handleSweep(ctx context.Context, sweep *sweep,
return nil
}

log.Warnf("Greedy batch selection algorithm failed for sweep %x: %v. "+
"Falling back to old approach.", sweep.swapHash[:6], err)
warnf("Greedy batch selection algorithm failed for sweep %x: %v."+
" Falling back to old approach.", sweep.swapHash[:6], err)

// If one of the batches accepts the sweep, we provide it to that batch.
for _, batch := range b.batches {
@@ -730,13 +786,13 @@ func (b *Batcher) spinUpBatchFromDB(ctx context.Context, batch *batch) error {
}

if len(dbSweeps) == 0 {
log.Infof("skipping restored batch %d as it has no sweeps",
infof("skipping restored batch %d as it has no sweeps",
batch.id)

// It is safe to drop this empty batch as it has no sweeps.
err := b.store.DropBatch(ctx, batch.id)
if err != nil {
log.Warnf("unable to drop empty batch %d: %v",
warnf("unable to drop empty batch %d: %v",
batch.id, err)
}

@@ -878,7 +934,7 @@ func (b *Batcher) monitorSpendAndNotify(ctx context.Context, sweep *sweep,
b.wg.Add(1)
go func() {
defer b.wg.Done()
log.Infof("Batcher monitoring spend for swap %x",
infof("Batcher monitoring spend for swap %x",
sweep.swapHash[:6])

for {
@@ -1057,7 +1113,7 @@ func (b *Batcher) loadSweep(ctx context.Context, swapHash lntypes.Hash,
}
} else {
if s.ConfTarget == 0 {
log.Warnf("Fee estimation was requested for zero "+
warnf("Fee estimation was requested for zero "+
"confTarget for sweep %x.", swapHash[:6])
}
minFeeRate, err = b.wallet.EstimateFeeRate(ctx, s.ConfTarget)
539 changes: 334 additions & 205 deletions sweepbatcher/sweep_batcher_test.go

Large diffs are not rendered by default.