backend, net: close the client connection when the backend is down #198

Merged 2 commits on Jan 19, 2023

66 changes: 59 additions & 7 deletions pkg/proxy/backend/backend_conn_mgr.go
@@ -43,7 +43,8 @@ var (
)

const (
-	DialTimeout = 5 * time.Second
+	DialTimeout          = 5 * time.Second
+	CheckBackendInterval = time.Minute
)

const (
@@ -79,6 +80,18 @@ const (
statusClosed
)

type BCConfig struct {
ProxyProtocol bool
RequireBackendTLS bool
CheckBackendInterval time.Duration
}

func (cfg *BCConfig) check() {
if cfg.CheckBackendInterval == time.Duration(0) {
cfg.CheckBackendInterval = CheckBackendInterval
}
}
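
To illustrate the defaulting behavior above: a zero-value interval is replaced by the package default. A minimal, self-contained sketch (the constant and type are trimmed copies for demonstration, not the real package):

```go
package main

import (
	"fmt"
	"time"
)

// Trimmed copies of the diff's constant and config type, for this sketch only.
const CheckBackendInterval = time.Minute

type BCConfig struct {
	CheckBackendInterval time.Duration
}

func (cfg *BCConfig) check() {
	// A zero value means the caller didn't set it, so fall back to the default.
	if cfg.CheckBackendInterval == 0 {
		cfg.CheckBackendInterval = CheckBackendInterval
	}
}

func main() {
	cfg := &BCConfig{} // interval left unset, as NewClientConnection does
	cfg.check()
	fmt.Println(cfg.CheckBackendInterval) // prints 1m0s
}
```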

// BackendConnManager migrates a session from one BackendConnection to another.
//
// The signal processing goroutine tries to migrate the session once it receives a signal.
@@ -97,13 +110,15 @@ type BackendConnManager struct {
authenticator *Authenticator
cmdProcessor *CmdProcessor
eventReceiver unsafe.Pointer
config *BCConfig
logger *zap.Logger
// type *signalRedirect, it saves the last signal if there are multiple signals.
// It will be set to nil after migration.
signal unsafe.Pointer
// redirectResCh is used to notify the event receiver asynchronously.
-	redirectResCh chan *redirectResult
-	closeStatus   atomic.Int32
+	redirectResCh      chan *redirectResult
+	closeStatus        atomic.Int32
+	checkBackendTicker *time.Ticker
// cancelFunc is used to cancel the signal processing goroutine.
cancelFunc context.CancelFunc
clientIO *pnet.PacketIO
@@ -115,16 +130,17 @@
}

// NewBackendConnManager creates a BackendConnManager.
-func NewBackendConnManager(logger *zap.Logger, handshakeHandler HandshakeHandler,
-	connectionID uint64, proxyProtocol, requireBackendTLS bool) *BackendConnManager {
+func NewBackendConnManager(logger *zap.Logger, handshakeHandler HandshakeHandler, connectionID uint64, config *BCConfig) *BackendConnManager {
+	config.check()
mgr := &BackendConnManager{
logger: logger,
config: config,
connectionID: connectionID,
cmdProcessor: NewCmdProcessor(),
handshakeHandler: handshakeHandler,
authenticator: &Authenticator{
-			proxyProtocol:     proxyProtocol,
-			requireBackendTLS: requireBackendTLS,
+			proxyProtocol:     config.ProxyProtocol,
+			requireBackendTLS: config.RequireBackendTLS,
salt: GenerateSalt(20),
},
// There are 2 types of signals, which may be sent concurrently.
@@ -158,6 +174,7 @@ func (mgr *BackendConnManager) Connect(ctx context.Context, clientIO *pnet.Packe
mgr.cmdProcessor.capability = mgr.authenticator.capability
childCtx, cancelFunc := context.WithCancel(ctx)
mgr.cancelFunc = cancelFunc
mgr.resetCheckBackendTicker()
mgr.wg.Run(func() {
mgr.processSignals(childCtx)
})
@@ -243,6 +260,7 @@ func (mgr *BackendConnManager) ExecuteCmd(ctx context.Context, request []byte) e
case statusClosing, statusClosed:
return nil
}
defer mgr.resetCheckBackendTicker()
waitingRedirect := atomic.LoadPointer(&mgr.signal) != nil
holdRequest, err := mgr.cmdProcessor.executeCmd(request, mgr.clientIO, mgr.backendIO, waitingRedirect)
if !holdRequest {
@@ -336,6 +354,7 @@ func (mgr *BackendConnManager) querySessionStates() (sessionStates, sessionToken
// processSignals runs in a goroutine to:
// - Receive redirection signals and then try to migrate the session.
// - Send redirection results to the event receiver.
// - Check if the backend is still alive.
func (mgr *BackendConnManager) processSignals(ctx context.Context) {
for {
select {
@@ -351,6 +370,8 @@ func (mgr *BackendConnManager) processSignals(ctx context.Context) {
mgr.processLock.Unlock()
case rs := <-mgr.redirectResCh:
mgr.notifyRedirectResult(ctx, rs)
case <-mgr.checkBackendTicker.C:
mgr.checkBackendActive()
case <-ctx.Done():
return
}
@@ -499,6 +520,34 @@ func (mgr *BackendConnManager) tryGracefulClose(ctx context.Context) {
mgr.closeStatus.Store(statusClosing)
}

func (mgr *BackendConnManager) checkBackendActive() {
switch mgr.closeStatus.Load() {
case statusClosing, statusClosed:
return
}

mgr.processLock.Lock()
defer mgr.processLock.Unlock()
if !mgr.backendIO.IsPeerActive() {
mgr.logger.Info("backend connection is closed, close client connection", zap.Stringer("client", mgr.clientIO.RemoteAddr()),
zap.Stringer("backend", mgr.backendIO.RemoteAddr()))
if err := mgr.clientIO.GracefulClose(); err != nil {
mgr.logger.Warn("graceful close client IO error", zap.Stringer("addr", mgr.clientIO.RemoteAddr()), zap.Error(err))
}
mgr.closeStatus.Store(statusClosing)
}
}

// Checking the backend is expensive, so only check it when the client has been idle for some time.
// This function should be called within the lock.
func (mgr *BackendConnManager) resetCheckBackendTicker() {
if mgr.checkBackendTicker == nil {
mgr.checkBackendTicker = time.NewTicker(mgr.config.CheckBackendInterval)
} else {
mgr.checkBackendTicker.Reset(mgr.config.CheckBackendInterval)
}
}
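
To see how the pieces fit: ExecuteCmd defers a resetCheckBackendTicker call after every command, and processSignals probes the backend only when the ticker fires, i.e. after a full interval with no traffic. A standalone sketch of that reset-on-activity pattern, with illustrative names that are not part of this PR:

```go
package main

import (
	"fmt"
	"time"
)

// idleChecker fires probe only after interval elapses with no activity,
// mirroring how BackendConnManager arms checkBackendTicker: every client
// command resets the ticker, so only idle connections get probed.
type idleChecker struct {
	interval time.Duration
	ticker   *time.Ticker
	probe    func()
}

func newIdleChecker(interval time.Duration, probe func()) *idleChecker {
	return &idleChecker{interval: interval, ticker: time.NewTicker(interval), probe: probe}
}

// onActivity postpones the next probe by a full interval.
func (c *idleChecker) onActivity() { c.ticker.Reset(c.interval) }

func (c *idleChecker) run(stop <-chan struct{}) {
	defer c.ticker.Stop()
	for {
		select {
		case <-c.ticker.C:
			c.probe()
		case <-stop:
			return
		}
	}
}

func main() {
	c := newIdleChecker(50*time.Millisecond, func() { fmt.Println("probe backend") })
	stop := make(chan struct{})
	go c.run(stop)

	// Busy phase: activity every 20ms keeps pushing the probe out.
	for i := 0; i < 5; i++ {
		time.Sleep(20 * time.Millisecond)
		c.onActivity()
	}
	time.Sleep(120 * time.Millisecond) // idle phase: the probe fires
	close(stop)
	time.Sleep(10 * time.Millisecond) // give run a moment to exit (sketch only)
}
```

The TestBackendInactive case added below exercises exactly this: queries spaced inside the interval keep the probe from firing, and closing the backend while the client is idle triggers the client-side close.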

func (mgr *BackendConnManager) ClientAddr() string {
if mgr.clientIO == nil {
return ""
@@ -542,6 +591,9 @@ func (mgr *BackendConnManager) Value(key any) any {
// Close releases all resources.
func (mgr *BackendConnManager) Close() error {
mgr.closeStatus.Store(statusClosing)
if mgr.checkBackendTicker != nil {
mgr.checkBackendTicker.Stop()
}
if mgr.cancelFunc != nil {
mgr.cancelFunc()
mgr.cancelFunc = nil
98 changes: 88 additions & 10 deletions pkg/proxy/backend/backend_conn_mgr_test.go
@@ -211,15 +211,15 @@ func (ts *backendMgrTester) redirectFail4Proxy(clientIO, backendIO *pnet.PacketI
return nil
}

-func (ts *backendMgrTester) checkConnClosed(_, _ *pnet.PacketIO) error {
-	for i := 0; i < 30; i++ {
+func (ts *backendMgrTester) checkConnClosed4Proxy(_, _ *pnet.PacketIO) error {
+	require.Eventually(ts.t, func() bool {
		switch ts.mp.closeStatus.Load() {
		case statusClosing, statusClosed:
-			return nil
+			return true
		}
-		time.Sleep(100 * time.Millisecond)
-	}
-	return errors.New("timeout")
+		return false
+	}, 3*time.Second, 100*time.Millisecond)
+	return nil
}

func (ts *backendMgrTester) runTests(runners []runner) {
@@ -621,7 +621,7 @@ func TestGracefulCloseWhenIdle(t *testing.T) {
},
// really closed
{
-			proxy: ts.checkConnClosed,
+			proxy: ts.checkConnClosed4Proxy,
},
}
ts.runTests(runners)
@@ -659,7 +659,7 @@ func TestGracefulCloseWhenActive(t *testing.T) {
},
// it will then automatically close
{
-			proxy: ts.checkConnClosed,
+			proxy: ts.checkConnClosed4Proxy,
},
}
ts.runTests(runners)
@@ -683,7 +683,7 @@ func TestGracefulCloseBeforeHandshake(t *testing.T) {
},
// it will then automatically close
{
-			proxy: ts.checkConnClosed,
+			proxy: ts.checkConnClosed4Proxy,
},
}
ts.runTests(runners)
@@ -763,7 +763,7 @@ func TestGetBackendIO(t *testing.T) {
}
},
}
-	mgr := NewBackendConnManager(logger.CreateLoggerForTest(t), handler, 0, false, false)
+	mgr := NewBackendConnManager(logger.CreateLoggerForTest(t), handler, 0, &BCConfig{})
var wg waitgroup.WaitGroup
for i := 0; i <= len(listeners); i++ {
wg.Run(func() {
@@ -790,3 +790,81 @@
wg.Wait()
}
}

func TestBackendInactive(t *testing.T) {
ts := newBackendMgrTester(t, func(config *testConfig) {
config.proxyConfig.checkBackendInterval = 10 * time.Millisecond
})
runners := []runner{
// 1st handshake
{
client: ts.mc.authenticate,
proxy: ts.firstHandshake4Proxy,
backend: ts.handshake4Backend,
},
// run some queries, spaced less than checkBackendInterval apart
{
client: func(packetIO *pnet.PacketIO) error {
for i := 0; i < 10; i++ {
time.Sleep(5 * time.Millisecond)
if err := ts.mc.request(packetIO); err != nil {
return err
}
}
return nil
},
proxy: func(clientIO, backendIO *pnet.PacketIO) error {
for i := 0; i < 10; i++ {
if err := ts.forwardCmd4Proxy(clientIO, backendIO); err != nil {
return err
}
}
return nil
},
backend: func(packetIO *pnet.PacketIO) error {
for i := 0; i < 10; i++ {
if err := ts.respondWithNoTxn4Backend(packetIO); err != nil {
return err
}
}
return nil
},
},
// run some queries, spaced more than checkBackendInterval apart
{
client: func(packetIO *pnet.PacketIO) error {
for i := 0; i < 5; i++ {
time.Sleep(30 * time.Millisecond)
if err := ts.mc.request(packetIO); err != nil {
return err
}
}
return nil
},
proxy: func(clientIO, backendIO *pnet.PacketIO) error {
for i := 0; i < 5; i++ {
if err := ts.forwardCmd4Proxy(clientIO, backendIO); err != nil {
return err
}
}
return nil
},
backend: func(packetIO *pnet.PacketIO) error {
for i := 0; i < 5; i++ {
if err := ts.respondWithNoTxn4Backend(packetIO); err != nil {
return err
}
}
return nil
},
},
// close the backend; the client connection should then close
{
proxy: ts.checkConnClosed4Proxy,
backend: func(packetIO *pnet.PacketIO) error {
return packetIO.Close()
},
},
}
ts.runTests(runners)
}
28 changes: 16 additions & 12 deletions pkg/proxy/backend/mock_proxy_test.go
@@ -26,19 +26,21 @@ import (
)

type proxyConfig struct {
-	frontendTLSConfig *tls.Config
-	backendTLSConfig  *tls.Config
-	handler           *CustomHandshakeHandler
-	sessionToken      string
-	capability        pnet.Capability
-	waitRedirect      bool
+	frontendTLSConfig    *tls.Config
+	backendTLSConfig     *tls.Config
+	handler              *CustomHandshakeHandler
+	checkBackendInterval time.Duration
+	sessionToken         string
+	capability           pnet.Capability
+	waitRedirect         bool
}

func newProxyConfig() *proxyConfig {
return &proxyConfig{
-		handler:      &CustomHandshakeHandler{},
-		capability:   defaultTestBackendCapability,
-		sessionToken: mockToken,
+		handler:              &CustomHandshakeHandler{},
+		capability:           defaultTestBackendCapability,
+		sessionToken:         mockToken,
+		checkBackendInterval: CheckBackendInterval,
}
}

@@ -56,9 +58,11 @@ type mockProxy struct {

func newMockProxy(t *testing.T, cfg *proxyConfig) *mockProxy {
mp := &mockProxy{
-		proxyConfig:        cfg,
-		logger:             logger.CreateLoggerForTest(t).Named("mockProxy"),
-		BackendConnManager: NewBackendConnManager(logger.CreateLoggerForTest(t), cfg.handler, 0, false, false),
+		proxyConfig:        cfg,
+		logger:             logger.CreateLoggerForTest(t).Named("mockProxy"),
+		BackendConnManager: NewBackendConnManager(logger.CreateLoggerForTest(t), cfg.handler, 0, &BCConfig{
+			CheckBackendInterval: cfg.checkBackendInterval,
+		}),
}
mp.cmdProcessor.capability = cfg.capability.Uint32()
return mp
5 changes: 4 additions & 1 deletion pkg/proxy/client/client_conn.go
@@ -42,7 +42,10 @@ type ClientConnection struct {

func NewClientConnection(logger *zap.Logger, conn net.Conn, frontendTLSConfig *tls.Config, backendTLSConfig *tls.Config,
hsHandler backend.HandshakeHandler, connID uint64, proxyProtocol, requireBackendTLS bool) *ClientConnection {
-	bemgr := backend.NewBackendConnManager(logger.Named("be"), hsHandler, connID, proxyProtocol, requireBackendTLS)
+	bemgr := backend.NewBackendConnManager(logger.Named("be"), hsHandler, connID, &backend.BCConfig{
+		ProxyProtocol:     proxyProtocol,
+		RequireBackendTLS: requireBackendTLS,
+	})
opts := make([]pnet.PacketIOption, 0, 2)
opts = append(opts, pnet.WithWrapError(ErrClientConn))
if proxyProtocol {
24 changes: 21 additions & 3 deletions pkg/proxy/net/packetio.go
@@ -141,7 +141,7 @@ func (p *PacketIO) GetSequence() uint8 {
func (p *PacketIO) readOnePacket() ([]byte, bool, error) {
var header [4]byte

-	if _, err := io.ReadFull(p.conn, header[:]); err != nil {
+	if _, err := io.ReadFull(p.buf, header[:]); err != nil {
return nil, false, errors.Wrap(ErrReadConn, err)
}
p.inBytes += 4
@@ -164,7 +164,7 @@

// refill mysql headers
if refill {
-	if _, err := io.ReadFull(p.conn, header[:]); err != nil {
+	if _, err := io.ReadFull(p.buf, header[:]); err != nil {
return nil, false, errors.Wrap(ErrReadConn, err)
}
p.inBytes += 4
@@ -178,7 +178,7 @@
length := int(uint32(header[0]) | uint32(header[1])<<8 | uint32(header[2])<<16)

data := make([]byte, length)
-	if _, err := io.ReadFull(p.conn, data); err != nil {
+	if _, err := io.ReadFull(p.buf, data); err != nil {
return nil, false, errors.Wrap(ErrReadConn, err)
}
p.inBytes += uint64(length)
@@ -261,6 +261,24 @@ func (p *PacketIO) Flush() error {
return nil
}

// IsPeerActive checks if the peer connection is still active.
// This function cannot be called concurrently with other functions of PacketIO.
// This function normally takes about 1ms, so don't call it too frequently.
// This function may incorrectly return true if the system is extremely slow.
func (p *PacketIO) IsPeerActive() bool {
if err := p.conn.SetReadDeadline(time.Now().Add(time.Millisecond)); err != nil {
return false
}
active := true
if _, err := p.buf.Peek(1); err != nil {
active = !errors.Is(err, io.EOF)
}
if err := p.conn.SetReadDeadline(time.Time{}); err != nil {
return false
}
return active
}
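
The trick relied on here: a short read deadline bounds the probe, Peek goes through the buffered reader so any already-buffered byte counts as liveness, io.EOF means the peer closed, and a timeout error just means the peer is quiet. A minimal self-contained sketch of the same pattern (the helper name and test server are illustrative, not the PacketIO API):

```go
package main

import (
	"bufio"
	"errors"
	"fmt"
	"io"
	"net"
	"time"
)

// isPeerActive reports whether the remote side of conn is still open, using
// the same deadline-plus-peek trick as IsPeerActive above. Illustrative only.
func isPeerActive(conn net.Conn, buf *bufio.Reader) bool {
	if err := conn.SetReadDeadline(time.Now().Add(time.Millisecond)); err != nil {
		return false
	}
	active := true
	if _, err := buf.Peek(1); err != nil {
		// io.EOF means the peer closed; a timeout just means no data yet.
		active = !errors.Is(err, io.EOF)
	}
	// Clear the deadline so later reads block normally again.
	if err := conn.SetReadDeadline(time.Time{}); err != nil {
		return false
	}
	return active
}

func main() {
	ln, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		panic(err)
	}
	defer ln.Close()
	go func() {
		c, err := ln.Accept()
		if err != nil {
			return
		}
		time.Sleep(50 * time.Millisecond)
		c.Close() // the peer goes away
	}()

	conn, err := net.Dial("tcp", ln.Addr().String())
	if err != nil {
		panic(err)
	}
	buf := bufio.NewReader(conn)
	fmt.Println(isPeerActive(conn, buf)) // true: peer still open, peek times out
	time.Sleep(100 * time.Millisecond)
	fmt.Println(isPeerActive(conn, buf)) // false: peer closed, peek sees EOF
}
```

As the comment above notes, this errs on the permissive side: on a badly overloaded host the 1ms peek can time out before a pending EOF is observed, so the check may report an active peer that has in fact gone away.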

func (p *PacketIO) GracefulClose() error {
if err := p.conn.SetDeadline(time.Now()); err != nil && !errors.Is(err, net.ErrClosed) {
return err