diff --git a/mongo/integration/client_test.go b/mongo/integration/client_test.go index d859981f85..4003506e8d 100644 --- a/mongo/integration/client_test.go +++ b/mongo/integration/client_test.go @@ -651,6 +651,27 @@ func TestClient(t *testing.T) { "expected 'OP_MSG' OpCode in wire message, got %q", pair.Sent.OpCode.String()) } }) + + // Test that OP_MSG is used for handshakes when loadBalanced is true. + opMsgLBOpts := mtest.NewOptions().ClientType(mtest.Proxy).MinServerVersion("5.0").Topologies(mtest.LoadBalanced) + mt.RunOpts("OP_MSG used for handshakes when loadBalanced is true", opMsgLBOpts, func(mt *mtest.T) { + err := mt.Client.Ping(context.Background(), mtest.PrimaryRp) + assert.Nil(mt, err, "Ping error: %v", err) + + msgPairs := mt.GetProxiedMessages() + assert.True(mt, len(msgPairs) >= 3, "expected at least 3 events, got %v", len(msgPairs)) + + // First three messages should be connection handshakes: one for the heartbeat connection, another for the + // application connection, and a final one for the RTT monitor connection. + for idx, pair := range msgPairs[:3] { + assert.Equal(mt, "hello", pair.CommandName, "expected command name 'hello' at index %d, got %s", idx, + pair.CommandName) + + // Assert that appended OpCode is OP_MSG when loadBalanced is true. + assert.Equal(mt, wiremessage.OpMsg, pair.Sent.OpCode, + "expected 'OP_MSG' OpCode in wire message, got %q", pair.Sent.OpCode.String()) + } + }) } func TestClientStress(t *testing.T) { diff --git a/x/mongo/driver/operation.go b/x/mongo/driver/operation.go index e5fffa8d44..993ef634f9 100644 --- a/x/mongo/driver/operation.go +++ b/x/mongo/driver/operation.go @@ -813,10 +813,10 @@ func (Operation) decompressWireMessage(wm []byte) ([]byte, error) { func (op Operation) createWireMessage(ctx context.Context, dst []byte, desc description.SelectedServer, conn Connection) ([]byte, startedInformation, error) { - - // If API version is not declared and wire version is unknown or less than 6, use OP_QUERY. - // Otherwise, use OP_MSG. - if op.ServerAPI == nil && (desc.WireVersion == nil || desc.WireVersion.Max < wiremessage.OpmsgWireVersion) { + // If topology is not LoadBalanced, API version is not declared, and wire version is unknown + // or less than 6, use OP_QUERY. Otherwise, use OP_MSG. + if desc.Kind != description.LoadBalanced && op.ServerAPI == nil && + (desc.WireVersion == nil || desc.WireVersion.Max < wiremessage.OpmsgWireVersion) { return op.createQueryWireMessage(dst, desc) } return op.createMsgWireMessage(ctx, dst, desc, conn) diff --git a/x/mongo/driver/operation/hello.go b/x/mongo/driver/operation/hello.go index d222ada93d..5e75c08eef 100644 --- a/x/mongo/driver/operation/hello.go +++ b/x/mongo/driver/operation/hello.go @@ -159,7 +159,9 @@ func (h *Hello) handshakeCommand(dst []byte, desc description.SelectedServer) ([ // command appends all necessary command fields. func (h *Hello) command(dst []byte, desc description.SelectedServer) ([]byte, error) { - if h.serverAPI != nil || desc.Server.HelloOK { + // Use "hello" if topology is LoadBalanced, API version is declared or server + // has responded with "helloOk". Otherwise, use legacy hello. + if desc.Kind == description.LoadBalanced || h.serverAPI != nil || desc.Server.HelloOK { dst = bsoncore.AppendInt32Element(dst, "hello", 1) } else { dst = bsoncore.AppendInt32Element(dst, internal.LegacyHello, 1)