From 49c7902daf2d7d4b04f3875a5abf8f3cf9ee2b6d Mon Sep 17 00:00:00 2001
From: Kevin Lin <developer@kevinlin.info>
Date: Sat, 10 Aug 2024 14:41:15 -0700
Subject: [PATCH] Bound connection pool background dials to configured dial
 timeout

---
 internal/pool/bench_test.go |  2 ++
 internal/pool/pool.go       | 12 ++++++++++--
 internal/pool/pool_test.go  |  5 +++++
 options.go                  |  1 +
 4 files changed, 18 insertions(+), 2 deletions(-)

diff --git a/internal/pool/bench_test.go b/internal/pool/bench_test.go
index 71049f480d..72308e1242 100644
--- a/internal/pool/bench_test.go
+++ b/internal/pool/bench_test.go
@@ -33,6 +33,7 @@ func BenchmarkPoolGetPut(b *testing.B) {
 				Dialer:          dummyDialer,
 				PoolSize:        bm.poolSize,
 				PoolTimeout:     time.Second,
+				DialTimeout:     1 * time.Second,
 				ConnMaxIdleTime: time.Hour,
 			})
 
@@ -76,6 +77,7 @@ func BenchmarkPoolGetRemove(b *testing.B) {
 				Dialer:          dummyDialer,
 				PoolSize:        bm.poolSize,
 				PoolTimeout:     time.Second,
+				DialTimeout:     1 * time.Second,
 				ConnMaxIdleTime: time.Hour,
 			})
 
diff --git a/internal/pool/pool.go b/internal/pool/pool.go
index 2125f3e133..b69c75f4f0 100644
--- a/internal/pool/pool.go
+++ b/internal/pool/pool.go
@@ -62,6 +62,7 @@ type Options struct {
 
 	PoolFIFO        bool
 	PoolSize        int
+	DialTimeout     time.Duration
 	PoolTimeout     time.Duration
 	MinIdleConns    int
 	MaxIdleConns    int
@@ -140,7 +141,10 @@ func (p *ConnPool) checkMinIdleConns() {
 }
 
 func (p *ConnPool) addIdleConn() error {
-	cn, err := p.dialConn(context.TODO(), true)
+	ctx, cancel := context.WithTimeout(context.Background(), p.cfg.DialTimeout)
+	defer cancel()
+
+	cn, err := p.dialConn(ctx, true)
 	if err != nil {
 		return err
 	}
@@ -230,15 +234,19 @@ func (p *ConnPool) tryDial() {
 			return
 		}
 
-		conn, err := p.cfg.Dialer(context.Background())
+		ctx, cancel := context.WithTimeout(context.Background(), p.cfg.DialTimeout)
+
+		conn, err := p.cfg.Dialer(ctx)
 		if err != nil {
 			p.setLastDialError(err)
 			time.Sleep(time.Second)
+			cancel()
 			continue
 		}
 
 		atomic.StoreUint32(&p.dialErrorsNum, 0)
 		_ = conn.Close()
+		cancel()
 		return
 	}
 }
diff --git a/internal/pool/pool_test.go b/internal/pool/pool_test.go
index 76dec996b5..3ed9716ac9 100644
--- a/internal/pool/pool_test.go
+++ b/internal/pool/pool_test.go
@@ -22,6 +22,7 @@ var _ = Describe("ConnPool", func() {
 			Dialer:          dummyDialer,
 			PoolSize:        10,
 			PoolTimeout:     time.Hour,
+			DialTimeout:     1 * time.Second,
 			ConnMaxIdleTime: time.Millisecond,
 		})
 	})
@@ -46,6 +47,7 @@ var _ = Describe("ConnPool", func() {
 			},
 			PoolSize:        10,
 			PoolTimeout:     time.Hour,
+			DialTimeout:     1 * time.Second,
 			ConnMaxIdleTime: time.Millisecond,
 			MinIdleConns:    minIdleConns,
 		})
@@ -129,6 +131,7 @@ var _ = Describe("MinIdleConns", func() {
 			PoolSize:        poolSize,
 			MinIdleConns:    minIdleConns,
 			PoolTimeout:     100 * time.Millisecond,
+			DialTimeout:     1 * time.Second,
 			ConnMaxIdleTime: -1,
 		})
 		Eventually(func() int {
@@ -306,6 +309,7 @@ var _ = Describe("race", func() {
 			Dialer:          dummyDialer,
 			PoolSize:        10,
 			PoolTimeout:     time.Minute,
+			DialTimeout:     1 * time.Second,
 			ConnMaxIdleTime: time.Millisecond,
 		})
 
@@ -336,6 +340,7 @@ var _ = Describe("race", func() {
 			PoolSize:     1000,
 			MinIdleConns: 50,
 			PoolTimeout:  3 * time.Second,
+			DialTimeout:  1 * time.Second,
 		}
 		p := pool.NewConnPool(opt)
 
diff --git a/options.go b/options.go
index 6ed693a0b0..7fd6538e10 100644
--- a/options.go
+++ b/options.go
@@ -515,6 +515,7 @@ func newConnPool(
 		PoolFIFO:        opt.PoolFIFO,
 		PoolSize:        opt.PoolSize,
 		PoolTimeout:     opt.PoolTimeout,
+		DialTimeout:     opt.DialTimeout,
 		MinIdleConns:    opt.MinIdleConns,
 		MaxIdleConns:    opt.MaxIdleConns,
 		MaxActiveConns:  opt.MaxActiveConns,