From 4398ffffc19eb4d2f3e9781c0f6bd5a66309a196 Mon Sep 17 00:00:00 2001 From: Oleg Jukovec Date: Thu, 10 Apr 2025 10:18:34 +0300 Subject: [PATCH] dial: Connect() does not cancel by context if no i/o NetDialer, GreetingDialer, AuthDialer and ProtocolDialer may not cancel Dial() on context expiration when network connection hangs. The issue occurred because context wasn't properly handled during network I/O operations, potentially causing infinite waiting. Part of #TNTP-2018 --- CHANGELOG.md | 5 +- connection.go | 2 +- dial.go | 109 ++++++++++++++++++++++++++++++++++++------- dial_test.go | 125 ++++++++++++++++++++++++++++++++++++++++++++++---- 4 files changed, 214 insertions(+), 27 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index aebbd1d3d..f9c59d95b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,9 @@ Versioning](http://semver.org/spec/v2.0.0.html) except to the first release. ### Fixed +- Connect() may not cancel Dial() call on context expiration if network + connection hangs (#443). + ## [v2.3.1] - 2025-04-03 The patch releases fixes expected Connect() behavior and reduces allocations. @@ -21,7 +24,7 @@ The patch releases fixes expected Connect() behavior and reduces allocations. ### Added - A usage of sync.Pool of msgpack.Decoder saves 2 object allocations per - a response decoding. + a response decoding (#440). ### Changed diff --git a/connection.go b/connection.go index 8c8552e60..2b43f2ec0 100644 --- a/connection.go +++ b/connection.go @@ -489,7 +489,7 @@ func (conn *Connection) dial(ctx context.Context) error { } req := newWatchRequest(key.(string)) - if err = writeRequest(c, req); err != nil { + if err = writeRequest(ctx, c, req); err != nil { st <- state return false } diff --git a/dial.go b/dial.go index d0acf69c0..9faeaa98f 100644 --- a/dial.go +++ b/dial.go @@ -289,6 +289,7 @@ func (d AuthDialer) Dial(ctx context.Context, opts DialOpts) (Conn, error) { if err != nil { return conn, err } + greeting := conn.Greeting() if greeting.Salt == "" { conn.Close() @@ -309,7 +310,7 @@ func (d AuthDialer) Dial(ctx context.Context, opts DialOpts) (Conn, error) { } } - if err := authenticate(conn, d.Auth, d.Username, d.Password, + if err := authenticate(ctx, conn, d.Auth, d.Username, d.Password, conn.Greeting().Salt); err != nil { conn.Close() return nil, fmt.Errorf("failed to authenticate: %w", err) @@ -340,7 +341,7 @@ func (d ProtocolDialer) Dial(ctx context.Context, opts DialOpts) (Conn, error) { protocolInfo: d.RequiredProtocolInfo, } - protocolConn.protocolInfo, err = identify(&protocolConn) + protocolConn.protocolInfo, err = identify(ctx, &protocolConn) if err != nil { protocolConn.Close() return nil, fmt.Errorf("failed to identify: %w", err) @@ -372,11 +373,12 @@ func (d GreetingDialer) Dial(ctx context.Context, opts DialOpts) (Conn, error) { greetingConn := greetingConn{ Conn: conn, } - version, salt, err := readGreeting(greetingConn) + version, salt, err := readGreeting(ctx, &greetingConn) if err != nil { greetingConn.Close() return nil, fmt.Errorf("failed to read greeting: %w", err) } + greetingConn.greeting = Greeting{ Version: version, Salt: salt, @@ -410,31 +412,67 @@ func parseAddress(address string) (string, string) { return network, address } +// ioWaiter waits in a background until an io operation done or a context +// is expired. It closes the connection and writes a context error into the +// output channel on context expiration. +// +// A user of the helper should close the first output channel after an IO +// operation done and read an error from a second channel to get the result +// of waiting. +func ioWaiter(ctx context.Context, conn Conn) (chan<- struct{}, <-chan error) { + doneIO := make(chan struct{}) + doneWait := make(chan error, 1) + + go func() { + defer close(doneWait) + + select { + case <-ctx.Done(): + conn.Close() + <-doneIO + doneWait <- ctx.Err() + case <-doneIO: + doneWait <- nil + } + }() + + return doneIO, doneWait +} + // readGreeting reads a greeting message. -func readGreeting(reader io.Reader) (string, string, error) { +func readGreeting(ctx context.Context, conn Conn) (string, string, error) { var version, salt string + doneRead, doneWait := ioWaiter(ctx, conn) + data := make([]byte, 128) - _, err := io.ReadFull(reader, data) + _, err := io.ReadFull(conn, data) + + close(doneRead) + if err == nil { version = bytes.NewBuffer(data[:64]).String() salt = bytes.NewBuffer(data[64:108]).String() } + if waitErr := <-doneWait; waitErr != nil { + err = waitErr + } + return version, salt, err } // identify sends info about client protocol, receives info // about server protocol in response and stores it in the connection. -func identify(conn Conn) (ProtocolInfo, error) { +func identify(ctx context.Context, conn Conn) (ProtocolInfo, error) { var info ProtocolInfo req := NewIdRequest(clientProtocolInfo) - if err := writeRequest(conn, req); err != nil { + if err := writeRequest(ctx, conn, req); err != nil { return info, err } - resp, err := readResponse(conn, req) + resp, err := readResponse(ctx, conn, req) if err != nil { if resp != nil && resp.Header().Error == iproto.ER_UNKNOWN_REQUEST_TYPE { @@ -495,7 +533,7 @@ func checkProtocolInfo(required ProtocolInfo, actual ProtocolInfo) error { } // authenticate authenticates for a connection. -func authenticate(c Conn, auth Auth, user string, pass string, salt string) error { +func authenticate(ctx context.Context, c Conn, auth Auth, user, pass, salt string) error { var req Request var err error @@ -511,37 +549,73 @@ func authenticate(c Conn, auth Auth, user string, pass string, salt string) erro return errors.New("unsupported method " + auth.String()) } - if err = writeRequest(c, req); err != nil { + if err = writeRequest(ctx, c, req); err != nil { return err } - if _, err = readResponse(c, req); err != nil { + if _, err = readResponse(ctx, c, req); err != nil { return err } return nil } // writeRequest writes a request to the writer. -func writeRequest(w writeFlusher, req Request) error { +func writeRequest(ctx context.Context, conn Conn, req Request) error { var packet smallWBuf err := pack(&packet, msgpack.NewEncoder(&packet), 0, req, ignoreStreamId, nil) if err != nil { return fmt.Errorf("pack error: %w", err) } - if _, err = w.Write(packet.b); err != nil { + + doneWrite, doneWait := ioWaiter(ctx, conn) + + _, err = conn.Write(packet.b) + + close(doneWrite) + + if waitErr := <-doneWait; waitErr != nil { + err = waitErr + } + + if err != nil { return fmt.Errorf("write error: %w", err) } - if err = w.Flush(); err != nil { + + doneWrite, doneWait = ioWaiter(ctx, conn) + + err = conn.Flush() + + close(doneWrite) + + if waitErr := <-doneWait; waitErr != nil { + err = waitErr + } + + if err != nil { return fmt.Errorf("flush error: %w", err) } + + if waitErr := <-doneWait; waitErr != nil { + err = waitErr + } + return err } // readResponse reads a response from the reader. -func readResponse(r io.Reader, req Request) (Response, error) { +func readResponse(ctx context.Context, conn Conn, req Request) (Response, error) { var lenbuf [packetLengthBytes]byte - respBytes, err := read(r, lenbuf[:]) + doneRead, doneWait := ioWaiter(ctx, conn) + + respBytes, err := read(conn, lenbuf[:]) + + close(doneRead) + + if waitErr := <-doneWait; waitErr != nil { + err = waitErr + } + if err != nil { return nil, fmt.Errorf("read error: %w", err) } @@ -555,10 +629,12 @@ func readResponse(r io.Reader, req Request) (Response, error) { if err != nil { return nil, fmt.Errorf("decode response header error: %w", err) } + resp, err := req.Response(header, &buf) if err != nil { return nil, fmt.Errorf("creating response error: %w", err) } + _, err = resp.Decode() if err != nil { switch err.(type) { @@ -568,5 +644,6 @@ func readResponse(r io.Reader, req Request) (Response, error) { return resp, fmt.Errorf("decode response body error: %w", err) } } + return resp, nil } diff --git a/dial_test.go b/dial_test.go index abfdf0d4f..87b9af5d8 100644 --- a/dial_test.go +++ b/dial_test.go @@ -87,6 +87,8 @@ type mockIoConn struct { readbuf, writebuf bytes.Buffer // Calls readWg/writeWg.Wait() in Read()/Flush(). readWg, writeWg sync.WaitGroup + // wgDoneOnClose call Done() on the wait groups on Close(). + wgDoneOnClose bool // How many times to wait before a wg.Wait() call. readWgDelay, writeWgDelay int // Write()/Read()/Flush()/Close() calls count. @@ -137,6 +139,12 @@ func (m *mockIoConn) Flush() error { } func (m *mockIoConn) Close() error { + if m.wgDoneOnClose { + m.readWg.Done() + m.writeWg.Done() + m.wgDoneOnClose = false + } + m.closeCnt++ return nil } @@ -165,6 +173,7 @@ func newMockIoConn() *mockIoConn { conn := new(mockIoConn) conn.readWg.Add(1) conn.writeWg.Add(1) + conn.wgDoneOnClose = true return conn } @@ -201,9 +210,6 @@ func TestConn_Close(t *testing.T) { conn.Close() assert.Equal(t, 1, dialer.conn.closeCnt) - - dialer.conn.readWg.Done() - dialer.conn.writeWg.Done() } type stubAddr struct { @@ -224,8 +230,6 @@ func TestConn_Addr(t *testing.T) { conn.addr = stubAddr{str: addr} }) defer func() { - dialer.conn.readWg.Done() - dialer.conn.writeWg.Done() conn.Close() }() @@ -242,8 +246,6 @@ func TestConn_Greeting(t *testing.T) { conn.greeting = greeting }) defer func() { - dialer.conn.readWg.Done() - dialer.conn.writeWg.Done() conn.Close() }() @@ -263,8 +265,6 @@ func TestConn_ProtocolInfo(t *testing.T) { conn.info = info }) defer func() { - dialer.conn.readWg.Done() - dialer.conn.writeWg.Done() conn.Close() }() @@ -284,6 +284,7 @@ func TestConn_ReadWrite(t *testing.T) { 0x01, 0xce, 0x00, 0x00, 0x00, 0x02, 0x80, // Body map. }) + conn.wgDoneOnClose = false }) defer func() { dialer.conn.writeWg.Done() @@ -579,6 +580,24 @@ func TestNetDialer_Dial(t *testing.T) { } } +func TestNetDialer_Dial_hang_connection(t *testing.T) { + l, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer l.Close() + + dialer := tarantool.NetDialer{ + Address: l.Addr().String(), + } + + ctx, cancel := test_helpers.GetConnectContext() + defer cancel() + + conn, err := dialer.Dial(ctx, tarantool.DialOpts{}) + + require.Nil(t, conn) + require.Error(t, err, context.DeadlineExceeded) +} + func TestNetDialer_Dial_requirements(t *testing.T) { l, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err) @@ -685,6 +704,7 @@ func TestAuthDialer_Dial_DialerError(t *testing.T) { ctx, cancel := test_helpers.GetConnectContext() defer cancel() + conn, err := dialer.Dial(ctx, tarantool.DialOpts{}) if conn != nil { conn.Close() @@ -717,6 +737,38 @@ func TestAuthDialer_Dial_NoSalt(t *testing.T) { } } +func TestConn_AuthDialer_hang_connection(t *testing.T) { + salt := fmt.Sprintf("%s", testDialSalt) + salt = base64.StdEncoding.EncodeToString([]byte(salt)) + mock := &mockIoDialer{ + init: func(conn *mockIoConn) { + conn.greeting.Salt = salt + conn.readWgDelay = 0 + conn.writeWgDelay = 0 + }, + } + dialer := tarantool.AuthDialer{ + Dialer: mock, + Username: "test", + Password: "test", + } + + ctx, cancel := test_helpers.GetConnectContext() + defer cancel() + + conn, err := tarantool.Connect(ctx, &dialer, + tarantool.Opts{ + Timeout: 1000 * time.Second, // Avoid pings. + SkipSchema: true, + }) + + require.Nil(t, conn) + require.Error(t, err, context.DeadlineExceeded) + require.Equal(t, mock.conn.writeCnt, 1) + require.Equal(t, mock.conn.readCnt, 0) + require.Greater(t, mock.conn.closeCnt, 1) +} + func TestAuthDialer_Dial(t *testing.T) { salt := fmt.Sprintf("%s", testDialSalt) salt = base64.StdEncoding.EncodeToString([]byte(salt)) @@ -726,6 +778,7 @@ func TestAuthDialer_Dial(t *testing.T) { conn.writeWgDelay = 1 conn.readWgDelay = 2 conn.readbuf.Write(okResponse) + conn.wgDoneOnClose = false }, } defer func() { @@ -758,6 +811,7 @@ func TestAuthDialer_Dial_PapSha256Auth(t *testing.T) { conn.writeWgDelay = 1 conn.readWgDelay = 2 conn.readbuf.Write(okResponse) + conn.wgDoneOnClose = false }, } defer func() { @@ -800,6 +854,34 @@ func TestProtocolDialer_Dial_DialerError(t *testing.T) { assert.EqualError(t, err, "some error") } +func TestConn_ProtocolDialer_hang_connection(t *testing.T) { + mock := &mockIoDialer{ + init: func(conn *mockIoConn) { + conn.info = tarantool.ProtocolInfo{Version: 1} + conn.readWgDelay = 0 + conn.writeWgDelay = 0 + }, + } + dialer := tarantool.ProtocolDialer{ + Dialer: mock, + } + + ctx, cancel := test_helpers.GetConnectContext() + defer cancel() + + conn, err := tarantool.Connect(ctx, &dialer, + tarantool.Opts{ + Timeout: 1000 * time.Second, // Avoid pings. + SkipSchema: true, + }) + + require.Nil(t, conn) + require.Error(t, err, context.DeadlineExceeded) + require.Equal(t, mock.conn.writeCnt, 1) + require.Equal(t, mock.conn.readCnt, 0) + require.Greater(t, mock.conn.closeCnt, 1) +} + func TestProtocolDialer_Dial_IdentifyFailed(t *testing.T) { dialer := tarantool.ProtocolDialer{ Dialer: &mockIoDialer{ @@ -898,6 +980,31 @@ func TestGreetingDialer_Dial_DialerError(t *testing.T) { assert.EqualError(t, err, "some error") } +func TestConn_GreetingDialer_hang_connection(t *testing.T) { + mock := &mockIoDialer{ + init: func(conn *mockIoConn) { + conn.readWgDelay = 0 + }, + } + dialer := tarantool.GreetingDialer{ + Dialer: mock, + } + + ctx, cancel := test_helpers.GetConnectContext() + defer cancel() + + conn, err := tarantool.Connect(ctx, &dialer, + tarantool.Opts{ + Timeout: 1000 * time.Second, // Avoid pings. + SkipSchema: true, + }) + + require.Nil(t, conn) + require.Error(t, err, context.DeadlineExceeded) + require.Equal(t, mock.conn.readCnt, 1) + require.Greater(t, mock.conn.closeCnt, 1) +} + func TestGreetingDialer_Dial_GreetingFailed(t *testing.T) { dialer := tarantool.GreetingDialer{ Dialer: &mockIoDialer{