diff --git a/sweepbatcher/cpfp.go b/sweepbatcher/cpfp.go new file mode 100644 index 000000000..9a695e127 --- /dev/null +++ b/sweepbatcher/cpfp.go @@ -0,0 +1,651 @@ +package sweepbatcher + +import ( + "bytes" + "context" + "fmt" + + "github.com/btcsuite/btcd/blockchain" + "github.com/btcsuite/btcd/btcutil" + "github.com/btcsuite/btcd/btcutil/psbt" + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcd/txscript" + "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/input" + "github.com/lightningnetwork/lnd/lntypes" + "github.com/lightningnetwork/lnd/lnwallet/chainfee" +) + +// ensurePresigned checks that we can sign a 1:1 transaction sweeping the input. +func (b *batch) ensurePresigned(ctx context.Context, newSweep *sweep) error { + if b.cfg.cpfpHelper == nil { + return fmt.Errorf("cpfpHelper is not installed") + } + if len(b.sweeps) != 0 { + return fmt.Errorf("ensurePresigned is done when adding to an " + + "empty batch") + } + + sweeps := []sweep{ + { + outpoint: newSweep.outpoint, + value: newSweep.value, + cpfp: newSweep.cpfp, + }, + } + + // Cache the destination address. + destAddr, err := b.getSweepsDestAddr(ctx, sweeps) + if err != nil { + return fmt.Errorf("failed to find destination address: %w", err) + } + + // Set LockTime to 0. It is not critical. + const currentHeight = 0 + + // Check if we can sign with minimum fee rate. + const feeRate = chainfee.FeePerKwFloor + + tx, _, _, _, err := constructUnsignedTx( + sweeps, destAddr, currentHeight, feeRate, + ) + if err != nil { + return fmt.Errorf("failed to construct unsigned tx "+ + "for feeRate %v: %w", feeRate, err) + } + + // Try to presign this transaction. + batchAmt := newSweep.value + signedTx, err := b.cfg.cpfpHelper.SignTx(ctx, tx, batchAmt, feeRate) + if err != nil { + return fmt.Errorf("failed to sign unsigned tx %v "+ + "for feeRate %v: %w", tx.TxHash(), feeRate, err) + } + + // Check the SignTx worked correctly. + err = CheckSignedTx(tx, signedTx, batchAmt, feeRate) + if err != nil { + return fmt.Errorf("signed tx doesn't correspond the "+ + "unsigned tx: %w", err) + } + + return nil +} + +// presign tries to presign batch sweep transactions composed of this batch and +// the sweep. It signs multiple versions of the transaction to make sure there +// is a transaction to be published if minRelayFee grows. +func (b *batch) presign(ctx context.Context, newSweep *sweep) error { + if b.cfg.cpfpHelper == nil { + return fmt.Errorf("cpfpHelper is not installed") + } + if len(b.sweeps) == 0 { + return fmt.Errorf("presigning is done when adding to a " + + "non-empty batch") + } + + // Create the list of sweeps of the future batch. + sweeps := make([]sweep, 0, len(b.sweeps)+1) + for _, sweep := range b.sweeps { + sweeps = append(sweeps, sweep) + } + existingSweeps := sweeps + sweeps = append(sweeps, *newSweep) + + // Cache the destination address. + destAddr, err := b.getSweepsDestAddr(ctx, existingSweeps) + if err != nil { + return fmt.Errorf("failed to find destination address: %w", err) + } + + return presign(ctx, b.cfg.cpfpHelper, sweeps, destAddr) +} + +// presigner tries to presign a batch transaction. +type presigner interface { + // Presign tries to presign a batch transaction. If the method returns + // nil, it is guaranteed that future calls to SignTx on this set of + // sweeps return valid signed transactions. + Presign(ctx context.Context, tx *wire.MsgTx, + inputAmt btcutil.Amount) error +} + +// presign tries to presign batch sweep transactions of the sweeps. It signs +// multiple versions of the transaction to make sure there is a transaction to +// be published if minRelayFee grows. +func presign(ctx context.Context, presigner presigner, sweeps []sweep, + destAddr btcutil.Address) error { + + if presigner == nil { + return fmt.Errorf("presigner is not installed") + } + + if len(sweeps) == 0 { + return fmt.Errorf("there are no sweeps") + } + + // Keep track of the total amount this batch is sweeping back. + batchAmt := btcutil.Amount(0) + for _, sweep := range sweeps { + batchAmt += sweep.value + } + + // Go from the floor (1.01 sat/vbyte) to 2k sat/vbyte with step of 1.5x. + const ( + start = chainfee.FeePerKwFloor + stop = chainfee.AbsoluteFeePerKwFloor * 2_000 + ) + + // Set LockTime to 0. It is not critical. + const currentHeight = 0 + + for feeRate := start; feeRate <= stop; feeRate = (feeRate * 3) / 2 { + // Construct an unsigned transaction for this fee rate. + tx, _, feeForWeight, fee, err := constructUnsignedTx( + sweeps, destAddr, currentHeight, feeRate, + ) + if err != nil { + return fmt.Errorf("failed to construct unsigned tx "+ + "for feeRate %v: %w", feeRate, err) + } + + // Try to presign this transaction. + err = presigner.Presign(ctx, tx, batchAmt) + if err != nil { + return fmt.Errorf("failed to presign unsigned tx %v "+ + "for feeRate %v: %w", tx.TxHash(), feeRate, err) + } + + // If fee was clamped, stop here, because fee rate won't grow. + if fee < feeForWeight { + break + } + } + + return nil +} + +// cpfpLabelPrefix is a prefix added to the label of the batch to form a label +// for CPFP transaction. +const cpfpLabelPrefix = "cpfp-for-" + +// publishWithCPFP creates sweep transaction using a custom transaction signer +// and publishes it. It may use CPFP if the custom signer returned a pre-signed +// transaction with insufficient fee. It returns fee of the first transaction, +// not including CPFP's fee, an error (if signing and/or publishing failed) and +// a boolean flag indicating signing success. This mode is incompatible with +// an external address, because it may use CPFP and is designed for batches. +func (b *batch) publishWithCPFP(ctx context.Context) (btcutil.Amount, error, + bool) { + + // Sanity check, there should be at least 1 sweep in this batch. + if len(b.sweeps) == 0 { + return 0, fmt.Errorf("no sweeps in batch"), false + } + + // Make sure that no external address is used. + for _, sweep := range b.sweeps { + if sweep.isExternalAddr { + return 0, fmt.Errorf("external address was used with " + + "a custom transaction signer"), false + } + } + + // Cache current height and desired feerate of the batch. + currentHeight := b.currentHeight + feeRate := b.rbfCache.FeeRate + + // Append this sweep to an array of sweeps. This is needed to keep the + // order of sweeps stored, as iterating the sweeps map does not + // guarantee same order. + sweeps := make([]sweep, 0, len(b.sweeps)) + for _, sweep := range b.sweeps { + sweeps = append(sweeps, sweep) + } + + // Cache the destination address. + address, err := b.getSweepsDestAddr(ctx, sweeps) + if err != nil { + return 0, fmt.Errorf("failed to find destination address: %w", + err), false + } + + // Construct unsigned batch transaction. + tx, weight, _, fee, err := constructUnsignedTx( + sweeps, address, currentHeight, feeRate, + ) + if err != nil { + return 0, fmt.Errorf("failed to construct tx: %w", err), + false + } + + // Adjust feeRate, because it may have been clamped. + feeRate = chainfee.NewSatPerKWeight(fee, weight) + + // Calculate total input amount. + batchAmt := btcutil.Amount(0) + for _, sweep := range sweeps { + batchAmt += sweep.value + } + + // Determine the current minimum relay fee based on our chain backend. + minRelayFee, err := b.wallet.MinRelayFee(ctx) + if err != nil { + return 0, fmt.Errorf("failed to get minRelayFee: %w", err), + false + } + + // Get a signed transaction. It may be either new transaction or a + // pre-signed one. + signedTx, err := b.cfg.cpfpHelper.SignTx(ctx, tx, batchAmt, minRelayFee) + if err != nil { + return 0, fmt.Errorf("failed to sign tx: %w", err), + false + } + + // Run sanity checks to make sure cpfpHelper.SignTx complied with all + // the invariants. + err = CheckSignedTx(tx, signedTx, batchAmt, minRelayFee) + if err != nil { + return 0, fmt.Errorf("signed tx doesn't correspond the "+ + "unsigned tx: %w", err), false + } + tx = signedTx + txHash := tx.TxHash() + + // Make sure tx weight matches the expected value. + realWeight := lntypes.WeightUnit( + blockchain.GetTransactionWeight(btcutil.NewTx(tx)), + ) + if realWeight != weight { + b.log().Warnf("actual weight of tx %v is %v, estimated as %d", + txHash, realWeight, weight) + } + + // Find actual fee rate of the signed transaction. It may differ from + // the desired fee rate, because SignTx may return a presigned tx. + output := btcutil.Amount(tx.TxOut[0].Value) + fee = batchAmt - output + signedFeeRate := chainfee.NewSatPerKWeight(fee, realWeight) + + b.log().Infof("attempting to publish custom signed tx=%v, "+ + "desiredFeerate=%v, signedFeeRate=%v, weight=%v, fee=%v, "+ + "sweeps=%d, destAddr=%s", txHash, feeRate, signedFeeRate, + weight, fee, len(tx.TxIn), address) + b.debugLogTx("serialized batch", tx) + + // Publish the transaction. If it fails, we don't return immediately, + // because we may still need a CPFP and it can be done against a + // previously published transaction. + publishErr1 := b.wallet.PublishTransaction( + ctx, tx, b.cfg.txLabeler(b.id), + ) + if publishErr1 == nil { + // Store the batch transaction's txid and pkScript, to use in + // CPFP and for monitoring purposes. + b.batchTxid = &txHash + b.batchPkScript = tx.TxOut[0].PkScript + + if err := b.persist(ctx); err != nil { + return 0, fmt.Errorf("failed to persist: %w", err), true + } + } else { + b.log().Infof("failed to publish custom signed tx=%v, "+ + "desiredFeerate=%v, signedFeeRate=%v, weight=%v, "+ + "fee=%v, sweeps=%d, destAddr=%s", txHash, feeRate, + signedFeeRate, weight, fee, len(tx.TxIn), address) + } + + // Load previously published tx if it exists. + var parentTx *wire.MsgTx + if b.batchTxid != nil { + parentTx, err = b.cfg.cpfpHelper.LoadTx(ctx, *b.batchTxid) + if err != nil { + return 0, fmt.Errorf("failed to load batch tx %v: %w", + *b.batchTxid, err), true + } + } else { + b.log().Warnf("need a CPFP, but there is no published tx known") + } + + // Print this log here, to keep isCPFPNeeded a pure function. + if parentTx != nil && len(parentTx.TxIn) < len(tx.TxIn) { + b.log().Infof("Skip publishing CPFP, because batch tx in mempool"+ + "has %d inputs, while the batch has now %d inputs", + len(parentTx.TxIn), len(tx.TxIn)) + } + + // Determine if CPFP is needed and its feerate. + needsCPFP, err := isCPFPNeeded( + parentTx, batchAmt, len(tx.TxIn), feeRate, signedFeeRate, + ) + if err != nil { + return 0, fmt.Errorf("failed to determine if CPFP is "+ + "needed: %w", err), false + } + + // If CPFP is not needed, we are done now. + if !needsCPFP { + b.log().Infof("CPFP is not needed") + + return fee, publishErr1, true + } + + b.log().Infof("CPFP is needed, parent is %v", parentTx.TxHash()) + + // Create and sign CPFP. + parentWeight := lntypes.WeightUnit( + blockchain.GetTransactionWeight(btcutil.NewTx(parentTx)), + ) + parentOutput := btcutil.Amount(parentTx.TxOut[0].Value) + parentFee := batchAmt - parentOutput + childTx, childFeeRate, err := makeUnsignedCPFP( + *b.batchTxid, parentOutput, parentWeight, parentFee, + minRelayFee, feeRate, address, currentHeight, + ) + if err != nil { + return 0, fmt.Errorf("failed to make CPFP tx: %w", err), + true + } + + childTx, err = b.signChildTx(ctx, childTx) + if err != nil { + return 0, fmt.Errorf("failed to sign CPFP tx: %w", err), + true + } + + childTxHash := childTx.TxHash() + parentFeeRate := chainfee.NewSatPerKWeight(parentFee, parentWeight) + b.log().Infof("attempting to publish child tx %v to CPFP parent tx %v,"+ + " effectiveFeeRate=%v, parentFeeRate=%v, childFeeRate=%v", + childTxHash, *b.batchTxid, feeRate, parentFeeRate, + childFeeRate) + b.debugLogTx("serialized child tx", childTx) + + // Publish child transaction. + publishErr2 := b.wallet.PublishTransaction( + ctx, childTx, cpfpLabelPrefix+b.cfg.txLabeler(b.id), + ) + if publishErr2 != nil { + b.log().Infof("failed to publish child tx %v to CPFP parent "+ + "tx %v, effectiveFeeRate=%v, parentFeeRate=%v, "+ + "childFeeRate=%v", childTxHash, *b.batchTxid, feeRate, + parentFeeRate, childFeeRate) + + return fee, publishErr2, true + } + + return fee, publishErr1, true +} + +// getSweepsDestAddr returns the destination address used by a group of sweeps. +// The method must be used in CPFP mode only. +func (b *batch) getSweepsDestAddr(ctx context.Context, + sweeps []sweep) (btcutil.Address, error) { + + if b.cfg.cpfpHelper == nil { + return nil, fmt.Errorf("getSweepsDestAddr used without CPFP") + } + + inputs := make([]wire.OutPoint, len(sweeps)) + for i, s := range sweeps { + if !s.cpfp { + return nil, fmt.Errorf("getSweepsDestAddr used on a " + + "non-CPFP input") + } + + inputs[i] = s.outpoint + } + + // Load pkScript from the CPFP helper. + pkScriptBytes, err := b.cfg.cpfpHelper.DestPkScript(ctx, inputs) + if err != nil { + return nil, fmt.Errorf("cpfpHelper.DestPkScript failed for "+ + "inputs %v: %w", inputs, err) + } + + // Convert pkScript to btcutil.Address. + pkScript, err := txscript.ParsePkScript(pkScriptBytes) + if err != nil { + return nil, fmt.Errorf("txscript.ParsePkScript failed for "+ + "pkScript %x returned for inputs %v: %w", pkScriptBytes, + inputs, err) + } + + address, err := pkScript.Address(b.cfg.chainParams) + if err != nil { + return nil, fmt.Errorf("pkScript.Address failed for "+ + "pkScript %x returned for inputs %v: %w", pkScriptBytes, + inputs, err) + } + + return address, nil +} + +// CheckSignedTx makes sure that signedTx matches the unsignedTx. It checks +// according to criteria specified in the description of CpfpHelper.SignTx. +func CheckSignedTx(unsignedTx, signedTx *wire.MsgTx, inputAmt btcutil.Amount, + minRelayFee chainfee.SatPerKWeight) error { + + // Make sure the set of inputs is the same. + unsignedMap := make(map[wire.OutPoint]uint32, len(unsignedTx.TxIn)) + for _, txIn := range unsignedTx.TxIn { + unsignedMap[txIn.PreviousOutPoint] = txIn.Sequence + } + for _, txIn := range signedTx.TxIn { + seq, has := unsignedMap[txIn.PreviousOutPoint] + if !has { + return fmt.Errorf("input %s is new in signed tx", + txIn.PreviousOutPoint) + } + if seq != txIn.Sequence { + return fmt.Errorf("sequence mismatch in input %s: "+ + "%d in unsigned, %d in signed", + txIn.PreviousOutPoint, seq, txIn.Sequence) + } + delete(unsignedMap, txIn.PreviousOutPoint) + } + for outpoint := range unsignedMap { + return fmt.Errorf("input %s is missing in signed tx", outpoint) + } + + // Compare outputs. + if len(unsignedTx.TxOut) != 1 { + return fmt.Errorf("unsigned tx has %d outputs, want 1", + len(unsignedTx.TxOut)) + } + if len(signedTx.TxOut) != 1 { + return fmt.Errorf("the signed tx has %d outputs, want 1", + len(signedTx.TxOut)) + } + unsignedOut := unsignedTx.TxOut[0] + signedOut := signedTx.TxOut[0] + if !bytes.Equal(unsignedOut.PkScript, signedOut.PkScript) { + return fmt.Errorf("mismatch of output pkScript: %v, %v", + unsignedOut.PkScript, signedOut.PkScript) + } + + // Find the feerate of signedTx. + fee := inputAmt - btcutil.Amount(signedOut.Value) + weight := lntypes.WeightUnit( + blockchain.GetTransactionWeight(btcutil.NewTx(signedTx)), + ) + feeRate := chainfee.NewSatPerKWeight(fee, weight) + if feeRate < minRelayFee { + return fmt.Errorf("feerate (%v) of signed tx is lower than "+ + "minRelayFee (%v)", feeRate, minRelayFee) + } + + // Check LockTime. + if signedTx.LockTime > unsignedTx.LockTime { + return fmt.Errorf("locktime (%d) of signed tx is higher than "+ + "locktime of unsigned tx (%d)", signedTx.LockTime, + unsignedTx.LockTime) + } + + // Check Version. + if signedTx.Version != unsignedTx.Version { + return fmt.Errorf("version (%d) of signed tx is not equal to "+ + "version of unsigned tx (%d)", signedTx.Version, + unsignedTx.Version) + } + + return nil +} + +// feeRateThresholdPPM is the ratio of accepted underpayment of fee for which +// no CPFP is used to adjust the effective fee rate. If the underpayment is +// higher, then CPFP is enabled. It is measured in PPM, current level is 2%. +const feeRateThresholdPPM = 2_0000 + +// isCPFPNeeded returns if CPFP is needed to make the effective fee rate close +// to the desired feeRate. The threshold is feeRateThresholdPPM. +func isCPFPNeeded(parentTx *wire.MsgTx, inputAmt btcutil.Amount, numSweeps int, + desiredFeeRate, signedFeeRate chainfee.SatPerKWeight) (bool, error) { + + // First, if feerate of the signed tx matches exactly the desired + // feerate, this means, that we didn't use a presigned transaction, + // which means that all the input are likely to be online, so we don't + // use CPFP. + if desiredFeeRate == signedFeeRate { + return false, nil + } + + // If no transaction was ever published, we can't do CPFP anyway. A + // warning is produced by the calling function in this case. + if parentTx == nil { + return false, nil + } + + // Sanity checks. + if len(parentTx.TxOut) != 1 { + return false, fmt.Errorf("batch tx must have one output, "+ + "but it has %d", len(parentTx.TxOut)) + } + + // Make sure that the parent transaction is signed. + for _, txIn := range parentTx.TxIn { + if len(txIn.Witness) == 0 { + return false, fmt.Errorf("the tx must be signed") + } + } + + // If previously published tx has fewer inputs than the current state + // of the batch, skip CPFP, since it would bump an outdated state. + if len(parentTx.TxIn) < numSweeps { + return false, nil + } + + // Previously published transaction must not have more inputs than the + // current batch state, because inputs are only added. + if len(parentTx.TxIn) > numSweeps { + return false, fmt.Errorf("parent tx has more inputs (%d) than "+ + "exist in the batch currently (%d)", len(parentTx.TxIn), + numSweeps) + } + + // Calculate fee rate of the transaction. + weight := lntypes.WeightUnit( + blockchain.GetTransactionWeight(btcutil.NewTx(parentTx)), + ) + fee := inputAmt - btcutil.Amount(parentTx.TxOut[0].Value) + if fee < 0 { + return false, fmt.Errorf("the tx has negative fee %v", fee) + } + parentFeeRate := chainfee.NewSatPerKWeight(fee, weight) + + // Check of the observed_feerate < desired_feerate - threshold. + threshold := desiredFeeRate * feeRateThresholdPPM / 1_000_000 + cpfpNeeded := parentFeeRate < desiredFeeRate-threshold + + return cpfpNeeded, nil +} + +// maxChildFeeSharePPM specifies max share (in ppm) of total funds that can be +// burn in the child transaction in CPFP. Currently it is set to 20%. +const maxChildFeeSharePPM = 20_0000 + +// makeUnsignedCPFP constructs unsigned child tx for CPFP to achieve desired +// effective fee rate. It also returns fee rate of the constructed child tx. +// The transaction spends the UTXO to the same address. Supports P2WKH, P2TR. +func makeUnsignedCPFP(parentTxid chainhash.Hash, parentOutput btcutil.Amount, + parentWeight lntypes.WeightUnit, parentFee btcutil.Amount, minRelayFee, + effectiveFeeRate chainfee.SatPerKWeight, address btcutil.Address, + currentHeight int32) (*wire.MsgTx, chainfee.SatPerKWeight, error) { + + // Estimate the weight of the child tx. + var estimator input.TxWeightEstimator + switch address.(type) { + case *btcutil.AddressWitnessPubKeyHash: + estimator.AddP2WKHInput() + estimator.AddP2WKHOutput() + + case *btcutil.AddressTaproot: + estimator.AddTaprootKeySpendInput(txscript.SigHashDefault) + estimator.AddP2TROutput() + + default: + return nil, 0, fmt.Errorf("unknown address type %T", address) + } + childWeight := estimator.Weight() + + // Estimate the fee of the child tx. + totalWeight := parentWeight + childWeight + totalFee := effectiveFeeRate.FeeForWeight(totalWeight) + childFee := totalFee - parentFee + childFeeRate := chainfee.NewSatPerKWeight(childFee, childWeight) + if childFeeRate < minRelayFee { + childFeeRate = minRelayFee + childFee = childFeeRate.FeeForWeight(childWeight) + } + if childFeeRate < effectiveFeeRate { + return nil, 0, fmt.Errorf("got child fee rate %v lower than "+ + "effective fee rate %v", childFeeRate, effectiveFeeRate) + } + if childFee > parentOutput*maxChildFeeSharePPM/1_000_000 { + return nil, 0, fmt.Errorf("child fee %v is higher than %d%% "+ + "of total funds %v", childFee, + maxChildFeeSharePPM*100/1_000_000, parentOutput) + } + + // Construct child tx. + childTx := &wire.MsgTx{ + Version: 2, + LockTime: uint32(currentHeight), + } + childTx.AddTxIn(&wire.TxIn{ + PreviousOutPoint: wire.OutPoint{ + Hash: parentTxid, + Index: 0, + }, + }) + pkScript, err := txscript.PayToAddrScript(address) + if err != nil { + return nil, 0, fmt.Errorf("txscript.PayToAddrScript "+ + "failed: %w", err) + } + childTx.AddTxOut(&wire.TxOut{ + PkScript: pkScript, + Value: int64(parentOutput - childFee), + }) + + return childTx, childFeeRate, nil +} + +// signChildTx signs child CPFP transaction using LND client. +func (b *batch) signChildTx(ctx context.Context, + unsignedTx *wire.MsgTx) (*wire.MsgTx, error) { + + // Create PSBT packet object. + packet, err := psbt.NewFromUnsignedTx(unsignedTx) + if err != nil { + return nil, fmt.Errorf("failed to create PSBT: %w", err) + } + + packet, err = b.wallet.SignPsbt(ctx, packet) + if err != nil { + return nil, fmt.Errorf("signing PSBT failed: %w", err) + } + + return psbt.Extract(packet) +} diff --git a/sweepbatcher/cpfp_test.go b/sweepbatcher/cpfp_test.go new file mode 100644 index 000000000..399e37184 --- /dev/null +++ b/sweepbatcher/cpfp_test.go @@ -0,0 +1,1293 @@ +package sweepbatcher + +import ( + "context" + "fmt" + "testing" + + "github.com/btcsuite/btcd/btcutil" + "github.com/btcsuite/btcd/chaincfg" + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcd/txscript" + "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/lntypes" + "github.com/lightningnetwork/lnd/lnwallet/chainfee" + "github.com/stretchr/testify/require" +) + +// mockPresigner is an implementation of Presigner used in TestPresign. +type mockPresigner struct { + // outputs collects outputs of presigned transactions. + outputs []btcutil.Amount + + // failAt is optional index of a call at which it fails, 1 based. + failAt int +} + +// Presign memorizes the value of the output and fails if the number of +// calls previously made is failAt. +func (p *mockPresigner) Presign(ctx context.Context, tx *wire.MsgTx, + inputAmt btcutil.Amount) error { + + if len(p.outputs)+1 == p.failAt { + return fmt.Errorf("test error in Presign") + } + + p.outputs = append(p.outputs, btcutil.Amount(tx.TxOut[0].Value)) + + return nil +} + +// TestPresign checks that function presign presigns correct set of transactions +// and handles edge cases properly. +func TestPresign(t *testing.T) { + // Prepare the necessary data for test cases. + op1 := wire.OutPoint{ + Hash: chainhash.Hash{1, 1, 1}, + Index: 1, + } + op2 := wire.OutPoint{ + Hash: chainhash.Hash{2, 2, 2}, + Index: 2, + } + + ctx := context.Background() + + cases := []struct { + name string + presigner presigner + sweeps []sweep + destAddr btcutil.Address + wantErr string + wantOutputs []btcutil.Amount + }{ + { + name: "error: no presigner", + sweeps: []sweep{ + { + outpoint: op1, + value: 1_000_000, + }, + }, + destAddr: destAddr, + wantErr: "presigner is not installed", + }, + + { + name: "error: no sweeps", + presigner: &mockPresigner{}, + destAddr: destAddr, + wantErr: "there are no sweeps", + }, + + { + name: "error: no destAddr", + presigner: &mockPresigner{}, + sweeps: []sweep{ + { + outpoint: op1, + value: 1_000_000, + }, + }, + wantErr: "unsupported address type ", + }, + + { + name: "two coop sweeps", + presigner: &mockPresigner{}, + sweeps: []sweep{ + { + outpoint: op1, + value: 1_000_000, + }, + { + outpoint: op2, + value: 2_000_000, + }, + }, + destAddr: destAddr, + wantOutputs: []btcutil.Amount{ + 2999842, 2999763, 2999645, 2999467, 2999200, + 2998800, 2998201, 2997301, 2995952, 2993927, + 2990890, 2986336, 2979503, 2969255, 2953882, + 2930824, 2896235, 2844353, 2766529, + }, + }, + + { + name: "small amount => fewer steps until clamped", + presigner: &mockPresigner{}, + sweeps: []sweep{ + { + outpoint: op1, + value: 1_000, + }, + { + outpoint: op2, + value: 2_000, + }, + }, + destAddr: destAddr, + wantOutputs: []btcutil.Amount{ + 2842, 2763, 2645, 2467, 2400, + }, + }, + + { + name: "third signing fails", + presigner: &mockPresigner{ + failAt: 3, + }, + sweeps: []sweep{ + { + outpoint: op1, + value: 1_000, + }, + { + outpoint: op2, + value: 2_000, + }, + }, + destAddr: destAddr, + wantErr: "for feeRate 568 sat/kw", + }, + } + + for _, tc := range cases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + err := presign(ctx, tc.presigner, tc.sweeps, tc.destAddr) + if tc.wantErr != "" { + require.Error(t, err) + require.ErrorContains(t, err, tc.wantErr) + } else { + require.NoError(t, err) + outputs := tc.presigner.(*mockPresigner).outputs + require.Equal(t, tc.wantOutputs, outputs) + } + }) + } +} + +// TestCheckSignedTx tests that function CheckSignedTx checks all the criteria +// of CpfpHelper.SignTx correctly. +func TestCheckSignedTx(t *testing.T) { + // Prepare the necessary data for test cases. + op1 := wire.OutPoint{ + Hash: chainhash.Hash{1, 1, 1}, + Index: 1, + } + op2 := wire.OutPoint{ + Hash: chainhash.Hash{2, 2, 2}, + Index: 2, + } + + batchPkScript, err := txscript.PayToAddrScript(destAddr) + require.NoError(t, err) + + cases := []struct { + name string + unsignedTx *wire.MsgTx + signedTx *wire.MsgTx + inputAmt btcutil.Amount + minRelayFee chainfee.SatPerKWeight + wantErr string + }{ + { + name: "success", + unsignedTx: &wire.MsgTx{ + Version: 2, + TxIn: []*wire.TxIn{ + { + PreviousOutPoint: op1, + Sequence: 1, + }, + { + PreviousOutPoint: op2, + Sequence: 2, + }, + }, + TxOut: []*wire.TxOut{ + { + Value: 2999374, + PkScript: batchPkScript, + }, + }, + LockTime: 800_000, + }, + signedTx: &wire.MsgTx{ + Version: 2, + TxIn: []*wire.TxIn{ + { + PreviousOutPoint: op2, + Sequence: 2, + Witness: wire.TxWitness{ + []byte("test"), + }, + }, + { + PreviousOutPoint: op1, + Sequence: 1, + Witness: wire.TxWitness{ + []byte("test"), + }, + }, + }, + TxOut: []*wire.TxOut{ + { + Value: 2999374, + PkScript: batchPkScript, + }, + }, + LockTime: 799_999, + }, + inputAmt: 3_000_000, + minRelayFee: 253, + wantErr: "", + }, + + { + name: "bad locktime", + unsignedTx: &wire.MsgTx{ + Version: 2, + TxIn: []*wire.TxIn{ + { + PreviousOutPoint: op1, + Sequence: 1, + }, + { + PreviousOutPoint: op2, + Sequence: 2, + }, + }, + TxOut: []*wire.TxOut{ + { + Value: 2999374, + PkScript: batchPkScript, + }, + }, + LockTime: 800_000, + }, + signedTx: &wire.MsgTx{ + Version: 2, + TxIn: []*wire.TxIn{ + { + PreviousOutPoint: op2, + Sequence: 2, + Witness: wire.TxWitness{ + []byte("test"), + }, + }, + { + PreviousOutPoint: op1, + Sequence: 1, + Witness: wire.TxWitness{ + []byte("test"), + }, + }, + }, + TxOut: []*wire.TxOut{ + { + Value: 2999374, + PkScript: batchPkScript, + }, + }, + LockTime: 800_001, + }, + inputAmt: 3_000_000, + minRelayFee: 253, + wantErr: "locktime", + }, + + { + name: "bad version", + unsignedTx: &wire.MsgTx{ + Version: 2, + TxIn: []*wire.TxIn{ + { + PreviousOutPoint: op1, + Sequence: 1, + }, + { + PreviousOutPoint: op2, + Sequence: 2, + }, + }, + TxOut: []*wire.TxOut{ + { + Value: 2999374, + PkScript: batchPkScript, + }, + }, + LockTime: 800_000, + }, + signedTx: &wire.MsgTx{ + Version: 3, + TxIn: []*wire.TxIn{ + { + PreviousOutPoint: op2, + Sequence: 2, + Witness: wire.TxWitness{ + []byte("test"), + }, + }, + { + PreviousOutPoint: op1, + Sequence: 1, + Witness: wire.TxWitness{ + []byte("test"), + }, + }, + }, + TxOut: []*wire.TxOut{ + { + Value: 2999374, + PkScript: batchPkScript, + }, + }, + LockTime: 799_999, + }, + inputAmt: 3_000_000, + minRelayFee: 253, + wantErr: "version", + }, + + { + name: "missing input", + unsignedTx: &wire.MsgTx{ + Version: 2, + TxIn: []*wire.TxIn{ + { + PreviousOutPoint: op1, + Sequence: 1, + }, + { + PreviousOutPoint: op2, + Sequence: 2, + }, + }, + TxOut: []*wire.TxOut{ + { + Value: 2999374, + PkScript: batchPkScript, + }, + }, + LockTime: 800_000, + }, + signedTx: &wire.MsgTx{ + Version: 2, + TxIn: []*wire.TxIn{ + { + PreviousOutPoint: op2, + Sequence: 2, + Witness: wire.TxWitness{ + []byte("test"), + }, + }, + }, + TxOut: []*wire.TxOut{ + { + Value: 2999374, + PkScript: batchPkScript, + }, + }, + LockTime: 799_999, + }, + inputAmt: 3_000_000, + minRelayFee: 253, + wantErr: "is missing in signed tx", + }, + + { + name: "extra input", + unsignedTx: &wire.MsgTx{ + Version: 2, + TxIn: []*wire.TxIn{ + { + PreviousOutPoint: op1, + Sequence: 1, + }, + }, + TxOut: []*wire.TxOut{ + { + Value: 2999374, + PkScript: batchPkScript, + }, + }, + LockTime: 800_000, + }, + signedTx: &wire.MsgTx{ + Version: 2, + TxIn: []*wire.TxIn{ + { + PreviousOutPoint: op2, + Sequence: 2, + Witness: wire.TxWitness{ + []byte("test"), + }, + }, + { + PreviousOutPoint: op1, + Sequence: 1, + Witness: wire.TxWitness{ + []byte("test"), + }, + }, + }, + TxOut: []*wire.TxOut{ + { + Value: 2999374, + PkScript: batchPkScript, + }, + }, + LockTime: 799_999, + }, + inputAmt: 3_000_000, + minRelayFee: 253, + wantErr: "is new in signed tx", + }, + + { + name: "mismatch of sequence numbers", + unsignedTx: &wire.MsgTx{ + Version: 2, + TxIn: []*wire.TxIn{ + { + PreviousOutPoint: op1, + Sequence: 1, + }, + { + PreviousOutPoint: op2, + Sequence: 2, + }, + }, + TxOut: []*wire.TxOut{ + { + Value: 2999374, + PkScript: batchPkScript, + }, + }, + LockTime: 800_000, + }, + signedTx: &wire.MsgTx{ + Version: 2, + TxIn: []*wire.TxIn{ + { + PreviousOutPoint: op2, + Sequence: 2, + Witness: wire.TxWitness{ + []byte("test"), + }, + }, + { + PreviousOutPoint: op1, + Sequence: 3, + Witness: wire.TxWitness{ + []byte("test"), + }, + }, + }, + TxOut: []*wire.TxOut{ + { + Value: 2999374, + PkScript: batchPkScript, + }, + }, + LockTime: 799_999, + }, + inputAmt: 3_000_000, + minRelayFee: 253, + wantErr: "sequence mismatch", + }, + + { + name: "extra output in unsignedTx", + unsignedTx: &wire.MsgTx{ + Version: 2, + TxIn: []*wire.TxIn{ + { + PreviousOutPoint: op1, + Sequence: 1, + }, + { + PreviousOutPoint: op2, + Sequence: 2, + }, + }, + TxOut: []*wire.TxOut{ + { + Value: 2999374, + PkScript: batchPkScript, + }, + { + Value: 2999374, + PkScript: batchPkScript, + }, + }, + LockTime: 800_000, + }, + signedTx: &wire.MsgTx{ + Version: 2, + TxIn: []*wire.TxIn{ + { + PreviousOutPoint: op2, + Sequence: 2, + Witness: wire.TxWitness{ + []byte("test"), + }, + }, + { + PreviousOutPoint: op1, + Sequence: 1, + Witness: wire.TxWitness{ + []byte("test"), + }, + }, + }, + TxOut: []*wire.TxOut{ + { + Value: 2999374, + PkScript: batchPkScript, + }, + }, + LockTime: 799_999, + }, + inputAmt: 3_000_000, + minRelayFee: 253, + wantErr: "unsigned tx has 2 outputs, want 1", + }, + + { + name: "extra output in signedTx", + unsignedTx: &wire.MsgTx{ + Version: 2, + TxIn: []*wire.TxIn{ + { + PreviousOutPoint: op1, + Sequence: 1, + }, + { + PreviousOutPoint: op2, + Sequence: 2, + }, + }, + TxOut: []*wire.TxOut{ + { + Value: 2999374, + PkScript: batchPkScript, + }, + }, + LockTime: 800_000, + }, + signedTx: &wire.MsgTx{ + Version: 2, + TxIn: []*wire.TxIn{ + { + PreviousOutPoint: op2, + Sequence: 2, + Witness: wire.TxWitness{ + []byte("test"), + }, + }, + { + PreviousOutPoint: op1, + Sequence: 1, + Witness: wire.TxWitness{ + []byte("test"), + }, + }, + }, + TxOut: []*wire.TxOut{ + { + Value: 2999374, + PkScript: batchPkScript, + }, + { + Value: 2999374, + PkScript: batchPkScript, + }, + }, + LockTime: 799_999, + }, + inputAmt: 3_000_000, + minRelayFee: 253, + wantErr: "the signed tx has 2 outputs, want 1", + }, + + { + name: "mismatch of output pk_script", + unsignedTx: &wire.MsgTx{ + Version: 2, + TxIn: []*wire.TxIn{ + { + PreviousOutPoint: op1, + Sequence: 1, + }, + { + PreviousOutPoint: op2, + Sequence: 2, + }, + }, + TxOut: []*wire.TxOut{ + { + Value: 2999374, + PkScript: batchPkScript, + }, + }, + LockTime: 800_000, + }, + signedTx: &wire.MsgTx{ + Version: 2, + TxIn: []*wire.TxIn{ + { + PreviousOutPoint: op2, + Sequence: 2, + Witness: wire.TxWitness{ + []byte("test"), + }, + }, + { + PreviousOutPoint: op1, + Sequence: 1, + Witness: wire.TxWitness{ + []byte("test"), + }, + }, + }, + TxOut: []*wire.TxOut{ + { + Value: 2999374, + PkScript: batchPkScript[1:], + }, + }, + LockTime: 799_999, + }, + inputAmt: 3_000_000, + minRelayFee: 253, + wantErr: "mismatch of output pkScript", + }, + + { + name: "too low feerate in signedTx", + unsignedTx: &wire.MsgTx{ + Version: 2, + TxIn: []*wire.TxIn{ + { + PreviousOutPoint: op1, + Sequence: 1, + }, + { + PreviousOutPoint: op2, + Sequence: 2, + }, + }, + TxOut: []*wire.TxOut{ + { + Value: 2999374, + PkScript: batchPkScript, + }, + }, + LockTime: 800_000, + }, + signedTx: &wire.MsgTx{ + Version: 2, + TxIn: []*wire.TxIn{ + { + PreviousOutPoint: op2, + Sequence: 2, + Witness: wire.TxWitness{ + []byte("test"), + }, + }, + { + PreviousOutPoint: op1, + Sequence: 1, + Witness: wire.TxWitness{ + []byte("test"), + }, + }, + }, + TxOut: []*wire.TxOut{ + { + Value: 2999374, + PkScript: batchPkScript, + }, + }, + LockTime: 799_999, + }, + inputAmt: 3_000_000, + minRelayFee: 250_000, + wantErr: "is lower than minRelayFee", + }, + } + + for _, tc := range cases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + err := CheckSignedTx( + tc.unsignedTx, tc.signedTx, tc.inputAmt, + tc.minRelayFee, + ) + if tc.wantErr != "" { + require.Error(t, err) + require.ErrorContains(t, err, tc.wantErr) + } else { + require.NoError(t, err) + } + }) + } +} + +// TestIsCPFPNeeded tests that function isCPFPNeeded works correctly, satisfying +// feeRateThresholdPPM. +func TestIsCPFPNeeded(t *testing.T) { + // Prepare the necessary data for test cases. + op1 := wire.OutPoint{ + Hash: chainhash.Hash{1, 1, 1}, + Index: 1, + } + op2 := wire.OutPoint{ + Hash: chainhash.Hash{2, 2, 2}, + Index: 2, + } + + batchPkScript, err := txscript.PayToAddrScript(destAddr) + require.NoError(t, err) + + witness := wire.TxWitness{ + make([]byte, 64), + } + + cases := []struct { + name string + parentTx *wire.MsgTx + inputAmt btcutil.Amount + numSweeps int + desiredFeeRate chainfee.SatPerKWeight + signedFeeRate chainfee.SatPerKWeight + wantErr string + wantNeedsCPFP bool + }{ + { + name: "fee rate matches exacly", + parentTx: &wire.MsgTx{ + Version: 2, + TxIn: []*wire.TxIn{ + { + PreviousOutPoint: op1, + Sequence: 1, + Witness: witness, + }, + { + PreviousOutPoint: op2, + Sequence: 2, + Witness: witness, + }, + }, + TxOut: []*wire.TxOut{ + { + Value: 2999374, + PkScript: batchPkScript, + }, + }, + }, + inputAmt: 3_000_000, + numSweeps: 2, + desiredFeeRate: 1000, + wantErr: "", + wantNeedsCPFP: false, + }, + { + name: "fee rate higher than needed", + parentTx: &wire.MsgTx{ + Version: 2, + TxIn: []*wire.TxIn{ + { + PreviousOutPoint: op1, + Sequence: 1, + Witness: witness, + }, + { + PreviousOutPoint: op2, + Sequence: 2, + Witness: witness, + }, + }, + TxOut: []*wire.TxOut{ + { + Value: 2999374, + PkScript: batchPkScript, + }, + }, + }, + inputAmt: 3_000_000, + numSweeps: 2, + desiredFeeRate: 900, + wantErr: "", + wantNeedsCPFP: false, + }, + { + name: "fee rate slightly lower than needed", + parentTx: &wire.MsgTx{ + Version: 2, + TxIn: []*wire.TxIn{ + { + PreviousOutPoint: op1, + Sequence: 1, + Witness: witness, + }, + { + PreviousOutPoint: op2, + Sequence: 2, + Witness: witness, + }, + }, + TxOut: []*wire.TxOut{ + { + Value: 2999374, + PkScript: batchPkScript, + }, + }, + }, + inputAmt: 3_000_000, + numSweeps: 2, + desiredFeeRate: 1020, + wantErr: "", + wantNeedsCPFP: false, + }, + { + name: "fee rate significantly lower than needed", + parentTx: &wire.MsgTx{ + Version: 2, + TxIn: []*wire.TxIn{ + { + PreviousOutPoint: op1, + Sequence: 1, + Witness: witness, + }, + { + PreviousOutPoint: op2, + Sequence: 2, + Witness: witness, + }, + }, + TxOut: []*wire.TxOut{ + { + Value: 2999374, + PkScript: batchPkScript, + }, + }, + }, + inputAmt: 3_000_000, + numSweeps: 2, + desiredFeeRate: 1100, + wantErr: "", + wantNeedsCPFP: true, + }, + { + name: "fewer inputs in parent transaction", + parentTx: &wire.MsgTx{ + Version: 2, + TxIn: []*wire.TxIn{ + { + PreviousOutPoint: op1, + Sequence: 1, + Witness: witness, + }, + { + PreviousOutPoint: op2, + Sequence: 2, + Witness: witness, + }, + }, + TxOut: []*wire.TxOut{ + { + Value: 2999374, + PkScript: batchPkScript, + }, + }, + }, + inputAmt: 3_000_000, + numSweeps: 3, + desiredFeeRate: 1100, + wantErr: "", + wantNeedsCPFP: false, + }, + { + name: "more inputs in parent transaction", + parentTx: &wire.MsgTx{ + Version: 2, + TxIn: []*wire.TxIn{ + { + PreviousOutPoint: op1, + Sequence: 1, + Witness: witness, + }, + { + PreviousOutPoint: op2, + Sequence: 2, + Witness: witness, + }, + }, + TxOut: []*wire.TxOut{ + { + Value: 2999374, + PkScript: batchPkScript, + }, + }, + }, + inputAmt: 3_000_000, + numSweeps: 1, + desiredFeeRate: 1100, + wantErr: "parent tx has more inputs", + }, + { + name: "signed fee rate equal to desired fee rate", + parentTx: &wire.MsgTx{ + Version: 2, + TxIn: []*wire.TxIn{ + { + PreviousOutPoint: op1, + Sequence: 1, + Witness: witness, + }, + { + PreviousOutPoint: op2, + Sequence: 2, + Witness: witness, + }, + }, + TxOut: []*wire.TxOut{ + { + Value: 2999374, + PkScript: batchPkScript, + }, + }, + }, + inputAmt: 3_000_000, + numSweeps: 2, + desiredFeeRate: 1100, + signedFeeRate: 1100, + wantErr: "", + wantNeedsCPFP: false, + }, + { + name: "error: tx has negative fee", + parentTx: &wire.MsgTx{ + Version: 2, + TxIn: []*wire.TxIn{ + { + PreviousOutPoint: op1, + Sequence: 1, + Witness: witness, + }, + { + PreviousOutPoint: op2, + Sequence: 2, + Witness: witness, + }, + }, + TxOut: []*wire.TxOut{ + { + Value: 3_001_000, + PkScript: batchPkScript, + }, + }, + }, + inputAmt: 3_000_000, + numSweeps: 2, + desiredFeeRate: 1000, + wantErr: "negative fee", + }, + { + name: "error: tx has multiple outputs", + parentTx: &wire.MsgTx{ + Version: 2, + TxIn: []*wire.TxIn{ + { + PreviousOutPoint: op1, + Sequence: 1, + Witness: witness, + }, + { + PreviousOutPoint: op2, + Sequence: 2, + Witness: witness, + }, + }, + TxOut: []*wire.TxOut{ + { + Value: 1_000_000, + PkScript: batchPkScript, + }, + { + Value: 2_000_000, + PkScript: batchPkScript, + }, + }, + }, + inputAmt: 3_000_000, + numSweeps: 2, + desiredFeeRate: 1000, + wantErr: "must have one output", + }, + { + name: "error: unsigned tx", + parentTx: &wire.MsgTx{ + Version: 2, + TxIn: []*wire.TxIn{ + { + PreviousOutPoint: op1, + Sequence: 1, + }, + { + PreviousOutPoint: op2, + Sequence: 2, + }, + }, + TxOut: []*wire.TxOut{ + { + Value: 2999374, + PkScript: batchPkScript, + }, + }, + }, + inputAmt: 3_000_000, + numSweeps: 2, + desiredFeeRate: 1000, + wantErr: "the tx must be signed", + }, + } + + for _, tc := range cases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + needsCPFP, err := isCPFPNeeded( + tc.parentTx, tc.inputAmt, tc.numSweeps, + tc.desiredFeeRate, tc.signedFeeRate, + ) + if tc.wantErr != "" { + require.Error(t, err) + require.ErrorContains(t, err, tc.wantErr) + } else { + require.NoError(t, err) + require.Equal(t, tc.wantNeedsCPFP, needsCPFP) + } + }) + } +} + +// TestMakeUnsignedCPFP tests that function makeUnsignedCPFP works correctly, +// satisfying maxChildFeeSharePPM and making sure that child fee rate is higher +// than effective fee rate and of minRelayFee. +func TestMakeUnsignedCPFP(t *testing.T) { + // Prepare the necessary data for test cases. + batchPkScript, err := txscript.PayToAddrScript(destAddr) + require.NoError(t, err) + + p2trAddr := "bcrt1pa38tp2hgjevqv3jcsxeu7v72n0s5a3ck8q2u8r" + + "k6mm67dv7uk26qq8je7e" + p2trAddress, err := btcutil.DecodeAddress(p2trAddr, nil) + require.NoError(t, err) + p2trPkScript, err := txscript.PayToAddrScript(p2trAddress) + require.NoError(t, err) + + serializedPubKey := []byte{ + 0x02, 0x19, 0x2d, 0x74, 0xd0, 0xcb, 0x94, 0x34, 0x4c, 0x95, + 0x69, 0xc2, 0xe7, 0x79, 0x01, 0x57, 0x3d, 0x8d, 0x79, 0x03, + 0xc3, 0xeb, 0xec, 0x3a, 0x95, 0x77, 0x24, 0x89, 0x5d, 0xca, + 0x52, 0xc6, 0xb4} + p2pkAddress, err := btcutil.NewAddressPubKey( + serializedPubKey, &chaincfg.RegressionNetParams, + ) + require.NoError(t, err) + + batchTxid := chainhash.Hash{5, 5, 5} + + op := wire.OutPoint{ + Hash: batchTxid, + Index: 0, + } + + cases := []struct { + name string + parentTxid chainhash.Hash + parentOutput btcutil.Amount + parentWeight lntypes.WeightUnit + parentFee btcutil.Amount + minRelayFee chainfee.SatPerKWeight + effectiveFeeRate chainfee.SatPerKWeight + address btcutil.Address + currentHeight int32 + wantErr string + wantUnsignedChild *wire.MsgTx + wantChildFeeRate chainfee.SatPerKWeight + }{ + { + name: "normal child creation", + parentTxid: batchTxid, + parentOutput: 2999374, + parentWeight: 626, + parentFee: 626, + minRelayFee: 253, + effectiveFeeRate: 2000, + address: p2trAddress, + currentHeight: 800_000, + wantUnsignedChild: &wire.MsgTx{ + Version: 2, + LockTime: 800_000, + TxIn: []*wire.TxIn{ + { + PreviousOutPoint: op, + }, + }, + TxOut: []*wire.TxOut{ + { + Value: 2997860, + PkScript: p2trPkScript, + }, + }, + }, + wantChildFeeRate: 3410, + }, + { + name: "p2wpkh address", + parentTxid: batchTxid, + parentOutput: 2999374, + parentWeight: 626, + parentFee: 626, + minRelayFee: 253, + effectiveFeeRate: 2000, + address: destAddr, + currentHeight: 800_000, + wantUnsignedChild: &wire.MsgTx{ + Version: 2, + LockTime: 800_000, + TxIn: []*wire.TxIn{ + { + PreviousOutPoint: op, + }, + }, + TxOut: []*wire.TxOut{ + { + Value: 2997870, + PkScript: batchPkScript, + }, + }, + }, + wantChildFeeRate: 3426, + }, + { + name: "error: p2pk address", + parentTxid: batchTxid, + parentOutput: 2999374, + parentWeight: 626, + parentFee: 626, + minRelayFee: 253, + effectiveFeeRate: 2000, + address: p2pkAddress, + currentHeight: 800_000, + wantErr: "unknown address type", + }, + { + name: "effective feerate as in parent", + parentTxid: batchTxid, + parentOutput: 2999374, + parentWeight: 626, + parentFee: 626, + minRelayFee: 253, + effectiveFeeRate: 1000, + address: p2trAddress, + currentHeight: 800_000, + wantUnsignedChild: &wire.MsgTx{ + Version: 2, + LockTime: 800_000, + TxIn: []*wire.TxIn{ + { + PreviousOutPoint: op, + }, + }, + TxOut: []*wire.TxOut{ + { + Value: 2998930, + PkScript: p2trPkScript, + }, + }, + }, + wantChildFeeRate: 1000, + }, + { + name: "effective feerate below parent", + parentTxid: batchTxid, + parentOutput: 2999374, + parentWeight: 626, + parentFee: 626, + minRelayFee: 253, + effectiveFeeRate: 500, + address: p2trAddress, + currentHeight: 800_000, + wantErr: "lower than effective fee rate", + }, + { + name: "high minRelayFee", + parentTxid: batchTxid, + parentOutput: 2999374, + parentWeight: 626, + parentFee: 626, + minRelayFee: 10_000, + effectiveFeeRate: 2000, + address: p2trAddress, + currentHeight: 800_000, + wantUnsignedChild: &wire.MsgTx{ + Version: 2, + LockTime: 800_000, + TxIn: []*wire.TxIn{ + { + PreviousOutPoint: op, + }, + }, + TxOut: []*wire.TxOut{ + { + Value: 2994934, + PkScript: p2trPkScript, + }, + }, + }, + wantChildFeeRate: 10_000, + }, + { + name: "child fee too high", + parentTxid: batchTxid, + parentOutput: 2999374, + parentWeight: 626, + parentFee: 626, + minRelayFee: 253, + effectiveFeeRate: 750_000, + address: p2trAddress, + currentHeight: 800_000, + wantErr: "is higher than 20% of total funds", + }, + } + + for _, tc := range cases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + childTx, childFeeRate, err := makeUnsignedCPFP( + tc.parentTxid, tc.parentOutput, tc.parentWeight, + tc.parentFee, tc.minRelayFee, + tc.effectiveFeeRate, tc.address, + tc.currentHeight, + ) + if tc.wantErr != "" { + require.Error(t, err) + require.ErrorContains(t, err, tc.wantErr) + } else { + require.NoError(t, err) + require.Equal(t, tc.wantUnsignedChild, childTx) + require.Equal( + t, tc.wantChildFeeRate, childFeeRate, + ) + } + }) + } +} diff --git a/sweepbatcher/greedy_batch_selection.go b/sweepbatcher/greedy_batch_selection.go index 30f1cb33a..88036630c 100644 --- a/sweepbatcher/greedy_batch_selection.go +++ b/sweepbatcher/greedy_batch_selection.go @@ -92,8 +92,8 @@ func (b *Batcher) greedyAddSweep(ctx context.Context, sweep *sweep) error { return nil } - log.Debugf("Batch selection algorithm returned batch id %d for"+ - " sweep %x, but acceptance failed.", batchId, + log().Debugf("Batch selection algorithm returned batch id %d "+ + "for sweep %x, but acceptance failed.", batchId, sweep.swapHash[:6]) } diff --git a/sweepbatcher/log.go b/sweepbatcher/log.go index 24d6cc297..7f33dc76a 100644 --- a/sweepbatcher/log.go +++ b/sweepbatcher/log.go @@ -2,15 +2,21 @@ package sweepbatcher import ( "fmt" + "sync/atomic" "github.com/btcsuite/btclog" "github.com/lightningnetwork/lnd/build" ) -// log is a logger that is initialized with no output filters. This +// log_ is a logger that is initialized with no output filters. This // means the package will not perform any logging by default until the // caller requests it. -var log btclog.Logger +var log_ atomic.Pointer[btclog.Logger] + +// log returns active logger. +func log() btclog.Logger { + return *log_.Load() +} // The default amount of logging is none. func init() { @@ -20,12 +26,12 @@ func init() { // batchPrefixLogger returns a logger that prefixes all log messages with // the ID. func batchPrefixLogger(batchID string) btclog.Logger { - return build.NewPrefixLog(fmt.Sprintf("[Batch %s]", batchID), log) + return build.NewPrefixLog(fmt.Sprintf("[Batch %s]", batchID), log()) } // UseLogger uses a specified Logger to output package logging info. // This should be used in preference to SetLogWriter if the caller is also // using btclog. func UseLogger(logger btclog.Logger) { - log = logger + log_.Store(&logger) } diff --git a/sweepbatcher/store_mock.go b/sweepbatcher/store_mock.go index 57cdd34b7..815b19917 100644 --- a/sweepbatcher/store_mock.go +++ b/sweepbatcher/store_mock.go @@ -4,6 +4,7 @@ import ( "context" "errors" "sort" + "sync" "github.com/btcsuite/btcd/btcutil" "github.com/lightningnetwork/lnd/lntypes" @@ -13,6 +14,7 @@ import ( type StoreMock struct { batches map[int32]dbBatch sweeps map[lntypes.Hash]dbSweep + mu sync.Mutex } // NewStoreMock instantiates a new mock store. @@ -28,6 +30,9 @@ func NewStoreMock() *StoreMock { func (s *StoreMock) FetchUnconfirmedSweepBatches(ctx context.Context) ( []*dbBatch, error) { + s.mu.Lock() + defer s.mu.Unlock() + result := []*dbBatch{} for _, batch := range s.batches { batch := batch @@ -44,6 +49,9 @@ func (s *StoreMock) FetchUnconfirmedSweepBatches(ctx context.Context) ( func (s *StoreMock) InsertSweepBatch(ctx context.Context, batch *dbBatch) (int32, error) { + s.mu.Lock() + defer s.mu.Unlock() + var id int32 if len(s.batches) == 0 { @@ -66,12 +74,18 @@ func (s *StoreMock) DropBatch(ctx context.Context, id int32) error { func (s *StoreMock) UpdateSweepBatch(ctx context.Context, batch *dbBatch) error { + s.mu.Lock() + defer s.mu.Unlock() + s.batches[batch.ID] = *batch return nil } // ConfirmBatch confirms a batch. func (s *StoreMock) ConfirmBatch(ctx context.Context, id int32) error { + s.mu.Lock() + defer s.mu.Unlock() + batch, ok := s.batches[id] if !ok { return errors.New("batch not found") @@ -87,6 +101,9 @@ func (s *StoreMock) ConfirmBatch(ctx context.Context, id int32) error { func (s *StoreMock) FetchBatchSweeps(ctx context.Context, id int32) ([]*dbSweep, error) { + s.mu.Lock() + defer s.mu.Unlock() + result := []*dbSweep{} for _, sweep := range s.sweeps { sweep := sweep @@ -104,7 +121,11 @@ func (s *StoreMock) FetchBatchSweeps(ctx context.Context, // UpsertSweep inserts a sweep into the database, or updates an existing sweep. func (s *StoreMock) UpsertSweep(ctx context.Context, sweep *dbSweep) error { + s.mu.Lock() + defer s.mu.Unlock() + s.sweeps[sweep.SwapHash] = *sweep + return nil } @@ -112,6 +133,9 @@ func (s *StoreMock) UpsertSweep(ctx context.Context, sweep *dbSweep) error { func (s *StoreMock) GetSweepStatus(ctx context.Context, swapHash lntypes.Hash) (bool, error) { + s.mu.Lock() + defer s.mu.Unlock() + sweep, ok := s.sweeps[swapHash] if !ok { return false, nil @@ -127,6 +151,9 @@ func (s *StoreMock) Close() error { // AssertSweepStored asserts that a sweep is stored. func (s *StoreMock) AssertSweepStored(id lntypes.Hash) bool { + s.mu.Lock() + defer s.mu.Unlock() + _, ok := s.sweeps[id] return ok } @@ -135,6 +162,9 @@ func (s *StoreMock) AssertSweepStored(id lntypes.Hash) bool { func (s *StoreMock) GetParentBatch(ctx context.Context, swapHash lntypes.Hash) ( *dbBatch, error) { + s.mu.Lock() + defer s.mu.Unlock() + for _, sweep := range s.sweeps { if sweep.SwapHash == swapHash { batch, ok := s.batches[sweep.BatchID] @@ -153,6 +183,9 @@ func (s *StoreMock) GetParentBatch(ctx context.Context, swapHash lntypes.Hash) ( func (s *StoreMock) TotalSweptAmount(ctx context.Context, batchID int32) ( btcutil.Amount, error) { + s.mu.Lock() + defer s.mu.Unlock() + batch, ok := s.batches[batchID] if !ok { return 0, errors.New("batch not found") diff --git a/sweepbatcher/sweep_batch.go b/sweepbatcher/sweep_batch.go index 6a4989cd9..a0ffb1b8a 100644 --- a/sweepbatcher/sweep_batch.go +++ b/sweepbatcher/sweep_batch.go @@ -9,6 +9,7 @@ import ( "math" "strings" "sync" + "sync/atomic" "time" "github.com/btcsuite/btcd/blockchain" @@ -16,6 +17,7 @@ import ( "github.com/btcsuite/btcd/btcec/v2/schnorr/musig2" "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/btcutil/psbt" + "github.com/btcsuite/btcd/chaincfg" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" @@ -119,6 +121,9 @@ type sweep struct { // but it failed. We try to spend a sweep cooperatively only once. This // status is not persisted in the DB. coopFailed bool + + // cpfp is set, if the sweep should be handled in CPFP mode. + cpfp bool } // batchState is the state of the batch. @@ -172,6 +177,14 @@ type batchConfig struct { // Note that musig2SignSweep must be nil in this case, however signer // client must still be provided, as it is used for non-coop spendings. customMuSig2Signer SignMuSig2 + + // cpfpHelper provides methods used when a custom tx signer and CPFP + // are enabled. + cpfpHelper CpfpHelper + + // chainParams are the chain parameters of the chain that is used by + // batches. + chainParams *chaincfg.Params } // rbfCache stores data related to our last fee bump. @@ -214,6 +227,12 @@ type batch struct { // reorgChan is the channel over which reorg notifications are received. reorgChan chan struct{} + // testReqs is a channel where test requests are received. + // This is used only in unit tests! The reason to have this is to + // avoid data races in require.Eventually calls running in parallel + // to the event loop. See method testRunInEventLoop(). + testReqs chan *testRequest + // errChan is the channel over which errors are received. errChan chan error @@ -284,8 +303,8 @@ type batch struct { // cfg is the configuration for this batch. cfg *batchConfig - // log is the logger for this batch. - log btclog.Logger + // log_ is the logger for this batch. + log_ atomic.Pointer[btclog.Logger] wg sync.WaitGroup } @@ -351,6 +370,7 @@ func NewBatch(cfg batchConfig, bk batchKit) *batch { spendChan: make(chan *chainntnfs.SpendDetail), confChan: make(chan *chainntnfs.TxConfirmation, 1), reorgChan: make(chan struct{}, 1), + testReqs: make(chan *testRequest), errChan: make(chan error, 1), callEnter: make(chan struct{}), callLeave: make(chan struct{}), @@ -387,7 +407,7 @@ func NewBatchFromDB(cfg batchConfig, bk batchKit) (*batch, error) { } } - return &batch{ + b := &batch{ id: bk.id, state: bk.state, primarySweepID: bk.primaryID, @@ -395,6 +415,7 @@ func NewBatchFromDB(cfg batchConfig, bk batchKit) (*batch, error) { spendChan: make(chan *chainntnfs.SpendDetail), confChan: make(chan *chainntnfs.TxConfirmation, 1), reorgChan: make(chan struct{}, 1), + testReqs: make(chan *testRequest), errChan: make(chan error, 1), callEnter: make(chan struct{}), callLeave: make(chan struct{}), @@ -412,13 +433,28 @@ func NewBatchFromDB(cfg batchConfig, bk batchKit) (*batch, error) { publishErrorHandler: bk.publishErrorHandler, purger: bk.purger, store: bk.store, - log: bk.log, cfg: &cfg, - }, nil + } + + b.setLog(bk.log) + + return b, nil +} + +// log returns current logger. +func (b *batch) log() btclog.Logger { + return *b.log_.Load() +} + +// setLog atomically replaces the logger. +func (b *batch) setLog(logger btclog.Logger) { + b.log_.Store(&logger) } // addSweep tries to add a sweep to the batch. If this is the first sweep being -// added to the batch then it also sets the primary sweep ID. +// added to the batch then it also sets the primary sweep ID. If CPFP mode is +// enabled, the result depends on the outcome of cpfpHelper.Presign for a +// non-empty batch. For an empty batch, the input needs to pass PresignSweep. func (b *batch) addSweep(ctx context.Context, sweep *sweep) (bool, error) { done, err := b.scheduleNextCall() defer done() @@ -430,7 +466,7 @@ func (b *batch) addSweep(ctx context.Context, sweep *sweep) (bool, error) { // If the provided sweep is nil, we can't proceed with any checks, so // we just return early. if sweep == nil { - b.log.Infof("the sweep is nil") + b.log().Infof("the sweep is nil") return false, nil } @@ -473,7 +509,7 @@ func (b *batch) addSweep(ctx context.Context, sweep *sweep) (bool, error) { // the batch, do not add another sweep to prevent the tx from becoming // non-standard. if len(b.sweeps) >= MaxSweepsPerBatch { - b.log.Infof("the batch has already too many sweeps (%d >= %d)", + b.log().Infof("the batch has already too many sweeps %d >= %d", len(b.sweeps), MaxSweepsPerBatch) return false, nil @@ -483,7 +519,7 @@ func (b *batch) addSweep(ctx context.Context, sweep *sweep) (bool, error) { // arrive here after the batch got closed because of a spend. In this // case we cannot add the sweep to this batch. if b.state != Open { - b.log.Infof("the batch state (%v) is not open", b.state) + b.log().Infof("the batch state (%v) is not open", b.state) return false, nil } @@ -493,14 +529,14 @@ func (b *batch) addSweep(ctx context.Context, sweep *sweep) (bool, error) { // we cannot add this sweep to the batch. for _, s := range b.sweeps { if s.isExternalAddr { - b.log.Infof("the batch already has a sweep (%x) with "+ + b.log().Infof("the batch already has a sweep %x with "+ "an external address", s.swapHash[:6]) return false, nil } if sweep.isExternalAddr { - b.log.Infof("the batch is not empty and new sweep (%x)"+ + b.log().Infof("the batch is not empty and new sweep %x"+ " has an external address", sweep.swapHash[:6]) return false, nil @@ -515,7 +551,7 @@ func (b *batch) addSweep(ctx context.Context, sweep *sweep) (bool, error) { int32(math.Abs(float64(sweep.timeout - s.timeout))) if timeoutDistance > b.cfg.maxTimeoutDistance { - b.log.Infof("too long timeout distance between the "+ + b.log().Infof("too long timeout distance between the "+ "batch and sweep %x: %d > %d", sweep.swapHash[:6], timeoutDistance, b.cfg.maxTimeoutDistance) @@ -524,6 +560,54 @@ func (b *batch) addSweep(ctx context.Context, sweep *sweep) (bool, error) { } } + // If CPFP mode is enabled, we should first presign the new version of + // batch transaction. Also ensure that all the sweeps in the batch use + // the same mode (CPFP or regular). + if sweep.cpfp { + // Ensure that all the sweeps in the batch use also CPFP mode. + for _, s := range b.sweeps { + if !s.cpfp { + b.log().Infof("failed to add sweep %x to the "+ + "batch, because the batch has "+ + "non-CPFP sweep %x", sweep.swapHash[:6], + s.swapHash[:6]) + + return false, nil + } + } + + if len(b.sweeps) != 0 { + if err := b.presign(ctx, sweep); err != nil { + b.log().Infof("failed to add sweep %x to the "+ + "batch, because failed to presign new "+ + "version of batch tx: %v", + sweep.swapHash[:6], err) + + return false, nil + } + } else { + if err := b.ensurePresigned(ctx, sweep); err != nil { + return false, fmt.Errorf("failed to check "+ + "signing of input %x, this means that "+ + "batcher.PresignSweep was not called "+ + "prior to AddSweep for this input: %w", + sweep.swapHash[:6], err) + } + } + } else { + // Ensure that all the sweeps in the batch don't use CPFP. + for _, s := range b.sweeps { + if s.cpfp { + b.log().Infof("failed to add sweep %x to the "+ + "batch, because the batch has "+ + "CPFP sweep %x", sweep.swapHash[:6], + s.swapHash[:6]) + + return false, nil + } + } + } + // Past this point we know that a new incoming sweep passes the // acceptance criteria and is now ready to be added to this batch. @@ -544,7 +628,7 @@ func (b *batch) addSweep(ctx context.Context, sweep *sweep) (bool, error) { } // Add the sweep to the batch's sweeps. - b.log.Infof("adding sweep %x", sweep.swapHash[:6]) + b.log().Infof("adding sweep %x", sweep.swapHash[:6]) b.sweeps[sweep.swapHash] = *sweep // Update FeeRate. Max(sweep.minFeeRate) for all the sweeps of @@ -572,7 +656,7 @@ func (b *batch) sweepExists(hash lntypes.Hash) bool { // Wait waits for the batch to gracefully stop. func (b *batch) Wait() { - b.log.Infof("Stopping") + b.log().Infof("Stopping") <-b.finished } @@ -613,7 +697,7 @@ func (b *batch) Run(ctx context.Context) error { // Set currentHeight here, because it may be needed in monitorSpend. select { case b.currentHeight = <-blockChan: - b.log.Debugf("initial height for the batch is %v", + b.log().Debugf("initial height for the batch is %v", b.currentHeight) case <-runCtx.Done(): @@ -652,7 +736,7 @@ func (b *batch) Run(ctx context.Context) error { // completes. timerChan := clock.TickAfter(b.cfg.batchPublishDelay) - b.log.Infof("started, primary %x, total sweeps %v", + b.log().Infof("started, primary %x, total sweeps %v", b.primarySweepID[0:6], len(b.sweeps)) for { @@ -662,7 +746,7 @@ func (b *batch) Run(ctx context.Context) error { // blockChan provides immediately the current tip. case height := <-blockChan: - b.log.Debugf("received block %v", height) + b.log().Debugf("received block %v", height) // Set the timer to publish the batch transaction after // the configured delay. @@ -670,7 +754,7 @@ func (b *batch) Run(ctx context.Context) error { b.currentHeight = height case <-initialDelayChan: - b.log.Debugf("initial delay of duration %v has ended", + b.log().Debugf("initial delay of duration %v has ended", b.cfg.initialDelay) // Set the timer to publish the batch transaction after @@ -680,8 +764,8 @@ func (b *batch) Run(ctx context.Context) error { case <-timerChan: // Check that batch is still open. if b.state != Open { - b.log.Debugf("Skipping publishing, because the"+ - " batch is not open (%v).", b.state) + b.log().Debugf("Skipping publishing, because "+ + "the batch is not open (%v).", b.state) continue } @@ -695,7 +779,7 @@ func (b *batch) Run(ctx context.Context) error { // initialDelayChan has just fired, this check passes. now := clock.Now() if skipBefore.After(now) { - b.log.Debugf(stillWaitingMsg, skipBefore, now) + b.log().Debugf(stillWaitingMsg, skipBefore, now) continue } @@ -715,14 +799,18 @@ func (b *batch) Run(ctx context.Context) error { case <-b.reorgChan: b.state = Open - b.log.Warnf("reorg detected, batch is able to accept " + - "new sweeps") + b.log().Warnf("reorg detected, batch is able to " + + "accept new sweeps") err := b.monitorSpend(ctx, b.sweeps[b.primarySweepID]) if err != nil { return err } + case testReq := <-b.testReqs: + testReq.handler() + close(testReq.quit) + case err := <-blockErrChan: return err @@ -735,6 +823,36 @@ func (b *batch) Run(ctx context.Context) error { } } +// testRunInEventLoop runs a function in the event loop blocking until +// the function returns. For unit tests only! +func (b *batch) testRunInEventLoop(ctx context.Context, handler func()) { + // If the event loop is finished, run the function. + select { + case <-b.stopping: + handler() + + return + default: + } + + quit := make(chan struct{}) + req := &testRequest{ + handler: handler, + quit: quit, + } + + select { + case b.testReqs <- req: + case <-ctx.Done(): + return + } + + select { + case <-quit: + case <-ctx.Done(): + } +} + // timeout returns minimum timeout as block height among sweeps of the batch. // If the batch is empty, return -1. func (b *batch) timeout() int32 { @@ -755,7 +873,7 @@ func (b *batch) timeout() int32 { func (b *batch) isUrgent(skipBefore time.Time) bool { timeout := b.timeout() if timeout <= 0 { - b.log.Warnf("Method timeout() returned %v. Number of"+ + b.log().Warnf("Method timeout() returned %v. Number of"+ " sweeps: %d. It may be an empty batch.", timeout, len(b.sweeps)) return false @@ -779,13 +897,29 @@ func (b *batch) isUrgent(skipBefore time.Time) bool { return false } - b.log.Debugf("cancelling waiting for urgent sweep (timeBank is %v, "+ + b.log().Debugf("cancelling waiting for urgent sweep (timeBank is %v, "+ "remainingWaiting is %v)", timeBank, remainingWaiting) // Signal to the caller to cancel initialDelay. return true } +// isCPFP returns if the batch uses CPFP mode. Currently CPFP and non-CPFP +// sweeps never appear in the same batch. Fails if the batch is empty. +func (b *batch) isCPFP() (bool, error) { + if len(b.sweeps) == 0 { + return false, fmt.Errorf("the batch is empty") + } + + for _, sweep := range b.sweeps { + if sweep.cpfp { + return true, nil + } + } + + return false, nil +} + // publish creates and publishes the latest batch transaction to the network. func (b *batch) publish(ctx context.Context) error { var ( @@ -795,7 +929,7 @@ func (b *batch) publish(ctx context.Context) error { ) if len(b.sweeps) == 0 { - b.log.Debugf("skipping publish: no sweeps in the batch") + b.log().Debugf("skipping publish: no sweeps in the batch") return nil } @@ -808,10 +942,22 @@ func (b *batch) publish(ctx context.Context) error { // logPublishError is a function which logs publish errors. logPublishError := func(errMsg string, err error) { - b.publishErrorHandler(err, errMsg, b.log) + b.publishErrorHandler(err, errMsg, b.log()) + } + + // Determine if we should use CPFP mode for the batch. + cpfp, err := b.isCPFP() + if err != nil { + return fmt.Errorf("failed to determine if the batch %d uses "+ + "CPFP mode: %w", b.id, err) + } + + if cpfp { + fee, err, signSuccess = b.publishWithCPFP(ctx) + } else { + fee, err, signSuccess = b.publishMixedBatch(ctx) } - fee, err, signSuccess = b.publishMixedBatch(ctx) if err != nil { if signSuccess { logPublishError("publish error", err) @@ -830,9 +976,9 @@ func (b *batch) publish(ctx context.Context) error { } } - b.log.Infof("published, total sweeps: %v, fees: %v", len(b.sweeps), fee) + b.log().Infof("published, total sweeps: %v, fees: %v", len(b.sweeps), fee) for _, sweep := range b.sweeps { - b.log.Infof("published sweep %x, value: %v", + b.log().Infof("published sweep %x, value: %v", sweep.swapHash[:6], sweep.value) } @@ -882,9 +1028,9 @@ func (b *batch) createPsbt(unsignedTx *wire.MsgTx, sweeps []sweep) ([]byte, // constructUnsignedTx creates unsigned tx from the sweeps, paying to the addr. // It also returns absolute fee (from weight and clamped). -func (b *batch) constructUnsignedTx(sweeps []sweep, - address btcutil.Address) (*wire.MsgTx, lntypes.WeightUnit, - btcutil.Amount, btcutil.Amount, error) { +func constructUnsignedTx(sweeps []sweep, address btcutil.Address, + currentHeight int32, feeRate chainfee.SatPerKWeight) (*wire.MsgTx, + lntypes.WeightUnit, btcutil.Amount, btcutil.Amount, error) { // Sanity check, there should be at least 1 sweep in this batch. if len(sweeps) == 0 { @@ -894,7 +1040,7 @@ func (b *batch) constructUnsignedTx(sweeps []sweep, // Create the batch transaction. batchTx := &wire.MsgTx{ Version: 2, - LockTime: uint32(b.currentHeight), + LockTime: uint32(currentHeight), } // Add transaction inputs and estimate its weight. @@ -946,7 +1092,7 @@ func (b *batch) constructUnsignedTx(sweeps []sweep, // Find weight and fee. weight := weightEstimate.Weight() - feeForWeight := b.rbfCache.FeeRate.FeeForWeight(weight) + feeForWeight := feeRate.FeeForWeight(weight) // Clamp the calculated fee to the max allowed fee amount for the batch. fee := clampBatchFee(feeForWeight, batchAmt) @@ -1026,13 +1172,13 @@ func (b *batch) publishMixedBatch(ctx context.Context) (btcutil.Amount, error, coopInputs int ) for attempt := 1; ; attempt++ { - b.log.Infof("Attempt %d of collecting cooperative signatures.", - attempt) + b.log().Infof("Attempt %d of collecting cooperative "+ + "signatures.", attempt) // Construct unsigned batch transaction. var err error - tx, weight, feeForWeight, fee, err = b.constructUnsignedTx( - sweeps, address, + tx, weight, feeForWeight, fee, err = constructUnsignedTx( + sweeps, address, b.currentHeight, b.rbfCache.FeeRate, ) if err != nil { return 0, fmt.Errorf("failed to construct tx: %w", err), @@ -1062,7 +1208,7 @@ func (b *batch) publishMixedBatch(ctx context.Context) (btcutil.Amount, error, ctx, i, sweep, tx, prevOutsMap, psbtBytes, ) if err != nil { - b.log.Infof("cooperative signing failed for "+ + b.log().Infof("cooperative signing failed for "+ "sweep %x: %v", sweep.swapHash[:6], err) // Set coopFailed flag for this sweep in all the @@ -1201,7 +1347,7 @@ func (b *batch) publishMixedBatch(ctx context.Context) (btcutil.Amount, error, } } txHash := tx.TxHash() - b.log.Infof("attempting to publish batch tx=%v with feerate=%v, "+ + b.log().Infof("attempting to publish batch tx=%v with feerate=%v, "+ "weight=%v, feeForWeight=%v, fee=%v, sweeps=%d, "+ "%d cooperative: (%s) and %d non-cooperative (%s), destAddr=%s", txHash, b.rbfCache.FeeRate, weight, feeForWeight, fee, @@ -1215,7 +1361,7 @@ func (b *batch) publishMixedBatch(ctx context.Context) (btcutil.Amount, error, blockchain.GetTransactionWeight(btcutil.NewTx(tx)), ) if realWeight != weight { - b.log.Warnf("actual weight of tx %v is %v, estimated as %d", + b.log().Warnf("actual weight of tx %v is %v, estimated as %d", txHash, realWeight, weight) } @@ -1239,11 +1385,11 @@ func (b *batch) debugLogTx(msg string, tx *wire.MsgTx) { // Serialize the transaction and convert to hex string. buf := bytes.NewBuffer(make([]byte, 0, tx.SerializeSize())) if err := tx.Serialize(buf); err != nil { - b.log.Errorf("failed to serialize tx for debug log: %v", err) + b.log().Errorf("failed to serialize tx for debug log: %v", err) return } - b.log.Debugf("%s: %s", msg, hex.EncodeToString(buf.Bytes())) + b.log().Debugf("%s: %s", msg, hex.EncodeToString(buf.Bytes())) } // musig2sign signs one sweep using musig2. @@ -1405,14 +1551,14 @@ func (b *batch) updateRbfRate(ctx context.Context) error { if b.rbfCache.FeeRate == 0 { // We set minFeeRate in each sweep, so fee rate is expected to // be initiated here. - b.log.Warnf("rbfCache.FeeRate is 0, which must not happen.") + b.log().Warnf("rbfCache.FeeRate is 0, which must not happen.") if b.cfg.batchConfTarget == 0 { - b.log.Warnf("updateRbfRate called with zero " + + b.log().Warnf("updateRbfRate called with zero " + "batchConfTarget") } - b.log.Infof("initializing rbf fee rate for conf target=%v", + b.log().Infof("initializing rbf fee rate for conf target=%v", b.cfg.batchConfTarget) rate, err := b.wallet.EstimateFeeRate( ctx, b.cfg.batchConfTarget, @@ -1461,7 +1607,7 @@ func (b *batch) monitorSpend(ctx context.Context, primarySweep sweep) error { defer cancel() defer b.wg.Done() - b.log.Infof("monitoring spend for outpoint %s", + b.log().Infof("monitoring spend for outpoint %s", primarySweep.outpoint.String()) for { @@ -1584,7 +1730,7 @@ func (b *batch) handleSpend(ctx context.Context, spendTx *wire.MsgTx) error { if len(spendTx.TxOut) > 0 { b.batchPkScript = spendTx.TxOut[0].PkScript } else { - b.log.Warnf("transaction %v has no outputs", txHash) + b.log().Warnf("transaction %v has no outputs", txHash) } // As a previous version of the batch transaction may get confirmed, @@ -1666,13 +1812,13 @@ func (b *batch) handleSpend(ctx context.Context, spendTx *wire.MsgTx) error { err := b.purger(&sweep) if err != nil { - b.log.Errorf("unable to purge sweep %x: %v", + b.log().Errorf("unable to purge sweep %x: %v", sweep.SwapHash[:6], err) } } }() - b.log.Infof("spent, total sweeps: %v, purged sweeps: %v", + b.log().Infof("spent, total sweeps: %v, purged sweeps: %v", len(notifyList), len(purgeList)) err := b.monitorConfirmations(ctx) @@ -1690,7 +1836,29 @@ func (b *batch) handleSpend(ctx context.Context, spendTx *wire.MsgTx) error { // handleConf handles a confirmation notification. This is the final step of the // batch. Here we signal to the batcher that this batch was completed. func (b *batch) handleConf(ctx context.Context) error { - b.log.Infof("confirmed in txid %s", b.batchTxid) + // If the batch is in CPFP mode, cleanup cpfpHelper. + cpfp, err := b.isCPFP() + if err != nil { + return fmt.Errorf("failed to determine if the batch %d uses "+ + "CPFP mode: %w", b.id, err) + } + + if cpfp { + b.log().Infof("Cleaning up CPFP store") + + inputs := make([]wire.OutPoint, 0, len(b.sweeps)) + for _, sweep := range b.sweeps { + inputs = append(inputs, sweep.outpoint) + } + + err := b.cfg.cpfpHelper.CleanupTransactions(ctx, inputs) + if err != nil { + return fmt.Errorf("failed to clean up store for "+ + "batch %d, inputs %v: %w", b.id, inputs, err) + } + } + + b.log().Infof("confirmed in txid %s", b.batchTxid) b.state = Confirmed return b.store.ConfirmBatch(ctx, b.id) @@ -1732,7 +1900,20 @@ func (b *batch) persist(ctx context.Context) error { // getBatchDestAddr returns the batch's destination address. If the batch // has already generated an address then the same one will be returned. +// The method must not be used in CPFP mode. Use getSweepsDestAddr instead. func (b *batch) getBatchDestAddr(ctx context.Context) (btcutil.Address, error) { + // Determine if we should use CPFP mode for the batch. + cpfp, err := b.isCPFP() + if err != nil { + return nil, fmt.Errorf("failed to determine if the batch %d "+ + "uses CPFP mode: %w", b.id, err) + } + + // Make sure that the method is not used for CPFP batches. + if cpfp { + return nil, fmt.Errorf("getBatchDestAddr used in CPFP mode") + } + var address btcutil.Address // If a batch address is set, use that. Otherwise, generate a @@ -1769,7 +1950,7 @@ func (b *batch) insertAndAcquireID(ctx context.Context) (int32, error) { } b.id = id - b.log = batchPrefixLogger(fmt.Sprintf("%d", b.id)) + b.setLog(batchPrefixLogger(fmt.Sprintf("%d", b.id))) return id, nil } diff --git a/sweepbatcher/sweep_batch_test.go b/sweepbatcher/sweep_batch_test.go new file mode 100644 index 000000000..dd873b300 --- /dev/null +++ b/sweepbatcher/sweep_batch_test.go @@ -0,0 +1,319 @@ +package sweepbatcher + +import ( + "fmt" + "testing" + + "github.com/btcsuite/btcd/btcutil" + "github.com/btcsuite/btcd/chaincfg" + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcd/txscript" + "github.com/btcsuite/btcd/wire" + "github.com/lightninglabs/loop/loopdb" + "github.com/lightninglabs/loop/utils" + "github.com/lightningnetwork/lnd/input" + "github.com/lightningnetwork/lnd/lntypes" + "github.com/lightningnetwork/lnd/lnwallet/chainfee" + "github.com/stretchr/testify/require" +) + +// TestConstructUnsignedTx verifies that the function constructUnsignedTx +// correctly creates unsigned transactions. +func TestConstructUnsignedTx(t *testing.T) { + // Prepare the necessary data for test cases. + op1 := wire.OutPoint{ + Hash: chainhash.Hash{1, 1, 1}, + Index: 1, + } + op2 := wire.OutPoint{ + Hash: chainhash.Hash{2, 2, 2}, + Index: 2, + } + + batchPkScript, err := txscript.PayToAddrScript(destAddr) + require.NoError(t, err) + + p2trAddr := "bcrt1pa38tp2hgjevqv3jcsxeu7v72n0s5a3ck8q2u8r" + + "k6mm67dv7uk26qq8je7e" + p2trAddress, err := btcutil.DecodeAddress(p2trAddr, nil) + require.NoError(t, err) + p2trPkScript, err := txscript.PayToAddrScript(p2trAddress) + require.NoError(t, err) + + serializedPubKey := []byte{ + 0x02, 0x19, 0x2d, 0x74, 0xd0, 0xcb, 0x94, 0x34, 0x4c, 0x95, + 0x69, 0xc2, 0xe7, 0x79, 0x01, 0x57, 0x3d, 0x8d, 0x79, 0x03, + 0xc3, 0xeb, 0xec, 0x3a, 0x95, 0x77, 0x24, 0x89, 0x5d, 0xca, + 0x52, 0xc6, 0xb4} + p2pkAddress, err := btcutil.NewAddressPubKey( + serializedPubKey, &chaincfg.RegressionNetParams, + ) + require.NoError(t, err) + + swapHash := lntypes.Hash{1, 1, 1} + + swapContract := &loopdb.SwapContract{ + CltvExpiry: 222, + AmountRequested: 2_000_000, + ProtocolVersion: loopdb.ProtocolVersionMuSig2, + HtlcKeys: htlcKeys, + } + + htlc, err := utils.GetHtlc( + swapHash, swapContract, &chaincfg.RegressionNetParams, + ) + require.NoError(t, err) + estimator := htlc.AddSuccessToEstimator + + brokenEstimator := func(*input.TxWeightEstimator) error { + return fmt.Errorf("weight estimator test failure") + } + + cases := []struct { + name string + sweeps []sweep + address btcutil.Address + currentHeight int32 + feeRate chainfee.SatPerKWeight + wantErr string + wantTx *wire.MsgTx + wantWeight lntypes.WeightUnit + wantFeeForWeight btcutil.Amount + wantFee btcutil.Amount + }{ + { + name: "no sweeps error", + wantErr: "no sweeps in batch", + }, + + { + name: "two coop sweeps", + sweeps: []sweep{ + { + outpoint: op1, + value: 1_000_000, + }, + { + outpoint: op2, + value: 2_000_000, + }, + }, + address: destAddr, + currentHeight: 800_000, + feeRate: 1000, + wantTx: &wire.MsgTx{ + Version: 2, + LockTime: 800_000, + TxIn: []*wire.TxIn{ + { + PreviousOutPoint: op1, + }, + { + PreviousOutPoint: op2, + }, + }, + TxOut: []*wire.TxOut{ + { + Value: 2999374, + PkScript: batchPkScript, + }, + }, + }, + wantWeight: 626, + wantFeeForWeight: 626, + wantFee: 626, + }, + + { + name: "p2tr destination address", + sweeps: []sweep{ + { + outpoint: op1, + value: 1_000_000, + }, + { + outpoint: op2, + value: 2_000_000, + }, + }, + address: p2trAddress, + currentHeight: 800_000, + feeRate: 1000, + wantTx: &wire.MsgTx{ + Version: 2, + LockTime: 800_000, + TxIn: []*wire.TxIn{ + { + PreviousOutPoint: op1, + }, + { + PreviousOutPoint: op2, + }, + }, + TxOut: []*wire.TxOut{ + { + Value: 2999326, + PkScript: p2trPkScript, + }, + }, + }, + wantWeight: 674, + wantFeeForWeight: 674, + wantFee: 674, + }, + + { + name: "unknown kind of address", + sweeps: []sweep{ + { + outpoint: op1, + value: 1_000_000, + }, + { + outpoint: op2, + value: 2_000_000, + }, + }, + address: nil, + wantErr: "unsupported address type", + }, + + { + name: "pay-to-pubkey address", + sweeps: []sweep{ + { + outpoint: op1, + value: 1_000_000, + }, + { + outpoint: op2, + value: 2_000_000, + }, + }, + address: p2pkAddress, + wantErr: "unknown address type", + }, + + { + name: "fee more than 20% clamped", + sweeps: []sweep{ + { + outpoint: op1, + value: 1_000_000, + }, + { + outpoint: op2, + value: 2_000_000, + }, + }, + address: destAddr, + currentHeight: 800_000, + feeRate: 1_000_000, + wantTx: &wire.MsgTx{ + Version: 2, + LockTime: 800_000, + TxIn: []*wire.TxIn{ + { + PreviousOutPoint: op1, + }, + { + PreviousOutPoint: op2, + }, + }, + TxOut: []*wire.TxOut{ + { + Value: 2400000, + PkScript: batchPkScript, + }, + }, + }, + wantWeight: 626, + wantFeeForWeight: 626_000, + wantFee: 600_000, + }, + + { + name: "coop and noncoop", + sweeps: []sweep{ + { + outpoint: op1, + value: 1_000_000, + }, + { + outpoint: op2, + value: 2_000_000, + nonCoopHint: true, + htlc: *htlc, + htlcSuccessEstimator: estimator, + }, + }, + address: destAddr, + currentHeight: 800_000, + feeRate: 1000, + wantTx: &wire.MsgTx{ + Version: 2, + LockTime: 800_000, + TxIn: []*wire.TxIn{ + { + PreviousOutPoint: op1, + }, + { + PreviousOutPoint: op2, + Sequence: 1, + }, + }, + TxOut: []*wire.TxOut{ + { + Value: 2999211, + PkScript: batchPkScript, + }, + }, + }, + wantWeight: 789, + wantFeeForWeight: 789, + wantFee: 789, + }, + + { + name: "weight estimator fails", + sweeps: []sweep{ + { + outpoint: op1, + value: 1_000_000, + }, + { + outpoint: op2, + value: 2_000_000, + nonCoopHint: true, + htlc: *htlc, + htlcSuccessEstimator: brokenEstimator, + }, + }, + address: destAddr, + currentHeight: 800_000, + feeRate: 1000, + wantErr: "sweep.htlcSuccessEstimator failed: " + + "weight estimator test failure", + }, + } + + for _, tc := range cases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + tx, weight, feeForW, fee, err := constructUnsignedTx( + tc.sweeps, tc.address, tc.currentHeight, + tc.feeRate, + ) + if tc.wantErr != "" { + require.Error(t, err) + require.ErrorContains(t, err, tc.wantErr) + } else { + require.NoError(t, err) + require.Equal(t, tc.wantTx, tx) + require.Equal(t, tc.wantWeight, weight) + require.Equal(t, tc.wantFeeForWeight, feeForW) + require.Equal(t, tc.wantFee, fee) + } + }) + } +} diff --git a/sweepbatcher/sweep_batcher.go b/sweepbatcher/sweep_batcher.go index 3fe9fe9c4..65b230303 100644 --- a/sweepbatcher/sweep_batcher.go +++ b/sweepbatcher/sweep_batcher.go @@ -153,6 +153,50 @@ type SignMuSig2 func(ctx context.Context, muSig2Version input.MuSig2Version, swapHash lntypes.Hash, rootHash chainhash.Hash, sigHash [32]byte, ) ([]byte, error) +// CpfpHelper provides methods used when a custom tx signer and CPFP are used. +// In this mode sweepbatcher uses transactions provided by CPFP helper, which +// may be pre-signed and non-RBF'able, in which case CPFP may be needed. CPFP +// helper also provides transactions it previously produced by txid and affects +// batch selection - it has method Presign called upon new batch creation and +// adding to existing batch. +type CpfpHelper interface { + // IsCpfpApplied returns if CPFP mode is enabled for a particular sweep. + // This method should always return the same value for the same sweep. + // Currently CPFP and non-CPFP sweeps never appear in the same batch. + IsCpfpApplied(ctx context.Context, input wire.OutPoint) (bool, error) + + // Presign tries to presign a batch transaction. If the method returns + // nil, it is guaranteed that future calls to SignTx on this set of + // sweeps return valid signed transactions. + Presign(ctx context.Context, tx *wire.MsgTx, + inputAmt btcutil.Amount) error + + // DestPkScript returns destination pkScript used in a presigned + // transaction sweeping the inputs. Returns an error, if such tx + // doesn't exist. If there are many such transactions, returns any + // of pkScript's. + DestPkScript(ctx context.Context, + inputs []wire.OutPoint) ([]byte, error) + + // SignTx signs an unsigned transaction or returns a pre-signed tx. + // It must satisfy the following invariants: + // - the set of inputs is the same, though the order may change; + // - the output is the same, but its amount may be different; + // - feerate is higher or equal to minRelayFee; + // - LockTime may be decreased; + // - transaction version must be the same; + // - Sequence numbers in the inputs must be preserved. + SignTx(ctx context.Context, tx *wire.MsgTx, inputAmt btcutil.Amount, + minRelayFee chainfee.SatPerKWeight) (*wire.MsgTx, error) + + // LoadTx returns any tx previously returned by SignTx. + LoadTx(ctx context.Context, txid chainhash.Hash) (*wire.MsgTx, error) + + // CleanupTransactions removes all transactions related to any of the + // outpoints. Should be called after sweep batch tx is fully confirmed. + CleanupTransactions(ctx context.Context, inputs []wire.OutPoint) error +} + // VerifySchnorrSig is a function that can be used to verify a schnorr // signature. type VerifySchnorrSig func(pubKey *btcec.PublicKey, hash, sig []byte) error @@ -225,6 +269,16 @@ var ( ErrBatcherShuttingDown = errors.New("batcher shutting down") ) +// testRequest is a function passed to an event loop and a channel used to +// wait until the function is executed. This is used in unit tests only! +type testRequest struct { + // handler is the function to an event loop. + handler func() + + // quit is closed when the handler completes. + quit chan struct{} +} + // Batcher is a system that is responsible for accepting sweep requests and // placing them in appropriate batches. It will spin up new batches as needed. type Batcher struct { @@ -234,6 +288,12 @@ type Batcher struct { // sweepReqs is a channel where sweep requests are received. sweepReqs chan SweepRequest + // testReqs is a channel where test requests are received. + // This is used only in unit tests! The reason to have this is to + // avoid data races in require.Eventually calls running in parallel + // to the event loop. See method testRunInEventLoop(). + testReqs chan *testRequest + // errChan is a channel where errors are received. errChan chan error @@ -313,6 +373,10 @@ type Batcher struct { // error. By default, it logs all errors as warnings, but "insufficient // fee" as Info. publishErrorHandler PublishErrorHandler + + // cpfpHelper provides methods used when a custom tx signer and CPFP + // are enabled. + cpfpHelper CpfpHelper } // BatcherConfig holds batcher configuration. @@ -353,6 +417,10 @@ type BatcherConfig struct { // error. By default, it logs all errors as warnings, but "insufficient // fee" as Info. publishErrorHandler PublishErrorHandler + + // cpfpHelper provides methods used when a custom tx signer and CPFP + // are enabled. + cpfpHelper CpfpHelper } // BatcherOption configures batcher behaviour. @@ -426,6 +494,20 @@ func WithPublishErrorHandler(handler PublishErrorHandler) BatcherOption { } } +// WithCpfpHelper instructs sweepbatcher to switch to mode in which it may use +// CPFP for fee bumping. In this mode it uses transactions provided by CPFP +// helper, which may be pre-signed and non-RBF'able, in which case CPFP may be +// needed. CPFP helper also provides transactions it previously produced by txid +// and affects batch selection - it has method Presign called upon new batch +// creation and adding to existing batch. In CPFP mode method PresignSweep must +// be called prior to AddSweep. If PresignSweep fails, AddSweep must not be +// called. +func WithCpfpHelper(cpfpHelper CpfpHelper) BatcherOption { + return func(cfg *BatcherConfig) { + cfg.cpfpHelper = cpfpHelper + } +} + // NewBatcher creates a new Batcher instance. func NewBatcher(wallet lndclient.WalletKitClient, chainNotifier lndclient.ChainNotifierClient, @@ -461,6 +543,7 @@ func NewBatcher(wallet lndclient.WalletKitClient, return &Batcher{ batches: make(map[int32]*batch), sweepReqs: make(chan SweepRequest), + testReqs: make(chan *testRequest), errChan: make(chan error, 1), quit: make(chan struct{}), initDone: make(chan struct{}), @@ -479,6 +562,7 @@ func NewBatcher(wallet lndclient.WalletKitClient, txLabeler: cfg.txLabeler, customMuSig2Signer: cfg.customMuSig2Signer, publishErrorHandler: cfg.publishErrorHandler, + cpfpHelper: cfg.cpfpHelper, } } @@ -518,29 +602,58 @@ func (b *Batcher) Run(ctx context.Context) error { case sweepReq := <-b.sweepReqs: sweep, err := b.fetchSweep(runCtx, sweepReq) if err != nil { - log.Warnf("fetchSweep failed: %v.", err) + log().Warnf("fetchSweep failed: %v.", err) + return err } err = b.handleSweep(runCtx, sweep, sweepReq.Notifier) if err != nil { - log.Warnf("handleSweep failed: %v.", err) + log().Warnf("handleSweep failed: %v.", err) + return err } + case testReq := <-b.testReqs: + testReq.handler() + close(testReq.quit) + case err := <-b.errChan: - log.Warnf("Batcher received an error: %v.", err) + log().Warnf("Batcher received an error: %v.", err) + return err case <-runCtx.Done(): - log.Infof("Stopping Batcher: run context cancelled.") + log().Infof("Stopping Batcher: run context cancelled.") + return runCtx.Err() } } } +// PresignSweep creates and stores presigned 1:1 transactions for the sweep. +// This method must be called prior to AddSweep if CPFP mode is enabled. +func (b *Batcher) PresignSweep(ctx context.Context, sweepOutpoint wire.OutPoint, + sweepValue btcutil.Amount, destAddress btcutil.Address) error { + + if b.cpfpHelper == nil { + return fmt.Errorf("cpfpHelper is not installed") + } + + sweeps := []sweep{ + { + outpoint: sweepOutpoint, + value: sweepValue, + }, + } + + return presign(ctx, b.cpfpHelper, sweeps, destAddress) +} + // AddSweep adds a sweep request to the batcher for handling. This will either -// place the sweep in an existing batch or create a new one. +// place the sweep in an existing batch or create a new one. In CPFP mode call +// PresignSweep prior to AddSweep. If PresignSweep fails, AddSweep must not be +// called. func (b *Batcher) AddSweep(sweepReq *SweepRequest) error { select { case b.sweepReqs <- *sweepReq: @@ -551,6 +664,36 @@ func (b *Batcher) AddSweep(sweepReq *SweepRequest) error { } } +// testRunInEventLoop runs a function in the event loop blocking until +// the function returns. For unit tests only! +func (b *Batcher) testRunInEventLoop(ctx context.Context, handler func()) { + // If the event loop is finished, run the function. + select { + case <-b.quit: + handler() + + return + default: + } + + quit := make(chan struct{}) + req := &testRequest{ + handler: handler, + quit: quit, + } + + select { + case b.testReqs <- req: + case <-ctx.Done(): + return + } + + select { + case <-quit: + case <-ctx.Done(): + } +} + // handleSweep handles a sweep request by either placing it in an existing // batch, or by spinning up a new batch for it. func (b *Batcher) handleSweep(ctx context.Context, sweep *sweep, @@ -561,8 +704,8 @@ func (b *Batcher) handleSweep(ctx context.Context, sweep *sweep, return err } - log.Infof("Batcher handling sweep %x, completed=%v", sweep.swapHash[:6], - completed) + log().Infof("Batcher handling sweep %x, cpfp=%v, completed=%v", + sweep.swapHash[:6], sweep.cpfp, completed) // If the sweep has already been completed in a confirmed batch then we // can't attach its notifier to the batch as that is no longer running. @@ -573,8 +716,8 @@ func (b *Batcher) handleSweep(ctx context.Context, sweep *sweep, // on-chain confirmations to prevent issues caused by reorgs. parentBatch, err := b.store.GetParentBatch(ctx, sweep.swapHash) if err != nil { - log.Errorf("unable to get parent batch for sweep %x: "+ - "%v", sweep.swapHash[:6], err) + log().Errorf("unable to get parent batch for sweep %x:"+ + " %v", sweep.swapHash[:6], err) return err } @@ -590,16 +733,18 @@ func (b *Batcher) handleSweep(ctx context.Context, sweep *sweep, sweep.notifier = notifier - // Check if the sweep is already in a batch. If that is the case, we - // provide the sweep to that batch and return. + // This is a check to see if a batch is completed. In that case we just + // lazily delete it. for _, batch := range b.batches { - // This is a check to see if a batch is completed. In that case - // we just lazily delete it and continue our scan. if batch.isComplete() { delete(b.batches, batch.id) continue } + } + // Check if the sweep is already in a batch. If that is the case, we + // provide the sweep to that batch and return. + for _, batch := range b.batches { if batch.sweepExists(sweep.swapHash) { accepted, err := batch.addSweep(ctx, sweep) if err != nil && !errors.Is(err, ErrBatchShuttingDown) { @@ -624,8 +769,8 @@ func (b *Batcher) handleSweep(ctx context.Context, sweep *sweep, return nil } - log.Warnf("Greedy batch selection algorithm failed for sweep %x: %v. "+ - "Falling back to old approach.", sweep.swapHash[:6], err) + log().Warnf("Greedy batch selection algorithm failed for sweep %x: %v."+ + " Falling back to old approach.", sweep.swapHash[:6], err) // If one of the batches accepts the sweep, we provide it to that batch. for _, batch := range b.batches { @@ -646,7 +791,8 @@ func (b *Batcher) handleSweep(ctx context.Context, sweep *sweep, return b.spinUpNewBatch(ctx, sweep) } -// spinUpNewBatch creates new batch, starts it and adds the sweep to it. +// spinUpNewBatch creates new batch, starts it and adds the sweep to it. If CPFP +// mode is enabled, the result also depends on outcome of cpfpHelper.Presign. func (b *Batcher) spinUpNewBatch(ctx context.Context, sweep *sweep) error { // Spin up a fresh batch. newBatch, err := b.spinUpBatch(ctx) @@ -730,13 +876,13 @@ func (b *Batcher) spinUpBatchFromDB(ctx context.Context, batch *batch) error { } if len(dbSweeps) == 0 { - log.Infof("skipping restored batch %d as it has no sweeps", + log().Infof("skipping restored batch %d as it has no sweeps", batch.id) // It is safe to drop this empty batch as it has no sweeps. err := b.store.DropBatch(ctx, batch.id) if err != nil { - log.Warnf("unable to drop empty batch %d: %v", + log().Warnf("unable to drop empty batch %d: %v", batch.id, err) } @@ -878,7 +1024,7 @@ func (b *Batcher) monitorSpendAndNotify(ctx context.Context, sweep *sweep, b.wg.Add(1) go func() { defer b.wg.Done() - log.Infof("Batcher monitoring spend for swap %x", + log().Infof("Batcher monitoring spend for swap %x", sweep.swapHash[:6]) for { @@ -1042,6 +1188,16 @@ func (b *Batcher) loadSweep(ctx context.Context, swapHash lntypes.Hash, swapHash[:6], err) } + // Determine if CPFP mode is used for this sweep. + var cpfp bool + if b.cpfpHelper != nil { + cpfp, err = b.cpfpHelper.IsCpfpApplied(ctx, outpoint) + if err != nil { + return nil, fmt.Errorf("failed to determine CPFP "+ + "status for sweep %x: %w", swapHash[:6], err) + } + } + // Find minimum fee rate for the sweep. Use customFeeRate if it is // provided, otherwise use wallet's EstimateFeeRate. var minFeeRate chainfee.SatPerKWeight @@ -1057,7 +1213,7 @@ func (b *Batcher) loadSweep(ctx context.Context, swapHash lntypes.Hash, } } else { if s.ConfTarget == 0 { - log.Warnf("Fee estimation was requested for zero "+ + log().Warnf("Fee estimation was requested for zero "+ "confTarget for sweep %x.", swapHash[:6]) } minFeeRate, err = b.wallet.EstimateFeeRate(ctx, s.ConfTarget) @@ -1085,6 +1241,7 @@ func (b *Batcher) loadSweep(ctx context.Context, swapHash lntypes.Hash, destAddr: s.DestAddr, minFeeRate: minFeeRate, nonCoopHint: s.NonCoopHint, + cpfp: cpfp, }, nil } @@ -1095,7 +1252,9 @@ func (b *Batcher) newBatchConfig(maxTimeoutDistance int32) batchConfig { noBumping: b.customFeeRate != nil, txLabeler: b.txLabeler, customMuSig2Signer: b.customMuSig2Signer, + cpfpHelper: b.cpfpHelper, clock: b.clock, + chainParams: b.chainParams, } } diff --git a/sweepbatcher/sweep_batcher_test.go b/sweepbatcher/sweep_batcher_test.go index fbfb0d418..d89456dba 100644 --- a/sweepbatcher/sweep_batcher_test.go +++ b/sweepbatcher/sweep_batcher_test.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "os" + "strings" "sync" "testing" "time" @@ -13,6 +14,7 @@ import ( "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btclog" "github.com/lightninglabs/lndclient" @@ -109,18 +111,35 @@ func checkBatcherError(t *testing.T, err error) { } } -// getOnlyBatch makes sure the batcher has exactly one batch and returns it. -func getOnlyBatch(batcher *Batcher) *batch { - if len(batcher.batches) != 1 { - panic(fmt.Sprintf("getOnlyBatch called on a batcher having "+ - "%d batches", len(batcher.batches))) - } +// getBatches returns batches in thread-safe way. +func getBatches(ctx context.Context, batcher *Batcher) []*batch { + var batches []*batch + batcher.testRunInEventLoop(ctx, func() { + for _, batch := range batcher.batches { + batches = append(batches, batch) + } + }) + + return batches +} + +// tryGetOnlyBatch returns a single batch if there is exactly one batch, or nil. +func tryGetOnlyBatch(ctx context.Context, batcher *Batcher) *batch { + batches := getBatches(ctx, batcher) - for _, batch := range batcher.batches { - return batch + if len(batches) == 1 { + return batches[0] + } else { + return nil } +} + +// getOnlyBatch makes sure the batcher has exactly one batch and returns it. +func getOnlyBatch(t *testing.T, ctx context.Context, batcher *Batcher) *batch { + batches := getBatches(ctx, batcher) + require.Len(t, batches, 1) - panic("unreachable") + return batches[0] } // testSweepBatcherBatchCreation tests that sweep requests enter the expected @@ -186,7 +205,7 @@ func testSweepBatcherBatchCreation(t *testing.T, store testStore, // Once batcher receives sweep request it will eventually spin up a // batch. require.Eventually(t, func() bool { - return len(batcher.batches) == 1 + return len(getBatches(ctx, batcher)) == 1 }, test.Timeout, eventuallyCheckFrequency) // Wait for tx to be published. @@ -236,7 +255,7 @@ func testSweepBatcherBatchCreation(t *testing.T, store testStore, // Batcher should not create a second batch as timeout distance is small // enough. require.Eventually(t, func() bool { - return len(batcher.batches) == 1 + return len(getBatches(ctx, batcher)) == 1 }, test.Timeout, eventuallyCheckFrequency) // Create a third sweep request that has more timeout distance than @@ -273,33 +292,43 @@ func testSweepBatcherBatchCreation(t *testing.T, store testStore, require.NoError(t, batcher.AddSweep(&sweepReq3)) + // Since the second batch got created we check that it registered its + // primary sweep's spend. + <-lnd.RegisterSpendChannel + // Batcher should create a second batch as timeout distance is greater // than the threshold require.Eventually(t, func() bool { - return len(batcher.batches) == 2 + return len(getBatches(ctx, batcher)) == 2 }, test.Timeout, eventuallyCheckFrequency) - // Since the second batch got created we check that it registered its - // primary sweep's spend. - <-lnd.RegisterSpendChannel - // Wait for tx to be published. <-lnd.TxPublishChannel require.Eventually(t, func() bool { // Verify that each batch has the correct number of sweeps // in it. - for _, batch := range batcher.batches { - switch batch.primarySweepID { - case sweepReq1.SwapHash: - if len(batch.sweeps) != 2 { - return false - } + batches := getBatches(ctx, batcher) - case sweepReq3.SwapHash: - if len(batch.sweeps) != 1 { - return false + for _, batch := range batches { + var bad bool + + batch.testRunInEventLoop(ctx, func() { + switch batch.primarySweepID { + case sweepReq1.SwapHash: + if len(batch.sweeps) != 2 { + bad = true + } + + case sweepReq3.SwapHash: + if len(batch.sweeps) != 1 { + bad = true + } } + }) + + if bad { + return false } } @@ -480,24 +509,26 @@ func testTxLabeler(t *testing.T, store testStore, // Deliver sweep request to batcher. require.NoError(t, batcher.AddSweep(&sweepReq1)) - // Eventually request will be consumed and a new batch will spin up. - require.Eventually(t, func() bool { - return len(batcher.batches) == 1 - }, test.Timeout, eventuallyCheckFrequency) - // When batch is successfully created it will execute it's first step, // which leads to a spend monitor of the primary sweep. <-lnd.RegisterSpendChannel + // Eventually request will be consumed and a new batch will spin up. + require.Eventually(t, func() bool { + return len(getBatches(ctx, batcher)) == 1 + }, test.Timeout, eventuallyCheckFrequency) + // Wait for tx to be published. <-lnd.TxPublishChannel // Find the batch and assign it to a local variable for easier access. var theBatch *batch - for _, btch := range batcher.batches { - if btch.primarySweepID == sweepReq1.SwapHash { - theBatch = btch - } + for _, btch := range getBatches(ctx, batcher) { + btch.testRunInEventLoop(ctx, func() { + if btch.primarySweepID == sweepReq1.SwapHash { + theBatch = btch + } + }) } // Now test the label. @@ -632,15 +663,15 @@ func testPublishErrorHandler(t *testing.T, store testStore, // Deliver sweep request to batcher. require.NoError(t, batcher.AddSweep(&sweepReq1)) - // Eventually request will be consumed and a new batch will spin up. - require.Eventually(t, func() bool { - return len(batcher.batches) == 1 - }, test.Timeout, eventuallyCheckFrequency) - // When batch is successfully created it will execute it's first step, // which leads to a spend monitor of the primary sweep. <-lnd.RegisterSpendChannel + // Eventually request will be consumed and a new batch will spin up. + require.Eventually(t, func() bool { + return len(getBatches(ctx, batcher)) == 1 + }, test.Timeout, eventuallyCheckFrequency) + // The first attempt to publish the batch tx is expected to fail. require.ErrorIs(t, <-publishErrorChan, testPublishError) @@ -710,26 +741,33 @@ func testSweepBatcherSimpleLifecycle(t *testing.T, store testStore, // Deliver sweep request to batcher. require.NoError(t, batcher.AddSweep(&sweepReq1)) - // Eventually request will be consumed and a new batch will spin up. - require.Eventually(t, func() bool { - return len(batcher.batches) == 1 - }, test.Timeout, eventuallyCheckFrequency) - // When batch is successfully created it will execute it's first step, // which leads to a spend monitor of the primary sweep. <-lnd.RegisterSpendChannel + // Eventually request will be consumed and a new batch will spin up. + require.Eventually(t, func() bool { + return len(getBatches(ctx, batcher)) == 1 + }, test.Timeout, eventuallyCheckFrequency) + // Find the batch and assign it to a local variable for easier access. batch := &batch{} - for _, btch := range batcher.batches { - if btch.primarySweepID == sweepReq1.SwapHash { - batch = btch - } + for _, btch := range getBatches(ctx, batcher) { + btch.testRunInEventLoop(ctx, func() { + if btch.primarySweepID == sweepReq1.SwapHash { + batch = btch + } + }) } require.Eventually(t, func() bool { // Batch should have the sweep stored. - return len(batch.sweeps) == 1 + var numSweeps int + batch.testRunInEventLoop(ctx, func() { + numSweeps = len(batch.sweeps) + }) + + return numSweeps == 1 }, test.Timeout, eventuallyCheckFrequency) // The primary sweep id should be that of the first inserted sweep. @@ -744,7 +782,12 @@ func testSweepBatcherSimpleLifecycle(t *testing.T, store testStore, // After receiving a height notification the batch will step again, // leading to a new spend monitoring. require.Eventually(t, func() bool { - return batch.currentHeight == 601 + var currentHeight int32 + batch.testRunInEventLoop(ctx, func() { + currentHeight = batch.currentHeight + }) + + return currentHeight == 601 }, test.Timeout, eventuallyCheckFrequency) // Wait for tx to be published. @@ -788,7 +831,12 @@ func testSweepBatcherSimpleLifecycle(t *testing.T, store testStore, // The batch should eventually read the spend notification and progress // its state to closed. require.Eventually(t, func() bool { - return batch.state == Closed + var state batchState + batch.testRunInEventLoop(ctx, func() { + state = batch.state + }) + + return state == Closed }, test.Timeout, eventuallyCheckFrequency) err = lnd.NotifyHeight(604) @@ -802,7 +850,12 @@ func testSweepBatcherSimpleLifecycle(t *testing.T, store testStore, // Eventually the batch receives the confirmation notification and // confirms itself. require.Eventually(t, func() bool { - return batch.isComplete() + var complete bool + batch.testRunInEventLoop(ctx, func() { + complete = batch.isComplete() + }) + + return complete }, test.Timeout, eventuallyCheckFrequency) } @@ -811,18 +864,26 @@ func testSweepBatcherSimpleLifecycle(t *testing.T, store testStore, type wrappedLogger struct { btclog.Logger + mu sync.Mutex + debugMessages []string infoMessages []string } // Debugf logs debug message. func (l *wrappedLogger) Debugf(format string, params ...interface{}) { + l.mu.Lock() + defer l.mu.Unlock() + l.debugMessages = append(l.debugMessages, format) l.Logger.Debugf(format, params...) } // Infof logs info message. func (l *wrappedLogger) Infof(format string, params ...interface{}) { + l.mu.Lock() + defer l.mu.Unlock() + l.infoMessages = append(l.infoMessages, format) l.Logger.Infof(format, params...) } @@ -930,17 +991,18 @@ func testDelays(t *testing.T, store testStore, batcherStore testBatcherStore) { // Eventually the batch is launched. require.Eventually(t, func() bool { - return len(batcher.batches) == 1 + return len(getBatches(ctx, batcher)) == 1 }, test.Timeout, eventuallyCheckFrequency) // Replace the logger in the batch with wrappedLogger to watch messages. - var batch1 *batch - for _, batch := range batcher.batches { - batch1 = batch - } - require.NotNil(t, batch1) - testLogger := &wrappedLogger{Logger: batch1.log} - batch1.log = testLogger + batch1 := getOnlyBatch(t, ctx, batcher) + var testLogger *wrappedLogger + batch1.testRunInEventLoop(ctx, func() { + testLogger = &wrappedLogger{ + Logger: batch1.log(), + } + batch1.setLog(testLogger) + }) // Advance the clock to publishDelay. It will trigger the publishDelay // timer, but won't result in publishing, because of initialDelay. @@ -950,7 +1012,10 @@ func testDelays(t *testing.T, store testStore, batcherStore testBatcherStore) { // Wait for batch publishing to be skipped, because initialDelay has not // ended. require.EventuallyWithT(t, func(c *assert.CollectT) { - require.Contains(t, testLogger.debugMessages, stillWaitingMsg) + testLogger.mu.Lock() + defer testLogger.mu.Unlock() + + assert.Contains(c, testLogger.debugMessages, stillWaitingMsg) }, test.Timeout, eventuallyCheckFrequency) // Advance the clock to the end of initialDelay. @@ -975,16 +1040,19 @@ func testDelays(t *testing.T, store testStore, batcherStore testBatcherStore) { return false } - // Make sure there is exactly one active batch. - if len(batcher.batches) != 1 { + batch := tryGetOnlyBatch(ctx, batcher) + if batch == nil { return false } - // Get the batch. - batch := getOnlyBatch(batcher) + // Make sure the batch has one sweep. + var numSweeps int + batch.testRunInEventLoop(ctx, func() { + numSweeps = len(batch.sweeps) + }) // Make sure the batch has one sweep. - return len(batch.sweeps) == 1 + return numSweeps == 1 }, test.Timeout, eventuallyCheckFrequency) // Make sure we have stored the batch. @@ -1020,25 +1088,6 @@ func testDelays(t *testing.T, store testStore, batcherStore testBatcherStore) { // Wait for the batcher to be initialized. <-batcher.initDone - // Wait for batch to load. - require.Eventually(t, func() bool { - // Make sure that the sweep was stored - if !batcherStore.AssertSweepStored(sweepReq.SwapHash) { - return false - } - - // Make sure there is exactly one active batch. - if len(batcher.batches) != 1 { - return false - } - - // Get the batch. - batch := getOnlyBatch(batcher) - - // Make sure the batch has one sweep. - return len(batch.sweeps) == 1 - }, test.Timeout, eventuallyCheckFrequency) - // Expect a timer to be set: 0 (instead of publishDelay), and // RegisterSpend to be called. The order is not determined, so catch // these actions from two separate goroutines. @@ -1051,6 +1100,9 @@ func testDelays(t *testing.T, store testStore, batcherStore testBatcherStore) { // Since a batch was created we check that it registered for its // primary sweep's spend. <-lnd.RegisterSpendChannel + + // Wait for tx to be published. + <-lnd.TxPublishChannel }() wg3.Add(1) @@ -1065,6 +1117,28 @@ func testDelays(t *testing.T, store testStore, batcherStore testBatcherStore) { // Wait for RegisterSpend and for timer registration. wg3.Wait() + // Wait for batch to load. + require.Eventually(t, func() bool { + // Make sure that the sweep was stored + if !batcherStore.AssertSweepStored(sweepReq.SwapHash) { + return false + } + + batch := tryGetOnlyBatch(ctx, batcher) + if batch == nil { + return false + } + + // Make sure the batch has one sweep. + var numSweeps int + batch.testRunInEventLoop(ctx, func() { + numSweeps = len(batch.sweeps) + }) + + // Make sure the batch has one sweep. + return numSweeps == 1 + }, test.Timeout, eventuallyCheckFrequency) + // Expect one timer: publishDelay (0). wantDelays = []time.Duration{0} require.Equal(t, wantDelays, delays) @@ -1073,9 +1147,6 @@ func testDelays(t *testing.T, store testStore, batcherStore testBatcherStore) { now = now.Add(time.Millisecond) testClock.SetTime(now) - // Wait for tx to be published. - <-lnd.TxPublishChannel - // Tick tock next block. err = lnd.NotifyHeight(601) require.NoError(t, err) @@ -1226,15 +1297,18 @@ func testDelays(t *testing.T, store testStore, batcherStore testBatcherStore) { require.Equal(t, wantDelays, delays) // Replace the logger in the batch with wrappedLogger to watch messages. - var batch2 *batch - for _, batch := range batcher.batches { - if batch.id != batch1.id { - batch2 = batch - } + var testLogger2 *wrappedLogger + for _, batch := range getBatches(ctx, batcher) { + batch.testRunInEventLoop(ctx, func() { + if batch.id != batch1.id { + testLogger2 = &wrappedLogger{ + Logger: batch.log(), + } + batch.setLog(testLogger2) + } + }) } - require.NotNil(t, batch2) - testLogger2 := &wrappedLogger{Logger: batch2.log} - batch2.log = testLogger2 + require.NotNil(t, testLogger2) // Add another sweep which is urgent. It will go to the same batch // to make sure minimum timeout is calculated properly. @@ -1274,7 +1348,10 @@ func testDelays(t *testing.T, store testStore, batcherStore testBatcherStore) { // Wait for sweep to be added to the batch. require.EventuallyWithT(t, func(c *assert.CollectT) { - require.Contains(t, testLogger2.infoMessages, "adding sweep %x") + testLogger2.mu.Lock() + defer testLogger2.mu.Unlock() + + assert.Contains(c, testLogger2.infoMessages, "adding sweep %x") }, test.Timeout, eventuallyCheckFrequency) // Advance the clock by publishDelay. Don't wait largeInitialDelay. @@ -1298,7 +1375,7 @@ func testMaxSweepsPerBatch(t *testing.T, store testStore, batcherStore testBatcherStore) { // Disable logging, because this test is very noisy. - oldLogger := log + oldLogger := log() UseLogger(build.NewSubLogger("SWEEP", nil)) defer UseLogger(oldLogger) @@ -1399,15 +1476,19 @@ func testMaxSweepsPerBatch(t *testing.T, store testStore, // Eventually the batches are launched and all the sweeps are added. require.Eventually(t, func() bool { // Make sure all the batches have started. - if len(batcher.batches) != expectedBatches { + batches := getBatches(ctx, batcher) + if len(batches) != expectedBatches { return false } // Make sure all the sweeps were added. sweepsNum := 0 - for _, batch := range batcher.batches { - sweepsNum += len(batch.sweeps) + for _, batch := range batches { + batch.testRunInEventLoop(ctx, func() { + sweepsNum += len(batch.sweeps) + }) } + return sweepsNum == swapsNum }, test.Timeout, eventuallyCheckFrequency) @@ -1588,20 +1669,27 @@ func testSweepBatcherSweepReentry(t *testing.T, store testStore, // Batcher should create a batch for the sweeps. require.Eventually(t, func() bool { - return len(batcher.batches) == 1 + return len(getBatches(ctx, batcher)) == 1 }, test.Timeout, eventuallyCheckFrequency) // Find the batch and store it in a local variable for easier access. b := &batch{} - for _, btch := range batcher.batches { - if btch.primarySweepID == sweepReq1.SwapHash { - b = btch - } + for _, btch := range getBatches(ctx, batcher) { + btch.testRunInEventLoop(ctx, func() { + if btch.primarySweepID == sweepReq1.SwapHash { + b = btch + } + }) } // Batcher should contain all sweeps. require.Eventually(t, func() bool { - return len(b.sweeps) == 3 + var numSweeps int + b.testRunInEventLoop(ctx, func() { + numSweeps = len(b.sweeps) + }) + + return numSweeps == 3 }, test.Timeout, eventuallyCheckFrequency) // Verify that the batch has a primary sweep id that matches the first @@ -1650,20 +1738,25 @@ func testSweepBatcherSweepReentry(t *testing.T, store testStore, // Eventually the batch reads the notification and proceeds to a closed // state. require.Eventually(t, func() bool { - return b.state == Closed + var state batchState + b.testRunInEventLoop(ctx, func() { + state = b.state + }) + + return state == Closed }, test.Timeout, eventuallyCheckFrequency) + // Since second batch was created we check that it registered for its + // primary sweep's spend. + <-lnd.RegisterSpendChannel + // While handling the spend notification the batch should detect that // some sweeps did not appear in the spending tx, therefore it redirects // them back to the batcher and the batcher inserts them in a new batch. require.Eventually(t, func() bool { - return len(batcher.batches) == 2 + return len(getBatches(ctx, batcher)) == 2 }, test.Timeout, eventuallyCheckFrequency) - // Since second batch was created we check that it registered for its - // primary sweep's spend. - <-lnd.RegisterSpendChannel - // We mock the confirmation notification. lnd.ConfChannel <- &chainntnfs.TxConfirmation{ Tx: spendingTx, @@ -1678,26 +1771,35 @@ func testSweepBatcherSweepReentry(t *testing.T, store testStore, // confirmation forever. <-lnd.TxPublishChannel + // Re-add one of remaining sweeps to trigger removing the completed + // batch from the batcher. + require.NoError(t, batcher.AddSweep(&sweepReq3)) + // Eventually the batch receives the confirmation notification, // gracefully exits and the batcher deletes it. require.Eventually(t, func() bool { - return len(batcher.batches) == 1 + return len(getBatches(ctx, batcher)) == 1 }, test.Timeout, eventuallyCheckFrequency) // Find the other batch, which includes the sweeps that did not appear // in the spending tx. - b = &batch{} - for _, btch := range batcher.batches { - b = btch - } + b = getOnlyBatch(t, ctx, batcher) // After all the sweeps enter, it should contain 2 sweeps. require.Eventually(t, func() bool { - return len(b.sweeps) == 2 + var numSweeps int + b.testRunInEventLoop(ctx, func() { + numSweeps = len(b.sweeps) + }) + return numSweeps == 2 }, test.Timeout, eventuallyCheckFrequency) // The batch should be in an open state. - require.Equal(t, b.state, Open) + var state batchState + b.testRunInEventLoop(ctx, func() { + state = b.state + }) + require.Equal(t, state, Open) } // testSweepBatcherNonWalletAddr tests that sweep requests that sweep to a non @@ -1753,16 +1855,16 @@ func testSweepBatcherNonWalletAddr(t *testing.T, store testStore, // Deliver sweep request to batcher. require.NoError(t, batcher.AddSweep(&sweepReq1)) + // Since a batch was created we check that it registered for its primary + // sweep's spend. + <-lnd.RegisterSpendChannel + // Once batcher receives sweep request it will eventually spin up a // batch. require.Eventually(t, func() bool { - return len(batcher.batches) == 1 + return len(getBatches(ctx, batcher)) == 1 }, test.Timeout, eventuallyCheckFrequency) - // Since a batch was created we check that it registered for its primary - // sweep's spend. - <-lnd.RegisterSpendChannel - // Wait for tx to be published. <-lnd.TxPublishChannel @@ -1803,16 +1905,16 @@ func testSweepBatcherNonWalletAddr(t *testing.T, store testStore, require.NoError(t, batcher.AddSweep(&sweepReq2)) + // Since a batch was created we check that it registered for its primary + // sweep's spend. + <-lnd.RegisterSpendChannel + // Batcher should create a second batch as first batch is a non wallet // addr batch. require.Eventually(t, func() bool { - return len(batcher.batches) == 2 + return len(getBatches(ctx, batcher)) == 2 }, test.Timeout, eventuallyCheckFrequency) - // Since a batch was created we check that it registered for its primary - // sweep's spend. - <-lnd.RegisterSpendChannel - // Wait for second batch to be published. <-lnd.TxPublishChannel @@ -1850,38 +1952,47 @@ func testSweepBatcherNonWalletAddr(t *testing.T, store testStore, require.NoError(t, batcher.AddSweep(&sweepReq3)) + // Since a batch was created we check that it registered for its primary + // sweep's spend. + <-lnd.RegisterSpendChannel + // Batcher should create a new batch as timeout distance is greater than // the threshold require.Eventually(t, func() bool { - return len(batcher.batches) == 3 + return len(getBatches(ctx, batcher)) == 3 }, test.Timeout, eventuallyCheckFrequency) - // Since a batch was created we check that it registered for its primary - // sweep's spend. - <-lnd.RegisterSpendChannel - // Wait for tx to be published for 3rd batch. <-lnd.TxPublishChannel require.Eventually(t, func() bool { // Verify that each batch has the correct number of sweeps // in it. - for _, batch := range batcher.batches { - switch batch.primarySweepID { - case sweepReq1.SwapHash: - if len(batch.sweeps) != 1 { - return false - } - - case sweepReq2.SwapHash: - if len(batch.sweeps) != 1 { - return false + batches := getBatches(ctx, batcher) + for _, batch := range batches { + var bad bool + + batch.testRunInEventLoop(ctx, func() { + switch batch.primarySweepID { + case sweepReq1.SwapHash: + if len(batch.sweeps) != 1 { + bad = true + } + + case sweepReq2.SwapHash: + if len(batch.sweeps) != 1 { + bad = true + } + + case sweepReq3.SwapHash: + if len(batch.sweeps) != 1 { + bad = true + } } + }) - case sweepReq3.SwapHash: - if len(batch.sweeps) != 1 { - return false - } + if bad { + return false } } @@ -2103,16 +2214,16 @@ func testSweepBatcherComposite(t *testing.T, store testStore, // Deliver sweep request to batcher. require.NoError(t, batcher.AddSweep(&sweepReq1)) + // Since a batch was created we check that it registered for its primary + // sweep's spend. + <-lnd.RegisterSpendChannel + // Once batcher receives sweep request it will eventually spin up a // batch. require.Eventually(t, func() bool { - return len(batcher.batches) == 1 + return len(getBatches(ctx, batcher)) == 1 }, test.Timeout, eventuallyCheckFrequency) - // Since a batch was created we check that it registered for its primary - // sweep's spend. - <-lnd.RegisterSpendChannel - // Wait for tx to be published. <-lnd.TxPublishChannel @@ -2124,7 +2235,7 @@ func testSweepBatcherComposite(t *testing.T, store testStore, // Batcher should not create a second batch as timeout distance is small // enough. require.Eventually(t, func() bool { - return len(batcher.batches) == 1 + return len(getBatches(ctx, batcher)) == 1 }, test.Timeout, eventuallyCheckFrequency) // Publish a block to trigger batch 1 republishing. @@ -2137,32 +2248,32 @@ func testSweepBatcherComposite(t *testing.T, store testStore, require.NoError(t, batcher.AddSweep(&sweepReq3)) + // Since a batch was created we check that it registered for its primary + // sweep's spend. + <-lnd.RegisterSpendChannel + // Batcher should create a second batch as this sweep pays to a non // wallet address. require.Eventually(t, func() bool { - return len(batcher.batches) == 2 + return len(getBatches(ctx, batcher)) == 2 }, test.Timeout, eventuallyCheckFrequency) - // Since a batch was created we check that it registered for its primary - // sweep's spend. - <-lnd.RegisterSpendChannel - // Wait for tx for the second batch to be published (1 sweep). tx = <-lnd.TxPublishChannel require.Equal(t, 1, len(tx.TxIn)) require.NoError(t, batcher.AddSweep(&sweepReq4)) + // Since a batch was created we check that it registered for its primary + // sweep's spend. + <-lnd.RegisterSpendChannel + // Batcher should create a third batch as timeout distance is greater // than the threshold. require.Eventually(t, func() bool { - return len(batcher.batches) == 3 + return len(getBatches(ctx, batcher)) == 3 }, test.Timeout, eventuallyCheckFrequency) - // Since a batch was created we check that it registered for its primary - // sweep's spend. - <-lnd.RegisterSpendChannel - // Wait for tx for the third batch to be published (1 sweep). tx = <-lnd.TxPublishChannel require.Equal(t, 1, len(tx.TxIn)) @@ -2181,21 +2292,21 @@ func testSweepBatcherComposite(t *testing.T, store testStore, // Batcher should not create a fourth batch as timeout distance is small // enough for it to join the last batch. require.Eventually(t, func() bool { - return len(batcher.batches) == 3 + return len(getBatches(ctx, batcher)) == 3 }, test.Timeout, eventuallyCheckFrequency) require.NoError(t, batcher.AddSweep(&sweepReq6)) + // Since a batch was created we check that it registered for its primary + // sweep's spend. + <-lnd.RegisterSpendChannel + // Batcher should create a fourth batch as this sweep pays to a non // wallet address. require.Eventually(t, func() bool { - return len(batcher.batches) == 4 + return len(getBatches(ctx, batcher)) == 4 }, test.Timeout, eventuallyCheckFrequency) - // Since a batch was created we check that it registered for its primary - // sweep's spend. - <-lnd.RegisterSpendChannel - // Wait for tx for the 4th batch to be published (1 sweep). tx = <-lnd.TxPublishChannel require.Equal(t, 1, len(tx.TxIn)) @@ -2203,27 +2314,35 @@ func testSweepBatcherComposite(t *testing.T, store testStore, require.Eventually(t, func() bool { // Verify that each batch has the correct number of sweeps in // it. - for _, batch := range batcher.batches { - switch batch.primarySweepID { - case sweepReq1.SwapHash: - if len(batch.sweeps) != 2 { - return false - } - - case sweepReq3.SwapHash: - if len(batch.sweeps) != 1 { - return false - } - - case sweepReq4.SwapHash: - if len(batch.sweeps) != 2 { - return false + batches := getBatches(ctx, batcher) + for _, batch := range batches { + var bad bool + batch.testRunInEventLoop(ctx, func() { + switch batch.primarySweepID { + case sweepReq1.SwapHash: + if len(batch.sweeps) != 2 { + bad = true + } + + case sweepReq3.SwapHash: + if len(batch.sweeps) != 1 { + bad = true + } + + case sweepReq4.SwapHash: + if len(batch.sweeps) != 2 { + bad = true + } + + case sweepReq6.SwapHash: + if len(batch.sweeps) != 1 { + bad = true + } } + }) - case sweepReq6.SwapHash: - if len(batch.sweeps) != 1 { - return false - } + if bad { + return false } } @@ -2360,8 +2479,11 @@ func testRestoringEmptyBatch(t *testing.T, store testStore, require.Eventually(t, func() bool { // Make sure that the sweep was stored and we have exactly one // active batch. - return batcherStore.AssertSweepStored(sweepReq.SwapHash) && - len(batcher.batches) == 1 + if !batcherStore.AssertSweepStored(sweepReq.SwapHash) { + return false + } + + return len(getBatches(ctx, batcher)) == 1 }, test.Timeout, eventuallyCheckFrequency) // Make sure we have only one batch stored (as we dropped the dormant @@ -2577,9 +2699,14 @@ func testHandleSweepTwice(t *testing.T, backend testStore, require.Eventually(t, func() bool { // Make sure that the sweep was stored and we have exactly one // active batch. - return batcherStore.AssertSweepStored(sweepReq1.SwapHash) && - batcherStore.AssertSweepStored(sweepReq2.SwapHash) && - len(batcher.batches) == 2 + if !batcherStore.AssertSweepStored(sweepReq1.SwapHash) { + return false + } + if !batcherStore.AssertSweepStored(sweepReq2.SwapHash) { + return false + } + + return len(getBatches(ctx, batcher)) == 2 }, test.Timeout, eventuallyCheckFrequency) // Change the second sweep so that it can be added to the first batch. @@ -2608,7 +2735,8 @@ func testHandleSweepTwice(t *testing.T, backend testStore, require.Eventually(t, func() bool { // Make sure there are two batches. - batches := batcher.batches + batches := getBatches(ctx, batcher) + if len(batches) != 2 { return false } @@ -2624,7 +2752,13 @@ func testHandleSweepTwice(t *testing.T, backend testStore, } // Make sure the second batch has the second sweep. - sweep2, has := secondBatch.sweeps[sweepReq2.SwapHash] + var ( + sweep2 sweep + has bool + ) + secondBatch.testRunInEventLoop(ctx, func() { + sweep2, has = secondBatch.sweeps[sweepReq2.SwapHash] + }) if !has { return false } @@ -2635,8 +2769,15 @@ func testHandleSweepTwice(t *testing.T, backend testStore, // Make sure each batch has one sweep. If the second sweep was added to // both batches, the following check won't pass. - for _, batch := range batcher.batches { - require.Equal(t, 1, len(batch.sweeps)) + batches := getBatches(ctx, batcher) + for _, batch := range batches { + var numSweeps int + batch.testRunInEventLoop(ctx, func() { + numSweeps = len(batch.sweeps) + }) + + // Make sure the batch has one sweep. + require.Equal(t, 1, numSweeps) } // Publish a block to trigger batch 2 republishing. @@ -2729,21 +2870,28 @@ func testRestoringPreservesConfTarget(t *testing.T, store testStore, return false } - // Make sure there is exactly one active batch. - if len(batcher.batches) != 1 { + batch := tryGetOnlyBatch(ctx, batcher) + if batch == nil { return false } - // Get the batch. - batch := getOnlyBatch(batcher) + // Make sure the batch has one sweep. + var ( + numSweeps int + confTarget int32 + ) + batch.testRunInEventLoop(ctx, func() { + numSweeps = len(batch.sweeps) + confTarget = batch.cfg.batchConfTarget + }) // Make sure the batch has one sweep. - if len(batch.sweeps) != 1 { + if numSweeps != 1 { return false } // Make sure the batch has proper batchConfTarget. - return batch.cfg.batchConfTarget == 123 + return confTarget == 123 }, test.Timeout, eventuallyCheckFrequency) // Make sure we have stored the batch. @@ -2772,6 +2920,12 @@ func testRestoringPreservesConfTarget(t *testing.T, store testStore, // Wait for the batcher to be initialized. <-batcher.initDone + // Expect registration for spend notification. + <-lnd.RegisterSpendChannel + + // Wait for tx to be published. + <-lnd.TxPublishChannel + // Wait for batch to load. require.Eventually(t, func() bool { // Make sure that the sweep was stored @@ -2779,26 +2933,28 @@ func testRestoringPreservesConfTarget(t *testing.T, store testStore, return false } - // Make sure there is exactly one active batch. - if len(batcher.batches) != 1 { + batch := tryGetOnlyBatch(ctx, batcher) + if batch == nil { return false } - // Get the batch. - batch := getOnlyBatch(batcher) + // Make sure the batch has one sweep. + var numSweeps int + batch.testRunInEventLoop(ctx, func() { + numSweeps = len(batch.sweeps) + }) // Make sure the batch has one sweep. - return len(batch.sweeps) == 1 + return numSweeps == 1 }, test.Timeout, eventuallyCheckFrequency) // Make sure batchConfTarget was preserved. - require.Equal(t, 123, int(getOnlyBatch(batcher).cfg.batchConfTarget)) - - // Expect registration for spend notification. - <-lnd.RegisterSpendChannel - - // Wait for tx to be published. - <-lnd.TxPublishChannel + batch := getOnlyBatch(t, ctx, batcher) + var confTarget int32 + batch.testRunInEventLoop(ctx, func() { + confTarget = batch.cfg.batchConfTarget + }) + require.Equal(t, int32(123), confTarget) // Now make the batcher quit by canceling the context. cancel() @@ -2810,11 +2966,22 @@ func testRestoringPreservesConfTarget(t *testing.T, store testStore, type sweepFetcherMock struct { store map[lntypes.Hash]*SweepInfo + mu sync.Mutex +} + +func (f *sweepFetcherMock) setSweep(hash lntypes.Hash, info *SweepInfo) { + f.mu.Lock() + defer f.mu.Unlock() + + f.store[hash] = info } func (f *sweepFetcherMock) FetchSweep(ctx context.Context, hash lntypes.Hash) ( *SweepInfo, error) { + f.mu.Lock() + defer f.mu.Unlock() + return f.store[hash], nil } @@ -2932,21 +3099,27 @@ func testSweepFetcher(t *testing.T, store testStore, return false } - // Make sure there is exactly one active batch. - if len(batcher.batches) != 1 { + // Try to get the batch. + batch := tryGetOnlyBatch(ctx, batcher) + if batch == nil { return false } - // Get the batch. - batch := getOnlyBatch(batcher) - // Make sure the batch has one sweep. - if len(batch.sweeps) != 1 { + var ( + numSweeps int + confTarget int32 + ) + batch.testRunInEventLoop(ctx, func() { + numSweeps = len(batch.sweeps) + confTarget = batch.cfg.batchConfTarget + }) + if numSweeps != 1 { return false } // Make sure the batch has proper batchConfTarget. - return batch.cfg.batchConfTarget == 123 + return confTarget == 123 }, test.Timeout, eventuallyCheckFrequency) // Get the published transaction and check the fee rate. @@ -3279,7 +3452,7 @@ func testWithMixedBatch(t *testing.T, store testStore, if i == 0 { sweepInfo.NonCoopHint = true } - sweepFetcher.store[swapHash] = sweepInfo + sweepFetcher.setSweep(swapHash, sweepInfo) // Create sweep request. sweepReq := SweepRequest{ @@ -3433,7 +3606,7 @@ func testWithMixedBatchCustom(t *testing.T, store testStore, ) require.NoError(t, err) - sweepFetcher.store[swapHash] = &SweepInfo{ + sweepFetcher.setSweep(swapHash, &SweepInfo{ Preimage: preimages[i], NonCoopHint: nonCoopHints[i], @@ -3445,7 +3618,7 @@ func testWithMixedBatchCustom(t *testing.T, store testStore, HTLC: *htlc, HTLCSuccessEstimator: htlc.AddSuccessToEstimator, DestAddr: destAddr, - } + }) // Create sweep request. sweepReq := SweepRequest{ @@ -3794,9 +3967,17 @@ func testFeeRateGrows(t *testing.T, store testStore, <-lnd.TxPublishChannel // Make sure the fee rate is feeRateMedium. - batch := getOnlyBatch(batcher) - require.Len(t, batch.sweeps, 1) - require.Equal(t, feeRateMedium, batch.rbfCache.FeeRate) + batch := getOnlyBatch(t, ctx, batcher) + var ( + numSweeps int + cachedFeeRate chainfee.SatPerKWeight + ) + batch.testRunInEventLoop(ctx, func() { + numSweeps = len(batch.sweeps) + cachedFeeRate = batch.rbfCache.FeeRate + }) + require.Equal(t, 1, numSweeps) + require.Equal(t, feeRateMedium, cachedFeeRate) // Now decrease the fee of sweep1. setFeeRate(swapHash1, feeRateLow) @@ -3810,7 +3991,10 @@ func testFeeRateGrows(t *testing.T, store testStore, <-lnd.TxPublishChannel // Make sure the fee rate is still feeRateMedium. - require.Equal(t, feeRateMedium, batch.rbfCache.FeeRate) + batch.testRunInEventLoop(ctx, func() { + cachedFeeRate = batch.rbfCache.FeeRate + }) + require.Equal(t, feeRateMedium, cachedFeeRate) // Add sweep2, with feeRateMedium. swapHash2 := lntypes.Hash{2, 2, 2} @@ -3856,8 +4040,12 @@ func testFeeRateGrows(t *testing.T, store testStore, <-lnd.TxPublishChannel // Make sure the fee rate is still feeRateMedium. - require.Len(t, batch.sweeps, 2) - require.Equal(t, feeRateMedium, batch.rbfCache.FeeRate) + batch.testRunInEventLoop(ctx, func() { + numSweeps = len(batch.sweeps) + cachedFeeRate = batch.rbfCache.FeeRate + }) + require.Equal(t, 2, numSweeps) + require.Equal(t, feeRateMedium, cachedFeeRate) // Now update fee rate of second sweep (which is not primary) to // feeRateHigh. Fee rate of sweep 1 is still feeRateLow. @@ -3873,14 +4061,1362 @@ func testFeeRateGrows(t *testing.T, store testStore, <-lnd.TxPublishChannel // Make sure the fee rate increased to feeRateHigh. - require.Equal(t, feeRateHigh, batch.rbfCache.FeeRate) + batch.testRunInEventLoop(ctx, func() { + cachedFeeRate = batch.rbfCache.FeeRate + }) + require.Equal(t, feeRateHigh, cachedFeeRate) } -// TestSweepBatcherBatchCreation tests that sweep requests enter the expected -// batch based on their timeout distance. -func TestSweepBatcherBatchCreation(t *testing.T) { - runTests(t, testSweepBatcherBatchCreation) -} +// mockCpfpHelper implements CpfpHelper interface and stores arguments passed +// in its methods to validate correctness of function publishWithCPFP. +type mockCpfpHelper struct { + // onlineOutpoints specifies which outpoints are capable of + // participating in presigning. + onlineOutpoints map[wire.OutPoint]bool + + // presignedBatches is the collection of presigned batches. + presignedBatches []*wire.MsgTx + + // mu should be hold by all the public methods of this type. + mu sync.Mutex + + // cleanupCalled is a channel where an element is sent every time + // CleanupTransactions is called. + cleanupCalled chan struct{} +} + +// newMockCpfpHelper returns new instance of mockCpfpHelper. +func newMockCpfpHelper() *mockCpfpHelper { + return &mockCpfpHelper{ + onlineOutpoints: make(map[wire.OutPoint]bool), + cleanupCalled: make(chan struct{}), + } +} + +// SetOutpointOnline changes the online status of an outpoint. +func (h *mockCpfpHelper) SetOutpointOnline(op wire.OutPoint, online bool) { + h.mu.Lock() + defer h.mu.Unlock() + + h.onlineOutpoints[op] = online +} + +// findOfflineInputs returns inputs of a tx which are offline. +func (h *mockCpfpHelper) findOfflineInputs(tx *wire.MsgTx) []wire.OutPoint { + offline := make([]wire.OutPoint, 0, len(tx.TxIn)) + for _, txIn := range tx.TxIn { + if !h.onlineOutpoints[txIn.PreviousOutPoint] { + offline = append(offline, txIn.PreviousOutPoint) + } + } + + return offline +} + +// sign signs the transaction. +func (h *mockCpfpHelper) sign(tx *wire.MsgTx) { + // Sign all the inputs. + for i := range tx.TxIn { + tx.TxIn[i].Witness = wire.TxWitness{ + make([]byte, 64), + } + } +} + +// getTxFeerate returns fee rate of a transaction. +func (h *mockCpfpHelper) getTxFeerate(tx *wire.MsgTx, + inputAmt btcutil.Amount) chainfee.SatPerKWeight { + + // "Sign" tx's copy to assess the weight. + tx2 := tx.Copy() + h.sign(tx2) + weight := lntypes.WeightUnit( + blockchain.GetTransactionWeight(btcutil.NewTx(tx2)), + ) + fee := btcutil.Amount(tx.TxOut[0].Value) - inputAmt + + return chainfee.NewSatPerKWeight(fee, weight) +} + +// IsCpfpApplied returns if the input was previously used in any call to the +// SetOutpointOnline method. +func (h *mockCpfpHelper) IsCpfpApplied(ctx context.Context, + input wire.OutPoint) (bool, error) { + + h.mu.Lock() + defer h.mu.Unlock() + + _, has := h.onlineOutpoints[input] + + return has, nil +} + +// Presign tries to presign the transaction. It succeeds if all the inputs +// are online. In case of success it adds the transaction to presignedBatches. +func (h *mockCpfpHelper) Presign(ctx context.Context, tx *wire.MsgTx, + inputAmt btcutil.Amount) error { + + h.mu.Lock() + defer h.mu.Unlock() + + if offline := h.findOfflineInputs(tx); len(offline) != 0 { + return fmt.Errorf("some inputs of tx are offline: %v", offline) + } + + tx = tx.Copy() + h.sign(tx) + h.presignedBatches = append(h.presignedBatches, tx) + + return nil +} + +// DestPkScript returns destination pkScript used in 1:1 presigned tx. +func (h *mockCpfpHelper) DestPkScript(ctx context.Context, + inputs []wire.OutPoint) ([]byte, error) { + + h.mu.Lock() + defer h.mu.Unlock() + + inputsSet := make(map[wire.OutPoint]struct{}, len(inputs)) + for _, input := range inputs { + inputsSet[input] = struct{}{} + } + if len(inputsSet) != len(inputs) { + return nil, fmt.Errorf("duplicate inputs") + } + + inputsMatch := func(tx *wire.MsgTx) bool { + if len(tx.TxIn) != len(inputsSet) { + return false + } + + for _, txIn := range tx.TxIn { + if _, has := inputsSet[txIn.PreviousOutPoint]; !has { + return false + } + } + + return true + } + + for _, tx := range h.presignedBatches { + if inputsMatch(tx) { + return tx.TxOut[0].PkScript, nil + } + } + + return nil, fmt.Errorf("tx sweeping inputs %v not found", inputs) +} + +// SignTx tries to sign the transaction. If all the inputs are online, it signs +// the exact transaction passed and adds it to presignedBatches. Otherwise it +// looks for a transaction in presignedBatches satisfying the criteria. +func (h *mockCpfpHelper) SignTx(ctx context.Context, tx *wire.MsgTx, + inputAmt btcutil.Amount, + minRelayFee chainfee.SatPerKWeight) (*wire.MsgTx, error) { + + h.mu.Lock() + defer h.mu.Unlock() + + // If all the inputs are online, sign this exact transaction. + if offline := h.findOfflineInputs(tx); len(offline) == 0 { + tx = tx.Copy() + h.sign(tx) + + // Add to the collection. + h.presignedBatches = append(h.presignedBatches, tx) + + return tx, nil + } + + // Find feerate of input tx. + inputFeeRate := h.getTxFeerate(tx, inputAmt) + + // Try to find a transaction in the collection satisfying all the + // criteria of CpfpHelper.SignTx. If there are many such transactions, + // select a transaction with feerate which is the closest to the feerate + // of the input tx. + var ( + bestTx *wire.MsgTx + bestFeerateDistance chainfee.SatPerKWeight + ) + for _, candidate := range h.presignedBatches { + err := CheckSignedTx(tx, candidate, inputAmt, minRelayFee) + if err != nil { + continue + } + + feeRate := h.getTxFeerate(candidate, inputAmt) + feeRateDistance := feeRate - inputFeeRate + if feeRateDistance < 0 { + feeRateDistance = -feeRateDistance + } + + if bestTx == nil || feeRateDistance < bestFeerateDistance { + bestTx = candidate + bestFeerateDistance = feeRateDistance + } + } + + if bestTx == nil { + return nil, fmt.Errorf("no such presigned tx found") + } + + return bestTx.Copy(), nil +} + +// LoadTx tries to load the transaction by txid. It scans presignedBatches. +func (h *mockCpfpHelper) LoadTx(ctx context.Context, + txid chainhash.Hash) (*wire.MsgTx, error) { + + h.mu.Lock() + defer h.mu.Unlock() + + for _, tx := range h.presignedBatches { + if tx.TxHash() == txid { + return tx.Copy(), nil + } + } + + return nil, fmt.Errorf("tx with ID %v not found", txid) +} + +// CleanupTransactions removes all transactions related to any of the outpoints. +func (h *mockCpfpHelper) CleanupTransactions(ctx context.Context, + inputs []wire.OutPoint) error { + + h.mu.Lock() + defer h.mu.Unlock() + + inputsSet := make(map[wire.OutPoint]struct{}, len(inputs)) + for _, input := range inputs { + inputsSet[input] = struct{}{} + } + if len(inputsSet) != len(inputs) { + return fmt.Errorf("duplicate inputs") + } + + var presignedBatches []*wire.MsgTx + + // Filter out transactions spending any of the inputs passed. + for _, tx := range h.presignedBatches { + var match bool + for _, txIn := range tx.TxIn { + if _, has := inputsSet[txIn.PreviousOutPoint]; has { + match = true + break + } + } + + if !match { + presignedBatches = append(presignedBatches, tx) + } + } + + h.presignedBatches = presignedBatches + + h.cleanupCalled <- struct{}{} + + return nil +} + +// dummySweepFetcherMock implements SweepFetcher by returning blank SweepInfo. +// It is used in TestCPFP, because it doesn't use any fields from SweepInfo. +type dummySweepFetcherMock struct { +} + +// FetchSweep returns blank SweepInfo. +func (f *dummySweepFetcherMock) FetchSweep(ctx context.Context, + hash lntypes.Hash) (*SweepInfo, error) { + + return &SweepInfo{ + // Set Timeout to prevent warning messages about timeout=0. + Timeout: 1000, + }, nil +} + +// testCPFP_input1_offline_then_input2 tests CPFP mode for the following +// scenario: first input is added, then goes offline, then feerate grows, one of +// presigned transactions is published, feerate grows further, then CPFP is used +// and then another online input is added and is assigned to another batch. +func testCPFP_input1_offline_then_input2(t *testing.T, + batcherStore testBatcherStore) { + + defer test.Guard(t)() + + batchPkScript, err := txscript.PayToAddrScript(destAddr) + require.NoError(t, err) + + lnd := test.NewMockLnd() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + const ( + feeRateLow = chainfee.SatPerKWeight(10_000) + feeRateMedium = chainfee.SatPerKWeight(30_000) + feeRateHigh = chainfee.SatPerKWeight(40_000) + ) + + currentFeeRate := feeRateLow + setFeeRate := func(feeRate chainfee.SatPerKWeight) { + currentFeeRate = feeRate + } + customFeeRate := func(_ context.Context, + _ lntypes.Hash) (chainfee.SatPerKWeight, error) { + + return currentFeeRate, nil + } + + cpfpHelper := newMockCpfpHelper() + + batcher := NewBatcher(lnd.WalletKit, lnd.ChainNotifier, lnd.Signer, + testMuSig2SignSweep, testVerifySchnorrSig, lnd.ChainParams, + batcherStore, &dummySweepFetcherMock{}, + WithCustomFeeRate(customFeeRate), WithCpfpHelper(cpfpHelper)) + + go func() { + err := batcher.Run(ctx) + checkBatcherError(t, err) + }() + + setFeeRate(feeRateLow) + + // Create the first sweep. + swapHash1 := lntypes.Hash{1, 1, 1} + op1 := wire.OutPoint{ + Hash: chainhash.Hash{1, 1}, + Index: 1, + } + sweepReq1 := SweepRequest{ + SwapHash: swapHash1, + Value: 1_000_000, + Outpoint: op1, + Notifier: &dummyNotifier, + } + + // This should fail, because the input is offline. + cpfpHelper.SetOutpointOnline(op1, false) + err = batcher.PresignSweep(ctx, op1, 1_000_000, destAddr) + require.Error(t, err) + require.ErrorContains(t, err, "offline") + + // Enable the input and try again. + cpfpHelper.SetOutpointOnline(op1, true) + err = batcher.PresignSweep(ctx, op1, 1_000_000, destAddr) + require.NoError(t, err) + + // Increase fee rate and turn off the input, so it can't sign updated + // tx. The feerate is close to the feerate of one of presigned txs, so + // there is no CPFP. + setFeeRate(feeRateMedium) + cpfpHelper.SetOutpointOnline(op1, false) + + // Deliver sweep request to batcher. + require.NoError(t, batcher.AddSweep(&sweepReq1)) + + // Since a batch was created we check that it registered for its primary + // sweep's spend. + <-lnd.RegisterSpendChannel + + // Wait for a transactions to be published. + parent := <-lnd.TxPublishChannel + require.Len(t, parent.TxIn, 1) + require.Len(t, parent.TxOut, 1) + require.Equal(t, op1, parent.TxIn[0].PreviousOutPoint) + require.Equal(t, int64(987034), parent.TxOut[0].Value) + require.Equal(t, batchPkScript, parent.TxOut[0].PkScript) + + // Make sure the fee rate is feeRateMedium. + batch := getOnlyBatch(t, ctx, batcher) + var ( + numSweeps int + cachedFeeRate chainfee.SatPerKWeight + ) + batch.testRunInEventLoop(ctx, func() { + numSweeps = len(batch.sweeps) + cachedFeeRate = batch.rbfCache.FeeRate + }) + require.Equal(t, 1, numSweeps) + require.Equal(t, feeRateMedium, cachedFeeRate) + + // Raise feerate and trigger new publishing. The parent tx should be the + // same plus a CPFP tx. + setFeeRate(feeRateHigh) + require.NoError(t, batcher.AddSweep(&sweepReq1)) + require.NoError(t, lnd.NotifyHeight(601)) + + parent2 := <-lnd.TxPublishChannel + child := <-lnd.TxPublishChannel + require.Equal(t, parent.TxHash(), parent2.TxHash()) + require.Len(t, child.TxIn, 1) + require.Len(t, child.TxOut, 1) + parentOp := wire.OutPoint{ + Hash: parent2.TxHash(), + Index: 0, + } + require.Equal(t, parentOp, child.TxIn[0].PreviousOutPoint) + require.Equal(t, int64(966600), child.TxOut[0].Value) + require.Equal(t, batchPkScript, child.TxOut[0].PkScript) + + // Now add another input. It is online, but the first input is still + // offline, so another input should go to another batch. + swapHash2 := lntypes.Hash{2, 2, 2} + op2 := wire.OutPoint{ + Hash: chainhash.Hash{2, 2}, + Index: 2, + } + sweepReq2 := SweepRequest{ + SwapHash: swapHash2, + Value: 2_000_000, + Outpoint: op2, + Notifier: &dummyNotifier, + } + cpfpHelper.SetOutpointOnline(op2, true) + err = batcher.PresignSweep(ctx, op2, 2_000_000, destAddr) + require.NoError(t, err) + + // Deliver sweep request to batcher. + require.NoError(t, batcher.AddSweep(&sweepReq2)) + + // Since a batch was created we check that it registered for its primary + // sweep's spend. + <-lnd.RegisterSpendChannel + + // Wait for a transactions to be published. + batch2 := <-lnd.TxPublishChannel + require.Len(t, batch2.TxIn, 1) + require.Len(t, batch2.TxOut, 1) + require.Equal(t, op2, batch2.TxIn[0].PreviousOutPoint) + require.Equal(t, int64(1984160), batch2.TxOut[0].Value) + require.Equal(t, batchPkScript, batch2.TxOut[0].PkScript) + + // Now confirm the first batch. Make sure its presigned transactions + // were removed, but not the transactions of the second batch. + presignedSize1 := len(cpfpHelper.presignedBatches) + + parent2hash := parent2.TxHash() + spendDetail := &chainntnfs.SpendDetail{ + SpentOutPoint: &sweepReq1.Outpoint, + SpendingTx: parent2, + SpenderTxHash: &parent2hash, + SpenderInputIndex: 0, + SpendingHeight: 601, + } + lnd.SpendChannel <- spendDetail + <-lnd.RegisterConfChannel + require.NoError(t, lnd.NotifyHeight(604)) + lnd.ConfChannel <- &chainntnfs.TxConfirmation{ + Tx: parent2, + } + + <-cpfpHelper.cleanupCalled + + presignedSize2 := len(cpfpHelper.presignedBatches) + require.Greater(t, presignedSize2, 0) + require.Greater(t, presignedSize1, presignedSize2) + + // Make sure we still have presigned transactions for the second batch. + cpfpHelper.SetOutpointOnline(op2, false) + _, err = cpfpHelper.SignTx( + ctx, batch2, 2_000_000, chainfee.FeePerKwFloor, + ) + require.NoError(t, err) +} + +// testCPFP_two_inputs_one_goes_offline tests CPFP mode for the following +// scenario: two online inputs are added, then one of them goes offline, then +// feerate grows and a presigned transaction is used. +func testCPFP_two_inputs_one_goes_offline(t *testing.T, + batcherStore testBatcherStore) { + + defer test.Guard(t)() + + batchPkScript, err := txscript.PayToAddrScript(destAddr) + require.NoError(t, err) + + lnd := test.NewMockLnd() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + const ( + feeRateLow = chainfee.SatPerKWeight(10_000) + feeRateMedium = chainfee.SatPerKWeight(30_000) + feeRateHigh = chainfee.SatPerKWeight(40_000) + ) + + currentFeeRate := feeRateLow + setFeeRate := func(feeRate chainfee.SatPerKWeight) { + currentFeeRate = feeRate + } + customFeeRate := func(_ context.Context, + _ lntypes.Hash) (chainfee.SatPerKWeight, error) { + + return currentFeeRate, nil + } + + cpfpHelper := newMockCpfpHelper() + + batcher := NewBatcher( + lnd.WalletKit, lnd.ChainNotifier, lnd.Signer, + testMuSig2SignSweep, testVerifySchnorrSig, lnd.ChainParams, + batcherStore, &dummySweepFetcherMock{}, + WithCustomFeeRate(customFeeRate), WithCpfpHelper(cpfpHelper), + ) + + go func() { + err := batcher.Run(ctx) + checkBatcherError(t, err) + }() + + setFeeRate(feeRateLow) + + // Create the first sweep. + swapHash1 := lntypes.Hash{1, 1, 1} + op1 := wire.OutPoint{ + Hash: chainhash.Hash{1, 1}, + Index: 1, + } + sweepReq1 := SweepRequest{ + SwapHash: swapHash1, + Value: 1_000_000, + Outpoint: op1, + Notifier: &dummyNotifier, + } + cpfpHelper.SetOutpointOnline(op1, true) + err = batcher.PresignSweep(ctx, op1, 1_000_000, destAddr) + require.NoError(t, err) + require.NoError(t, batcher.AddSweep(&sweepReq1)) + + // Since a batch was created we check that it registered for its primary + // sweep's spend. + <-lnd.RegisterSpendChannel + + // Add second sweep. + swapHash2 := lntypes.Hash{2, 2, 2} + op2 := wire.OutPoint{ + Hash: chainhash.Hash{2, 2}, + Index: 2, + } + sweepReq2 := SweepRequest{ + SwapHash: swapHash2, + Value: 2_000_000, + Outpoint: op2, + Notifier: &dummyNotifier, + } + cpfpHelper.SetOutpointOnline(op2, true) + err = batcher.PresignSweep(ctx, op2, 2_000_000, destAddr) + require.NoError(t, err) + require.NoError(t, batcher.AddSweep(&sweepReq2)) + + // Wait for a transactions to be published. + parent := <-lnd.TxPublishChannel + require.Len(t, parent.TxIn, 2) + require.Len(t, parent.TxOut, 1) + require.ElementsMatch( + t, []wire.OutPoint{op1, op2}, + []wire.OutPoint{ + parent.TxIn[0].PreviousOutPoint, + parent.TxIn[1].PreviousOutPoint, + }, + ) + require.Equal(t, int64(2993740), parent.TxOut[0].Value) + require.Equal(t, batchPkScript, parent.TxOut[0].PkScript) + + // Now turn off the second input, raise feerate and trigger new + // publishing. The feerate is close to one of the presigned feerates, + // so this should result in RBF. + cpfpHelper.SetOutpointOnline(op2, false) + setFeeRate(feeRateMedium) + require.NoError(t, batcher.AddSweep(&sweepReq1)) + require.NoError(t, batcher.AddSweep(&sweepReq2)) + require.NoError(t, lnd.NotifyHeight(601)) + + parent2 := <-lnd.TxPublishChannel + require.NotEqual(t, parent.TxHash(), parent2.TxHash()) + require.Len(t, parent2.TxIn, 2) + require.Len(t, parent2.TxOut, 1) + require.ElementsMatch( + t, []wire.OutPoint{op1, op2}, + []wire.OutPoint{ + parent.TxIn[0].PreviousOutPoint, + parent.TxIn[1].PreviousOutPoint, + }, + ) + require.Equal(t, int64(2979503), parent2.TxOut[0].Value) + require.Equal(t, batchPkScript, parent2.TxOut[0].PkScript) +} + +// testCPFP_cpfp_previous_version tests CPFP mode for the following scenario: +// one input is added, a transaction is published, then the input goes offline +// and feerate grows, RBF is attempted, but broadcast fails, so the batcher +// CPFPs previously published version. +func testCPFP_cpfp_previous_version(t *testing.T, + batcherStore testBatcherStore) { + + defer test.Guard(t)() + + batchPkScript, err := txscript.PayToAddrScript(destAddr) + require.NoError(t, err) + + lnd := test.NewMockLnd() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + const ( + feeRateLow = chainfee.SatPerKWeight(10_000) + feeRateMedium = chainfee.SatPerKWeight(30_000) + feeRateHigh = chainfee.SatPerKWeight(40_000) + ) + + currentFeeRate := feeRateLow + setFeeRate := func(feeRate chainfee.SatPerKWeight) { + currentFeeRate = feeRate + } + customFeeRate := func(_ context.Context, + _ lntypes.Hash) (chainfee.SatPerKWeight, error) { + + return currentFeeRate, nil + } + + cpfpHelper := newMockCpfpHelper() + + batcher := NewBatcher( + lnd.WalletKit, lnd.ChainNotifier, lnd.Signer, + testMuSig2SignSweep, testVerifySchnorrSig, lnd.ChainParams, + batcherStore, &dummySweepFetcherMock{}, + WithCustomFeeRate(customFeeRate), WithCpfpHelper(cpfpHelper), + ) + + go func() { + err := batcher.Run(ctx) + checkBatcherError(t, err) + }() + + setFeeRate(feeRateLow) + + // Create the first sweep. + swapHash1 := lntypes.Hash{1, 1, 1} + op1 := wire.OutPoint{ + Hash: chainhash.Hash{1, 1}, + Index: 1, + } + sweepReq1 := SweepRequest{ + SwapHash: swapHash1, + Value: 1_000_000, + Outpoint: op1, + Notifier: &dummyNotifier, + } + cpfpHelper.SetOutpointOnline(op1, true) + err = batcher.PresignSweep(ctx, op1, 1_000_000, destAddr) + require.NoError(t, err) + require.NoError(t, batcher.AddSweep(&sweepReq1)) + + // Since a batch was created we check that it registered for its primary + // sweep's spend. + <-lnd.RegisterSpendChannel + + // Wait for a transactions to be published. + parent := <-lnd.TxPublishChannel + require.Len(t, parent.TxIn, 1) + require.Len(t, parent.TxOut, 1) + require.Equal(t, op1, parent.TxIn[0].PreviousOutPoint) + require.Equal(t, int64(996040), parent.TxOut[0].Value) + require.Equal(t, batchPkScript, parent.TxOut[0].PkScript) + + // Now turn off the first input, raise feerate and trigger new + // publishing, which will fail. + var failedToPublishTx *wire.MsgTx + lnd.PublishHandler = func(ctx context.Context, tx *wire.MsgTx, + label string) error { + + // We should fail the first publishing, which is a sweep, + // but we shouldn't fail CPFP publishing. + if strings.HasPrefix(label, cpfpLabelPrefix) { + return nil + } + + failedToPublishTx = tx + + return fmt.Errorf("test error") + } + cpfpHelper.SetOutpointOnline(op1, false) + setFeeRate(feeRateMedium) + require.NoError(t, batcher.AddSweep(&sweepReq1)) + require.NoError(t, lnd.NotifyHeight(601)) + + child := <-lnd.TxPublishChannel + require.NotEqual(t, parent.TxHash(), child.TxHash()) + require.Len(t, child.TxIn, 1) + require.Len(t, child.TxOut, 1) + require.Equal(t, wire.OutPoint{ + Hash: parent.TxHash(), + Index: 0, + }, child.TxIn[0].PreviousOutPoint) + require.Equal(t, int64(974950), child.TxOut[0].Value) + + // Make sure the failed attempt used higher feerate than parent. + require.Equal(t, int64(987034), failedToPublishTx.TxOut[0].Value) +} + +// testCPFP_no_cpfp_if_all_online tests CPFP mode for the following scenario: +// one input is added, a transaction is published, then feerate grows, RBF is +// attempted, but broadcast fails, but CPFP is not used, because all the inputs +// are online (which is deduced by SignTx signing a tx with the same feerate as +// requested). +func testCPFP_no_cpfp_if_all_online(t *testing.T, + batcherStore testBatcherStore) { + + defer test.Guard(t)() + + batchPkScript, err := txscript.PayToAddrScript(destAddr) + require.NoError(t, err) + + lnd := test.NewMockLnd() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + const ( + feeRateLow = chainfee.SatPerKWeight(10_000) + feeRateMedium = chainfee.SatPerKWeight(30_000) + feeRateHigh = chainfee.SatPerKWeight(40_000) + ) + + currentFeeRate := feeRateLow + setFeeRate := func(feeRate chainfee.SatPerKWeight) { + currentFeeRate = feeRate + } + customFeeRate := func(_ context.Context, + _ lntypes.Hash) (chainfee.SatPerKWeight, error) { + + return currentFeeRate, nil + } + + cpfpHelper := newMockCpfpHelper() + + batcher := NewBatcher( + lnd.WalletKit, lnd.ChainNotifier, lnd.Signer, + testMuSig2SignSweep, testVerifySchnorrSig, lnd.ChainParams, + batcherStore, &dummySweepFetcherMock{}, + WithCustomFeeRate(customFeeRate), WithCpfpHelper(cpfpHelper), + ) + + go func() { + err := batcher.Run(ctx) + checkBatcherError(t, err) + }() + + setFeeRate(feeRateLow) + + // Create the first sweep. + swapHash1 := lntypes.Hash{1, 1, 1} + op1 := wire.OutPoint{ + Hash: chainhash.Hash{1, 1}, + Index: 1, + } + sweepReq1 := SweepRequest{ + SwapHash: swapHash1, + Value: 1_000_000, + Outpoint: op1, + Notifier: &dummyNotifier, + } + cpfpHelper.SetOutpointOnline(op1, true) + err = batcher.PresignSweep(ctx, op1, 1_000_000, destAddr) + require.NoError(t, err) + require.NoError(t, batcher.AddSweep(&sweepReq1)) + + // Since a batch was created we check that it registered for its primary + // sweep's spend. + <-lnd.RegisterSpendChannel + + // Wait for a transactions to be published. + parent := <-lnd.TxPublishChannel + require.Len(t, parent.TxIn, 1) + require.Len(t, parent.TxOut, 1) + require.Equal(t, op1, parent.TxIn[0].PreviousOutPoint) + require.Equal(t, int64(996040), parent.TxOut[0].Value) + require.Equal(t, batchPkScript, parent.TxOut[0].PkScript) + + // Replace the logger in the batch with wrappedLogger to watch messages. + batch := getOnlyBatch(t, ctx, batcher) + testLogger := &wrappedLogger{ + Logger: batch.log(), + } + batch.setLog(testLogger) + + // Now turn off the first input, raise feerate and trigger new + // publishing, which will fail. + lnd.PublishHandler = func(ctx context.Context, tx *wire.MsgTx, + label string) error { + + return fmt.Errorf("test error") + } + setFeeRate(feeRateMedium) + require.NoError(t, batcher.AddSweep(&sweepReq1)) + require.NoError(t, lnd.NotifyHeight(601)) + + // Wait for batcher to log that CPFP is not needed. + require.EventuallyWithT(t, func(c *assert.CollectT) { + testLogger.mu.Lock() + defer testLogger.mu.Unlock() + + assert.Contains( + c, testLogger.infoMessages, "CPFP is not needed", + ) + }, test.Timeout, eventuallyCheckFrequency) +} + +// testCPFP_first_publish_fails tests CPFP mode for the following scenario: +// one input is added and goes offline, feerate grows a transaction is attempted +// to be published, but fails, no CPFP is attempted. Then the input goes online +// and is published being signed online. +func testCPFP_first_publish_fails(t *testing.T, + batcherStore testBatcherStore) { + + defer test.Guard(t)() + + batchPkScript, err := txscript.PayToAddrScript(destAddr) + require.NoError(t, err) + + lnd := test.NewMockLnd() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + const ( + feeRateLow = chainfee.SatPerKWeight(10_000) + feeRateMedium = chainfee.SatPerKWeight(30_000) + feeRateHigh = chainfee.SatPerKWeight(40_000) + ) + + currentFeeRate := feeRateLow + setFeeRate := func(feeRate chainfee.SatPerKWeight) { + currentFeeRate = feeRate + } + customFeeRate := func(_ context.Context, + _ lntypes.Hash) (chainfee.SatPerKWeight, error) { + + return currentFeeRate, nil + } + + cpfpHelper := newMockCpfpHelper() + + batcher := NewBatcher( + lnd.WalletKit, lnd.ChainNotifier, lnd.Signer, + testMuSig2SignSweep, testVerifySchnorrSig, lnd.ChainParams, + batcherStore, &dummySweepFetcherMock{}, + WithCustomFeeRate(customFeeRate), WithCpfpHelper(cpfpHelper), + ) + + go func() { + err := batcher.Run(ctx) + checkBatcherError(t, err) + }() + + setFeeRate(feeRateLow) + + // Create the first sweep. + swapHash1 := lntypes.Hash{1, 1, 1} + op1 := wire.OutPoint{ + Hash: chainhash.Hash{1, 1}, + Index: 1, + } + sweepReq1 := SweepRequest{ + SwapHash: swapHash1, + Value: 1_000_000, + Outpoint: op1, + Notifier: &dummyNotifier, + } + cpfpHelper.SetOutpointOnline(op1, true) + err = batcher.PresignSweep(ctx, op1, 1_000_000, destAddr) + require.NoError(t, err) + cpfpHelper.SetOutpointOnline(op1, false) + + // Make sure that publish attempt fails. + lnd.PublishHandler = func(ctx context.Context, tx *wire.MsgTx, + label string) error { + + return fmt.Errorf("test error") + } + + // Add the sweep, triggering the publish attempt. + require.NoError(t, batcher.AddSweep(&sweepReq1)) + + // Since a batch was created we check that it registered for its primary + // sweep's spend. + <-lnd.RegisterSpendChannel + + // Replace the logger in the batch with wrappedLogger to watch messages. + batch := getOnlyBatch(t, ctx, batcher) + testLogger := &wrappedLogger{ + Logger: batch.log(), + } + batch.setLog(testLogger) + + // Trigger another publish attempt in case "CPFP is not needed" was + // logged before we installed the logger watcher. + require.NoError(t, lnd.NotifyHeight(601)) + + // Wait for batcher to log that CPFP is not needed. + require.EventuallyWithT(t, func(c *assert.CollectT) { + testLogger.mu.Lock() + defer testLogger.mu.Unlock() + + assert.Contains( + c, testLogger.infoMessages, "CPFP is not needed", + ) + }, test.Timeout, eventuallyCheckFrequency) + + // Now turn on the first input, raise feerate and trigger new + // publishing, which should succeed. + lnd.PublishHandler = nil + setFeeRate(feeRateMedium) + cpfpHelper.SetOutpointOnline(op1, true) + require.NoError(t, batcher.AddSweep(&sweepReq1)) + require.NoError(t, lnd.NotifyHeight(602)) + + // Wait for a transactions to be published. + parent := <-lnd.TxPublishChannel + require.Len(t, parent.TxIn, 1) + require.Len(t, parent.TxOut, 1) + require.Equal(t, op1, parent.TxIn[0].PreviousOutPoint) + require.Equal(t, int64(988120), parent.TxOut[0].Value) + require.Equal(t, batchPkScript, parent.TxOut[0].PkScript) +} + +// testCPFP_cpfp_publishing_fails tests CPFP mode for the following scenario: +// one input is added, a transaction is published, then the input goes offline +// and feerate grows, RBF is published and then CPFP is attempted to achieve +// the exact desired fee rate, but fails to be published. After then another +// block comes in and both the parent and the child are published and this +// succeeds. +func testCPFP_cpfp_publishing_fails(t *testing.T, + batcherStore testBatcherStore) { + + defer test.Guard(t)() + + batchPkScript, err := txscript.PayToAddrScript(destAddr) + require.NoError(t, err) + + lnd := test.NewMockLnd() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + const ( + feeRateLow = chainfee.SatPerKWeight(10_000) + feeRateMedium = chainfee.SatPerKWeight(30_000) + feeRateHigh = chainfee.SatPerKWeight(40_000) + ) + + currentFeeRate := feeRateLow + setFeeRate := func(feeRate chainfee.SatPerKWeight) { + currentFeeRate = feeRate + } + customFeeRate := func(_ context.Context, + _ lntypes.Hash) (chainfee.SatPerKWeight, error) { + + return currentFeeRate, nil + } + + cpfpHelper := newMockCpfpHelper() + + batcher := NewBatcher( + lnd.WalletKit, lnd.ChainNotifier, lnd.Signer, + testMuSig2SignSweep, testVerifySchnorrSig, lnd.ChainParams, + batcherStore, &dummySweepFetcherMock{}, + WithCustomFeeRate(customFeeRate), WithCpfpHelper(cpfpHelper), + ) + + go func() { + err := batcher.Run(ctx) + checkBatcherError(t, err) + }() + + setFeeRate(feeRateLow) + + // Create the first sweep. + swapHash1 := lntypes.Hash{1, 1, 1} + op1 := wire.OutPoint{ + Hash: chainhash.Hash{1, 1}, + Index: 1, + } + sweepReq1 := SweepRequest{ + SwapHash: swapHash1, + Value: 1_000_000, + Outpoint: op1, + Notifier: &dummyNotifier, + } + cpfpHelper.SetOutpointOnline(op1, true) + err = batcher.PresignSweep(ctx, op1, 1_000_000, destAddr) + require.NoError(t, err) + require.NoError(t, batcher.AddSweep(&sweepReq1)) + + // Since a batch was created we check that it registered for its primary + // sweep's spend. + <-lnd.RegisterSpendChannel + + // Wait for a transactions to be published. + parent := <-lnd.TxPublishChannel + require.Len(t, parent.TxIn, 1) + require.Len(t, parent.TxOut, 1) + require.Equal(t, op1, parent.TxIn[0].PreviousOutPoint) + require.Equal(t, int64(996040), parent.TxOut[0].Value) + require.Equal(t, batchPkScript, parent.TxOut[0].PkScript) + + // Replace the logger in the batch with wrappedLogger to watch messages. + batch := getOnlyBatch(t, ctx, batcher) + testLogger := &wrappedLogger{ + Logger: batch.log(), + } + batch.setLog(testLogger) + + // Now turn off the first input, raise feerate and trigger new + // publishing, which will succeed. But then the CPFP will fail. + var failedToPublishTx *wire.MsgTx + lnd.PublishHandler = func(ctx context.Context, tx *wire.MsgTx, + label string) error { + + // We should fail the CPFP only. + if strings.HasPrefix(label, cpfpLabelPrefix) { + failedToPublishTx = tx + + return fmt.Errorf("test error") + } + + return nil + } + cpfpHelper.SetOutpointOnline(op1, false) + setFeeRate(feeRateHigh) + require.NoError(t, batcher.AddSweep(&sweepReq1)) + require.NoError(t, lnd.NotifyHeight(601)) + + // Expect new version of the batch to be published. This is one + // of the presigned transactions. + parent2 := <-lnd.TxPublishChannel + require.NotEqual(t, parent.TxHash(), parent2.TxHash()) + require.Len(t, parent2.TxIn, 1) + require.Len(t, parent2.TxOut, 1) + require.Equal(t, op1, parent2.TxIn[0].PreviousOutPoint) + require.Equal(t, int64(987034), parent2.TxOut[0].Value) + require.Equal(t, batchPkScript, parent2.TxOut[0].PkScript) + + // Wait for batcher to log that CPFP has failed. + require.Eventually(t, func() bool { + testLogger.mu.Lock() + defer testLogger.mu.Unlock() + + for _, msg := range testLogger.infoMessages { + match := strings.Contains( + msg, "failed to publish child tx", + ) + if match { + return true + } + } + + return false + }, test.Timeout, eventuallyCheckFrequency) + + // Make sure that the failed to publish tx is our expected CPFP + // spending parent2. + require.Len(t, failedToPublishTx.TxIn, 1) + require.Len(t, failedToPublishTx.TxOut, 1) + require.Equal(t, wire.OutPoint{ + Hash: parent2.TxHash(), + Index: 0, + }, failedToPublishTx.TxIn[0].PreviousOutPoint) + require.Equal(t, int64(966600), failedToPublishTx.TxOut[0].Value) + require.Equal(t, batchPkScript, failedToPublishTx.TxOut[0].PkScript) + + // Great, now les all published transactions pass and trigger another + // publishing. + lnd.PublishHandler = nil + require.NoError(t, lnd.NotifyHeight(602)) + + // Expect a parent and a child to be published. + parent2a := <-lnd.TxPublishChannel + require.Equal(t, parent2.TxHash(), parent2a.TxHash()) + + child := <-lnd.TxPublishChannel + require.Len(t, child.TxIn, 1) + require.Len(t, child.TxOut, 1) + require.Equal(t, wire.OutPoint{ + Hash: parent2a.TxHash(), + Index: 0, + }, child.TxIn[0].PreviousOutPoint) + require.Equal(t, int64(966600), child.TxOut[0].Value) + require.Equal(t, batchPkScript, child.TxOut[0].PkScript) +} + +// testCPFP_cpfp_and_regular_sweeps tests a combination of CPFP mode and regular +// mode for the following scenario: one regular input is added, then a CPFP +// input is added and it goes to another batch, because they shouldn't appear +// in the same batch. Then another regular and another CPFP inputs are added and +// go to the existing batches of their types. +func testCPFP_cpfp_and_regular_sweeps(t *testing.T, store testStore, + batcherStore testBatcherStore) { + + defer test.Guard(t)() + + lnd := test.NewMockLnd() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + const ( + feeRateLow = chainfee.SatPerKWeight(10_000) + feeRateMedium = chainfee.SatPerKWeight(30_000) + feeRateHigh = chainfee.SatPerKWeight(40_000) + ) + + currentFeeRate := feeRateLow + setFeeRate := func(feeRate chainfee.SatPerKWeight) { + currentFeeRate = feeRate + } + customFeeRate := func(_ context.Context, + _ lntypes.Hash) (chainfee.SatPerKWeight, error) { + + return currentFeeRate, nil + } + + cpfpHelper := newMockCpfpHelper() + + sweepStore, err := NewSweepFetcherFromSwapStore(store, lnd.ChainParams) + require.NoError(t, err) + + batcher := NewBatcher( + lnd.WalletKit, lnd.ChainNotifier, lnd.Signer, + testMuSig2SignSweep, testVerifySchnorrSig, lnd.ChainParams, + batcherStore, sweepStore, + WithCustomFeeRate(customFeeRate), WithCpfpHelper(cpfpHelper), + ) + + go func() { + err := batcher.Run(ctx) + checkBatcherError(t, err) + }() + + setFeeRate(feeRateLow) + + ///////////////////////////////////// + // Create the first regular sweep. // + ///////////////////////////////////// + swapHash1 := lntypes.Hash{1, 1, 1} + op1 := wire.OutPoint{ + Hash: chainhash.Hash{1, 1}, + Index: 1, + } + sweepReq1 := SweepRequest{ + SwapHash: swapHash1, + Value: 1_000_000, + Outpoint: op1, + Notifier: &dummyNotifier, + } + + swap1 := &loopdb.LoopOutContract{ + SwapContract: loopdb.SwapContract{ + CltvExpiry: 111, + AmountRequested: 1_000_000, + ProtocolVersion: loopdb.ProtocolVersionMuSig2, + HtlcKeys: htlcKeys, + + // Make preimage unique to pass SQL constraints. + Preimage: lntypes.Preimage{1}, + }, + + DestAddr: destAddr, + SwapInvoice: swapInvoice, + SweepConfTarget: 111, + } + + err = store.CreateLoopOut(ctx, swapHash1, swap1) + require.NoError(t, err) + store.AssertLoopOutStored() + + // Deliver sweep request to batcher. + require.NoError(t, batcher.AddSweep(&sweepReq1)) + + // Since a batch was created we check that it registered for its primary + // sweep's spend. + <-lnd.RegisterSpendChannel + + // Wait for a transactions to be published. + tx1 := <-lnd.TxPublishChannel + require.Len(t, tx1.TxIn, 1) + require.Len(t, tx1.TxOut, 1) + + ////////////////////////////////// + // Create the first CPFP sweep. // + ////////////////////////////////// + swapHash2 := lntypes.Hash{2, 2, 2} + op2 := wire.OutPoint{ + Hash: chainhash.Hash{2, 2}, + Index: 2, + } + + swap2 := &loopdb.LoopOutContract{ + SwapContract: loopdb.SwapContract{ + CltvExpiry: 111, + AmountRequested: 2_000_000, + ProtocolVersion: loopdb.ProtocolVersionMuSig2, + HtlcKeys: htlcKeys, + + // Make preimage unique to pass SQL constraints. + Preimage: lntypes.Preimage{2}, + }, + + DestAddr: destAddr, + SwapInvoice: swapInvoice, + SweepConfTarget: 111, + } + + err = store.CreateLoopOut(ctx, swapHash2, swap2) + require.NoError(t, err) + store.AssertLoopOutStored() + + sweepReq2 := SweepRequest{ + SwapHash: swapHash2, + Value: 2_000_000, + Outpoint: op2, + Notifier: &dummyNotifier, + } + cpfpHelper.SetOutpointOnline(op2, true) + err = batcher.PresignSweep(ctx, op2, 2_000_000, destAddr) + require.NoError(t, err) + require.NoError(t, batcher.AddSweep(&sweepReq2)) + + // Since a batch was created we check that it registered for its primary + // sweep's spend. + <-lnd.RegisterSpendChannel + + // Wait for a transactions to be published. + tx2 := <-lnd.TxPublishChannel + require.Len(t, tx2.TxIn, 1) + require.Len(t, tx2.TxOut, 1) + require.Equal(t, op2, tx2.TxIn[0].PreviousOutPoint) + + ////////////////////////////////////// + // Create the second regular sweep. // + ////////////////////////////////////// + swapHash3 := lntypes.Hash{3, 3, 3} + op3 := wire.OutPoint{ + Hash: chainhash.Hash{3, 3}, + Index: 3, + } + sweepReq3 := SweepRequest{ + SwapHash: swapHash3, + Value: 4_000_000, + Outpoint: op3, + Notifier: &dummyNotifier, + } + + swap3 := &loopdb.LoopOutContract{ + SwapContract: loopdb.SwapContract{ + CltvExpiry: 111, + AmountRequested: 4_000_000, + ProtocolVersion: loopdb.ProtocolVersionMuSig2, + HtlcKeys: htlcKeys, + + // Make preimage unique to pass SQL constraints. + Preimage: lntypes.Preimage{3}, + }, + + DestAddr: destAddr, + SwapInvoice: swapInvoice, + SweepConfTarget: 111, + } + + err = store.CreateLoopOut(ctx, swapHash3, swap3) + require.NoError(t, err) + store.AssertLoopOutStored() + + // Deliver sweep request to batcher. + require.NoError(t, batcher.AddSweep(&sweepReq3)) + + /////////////////////////////////// + // Create the second CPFP sweep. // + /////////////////////////////////// + swapHash4 := lntypes.Hash{4, 4, 4} + op4 := wire.OutPoint{ + Hash: chainhash.Hash{4, 4}, + Index: 4, + } + + swap4 := &loopdb.LoopOutContract{ + SwapContract: loopdb.SwapContract{ + CltvExpiry: 111, + AmountRequested: 3_000_000, + ProtocolVersion: loopdb.ProtocolVersionMuSig2, + HtlcKeys: htlcKeys, + + // Make preimage unique to pass SQL constraints. + Preimage: lntypes.Preimage{4}, + }, + + DestAddr: destAddr, + SwapInvoice: swapInvoice, + SweepConfTarget: 111, + } + + err = store.CreateLoopOut(ctx, swapHash4, swap4) + require.NoError(t, err) + store.AssertLoopOutStored() + + sweepReq4 := SweepRequest{ + SwapHash: swapHash4, + Value: 3_000_000, + Outpoint: op4, + Notifier: &dummyNotifier, + } + cpfpHelper.SetOutpointOnline(op4, true) + err = batcher.PresignSweep(ctx, op4, 4_000_000, destAddr) + require.NoError(t, err) + require.NoError(t, batcher.AddSweep(&sweepReq4)) + + // Wait for the both batches to have two sweeps. + require.Eventually(t, func() bool { + // Make sure there are two batches. + batches := getBatches(ctx, batcher) + if len(batches) != 2 { + return false + } + + // Make sure each batch has two sweeps. + for _, batch := range batches { + var numSweeps int + batch.testRunInEventLoop(ctx, func() { + numSweeps = len(batch.sweeps) + }) + if numSweeps != 2 { + return false + } + } + + return true + }, test.Timeout, eventuallyCheckFrequency) + + // Mine a block to trigger both batches publishing. + require.NoError(t, lnd.NotifyHeight(601)) + + // Wait for a transactions to be published. + tx3 := <-lnd.TxPublishChannel + require.Len(t, tx3.TxIn, 2) + require.Len(t, tx3.TxOut, 1) + require.Equal(t, int64(4993740), tx3.TxOut[0].Value) + + tx4 := <-lnd.TxPublishChannel + require.Len(t, tx4.TxIn, 2) + require.Len(t, tx4.TxOut, 1) + require.Equal(t, int64(4993740), tx4.TxOut[0].Value) +} + +// TestSweepBatcherBatchCreation tests that sweep requests enter the expected +// batch based on their timeout distance. +func TestSweepBatcherBatchCreation(t *testing.T) { + runTests(t, testSweepBatcherBatchCreation) +} // TestFeeBumping tests that sweep is RBFed with slightly higher fee rate after // each block unless WithCustomFeeRate is passed. @@ -4023,6 +5559,41 @@ func TestFeeRateGrows(t *testing.T) { runTests(t, testFeeRateGrows) } +// TestCPFP tests CPFP mode. This test doesn't use loopdb. +func TestCPFP(t *testing.T) { + logger := btclog.NewBackend(os.Stdout).Logger("SWEEP") + logger.SetLevel(btclog.LevelTrace) + UseLogger(logger) + + t.Run("input1_offline_then_input2", func(t *testing.T) { + testCPFP_input1_offline_then_input2(t, NewStoreMock()) + }) + + t.Run("two_inputs_one_goes_offline", func(t *testing.T) { + testCPFP_two_inputs_one_goes_offline(t, NewStoreMock()) + }) + + t.Run("cpfp_previous_version", func(t *testing.T) { + testCPFP_cpfp_previous_version(t, NewStoreMock()) + }) + + t.Run("no_cpfp_if_all_online", func(t *testing.T) { + testCPFP_no_cpfp_if_all_online(t, NewStoreMock()) + }) + + t.Run("first_publish_fails", func(t *testing.T) { + testCPFP_first_publish_fails(t, NewStoreMock()) + }) + + t.Run("cpfp_publishing_fails", func(t *testing.T) { + testCPFP_cpfp_publishing_fails(t, NewStoreMock()) + }) + + t.Run("cpfp_and_regular_sweeps", func(t *testing.T) { + runTests(t, testCPFP_cpfp_and_regular_sweeps) + }) +} + // testBatcherStore is BatcherStore used in tests. type testBatcherStore interface { BatcherStore @@ -4035,6 +5606,8 @@ type loopdbBatcherStore struct { BatcherStore sweepsSet map[lntypes.Hash]struct{} + + mu sync.Mutex } // UpsertSweep inserts a sweep into the database, or updates an existing sweep @@ -4042,6 +5615,9 @@ type loopdbBatcherStore struct { func (s *loopdbBatcherStore) UpsertSweep(ctx context.Context, sweep *dbSweep) error { + s.mu.Lock() + defer s.mu.Unlock() + err := s.BatcherStore.UpsertSweep(ctx, sweep) if err == nil { s.sweepsSet[sweep.SwapHash] = struct{}{} @@ -4051,7 +5627,11 @@ func (s *loopdbBatcherStore) UpsertSweep(ctx context.Context, // AssertSweepStored asserts that a sweep is stored. func (s *loopdbBatcherStore) AssertSweepStored(id lntypes.Hash) bool { + s.mu.Lock() + defer s.mu.Unlock() + _, has := s.sweepsSet[id] + return has } diff --git a/test/lnd_services_mock.go b/test/lnd_services_mock.go index db4447448..aaf5c1106 100644 --- a/test/lnd_services_mock.go +++ b/test/lnd_services_mock.go @@ -29,6 +29,7 @@ func NewMockLnd() *LndMockServices { lightningClient := &mockLightningClient{} walletKit := &mockWalletKit{ feeEstimates: make(map[int32]chainfee.SatPerKWeight), + minRelayFee: chainfee.FeePerKwFloor, } chainNotifier := &mockChainNotifier{} signer := &mockSigner{} @@ -128,6 +129,11 @@ type SignOutputRawRequest struct { SignDescriptors []*lndclient.SignDescriptor } +// PublishHandler is optional transaction handler function called upon calling +// the method PublishTransaction. +type PublishHandler func(ctx context.Context, tx *wire.MsgTx, + label string) error + // LndMockServices provides a full set of mocked lnd services. type LndMockServices struct { lndclient.LndServices @@ -173,6 +179,8 @@ type LndMockServices struct { WaitForFinished func() + PublishHandler PublishHandler + lock sync.Mutex } @@ -278,3 +286,7 @@ func (s *LndMockServices) SetFeeEstimate(confTarget int32, confTarget, feeEstimate, ) } + +func (s *LndMockServices) SetMinRelayFee(feeEstimate chainfee.SatPerKWeight) { + s.LndServices.WalletKit.(*mockWalletKit).setMinRelayFee(feeEstimate) +} diff --git a/test/walletkit_mock.go b/test/walletkit_mock.go index 637686c68..332d78866 100644 --- a/test/walletkit_mock.go +++ b/test/walletkit_mock.go @@ -1,6 +1,7 @@ package test import ( + "bytes" "context" "errors" "fmt" @@ -34,6 +35,7 @@ type mockWalletKit struct { feeEstimateLock sync.Mutex feeEstimates map[int32]chainfee.SatPerKWeight + minRelayFee chainfee.SatPerKWeight } var _ lndclient.WalletKitClient = (*mockWalletKit)(nil) @@ -111,7 +113,13 @@ func (m *mockWalletKit) NextAddr(context.Context, string, walletrpc.AddressType, } func (m *mockWalletKit) PublishTransaction(ctx context.Context, tx *wire.MsgTx, - _ string) error { + label string) error { + + if m.lnd.PublishHandler != nil { + if err := m.lnd.PublishHandler(ctx, tx, label); err != nil { + return err + } + } m.lnd.AddTx(tx) m.lnd.TxPublishChannel <- tx @@ -169,6 +177,24 @@ func (m *mockWalletKit) EstimateFeeRate(ctx context.Context, return feeEstimate, nil } +func (m *mockWalletKit) setMinRelayFee(fee chainfee.SatPerKWeight) { + m.feeEstimateLock.Lock() + defer m.feeEstimateLock.Unlock() + + m.minRelayFee = fee +} + +// MinRelayFee returns the current minimum relay fee based on our chain backend +// in sat/kw. It can be set with setMinRelayFee. +func (m *mockWalletKit) MinRelayFee( + ctx context.Context) (chainfee.SatPerKWeight, error) { + + m.feeEstimateLock.Lock() + defer m.feeEstimateLock.Unlock() + + return m.minRelayFee, nil +} + // ListSweeps returns a list of the sweep transaction ids known to our node. func (m *mockWalletKit) ListSweeps(_ context.Context, _ int32) ( []string, error) { @@ -227,6 +253,25 @@ func (m *mockWalletKit) FundPsbt(_ context.Context, return nil, 0, nil, nil } +// finalScriptWitness is a sample signature suitable to put into PSBT. +var finalScriptWitness = func() []byte { + const pver = 0 + var buf bytes.Buffer + + // Write the number of witness elements. + if err := wire.WriteVarInt(&buf, pver, 1); err != nil { + panic(err) + } + + // Write a single witness element with a signature. + signature := make([]byte, 64) + if err := wire.WriteVarBytes(&buf, pver, signature); err != nil { + panic(err) + } + + return buf.Bytes() +}() + // SignPsbt expects a partial transaction with all inputs and outputs // fully declared and tries to sign all unsigned inputs that have all // required fields (UTXO information, BIP32 derivation information, @@ -239,9 +284,19 @@ func (m *mockWalletKit) FundPsbt(_ context.Context, // locking or input/output/fee value validation, PSBT finalization). Any // input that is incomplete will be skipped. func (m *mockWalletKit) SignPsbt(_ context.Context, - _ *psbt.Packet) (*psbt.Packet, error) { + packet *psbt.Packet) (*psbt.Packet, error) { - return nil, nil + inputs := make([]psbt.PInput, len(packet.Inputs)) + copy(inputs, packet.Inputs) + + for i := range inputs { + inputs[i].FinalScriptWitness = finalScriptWitness + } + + signedPacket := *packet + signedPacket.Inputs = inputs + + return &signedPacket, nil } // FinalizePsbt expects a partial transaction with all inputs and