Skip to content

Commit c162879

Browse files
committed
feat: reduce the type assertion of CheckConn
Signed-off-by: monkey <[email protected]>
1 parent 67824eb commit c162879

File tree

4 files changed

+36
-25
lines changed

4 files changed

+36
-25
lines changed

internal/pool/conn.go

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,10 @@ package pool
33
import (
44
"bufio"
55
"context"
6+
"crypto/tls"
67
"net"
78
"sync/atomic"
9+
"syscall"
810
"time"
911

1012
"github.com/redis/go-redis/v9/internal/proto"
@@ -16,6 +18,9 @@ type Conn struct {
1618
usedAt int64 // atomic
1719
netConn net.Conn
1820

21+
// for checking the health status of the connection, it may be nil.
22+
rawConn syscall.RawConn
23+
1924
rd *proto.Reader
2025
bw *bufio.Writer
2126
wr *proto.Writer
@@ -34,6 +39,7 @@ func NewConn(netConn net.Conn) *Conn {
3439
cn.bw = bufio.NewWriter(netConn)
3540
cn.wr = proto.NewWriter(cn.bw)
3641
cn.SetUsedAt(time.Now())
42+
cn.setRawConn()
3743
return cn
3844
}
3945

@@ -50,6 +56,23 @@ func (cn *Conn) SetNetConn(netConn net.Conn) {
5056
cn.netConn = netConn
5157
cn.rd.Reset(netConn)
5258
cn.bw.Reset(netConn)
59+
cn.setRawConn()
60+
}
61+
62+
func (cn *Conn) setRawConn() {
63+
conn := cn.netConn
64+
if conn == nil {
65+
return
66+
}
67+
if tlsConn, ok := conn.(*tls.Conn); ok {
68+
conn = tlsConn.NetConn()
69+
}
70+
71+
if sysConn, ok := conn.(syscall.Conn); ok {
72+
if rawConn, err := sysConn.SyscallConn(); err == nil {
73+
cn.rawConn = rawConn
74+
}
75+
}
5376
}
5477

5578
func (cn *Conn) Write(b []byte) (int, error) {

internal/pool/conn_check.go

Lines changed: 1 addition & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -3,35 +3,15 @@
33
package pool
44

55
import (
6-
"crypto/tls"
76
"errors"
87
"io"
9-
"net"
108
"syscall"
11-
"time"
129
)
1310

1411
var errUnexpectedRead = errors.New("unexpected read from socket")
1512

16-
func connCheck(conn net.Conn) error {
17-
// Reset previous timeout.
18-
_ = conn.SetDeadline(time.Time{})
19-
20-
// Check if tls.Conn.
21-
if c, ok := conn.(*tls.Conn); ok {
22-
conn = c.NetConn()
23-
}
24-
sysConn, ok := conn.(syscall.Conn)
25-
if !ok {
26-
return nil
27-
}
28-
rawConn, err := sysConn.SyscallConn()
29-
if err != nil {
30-
return err
31-
}
32-
13+
func connCheck(rawConn syscall.RawConn) error {
3314
var sysErr error
34-
3515
if err := rawConn.Read(func(fd uintptr) bool {
3616
var buf [1]byte
3717
n, err := syscall.Read(int(fd), buf[:])

internal/pool/conn_check_dummy.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@
22

33
package pool
44

5-
import "net"
5+
import (
6+
"syscall"
7+
)
68

7-
func connCheck(conn net.Conn) error {
9+
func connCheck(_ syscall.RawConn) error {
810
return nil
911
}

internal/pool/pool.go

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -499,6 +499,8 @@ func (p *ConnPool) Close() error {
499499
return firstErr
500500
}
501501

502+
var zeroTime = time.Time{}
503+
502504
func (p *ConnPool) isHealthyConn(cn *Conn) bool {
503505
now := time.Now()
504506

@@ -509,8 +511,12 @@ func (p *ConnPool) isHealthyConn(cn *Conn) bool {
509511
return false
510512
}
511513

512-
if connCheck(cn.netConn) != nil {
513-
return false
514+
if cn.rawConn != nil {
515+
// reset previous timeout.
516+
_ = cn.netConn.SetDeadline(zeroTime)
517+
if connCheck(cn.rawConn) != nil {
518+
return false
519+
}
514520
}
515521

516522
cn.SetUsedAt(now)

0 commit comments

Comments
 (0)