diff --git a/pkg/proxy/backend/backend_conn_mgr.go b/pkg/proxy/backend/backend_conn_mgr.go
index 156b5e1d..09ab267a 100644
--- a/pkg/proxy/backend/backend_conn_mgr.go
+++ b/pkg/proxy/backend/backend_conn_mgr.go
@@ -43,7 +43,8 @@ var (
 )

 const (
-    DialTimeout = 5 * time.Second
+    DialTimeout          = 5 * time.Second
+    CheckBackendInterval = time.Minute
 )

 const (
@@ -79,6 +80,18 @@ const (
     statusClosed
 )

+type BCConfig struct {
+    ProxyProtocol        bool
+    RequireBackendTLS    bool
+    CheckBackendInterval time.Duration
+}
+
+func (cfg *BCConfig) check() {
+    if cfg.CheckBackendInterval == time.Duration(0) {
+        cfg.CheckBackendInterval = CheckBackendInterval
+    }
+}
+
 // BackendConnManager migrates a session from one BackendConnection to another.
 //
 // The signal processing goroutine tries to migrate the session once it receives a signal.
@@ -97,13 +110,15 @@ type BackendConnManager struct {
     authenticator    *Authenticator
     cmdProcessor     *CmdProcessor
     eventReceiver    unsafe.Pointer
+    config           *BCConfig
     logger           *zap.Logger
     // type *signalRedirect, it saves the last signal if there are multiple signals.
     // It will be set to nil after migration.
     signal unsafe.Pointer
     // redirectResCh is used to notify the event receiver asynchronously.
-    redirectResCh chan *redirectResult
-    closeStatus   atomic.Int32
+    redirectResCh      chan *redirectResult
+    closeStatus        atomic.Int32
+    checkBackendTicker *time.Ticker
     // cancelFunc is used to cancel the signal processing goroutine.
     cancelFunc context.CancelFunc
     clientIO   *pnet.PacketIO
@@ -115,16 +130,17 @@ type BackendConnManager struct {
 }

 // NewBackendConnManager creates a BackendConnManager.
-func NewBackendConnManager(logger *zap.Logger, handshakeHandler HandshakeHandler,
-    connectionID uint64, proxyProtocol, requireBackendTLS bool) *BackendConnManager {
+func NewBackendConnManager(logger *zap.Logger, handshakeHandler HandshakeHandler, connectionID uint64, config *BCConfig) *BackendConnManager {
+    config.check()
     mgr := &BackendConnManager{
         logger:           logger,
+        config:           config,
         connectionID:     connectionID,
         cmdProcessor:     NewCmdProcessor(),
         handshakeHandler: handshakeHandler,
         authenticator: &Authenticator{
-            proxyProtocol:     proxyProtocol,
-            requireBackendTLS: requireBackendTLS,
+            proxyProtocol:     config.ProxyProtocol,
+            requireBackendTLS: config.RequireBackendTLS,
             salt:              GenerateSalt(20),
         },
         // There are 2 types of signals, which may be sent concurrently.
@@ -158,6 +174,7 @@ func (mgr *BackendConnManager) Connect(ctx context.Context, clientIO *pnet.Packe
     mgr.cmdProcessor.capability = mgr.authenticator.capability
     childCtx, cancelFunc := context.WithCancel(ctx)
     mgr.cancelFunc = cancelFunc
+    mgr.resetCheckBackendTicker()
     mgr.wg.Run(func() {
         mgr.processSignals(childCtx)
     })
@@ -243,6 +260,7 @@ func (mgr *BackendConnManager) ExecuteCmd(ctx context.Context, request []byte) e
     case statusClosing, statusClosed:
         return nil
     }
+    defer mgr.resetCheckBackendTicker()
     waitingRedirect := atomic.LoadPointer(&mgr.signal) != nil
     holdRequest, err := mgr.cmdProcessor.executeCmd(request, mgr.clientIO, mgr.backendIO, waitingRedirect)
     if !holdRequest {
@@ -336,6 +354,7 @@ func (mgr *BackendConnManager) querySessionStates() (sessionStates, sessionToken
 // processSignals runs in a goroutine to:
 // - Receive redirection signals and then try to migrate the session.
 // - Send redirection results to the event receiver.
+// - Check if the backend is still alive.
 func (mgr *BackendConnManager) processSignals(ctx context.Context) {
     for {
         select {
@@ -351,6 +370,8 @@ func (mgr *BackendConnManager) processSignals(ctx context.Context) {
             mgr.processLock.Unlock()
         case rs := <-mgr.redirectResCh:
             mgr.notifyRedirectResult(ctx, rs)
+        case <-mgr.checkBackendTicker.C:
+            mgr.checkBackendActive()
         case <-ctx.Done():
             return
         }
@@ -499,6 +520,34 @@ func (mgr *BackendConnManager) tryGracefulClose(ctx context.Context) {
     mgr.closeStatus.Store(statusClosing)
 }

+func (mgr *BackendConnManager) checkBackendActive() {
+    switch mgr.closeStatus.Load() {
+    case statusClosing, statusClosed:
+        return
+    }
+
+    mgr.processLock.Lock()
+    defer mgr.processLock.Unlock()
+    if !mgr.backendIO.IsPeerActive() {
+        mgr.logger.Info("backend connection is closed, close client connection", zap.Stringer("client", mgr.clientIO.RemoteAddr()),
+            zap.Stringer("backend", mgr.backendIO.RemoteAddr()))
+        if err := mgr.clientIO.GracefulClose(); err != nil {
+            mgr.logger.Warn("graceful close client IO error", zap.Stringer("addr", mgr.clientIO.RemoteAddr()), zap.Error(err))
+        }
+        mgr.closeStatus.Store(statusClosing)
+    }
+}
+
+// Checking backend is expensive, so only check it when the client is idle for some time.
+// This function should be called within the lock.
+func (mgr *BackendConnManager) resetCheckBackendTicker() {
+    if mgr.checkBackendTicker == nil {
+        mgr.checkBackendTicker = time.NewTicker(mgr.config.CheckBackendInterval)
+    } else {
+        mgr.checkBackendTicker.Reset(mgr.config.CheckBackendInterval)
+    }
+}
+
 func (mgr *BackendConnManager) ClientAddr() string {
     if mgr.clientIO == nil {
         return ""
@@ -542,6 +591,9 @@ func (mgr *BackendConnManager) Value(key any) any {
 // Close releases all resources.
 func (mgr *BackendConnManager) Close() error {
     mgr.closeStatus.Store(statusClosing)
+    if mgr.checkBackendTicker != nil {
+        mgr.checkBackendTicker.Stop()
+    }
     if mgr.cancelFunc != nil {
         mgr.cancelFunc()
         mgr.cancelFunc = nil
diff --git a/pkg/proxy/backend/backend_conn_mgr_test.go b/pkg/proxy/backend/backend_conn_mgr_test.go
index e3569503..89a267fe 100644
--- a/pkg/proxy/backend/backend_conn_mgr_test.go
+++ b/pkg/proxy/backend/backend_conn_mgr_test.go
@@ -211,15 +211,15 @@ func (ts *backendMgrTester) redirectFail4Proxy(clientIO, backendIO *pnet.PacketI
     return nil
 }

-func (ts *backendMgrTester) checkConnClosed(_, _ *pnet.PacketIO) error {
-    for i := 0; i < 30; i++ {
+func (ts *backendMgrTester) checkConnClosed4Proxy(_, _ *pnet.PacketIO) error {
+    require.Eventually(ts.t, func() bool {
         switch ts.mp.closeStatus.Load() {
         case statusClosing, statusClosed:
-            return nil
+            return true
         }
-        time.Sleep(100 * time.Millisecond)
-    }
-    return errors.New("timeout")
+        return false
+    }, 3*time.Second, 100*time.Millisecond)
+    return nil
 }

 func (ts *backendMgrTester) runTests(runners []runner) {
@@ -621,7 +621,7 @@ func TestGracefulCloseWhenIdle(t *testing.T) {
         },
         // really closed
         {
-            proxy: ts.checkConnClosed,
+            proxy: ts.checkConnClosed4Proxy,
         },
     }
     ts.runTests(runners)
@@ -659,7 +659,7 @@ func TestGracefulCloseWhenActive(t *testing.T) {
         },
         // it will then automatically close
         {
-            proxy: ts.checkConnClosed,
+            proxy: ts.checkConnClosed4Proxy,
         },
     }
     ts.runTests(runners)
@@ -683,7 +683,7 @@ func TestGracefulCloseBeforeHandshake(t *testing.T) {
         },
         // it will then automatically close
         {
-            proxy: ts.checkConnClosed,
+            proxy: ts.checkConnClosed4Proxy,
         },
     }
     ts.runTests(runners)
@@ -763,7 +763,7 @@ func TestGetBackendIO(t *testing.T) {
             }
         },
     }
-    mgr := NewBackendConnManager(logger.CreateLoggerForTest(t), handler, 0, false, false)
+    mgr := NewBackendConnManager(logger.CreateLoggerForTest(t), handler, 0, &BCConfig{})
     var wg waitgroup.WaitGroup
     for i := 0; i <= len(listeners); i++ {
         wg.Run(func() {
@@ -790,3 +790,81 @@ func TestGetBackendIO(t *testing.T) {
         wg.Wait()
     }
 }
+
+func TestBackendInactive(t *testing.T) {
+    ts := newBackendMgrTester(t, func(config *testConfig) {
+        config.proxyConfig.checkBackendInterval = 10 * time.Millisecond
+    })
+    runners := []runner{
+        // 1st handshake
+        {
+            client:  ts.mc.authenticate,
+            proxy:   ts.firstHandshake4Proxy,
+            backend: ts.handshake4Backend,
+        },
+        // do some queries and the interval is less than checkBackendInterval
+        {
+            client: func(packetIO *pnet.PacketIO) error {
+                for i := 0; i < 10; i++ {
+                    time.Sleep(5 * time.Millisecond)
+                    if err := ts.mc.request(packetIO); err != nil {
+                        return err
+                    }
+                }
+                return nil
+            },
+            proxy: func(clientIO, backendIO *pnet.PacketIO) error {
+                for i := 0; i < 10; i++ {
+                    if err := ts.forwardCmd4Proxy(clientIO, backendIO); err != nil {
+                        return err
+                    }
+                }
+                return nil
+            },
+            backend: func(packetIO *pnet.PacketIO) error {
+                for i := 0; i < 10; i++ {
+                    if err := ts.respondWithNoTxn4Backend(packetIO); err != nil {
+                        return err
+                    }
+                }
+                return nil
+            },
+        },
+        // do some queries and the interval is longer than checkBackendInterval
+        {
+            client: func(packetIO *pnet.PacketIO) error {
+                for i := 0; i < 5; i++ {
+                    time.Sleep(30 * time.Millisecond)
+                    if err := ts.mc.request(packetIO); err != nil {
+                        return err
+                    }
+                }
+                return nil
+            },
+            proxy: func(clientIO, backendIO *pnet.PacketIO) error {
+                for i := 0; i < 5; i++ {
+                    if err := ts.forwardCmd4Proxy(clientIO, backendIO); err != nil {
+                        return err
+                    }
+                }
+                return nil
+            },
+            backend: func(packetIO *pnet.PacketIO) error {
+                for i := 0; i < 5; i++ {
+                    if err := ts.respondWithNoTxn4Backend(packetIO); err != nil {
+                        return err
+                    }
+                }
+                return nil
+            },
+        },
+        // close the backend and the client connection will close
+        {
+            proxy: ts.checkConnClosed4Proxy,
+            backend: func(packetIO *pnet.PacketIO) error {
+                return packetIO.Close()
+            },
+        },
+    }
+    ts.runTests(runners)
+}
diff --git a/pkg/proxy/backend/mock_proxy_test.go b/pkg/proxy/backend/mock_proxy_test.go
index 2bfd6901..3671d731 100644
--- a/pkg/proxy/backend/mock_proxy_test.go
+++ b/pkg/proxy/backend/mock_proxy_test.go
@@ -26,19 +26,21 @@ import (
 )

 type proxyConfig struct {
-    frontendTLSConfig *tls.Config
-    backendTLSConfig  *tls.Config
-    handler           *CustomHandshakeHandler
-    sessionToken      string
-    capability        pnet.Capability
-    waitRedirect      bool
+    frontendTLSConfig    *tls.Config
+    backendTLSConfig     *tls.Config
+    handler              *CustomHandshakeHandler
+    checkBackendInterval time.Duration
+    sessionToken         string
+    capability           pnet.Capability
+    waitRedirect         bool
 }

 func newProxyConfig() *proxyConfig {
     return &proxyConfig{
-        handler:      &CustomHandshakeHandler{},
-        capability:   defaultTestBackendCapability,
-        sessionToken: mockToken,
+        handler:              &CustomHandshakeHandler{},
+        capability:           defaultTestBackendCapability,
+        sessionToken:         mockToken,
+        checkBackendInterval: CheckBackendInterval,
     }
 }

@@ -56,9 +58,11 @@ type mockProxy struct {

 func newMockProxy(t *testing.T, cfg *proxyConfig) *mockProxy {
     mp := &mockProxy{
-        proxyConfig:        cfg,
-        logger:             logger.CreateLoggerForTest(t).Named("mockProxy"),
-        BackendConnManager: NewBackendConnManager(logger.CreateLoggerForTest(t), cfg.handler, 0, false, false),
+        proxyConfig: cfg,
+        logger:      logger.CreateLoggerForTest(t).Named("mockProxy"),
+        BackendConnManager: NewBackendConnManager(logger.CreateLoggerForTest(t), cfg.handler, 0, &BCConfig{
+            CheckBackendInterval: cfg.checkBackendInterval,
+        }),
     }
     mp.cmdProcessor.capability = cfg.capability.Uint32()
     return mp
diff --git a/pkg/proxy/client/client_conn.go b/pkg/proxy/client/client_conn.go
index d9760c04..7ef8cd86 100644
--- a/pkg/proxy/client/client_conn.go
+++ b/pkg/proxy/client/client_conn.go
@@ -42,7 +42,10 @@ type ClientConnection struct {

 func NewClientConnection(logger *zap.Logger, conn net.Conn, frontendTLSConfig *tls.Config, backendTLSConfig *tls.Config,
     hsHandler backend.HandshakeHandler, connID uint64, proxyProtocol, requireBackendTLS bool) *ClientConnection {
-    bemgr := backend.NewBackendConnManager(logger.Named("be"), hsHandler, connID, proxyProtocol, requireBackendTLS)
+    bemgr := backend.NewBackendConnManager(logger.Named("be"), hsHandler, connID, &backend.BCConfig{
+        ProxyProtocol:     proxyProtocol,
+        RequireBackendTLS: requireBackendTLS,
+    })
     opts := make([]pnet.PacketIOption, 0, 2)
     opts = append(opts, pnet.WithWrapError(ErrClientConn))
     if proxyProtocol {
diff --git a/pkg/proxy/net/packetio.go b/pkg/proxy/net/packetio.go
index 393d4660..f92b0174 100644
--- a/pkg/proxy/net/packetio.go
+++ b/pkg/proxy/net/packetio.go
@@ -141,7 +141,7 @@ func (p *PacketIO) GetSequence() uint8 {

 func (p *PacketIO) readOnePacket() ([]byte, bool, error) {
     var header [4]byte
-    if _, err := io.ReadFull(p.conn, header[:]); err != nil {
+    if _, err := io.ReadFull(p.buf, header[:]); err != nil {
         return nil, false, errors.Wrap(ErrReadConn, err)
     }
     p.inBytes += 4
@@ -164,7 +164,7 @@ func (p *PacketIO) readOnePacket() ([]byte, bool, error) {

         // refill mysql headers
         if refill {
-            if _, err := io.ReadFull(p.conn, header[:]); err != nil {
+            if _, err := io.ReadFull(p.buf, header[:]); err != nil {
                 return nil, false, errors.Wrap(ErrReadConn, err)
             }
             p.inBytes += 4
@@ -178,7 +178,7 @@ func (p *PacketIO) readOnePacket() ([]byte, bool, error) {

     length := int(uint32(header[0]) | uint32(header[1])<<8 | uint32(header[2])<<16)
     data := make([]byte, length)
-    if _, err := io.ReadFull(p.conn, data); err != nil {
+    if _, err := io.ReadFull(p.buf, data); err != nil {
         return nil, false, errors.Wrap(ErrReadConn, err)
     }
     p.inBytes += uint64(length)
@@ -261,6 +261,24 @@ func (p *PacketIO) Flush() error {
     return nil
 }

+// IsPeerActive checks if the peer connection is still active.
+// This function cannot be called concurrently with other functions of PacketIO.
+// This function normally costs 1ms, so don't call it too frequently.
+// This function may incorrectly return true if the system is extremely slow.
+func (p *PacketIO) IsPeerActive() bool {
+    if err := p.conn.SetReadDeadline(time.Now().Add(time.Millisecond)); err != nil {
+        return false
+    }
+    active := true
+    if _, err := p.buf.Peek(1); err != nil {
+        active = !errors.Is(err, io.EOF)
+    }
+    if err := p.conn.SetReadDeadline(time.Time{}); err != nil {
+        return false
+    }
+    return active
+}
+
 func (p *PacketIO) GracefulClose() error {
     if err := p.conn.SetDeadline(time.Now()); err != nil && !errors.Is(err, net.ErrClosed) {
         return err
diff --git a/pkg/proxy/net/packetio_test.go b/pkg/proxy/net/packetio_test.go
index 281853b8..92456eba 100644
--- a/pkg/proxy/net/packetio_test.go
+++ b/pkg/proxy/net/packetio_test.go
@@ -204,3 +204,61 @@ func TestPacketIOClose(t *testing.T) {
         1,
     )
 }
+
+func TestPeerActive(t *testing.T) {
+    stls, ctls, err := security.CreateTLSConfigForTest()
+    require.NoError(t, err)
+    ch := make(chan struct{})
+    testTCPConn(t,
+        func(t *testing.T, cli *PacketIO) {
+            // It's active at the beginning.
+            require.True(t, cli.IsPeerActive())
+            ch <- struct{}{} // let srv write packet
+            // ReadPacket still reads the whole data after checking.
+            ch <- struct{}{}
+            require.True(t, cli.IsPeerActive())
+            data, err := cli.ReadPacket()
+            require.NoError(t, err)
+            require.Equal(t, "123", string(data))
+            // IsPeerActive works after reading data.
+            require.True(t, cli.IsPeerActive())
+            // IsPeerActive works after writing data.
+            require.NoError(t, cli.WritePacket([]byte("456"), true))
+            require.True(t, cli.IsPeerActive())
+            // upgrade to TLS and try again
+            require.NoError(t, cli.ClientTLSHandshake(ctls))
+            require.True(t, cli.IsPeerActive())
+            data, err = cli.ReadPacket()
+            require.NoError(t, err)
+            require.Equal(t, "123", string(data))
+            require.True(t, cli.IsPeerActive())
+            require.NoError(t, cli.WritePacket([]byte("456"), true))
+            require.True(t, cli.IsPeerActive())
+            // It's not active after the peer closes.
+            ch <- struct{}{}
+            ch <- struct{}{}
+            require.False(t, cli.IsPeerActive())
+        },
+        func(t *testing.T, srv *PacketIO) {
+            <-ch
+            err := srv.WritePacket([]byte("123"), true)
+            require.NoError(t, err)
+            <-ch
+            data, err := srv.ReadPacket()
+            require.NoError(t, err)
+            require.Equal(t, "456", string(data))
+            // upgrade to TLS and try again
+            _, err = srv.ServerTLSHandshake(stls)
+            require.NoError(t, err)
+            err = srv.WritePacket([]byte("123"), true)
+            require.NoError(t, err)
+            data, err = srv.ReadPacket()
+            require.NoError(t, err)
+            require.Equal(t, "456", string(data))
+            <-ch
+            require.NoError(t, srv.Close())
+            <-ch
+        },
+        10,
+    )
+}
diff --git a/pkg/proxy/net/tls.go b/pkg/proxy/net/tls.go
index 754858d3..0e5ecb01 100644
--- a/pkg/proxy/net/tls.go
+++ b/pkg/proxy/net/tls.go
@@ -15,6 +15,7 @@
 package net

 import (
+    "bufio"
     "crypto/tls"

     "github.com/pingcap/TiProxy/lib/util/errors"
@@ -28,6 +29,8 @@ func (p *PacketIO) ServerTLSHandshake(tlsConfig *tls.Config) (tls.ConnectionStat
     }
     p.conn = tlsConn
     p.buf.Writer.Reset(p.conn)
+    // Wrap it with another buffer to enable Peek.
+    p.buf = bufio.NewReadWriter(bufio.NewReaderSize(p.conn, defaultReaderSize), p.buf.Writer)
     return tlsConn.ConnectionState(), nil
 }

@@ -39,5 +42,7 @@ func (p *PacketIO) ClientTLSHandshake(tlsConfig *tls.Config) error {
     }
     p.conn = tlsConn
     p.buf.Writer.Reset(p.conn)
+    // Wrap it with another buffer to enable Peek.
+    p.buf = bufio.NewReadWriter(bufio.NewReaderSize(p.conn, defaultReaderSize), p.buf.Writer)
     return nil
 }
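
IsPeerActive in the patch works by arming a 1ms read deadline and peeking one byte through the buffered reader: a timeout only means the peer is idle, while io.EOF means it has closed the connection, and the deadline is cleared afterwards so normal reads are unaffected (the bufio re-wrap in tls.go exists precisely so Peek keeps working after the TLS upgrade). Below is a minimal standalone sketch of the same probe over plain TCP; isPeerActive and the throwaway listener are invented for illustration and are not part of the patch.

package main

import (
    "bufio"
    "errors"
    "fmt"
    "io"
    "net"
    "time"
)

// isPeerActive probes whether the peer is still alive without consuming
// application data: arm a short read deadline, Peek one byte through the
// buffered reader, and treat io.EOF as "peer closed". A deadline timeout
// just means the peer is idle.
func isPeerActive(conn net.Conn, r *bufio.Reader) bool {
    if err := conn.SetReadDeadline(time.Now().Add(time.Millisecond)); err != nil {
        return false
    }
    active := true
    if _, err := r.Peek(1); err != nil {
        active = !errors.Is(err, io.EOF)
    }
    // Clear the deadline so later reads are not affected.
    if err := conn.SetReadDeadline(time.Time{}); err != nil {
        return false
    }
    return active
}

func main() {
    ln, err := net.Listen("tcp", "127.0.0.1:0")
    if err != nil {
        panic(err)
    }
    defer ln.Close()

    srvConnCh := make(chan net.Conn, 1)
    go func() {
        c, err := ln.Accept()
        if err != nil {
            panic(err)
        }
        srvConnCh <- c
    }()

    cli, err := net.Dial("tcp", ln.Addr().String())
    if err != nil {
        panic(err)
    }
    r := bufio.NewReader(cli)
    srv := <-srvConnCh

    fmt.Println("peer active while open:", isPeerActive(cli, r)) // true: Peek times out, no EOF
    _ = srv.Close()
    time.Sleep(10 * time.Millisecond) // give the FIN time to arrive
    fmt.Println("peer active after close:", isPeerActive(cli, r)) // false: Peek sees io.EOF
}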
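On the ticker side, ExecuteCmd defers resetCheckBackendTicker, so the ticker that processSignals selects on only fires after a full CheckBackendInterval with no client traffic, which is what TestBackendInactive exercises with its 10ms interval. The following standalone sketch illustrates that reset-on-activity pattern; idleProber and everything else in it are invented for illustration and are not part of the patch.

package main

import (
    "context"
    "fmt"
    "time"
)

// idleProber mimics the pattern above: every command resets the ticker,
// so the comparatively expensive liveness check only runs once the
// connection has been idle for a full interval.
type idleProber struct {
    interval time.Duration
    ticker   *time.Ticker
}

func (p *idleProber) resetTicker() {
    if p.ticker == nil {
        p.ticker = time.NewTicker(p.interval)
    } else {
        p.ticker.Reset(p.interval)
    }
}

func (p *idleProber) loop(ctx context.Context, check func()) {
    for {
        select {
        case <-p.ticker.C:
            check()
        case <-ctx.Done():
            p.ticker.Stop()
            return
        }
    }
}

func main() {
    ctx, cancel := context.WithTimeout(context.Background(), 600*time.Millisecond)
    defer cancel()

    p := &idleProber{interval: 100 * time.Millisecond}
    p.resetTicker()
    go p.loop(ctx, func() { fmt.Println("idle: probing backend") })

    // Commands arrive faster than the interval, so the probe never fires.
    for i := 0; i < 5; i++ {
        time.Sleep(50 * time.Millisecond)
        p.resetTicker() // analogous to the deferred resetCheckBackendTicker in ExecuteCmd
    }
    // Once the client goes idle, the probe fires after one full interval.
    time.Sleep(300 * time.Millisecond)
}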