diff --git a/internal/integration/csot_test.go b/internal/integration/csot_test.go new file mode 100644 index 0000000000..3112ba8be0 --- /dev/null +++ b/internal/integration/csot_test.go @@ -0,0 +1,522 @@ +// Copyright (C) MongoDB, Inc. 2024-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package integration + +import ( + "context" + "errors" + "testing" + "time" + + "go.mongodb.org/mongo-driver/v2/bson" + "go.mongodb.org/mongo-driver/v2/event" + "go.mongodb.org/mongo-driver/v2/internal/assert" + "go.mongodb.org/mongo-driver/v2/internal/eventtest" + "go.mongodb.org/mongo-driver/v2/internal/integration/mtest" + "go.mongodb.org/mongo-driver/v2/internal/require" + "go.mongodb.org/mongo-driver/v2/mongo" + "go.mongodb.org/mongo-driver/v2/mongo/options" + "go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore" + "go.mongodb.org/mongo-driver/v2/x/mongo/driver" +) + +// Test automatic "maxTimeMS" appending and connection closing behavior. 
+func TestCSOT_maxTimeMS(t *testing.T) { + mt := mtest.New(t, mtest.NewOptions().CreateClient(false)) + + testCases := []struct { + desc string + commandName string + setup func(coll *mongo.Collection) error + operation func(ctx context.Context, coll *mongo.Collection) error + sendsMaxTimeMS bool + topologies []mtest.TopologyKind + }{ + { + desc: "FindOne", + commandName: "find", + setup: func(coll *mongo.Collection) error { + _, err := coll.InsertOne(context.Background(), bson.D{}) + return err + }, + operation: func(ctx context.Context, coll *mongo.Collection) error { + return coll.FindOne(ctx, bson.D{}).Err() + }, + sendsMaxTimeMS: true, + }, + { + desc: "Find", + commandName: "find", + setup: func(coll *mongo.Collection) error { + _, err := coll.InsertOne(context.Background(), bson.D{}) + return err + }, + operation: func(ctx context.Context, coll *mongo.Collection) error { + _, err := coll.Find(ctx, bson.D{}) + return err + }, + sendsMaxTimeMS: false, + }, + { + desc: "FindOneAndDelete", + commandName: "findAndModify", + setup: func(coll *mongo.Collection) error { + _, err := coll.InsertOne(context.Background(), bson.D{}) + return err + }, + operation: func(ctx context.Context, coll *mongo.Collection) error { + return coll.FindOneAndDelete(ctx, bson.D{}).Err() + }, + sendsMaxTimeMS: true, + }, + { + desc: "FindOneAndUpdate", + commandName: "findAndModify", + setup: func(coll *mongo.Collection) error { + _, err := coll.InsertOne(context.Background(), bson.D{}) + return err + }, + operation: func(ctx context.Context, coll *mongo.Collection) error { + return coll.FindOneAndUpdate(ctx, bson.D{}, bson.M{"$set": bson.M{"key": "value"}}).Err() + }, + sendsMaxTimeMS: true, + }, + { + desc: "FindOneAndReplace", + commandName: "findAndModify", + setup: func(coll *mongo.Collection) error { + _, err := coll.InsertOne(context.Background(), bson.D{}) + return err + }, + operation: func(ctx context.Context, coll *mongo.Collection) error { + return coll.FindOneAndReplace(ctx, 
bson.D{}, bson.D{}).Err() + }, + sendsMaxTimeMS: true, + }, + { + desc: "InsertOne", + commandName: "insert", + operation: func(ctx context.Context, coll *mongo.Collection) error { + _, err := coll.InsertOne(ctx, bson.D{}) + return err + }, + sendsMaxTimeMS: true, + }, + { + desc: "InsertMany", + commandName: "insert", + operation: func(ctx context.Context, coll *mongo.Collection) error { + _, err := coll.InsertMany(ctx, []interface{}{bson.D{}}) + return err + }, + sendsMaxTimeMS: true, + }, + { + desc: "UpdateOne", + commandName: "update", + operation: func(ctx context.Context, coll *mongo.Collection) error { + _, err := coll.UpdateOne(ctx, bson.D{}, bson.M{"$set": bson.M{"key": "value"}}) + return err + }, + sendsMaxTimeMS: true, + }, + { + desc: "UpdateMany", + commandName: "update", + operation: func(ctx context.Context, coll *mongo.Collection) error { + _, err := coll.UpdateMany(ctx, bson.D{}, bson.M{"$set": bson.M{"key": "value"}}) + return err + }, + sendsMaxTimeMS: true, + }, + { + desc: "ReplaceOne", + commandName: "update", + operation: func(ctx context.Context, coll *mongo.Collection) error { + _, err := coll.ReplaceOne(ctx, bson.D{}, bson.D{}) + return err + }, + sendsMaxTimeMS: true, + }, + { + desc: "DeleteOne", + commandName: "delete", + operation: func(ctx context.Context, coll *mongo.Collection) error { + _, err := coll.DeleteOne(ctx, bson.D{}) + return err + }, + sendsMaxTimeMS: true, + }, + { + desc: "DeleteMany", + commandName: "delete", + operation: func(ctx context.Context, coll *mongo.Collection) error { + _, err := coll.DeleteMany(ctx, bson.D{}) + return err + }, + sendsMaxTimeMS: true, + }, + { + desc: "Distinct", + commandName: "distinct", + operation: func(ctx context.Context, coll *mongo.Collection) error { + return coll.Distinct(ctx, "name", bson.D{}).Err() + }, + sendsMaxTimeMS: true, + }, + { + desc: "Aggregate", + commandName: "aggregate", + operation: func(ctx context.Context, coll *mongo.Collection) error { + _, err := 
coll.Aggregate(ctx, mongo.Pipeline{}) + return err + }, + sendsMaxTimeMS: false, + }, + { + desc: "Watch", + commandName: "aggregate", + operation: func(ctx context.Context, coll *mongo.Collection) error { + cs, err := coll.Watch(ctx, mongo.Pipeline{}) + if cs != nil { + cs.Close(context.Background()) + } + return err + }, + sendsMaxTimeMS: true, + // Change Streams aren't supported on standalone topologies. + topologies: []mtest.TopologyKind{ + mtest.ReplicaSet, + mtest.Sharded, + }, + }, + { + desc: "Cursor getMore", + commandName: "getMore", + setup: func(coll *mongo.Collection) error { + _, err := coll.InsertMany(context.Background(), []interface{}{bson.D{}, bson.D{}}) + return err + }, + operation: func(ctx context.Context, coll *mongo.Collection) error { + cursor, err := coll.Find(ctx, bson.D{}, options.Find().SetBatchSize(1)) + if err != nil { + return err + } + var res []bson.D + return cursor.All(ctx, &res) + }, + sendsMaxTimeMS: false, + }, + } + + // getStartedEvent returns the first command started event that matches the + // specified command name. + getStartedEvent := func(mt *mtest.T, command string) *event.CommandStartedEvent { + for { + evt := mt.GetStartedEvent() + if evt == nil { + break + } + _, err := evt.Command.LookupErr(command) + if errors.Is(err, bsoncore.ErrElementNotFound) { + continue + } + return evt + } + + mt.Errorf("could not find command started event for command %q", command) + mt.FailNow() + return nil + } + + // assertMaxTimeMSIsSet asserts that "maxTimeMS" is set to a positive value + // on the given command document. 
+ assertMaxTimeMSIsSet := func(mt *mtest.T, command bson.Raw) { + mt.Helper() + + maxTimeVal := command.Lookup("maxTimeMS") + + require.Greater(mt, + len(maxTimeVal.Value), + 0, + "expected maxTimeMS BSON value to be non-empty") + require.Equal(mt, + maxTimeVal.Type, + bson.TypeInt64, + "expected maxTimeMS BSON value to be type Int64") + assert.Greater(mt, + maxTimeVal.Int64(), + int64(0), + "expected maxTimeMS value to be greater than 0") + } + + // assertMaxTimeMSNotSet asserts that "maxTimeMS" is not set on the given + // command document. + assertMaxTimeMSNotSet := func(mt *mtest.T, command bson.Raw) { + mt.Helper() + + _, err := command.LookupErr("maxTimeMS") + assert.ErrorIs(mt, + err, + bsoncore.ErrElementNotFound, + "expected maxTimeMS BSON value to be missing, but is present") + } + + for _, tc := range testCases { + mt.RunOpts(tc.desc, mtest.NewOptions().Topologies(tc.topologies...), func(mt *mtest.T) { + mt.Run("timeoutMS not set", func(mt *mtest.T) { + if tc.setup != nil { + err := tc.setup(mt.Coll) + require.NoError(mt, err) + } + + err := tc.operation(context.Background(), mt.Coll) + require.NoError(mt, err) + + evt := getStartedEvent(mt, tc.commandName) + assertMaxTimeMSNotSet(mt, evt.Command) + }) + + mt.Run("Context with deadline", func(mt *mtest.T) { + if tc.setup != nil { + err := tc.setup(mt.Coll) + require.NoError(mt, err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + err := tc.operation(ctx, mt.Coll) + require.NoError(mt, err) + + evt := getStartedEvent(mt, tc.commandName) + if tc.sendsMaxTimeMS { + assertMaxTimeMSIsSet(mt, evt.Command) + } else { + assertMaxTimeMSNotSet(mt, evt.Command) + } + }) + + csotOpts := mtest.NewOptions(). 
+ ClientOptions(options.Client().SetTimeout(10 * time.Second)) + mt.RunOpts("timeoutMS and context.Background", csotOpts, func(mt *mtest.T) { + if tc.setup != nil { + err := tc.setup(mt.Coll) + require.NoError(mt, err) + } + + err := tc.operation(context.Background(), mt.Coll) + require.NoError(mt, err) + + evt := getStartedEvent(mt, tc.commandName) + if tc.sendsMaxTimeMS { + assertMaxTimeMSIsSet(mt, evt.Command) + } else { + assertMaxTimeMSNotSet(mt, evt.Command) + } + }) + + mt.RunOpts("timeoutMS and Context with deadline", csotOpts, func(mt *mtest.T) { + if tc.setup != nil { + err := tc.setup(mt.Coll) + require.NoError(mt, err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + err := tc.operation(ctx, mt.Coll) + require.NoError(mt, err) + + evt := getStartedEvent(mt, tc.commandName) + if tc.sendsMaxTimeMS { + assertMaxTimeMSIsSet(mt, evt.Command) + } else { + assertMaxTimeMSNotSet(mt, evt.Command) + } + }) + + opts := mtest.NewOptions(). + // Blocking failpoints don't work on pre-4.2 and sharded + // clusters. + Topologies(mtest.Single, mtest.ReplicaSet). + MinServerVersion("4.2") + mt.RunOpts("prevents connection closure", opts, func(mt *mtest.T) { + if tc.setup != nil { + err := tc.setup(mt.Coll) + require.NoError(mt, err) + } + + mt.SetFailPoint(mtest.FailPoint{ + ConfigureFailPoint: "failCommand", + Mode: "alwaysOn", + Data: mtest.FailPointData{ + FailCommands: []string{tc.commandName}, + BlockConnection: true, + // Note that some operations (currently Find and + // Aggregate) do not send maxTimeMS by default, meaning + // that the server will only respond after BlockTimeMS + // is elapsed. If the amount of time that the driver + // waits for responses after a timeout is significantly + // lower than BlockTimeMS, this test will start failing + // for those operations. + BlockTimeMS: 500, + }, + }) + + tpm := eventtest.NewTestPoolMonitor() + mt.ResetClient(options.Client(). 
+ SetPoolMonitor(tpm.PoolMonitor)) + + // Run 5 operations that time out, then assert that no + // connections were closed. + for i := 0; i < 5; i++ { + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Millisecond) + err := tc.operation(ctx, mt.Coll) + cancel() + + if !mongo.IsTimeout(err) { + t.Logf("Operation %d returned a non-timeout error: %v", i, err) + } + } + + closedEvents := tpm.Events(func(pe *event.PoolEvent) bool { + return pe.Type == event.ConnectionClosed + }) + assert.Len(mt, closedEvents, 0, "expected no connection closed event") + }) + }) + } + + csotOpts := mtest.NewOptions().ClientOptions(options.Client().SetTimeout(10 * time.Second)) + mt.RunOpts("omitted for values greater than 2147483647ms", csotOpts, func(mt *mtest.T) { + ctx, cancel := context.WithTimeout(context.Background(), (2147483647+1000)*time.Millisecond) + defer cancel() + _, err := mt.Coll.InsertOne(ctx, bson.D{}) + require.NoError(t, err) + + evt := mt.GetStartedEvent() + _, err = evt.Command.LookupErr("maxTimeMS") + assert.ErrorIs(mt, + err, + bsoncore.ErrElementNotFound, + "expected maxTimeMS BSON value to be missing, but is present") + }) +} + +func TestCSOT_errors(t *testing.T) { + mt := mtest.New(t, mtest.NewOptions(). + CreateClient(false). + // Blocking failpoints don't work on pre-4.2 and sharded clusters. + Topologies(mtest.Single, mtest.ReplicaSet). + MinServerVersion("4.2"). + // Enable CSOT. + ClientOptions(options.Client().SetTimeout(10*time.Second))) + + // Test that, when CSOT is enabled, the error returned when the database + // returns a MaxTimeMSExpired error (error code 50) wraps + // "context.DeadlineExceeded". 
+ mt.Run("MaxTimeMSExceeded wraps context.DeadlineExceeded", func(mt *mtest.T) { + _, err := mt.Coll.InsertOne(context.Background(), bson.D{}) + require.NoError(mt, err, "InsertOne error") + + mt.SetFailPoint(mtest.FailPoint{ + ConfigureFailPoint: "failCommand", + Mode: mtest.FailPointMode{ + Times: 1, + }, + Data: mtest.FailPointData{ + FailCommands: []string{"find"}, + ErrorCode: 50, // MaxTimeMSExpired + }, + }) + + err = mt.Coll.FindOne(context.Background(), bson.D{}).Err() + + assert.True(mt, + errors.Is(err, context.DeadlineExceeded), + "expected error %[1]T(%[1]q) to wrap context.DeadlineExceeded", + err) + assert.True(mt, + mongo.IsTimeout(err), + "expected error %[1]T(%[1]q) to be a timeout error", + err) + }) + + // Test that, when CSOT is enabled, the error returned when a context + // deadline is exceeded during a network operation wraps + // "context.DeadlineExceeded". + mt.Run("Context timeout wraps context.DeadlineExceeded", func(mt *mtest.T) { + _, err := mt.Coll.InsertOne(context.Background(), bson.D{}) + require.NoError(mt, err, "InsertOne error") + + mt.SetFailPoint(mtest.FailPoint{ + ConfigureFailPoint: "failCommand", + Mode: mtest.FailPointMode{ + Times: 1, + }, + Data: mtest.FailPointData{ + FailCommands: []string{"find"}, + BlockConnection: true, + BlockTimeMS: 500, + }, + }) + + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Millisecond) + defer cancel() + err = mt.Coll.FindOne(ctx, bson.D{}).Err() + + assert.False(mt, + errors.Is(err, driver.ErrDeadlineWouldBeExceeded), + "expected error %[1]T(%[1]q) to not wrap driver.ErrDeadlineWouldBeExceeded", + err) + assert.True(mt, + errors.Is(err, context.DeadlineExceeded), + "expected error %[1]T(%[1]q) to wrap context.DeadlineExceeded", + err) + assert.True(mt, + mongo.IsTimeout(err), + "expected error %[1]T(%[1]q) to be a timeout error", + err) + }) + + mt.Run("timeoutMS timeout wraps context.DeadlineExceeded", func(mt *mtest.T) { + _, err := 
mt.Coll.InsertOne(context.Background(), bson.D{}) + require.NoError(mt, err, "InsertOne error") + + mt.SetFailPoint(mtest.FailPoint{ + ConfigureFailPoint: "failCommand", + Mode: mtest.FailPointMode{ + Times: 1, + }, + Data: mtest.FailPointData{ + FailCommands: []string{"find"}, + BlockConnection: true, + BlockTimeMS: 100, + }, + }) + + // Set timeoutMS=10 to run the FindOne, then unset it so the mtest + // cleanup operations pass successfully (e.g. unsetting failpoints). + mt.ResetClient(options.Client().SetTimeout(10 * time.Millisecond)) + defer mt.ResetClient(options.Client()) + err = mt.Coll.FindOne(context.Background(), bson.D{}).Err() + + assert.False(mt, + errors.Is(err, driver.ErrDeadlineWouldBeExceeded), + "expected error %[1]T(%[1]q) to not wrap driver.ErrDeadlineWouldBeExceeded", + err) + assert.True(mt, + errors.Is(err, context.DeadlineExceeded), + "expected error %[1]T(%[1]q) to wrap context.DeadlineExceeded", + err) + assert.True(mt, + mongo.IsTimeout(err), + "expected error %[1]T(%[1]q) to be a timeout error", + err) + }) +} diff --git a/internal/integration/mtest/mongotest.go b/internal/integration/mtest/mongotest.go index 1e3879b9ef..fb744f0b30 100644 --- a/internal/integration/mtest/mongotest.go +++ b/internal/integration/mtest/mongotest.go @@ -208,7 +208,7 @@ func (t *T) cleanup() { // Run creates a new T instance for a sub-test and runs the given callback. It also creates a new collection using the // given name which is available to the callback through the T.Coll variable and is dropped after the callback // returns. -func (t *T) Run(name string, callback func(*T)) { +func (t *T) Run(name string, callback func(mt *T)) { t.RunOpts(name, NewOptions(), callback) } @@ -216,7 +216,7 @@ func (t *T) Run(name string, callback func(*T)) { // constraints specified in the options, the new sub-test will be skipped automatically. If the test is not skipped, // the callback will be run with the new T instance. 
RunOpts creates a new collection with the given name which is // available to the callback through the T.Coll variable and is dropped after the callback returns. -func (t *T) RunOpts(name string, opts *Options, callback func(*T)) { +func (t *T) RunOpts(name string, opts *Options, callback func(mt *T)) { t.T.Run(name, func(wrapped *testing.T) { sub := newT(wrapped, t.baseOpts, opts) diff --git a/internal/integration/unified/unified_spec_runner.go b/internal/integration/unified/unified_spec_runner.go index 8ec5d14454..23605e8991 100644 --- a/internal/integration/unified/unified_spec_runner.go +++ b/internal/integration/unified/unified_spec_runner.go @@ -61,6 +61,31 @@ var ( "unpin when a new transaction is started": "Implement GODRIVER-3034", "unpin when a non-transaction write operation uses a session": "Implement GODRIVER-3034", "unpin when a non-transaction read operation uses a session": "Implement GODRIVER-3034", + + // DRIVERS-2722: Setting "maxTimeMS" on a command that creates a cursor + // also limits the lifetime of the cursor. That may be surprising to + // users, so omit "maxTimeMS" from operations that return user-managed + // cursors. + "timeoutMS can be overridden for a find": "maxTimeMS is disabled on find and aggregate. See DRIVERS-2722.", + "timeoutMS can be configured for an operation - find on collection": "maxTimeMS is disabled on find and aggregate. See DRIVERS-2722.", + "timeoutMS can be configured for an operation - aggregate on collection": "maxTimeMS is disabled on find and aggregate. See DRIVERS-2722.", + "timeoutMS can be configured for an operation - aggregate on database": "maxTimeMS is disabled on find and aggregate. See DRIVERS-2722.", + "timeoutMS can be configured on a MongoClient - find on collection": "maxTimeMS is disabled on find and aggregate. See DRIVERS-2722.", + "timeoutMS can be configured on a MongoClient - aggregate on collection": "maxTimeMS is disabled on find and aggregate. 
See DRIVERS-2722.", + "timeoutMS can be configured on a MongoClient - aggregate on database": "maxTimeMS is disabled on find and aggregate. See DRIVERS-2722.", + "operation is retried multiple times for non-zero timeoutMS - find on collection": "maxTimeMS is disabled on find and aggregate. See DRIVERS-2722.", + "operation is retried multiple times for non-zero timeoutMS - aggregate on collection": "maxTimeMS is disabled on find and aggregate. See DRIVERS-2722.", + "operation is retried multiple times for non-zero timeoutMS - aggregate on database": "maxTimeMS is disabled on find and aggregate. See DRIVERS-2722.", + "timeoutMS applied to find command": "maxTimeMS is disabled on find and aggregate. See DRIVERS-2722.", + + // DRIVERS-2953: This test requires that the driver sends a "getMore" + // with "maxTimeMS" set. However, "getMore" can only include "maxTimeMS" + // for tailable awaitData cursors. Including "maxTimeMS" on "getMore" + // for any other cursor type results in a server error: + // + // (BadValue) cannot set maxTimeMS on getMore command for a non-awaitData cursor + // + "Non-tailable cursor lifetime remaining timeoutMS applied to getMore if timeoutMode is unset": "maxTimeMS can't be set on a getMore. See DRIVERS-2953", } logMessageValidatorTimeout = 10 * time.Millisecond diff --git a/mongo/collection.go b/mongo/collection.go index a73ff90760..5ba3113c0c 100644 --- a/mongo/collection.go +++ b/mongo/collection.go @@ -949,7 +949,12 @@ func aggregate(a aggregateParams, opts ...options.Lister[options.AggregateOption ServerAPI(a.client.serverAPI). HasOutputStage(hasOutputStage). Timeout(a.client.timeout). - Authenticator(a.client.authenticator) + Authenticator(a.client.authenticator). + // Omit "maxTimeMS" from operations that return a user-managed cursor to + // prevent confusing "cursor not found" errors. + // + // See DRIVERS-2722 for more detail. 
+ OmitMaxTimeMS(true) if args.AllowDiskUse != nil { op.AllowDiskUse(*args.AllowDiskUse) @@ -1293,11 +1298,20 @@ func (coll *Collection) Find(ctx context.Context, filter interface{}, if err != nil { return nil, err } - return coll.find(ctx, filter, args) + + // Omit "maxTimeMS" from operations that return a user-managed cursor to + // prevent confusing "cursor not found" errors. + // + // See DRIVERS-2722 for more detail. + return coll.find(ctx, filter, true, args) } -func (coll *Collection) find(ctx context.Context, filter interface{}, - args *options.FindOptions) (cur *Cursor, err error) { +func (coll *Collection) find( + ctx context.Context, + filter interface{}, + omitMaxTimeMS bool, + args *options.FindOptions, +) (cur *Cursor, err error) { if ctx == nil { ctx = context.Background() @@ -1335,7 +1349,8 @@ func (coll *Collection) find(ctx context.Context, filter interface{}, CommandMonitor(coll.client.monitor).ServerSelector(selector). ClusterClock(coll.client.clock).Database(coll.db.name).Collection(coll.name). Deployment(coll.client.deployment).Crypt(coll.client.cryptFLE).ServerAPI(coll.client.serverAPI). - Timeout(coll.client.timeout).Logger(coll.client.logger).Authenticator(coll.client.authenticator) + Timeout(coll.client.timeout).Logger(coll.client.logger).Authenticator(coll.client.authenticator). 
+ OmitMaxTimeMS(omitMaxTimeMS) cursorOpts := coll.client.createBaseCursorOptions() @@ -1500,7 +1515,7 @@ func (coll *Collection) FindOne(ctx context.Context, filter interface{}, if err != nil { return nil } - cursor, err := coll.find(ctx, filter, newFindArgsFromFindOneArgs(args)) + cursor, err := coll.find(ctx, filter, false, newFindArgsFromFindOneArgs(args)) return &SingleResult{ ctx: ctx, cur: cursor, diff --git a/x/mongo/driver/batch_cursor.go b/x/mongo/driver/batch_cursor.go index fa0bb90665..ddf4913ef5 100644 --- a/x/mongo/driver/batch_cursor.go +++ b/x/mongo/driver/batch_cursor.go @@ -441,6 +441,12 @@ func (bc *BatchCursor) getMore(ctx context.Context) { Crypt: bc.crypt, ServerAPI: bc.serverAPI, + // Omit the automatically-calculated maxTimeMS because setting maxTimeMS + // on a non-awaitData cursor causes a server error. For awaitData + // cursors, maxTimeMS is set when maxAwaitTime is specified by the above + // CommandFn. + OmitMaxTimeMS: true, + // No read preference is passed to the getMore command, // resulting in the default read preference: "primaryPreferred". // Since this could be confusing, and there is no requirement diff --git a/x/mongo/driver/errors.go b/x/mongo/driver/errors.go index 1a0afca5d9..61847329f2 100644 --- a/x/mongo/driver/errors.go +++ b/x/mongo/driver/errors.go @@ -509,7 +509,7 @@ func ExtractErrorFromServerResponse(doc bsoncore.Document) error { errmsg = "command failed" } - return Error{ + err := Error{ Code: code, Message: errmsg, Name: codeName, @@ -517,6 +517,20 @@ func ExtractErrorFromServerResponse(doc bsoncore.Document) error { TopologyVersion: tv, Raw: doc, } + + // If we get a MaxTimeMSExpired error, assume that the error was caused + // by setting "maxTimeMS" on the command based on the context deadline + // or on "timeoutMS". 
In that case, make the error wrap + // context.DeadlineExceeded so that users can always check + // + // errors.Is(err, context.DeadlineExceeded) + // + // for either client-side or server-side timeouts. + if err.Code == 50 { + err.Wrapped = context.DeadlineExceeded + } + + return err } if len(wcError.WriteErrors) > 0 || wcError.WriteConcernError != nil { diff --git a/x/mongo/driver/integration/aggregate_test.go b/x/mongo/driver/integration/aggregate_test.go index b055d9f11e..0d34c4db18 100644 --- a/x/mongo/driver/integration/aggregate_test.go +++ b/x/mongo/driver/integration/aggregate_test.go @@ -10,113 +10,21 @@ import ( "bytes" "context" "testing" - "time" "go.mongodb.org/mongo-driver/v2/bson" - "go.mongodb.org/mongo-driver/v2/event" "go.mongodb.org/mongo-driver/v2/internal/integtest" - "go.mongodb.org/mongo-driver/v2/internal/require" "go.mongodb.org/mongo-driver/v2/internal/serverselector" "go.mongodb.org/mongo-driver/v2/mongo/writeconcern" "go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/v2/x/mongo/driver" "go.mongodb.org/mongo-driver/v2/x/mongo/driver/operation" - "go.mongodb.org/mongo-driver/v2/x/mongo/driver/topology" ) -func setUpMonitor() (*event.CommandMonitor, chan *event.CommandStartedEvent, chan *event.CommandSucceededEvent, chan *event.CommandFailedEvent) { - started := make(chan *event.CommandStartedEvent, 1) - succeeded := make(chan *event.CommandSucceededEvent, 1) - failed := make(chan *event.CommandFailedEvent, 1) - - return &event.CommandMonitor{ - Started: func(_ context.Context, e *event.CommandStartedEvent) { - started <- e - }, - Succeeded: func(_ context.Context, e *event.CommandSucceededEvent) { - succeeded <- e - }, - Failed: func(_ context.Context, e *event.CommandFailedEvent) { - failed <- e - }, - }, started, succeeded, failed -} - -func skipIfBelow32(ctx context.Context, t *testing.T, topo *topology.Topology) { - server, err := topo.SelectServer(ctx, &serverselector.Write{}) - noerr(t, err) - - 
versionCmd := bsoncore.BuildDocument(nil, bsoncore.AppendInt32Element(nil, "serverStatus", 1)) - serverStatus, err := runCommand(server, dbName, versionCmd) - noerr(t, err) - version, err := serverStatus.LookupErr("version") - noerr(t, err) - - if integtest.CompareVersions(t, version.StringValue(), "3.2") < 0 { - t.Skip() - } -} - func TestAggregate(t *testing.T) { if testing.Short() { t.Skip("skipping integration test in short mode") } - t.Run("TestMaxTimeMSInGetMore", func(t *testing.T) { - ctx := context.Background() - monitor, started, succeeded, failed := setUpMonitor() - dbName := "TestAggMaxTimeDB" - collName := "TestAggMaxTimeColl" - top := integtest.MonitoredTopology(t, dbName, monitor) - clearChannels(started, succeeded, failed) - skipIfBelow32(ctx, t, top) - - clearChannels(started, succeeded, failed) - err := operation.NewInsert( - bsoncore.BuildDocument(nil, bsoncore.AppendInt32Element(nil, "x", 1)), - bsoncore.BuildDocument(nil, bsoncore.AppendInt32Element(nil, "x", 1)), - bsoncore.BuildDocument(nil, bsoncore.AppendInt32Element(nil, "x", 1)), - ).Collection(collName).Database(dbName). - Deployment(top).ServerSelector(&serverselector.Write{}).Execute(context.Background()) - noerr(t, err) - - clearChannels(started, succeeded, failed) - op := operation.NewAggregate(bsoncore.BuildDocumentFromElements(nil)). - Collection(collName).Database(dbName).Deployment(top).ServerSelector(&serverselector.Write{}). 
- CommandMonitor(monitor).BatchSize(2) - - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) - defer cancel() - - err = op.Execute(ctx) - noerr(t, err) - batchCursor, err := op.Result(driver.CursorOptions{BatchSize: 2, CommandMonitor: monitor}) - noerr(t, err) - - var e *event.CommandStartedEvent - select { - case e = <-started: - case <-time.After(2000 * time.Millisecond): - t.Fatal("timed out waiting for aggregate") - } - - require.Equal(t, "aggregate", e.CommandName) - - clearChannels(started, succeeded, failed) - // first Next() should automatically return true - require.True(t, batchCursor.Next(ctx), "expected true from first Next, got false") - clearChannels(started, succeeded, failed) - batchCursor.Next(ctx) // should do getMore - - select { - case e = <-started: - case <-time.After(200 * time.Millisecond): - t.Fatal("timed out waiting for getMore") - } - require.Equal(t, "getMore", e.CommandName) - _, err = e.Command.LookupErr("maxTimeMS") - noerr(t, err) - }) t.Run("Multiple Batches", func(t *testing.T) { ds := []bsoncore.Document{ bsoncore.BuildDocument(nil, bsoncore.AppendInt32Element(nil, "_id", 1)), @@ -185,15 +93,3 @@ func TestAggregate(t *testing.T) { }) } - -func clearChannels(s chan *event.CommandStartedEvent, succ chan *event.CommandSucceededEvent, f chan *event.CommandFailedEvent) { - for len(s) > 0 { - <-s - } - for len(succ) > 0 { - <-succ - } - for len(f) > 0 { - <-f - } -} diff --git a/x/mongo/driver/integration/main_test.go b/x/mongo/driver/integration/main_test.go index c1b68d27d6..f1c24bfd0d 100644 --- a/x/mongo/driver/integration/main_test.go +++ b/x/mongo/driver/integration/main_test.go @@ -119,13 +119,13 @@ func addCompressorToURI(uri string) string { return uri + "compressors=" + comp } -// runCommand runs an arbitrary command on a given database of target server -func runCommand(s driver.Server, db string, cmd bsoncore.Document) (bsoncore.Document, error) { +// runCommand runs an arbitrary command on a 
given database of the target +// server. +func runCommand(s driver.Server, db string, cmd bsoncore.Document) error { op := operation.NewCommand(cmd). - Database(db).Deployment(driver.SingleServerDeployment{Server: s}) - err := op.Execute(context.Background()) - res := op.Result() - return res, err + Database(db). + Deployment(driver.SingleServerDeployment{Server: s}) + return op.Execute(context.Background()) } // dropCollection drops the collection in the test cluster. diff --git a/x/mongo/driver/integration/scram_test.go b/x/mongo/driver/integration/scram_test.go index 44aa88255a..d75623e051 100644 --- a/x/mongo/driver/integration/scram_test.go +++ b/x/mongo/driver/integration/scram_test.go @@ -147,8 +147,7 @@ func runScramAuthTest(t *testing.T, credential options.Credential) error { noerr(t, err) cmd := bsoncore.BuildDocument(nil, bsoncore.AppendInt32Element(nil, "dbstats", 1)) - _, err = runCommand(server, integtest.DBName(t), cmd) - return err + return runCommand(server, integtest.DBName(t), cmd) } func createScramUsers(t *testing.T, s driver.Server, cases []scramTestCase) error { @@ -169,7 +168,7 @@ func createScramUsers(t *testing.T, s driver.Server, cases []scramTestCase) erro )), bsoncore.AppendArrayElement(nil, "mechanisms", bsoncore.BuildArray(nil, values...)), ) - _, err := runCommand(s, db, newUserCmd) + err := runCommand(s, db, newUserCmd) if err != nil { return fmt.Errorf("Couldn't create user '%s' on db '%s': %w", c.username, integtest.DBName(t), err) } diff --git a/x/mongo/driver/operation.go b/x/mongo/driver/operation.go index 35e363a1a6..e368a1e40c 100644 --- a/x/mongo/driver/operation.go +++ b/x/mongo/driver/operation.go @@ -1213,7 +1213,7 @@ func (op Operation) addBatchArray(dst []byte) []byte { func (op Operation) createLegacyHandshakeWireMessage( ctx context.Context, - maxTimeMS uint64, + maxTimeMS int64, dst []byte, desc description.SelectedServer, ) ([]byte, startedInformation, error) { @@ -1272,7 +1272,7 @@ func (op Operation) 
createLegacyHandshakeWireMessage( // If maxTimeMS is greater than 0 append it to wire message. A maxTimeMS value of 0 only explicitly // specifies the default behavior of no timeout server-side. if maxTimeMS > 0 { - dst = bsoncore.AppendInt64Element(dst, "maxTimeMS", int64(maxTimeMS)) + dst = bsoncore.AppendInt64Element(dst, "maxTimeMS", maxTimeMS) } dst, _ = bsoncore.AppendDocumentEnd(dst, idx) @@ -1293,7 +1293,7 @@ func (op Operation) createLegacyHandshakeWireMessage( func (op Operation) createMsgWireMessage( ctx context.Context, - maxTimeMS uint64, + maxTimeMS int64, dst []byte, desc description.SelectedServer, conn *mnet.Connection, @@ -1343,7 +1343,7 @@ func (op Operation) createMsgWireMessage( // If maxTimeMS is greater than 0 append it to wire message. A maxTimeMS value of 0 only explicitly // specifies the default behavior of no timeout server-side. if maxTimeMS > 0 { - dst = bsoncore.AppendInt64Element(dst, "maxTimeMS", int64(maxTimeMS)) + dst = bsoncore.AppendInt64Element(dst, "maxTimeMS", maxTimeMS) } dst = bsoncore.AppendStringElement(dst, "$db", op.Database) @@ -1389,7 +1389,7 @@ func isLegacyHandshake(op Operation, desc description.SelectedServer) bool { func (op Operation) createWireMessage( ctx context.Context, - maxTimeMS uint64, + maxTimeMS int64, dst []byte, desc description.SelectedServer, conn *mnet.Connection, @@ -1666,7 +1666,7 @@ func (op Operation) addClusterTime(dst []byte, desc description.SelectedServer) // if the ctx is a Timeout context. If the context is not a Timeout context, it uses the // operation's MaxTimeMS if set. If no MaxTimeMS is set on the operation, and context is // not a Timeout context, calculateMaxTimeMS returns 0. 
-func (op Operation) calculateMaxTimeMS(ctx context.Context, rttMin time.Duration, rttStats string) (uint64, error) { +func (op Operation) calculateMaxTimeMS(ctx context.Context, rttMin time.Duration, rttStats string) (int64, error) { if op.OmitMaxTimeMS { return 0, nil } @@ -1683,13 +1683,23 @@ func (op Operation) calculateMaxTimeMS(ctx context.Context, rttMin time.Duration maxTimeMS := int64((remainingTimeout - rttMin + time.Millisecond - 1) / time.Millisecond) if maxTimeMS <= 0 { return 0, fmt.Errorf( - "remaining time %v until context deadline is less than or equal to rtt minimum: %w\n%v", + "remaining time %v until context deadline is less than or equal to min network round-trip time %v (%v): %w", remainingTimeout, - ErrDeadlineWouldBeExceeded, - rttStats) + rttMin, + rttStats, + ErrDeadlineWouldBeExceeded) } - return uint64(maxTimeMS), nil + // The server will return a "BadValue" error if maxTimeMS is greater + // than the maximum positive int32 value (about 24.9 days). If the + // user specified a timeout value greater than that, omit maxTimeMS + // and let the client-side timeout handle cancelling the op if the + // timeout is ever reached. 
+ if maxTimeMS > math.MaxInt32 { + return 0, nil + } + + return maxTimeMS, nil } // updateClusterTimes updates the cluster times for the session and cluster clock attached to this diff --git a/x/mongo/driver/operation/aggregate.go b/x/mongo/driver/operation/aggregate.go index a80e8b035e..5b5fd02192 100644 --- a/x/mongo/driver/operation/aggregate.go +++ b/x/mongo/driver/operation/aggregate.go @@ -49,6 +49,7 @@ type Aggregate struct { hasOutputStage bool customOptions map[string]bsoncore.Value timeout *time.Duration + omitMaxTimeMS bool result driver.CursorResponse } @@ -112,6 +113,7 @@ func (a *Aggregate) Execute(ctx context.Context) error { Timeout: a.timeout, Name: driverutil.AggregateOp, Authenticator: a.authenticator, + OmitMaxTimeMS: a.omitMaxTimeMS, }.Execute(ctx) } @@ -416,3 +418,14 @@ func (a *Aggregate) Authenticator(authenticator driver.Authenticator) *Aggregate a.authenticator = authenticator return a } + +// OmitMaxTimeMS omits the automatically-calculated "maxTimeMS" from the +// command. +func (a *Aggregate) OmitMaxTimeMS(omit bool) *Aggregate { + if a == nil { + a = new(Aggregate) + } + + a.omitMaxTimeMS = omit + return a +} diff --git a/x/mongo/driver/operation/find.go b/x/mongo/driver/operation/find.go index 803e2768c2..21cb92eca0 100644 --- a/x/mongo/driver/operation/find.go +++ b/x/mongo/driver/operation/find.go @@ -62,6 +62,7 @@ type Find struct { serverAPI *driver.ServerAPIOptions timeout *time.Duration logger *logger.Logger + omitMaxTimeMS bool } // NewFind constructs and returns a new Find. @@ -109,6 +110,7 @@ func (f *Find) Execute(ctx context.Context) error { Logger: f.logger, Name: driverutil.FindOp, Authenticator: f.authenticator, + OmitMaxTimeMS: f.omitMaxTimeMS, }.Execute(ctx) } @@ -559,3 +561,14 @@ func (f *Find) Authenticator(authenticator driver.Authenticator) *Find { f.authenticator = authenticator return f } + +// OmitMaxTimeMS omits the automatically-calculated "maxTimeMS" from the +// command. 
+func (f *Find) OmitMaxTimeMS(omit bool) *Find { + if f == nil { + f = new(Find) + } + + f.omitMaxTimeMS = omit + return f +} diff --git a/x/mongo/driver/operation_test.go b/x/mongo/driver/operation_test.go index 3efee28865..8ab473f33e 100644 --- a/x/mongo/driver/operation_test.go +++ b/x/mongo/driver/operation_test.go @@ -286,7 +286,7 @@ func TestOperation(t *testing.T) { rtt RTTMonitor rttMin time.Duration rttStats string - want uint64 + want int64 err error }{ { @@ -644,6 +644,35 @@ func TestOperation(t *testing.T) { // the TransientTransactionError label. assert.Equal(t, err, context.Canceled, "expected context.Canceled error, got %v", err) }) + t.Run("ErrDeadlineWouldBeExceeded wraps context.DeadlineExceeded", func(t *testing.T) { + // Create a deployment that returns a server that reports a minimum + // RTT of 1 minute. + d := new(mockDeployment) + d.returns.server = mockServer{ + conn: mnet.NewConnection(&mockConnection{}), + rttMonitor: mockRTTMonitor{min: 1 * time.Minute}, + } + + // Create an operation with a Timeout specified to enable CSOT behavior. + var dur time.Duration + op := Operation{ + Database: "foobar", + Deployment: d, + CommandFn: func(dst []byte, desc description.SelectedServer) ([]byte, error) { + return dst, nil + }, + Timeout: &dur, + } + + // Call the operation with a context with a deadline less than the + // minimum RTT configured above.
+ ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + err := op.Execute(ctx) + + assert.ErrorIs(t, err, ErrDeadlineWouldBeExceeded) + assert.ErrorIs(t, err, context.DeadlineExceeded) + }) } func createExhaustServerResponse(response bsoncore.Document, moreToCome bool) []byte { @@ -709,6 +738,25 @@ func (m *mockServerSelector) String() string { panic("not implemented") } +type mockServer struct { + conn *mnet.Connection + err error + rttMonitor RTTMonitor +} + +func (ms mockServer) Connection(context.Context) (*mnet.Connection, error) { return ms.conn, ms.err } +func (ms mockServer) RTTMonitor() RTTMonitor { return ms.rttMonitor } + +type mockRTTMonitor struct { + ewma time.Duration + min time.Duration + stats string +} + +func (mrm mockRTTMonitor) EWMA() time.Duration { return mrm.ewma } +func (mrm mockRTTMonitor) Min() time.Duration { return mrm.min } +func (mrm mockRTTMonitor) Stats() string { return mrm.stats } + type mockConnection struct { // parameters pWriteWM []byte diff --git a/x/mongo/driver/topology/connection.go b/x/mongo/driver/topology/connection.go index 4455516cae..34e04be1b6 100644 --- a/x/mongo/driver/topology/connection.go +++ b/x/mongo/driver/topology/connection.go @@ -79,6 +79,10 @@ type connection struct { // oidcTokenGenID is the monotonic generation ID for OIDC tokens, used to invalidate // accessTokens in the OIDC authenticator cache. oidcTokenGenID uint64 + + // awaitingResponse indicates that the server response was not completely + // read before returning the connection to the pool. + awaitingResponse bool } // newConnection handles the creation of a connection. It does not connect the connection. @@ -373,8 +377,16 @@ func (c *connection) readWireMessage(ctx context.Context) ([]byte, error) { dst, errMsg, err := c.read(ctx) if err != nil { - // We closeConnection the connection because we don't know if there are other bytes left to read. 
- c.close() + if nerr := net.Error(nil); errors.As(err, &nerr) && nerr.Timeout() { + // If the error was a timeout error, instead of closing the + // connection mark it as awaiting response so the pool can read the + // response before making it available to other operations. + c.awaitingResponse = true + } else { + // Otherwise, close the connection because we don't know what + // the connection state is. + c.close() + } message := errMsg if errors.Is(err, io.EOF) { message = "socket was unexpectedly closed" diff --git a/x/mongo/driver/topology/pool.go b/x/mongo/driver/topology/pool.go index ecb0246c7d..e9a9e8ef20 100644 --- a/x/mongo/driver/topology/pool.go +++ b/x/mongo/driver/topology/pool.go @@ -767,6 +767,80 @@ func (p *pool) removeConnection(conn *connection, reason reason, err error) erro return nil } +var ( + // BGReadTimeout is the maximum amount of time to wait when trying to read + // the server reply on a connection after an operation timed out. The + // default is 400ms. + // + // Deprecated: BGReadTimeout is intended for internal use only and may be + // removed or modified at any time. + BGReadTimeout = 400 * time.Millisecond + + // BGReadCallback is a callback for monitoring the behavior of the + // background-read-on-timeout connection preserving mechanism. + // + // Deprecated: BGReadCallback is intended for internal use only and may be + // removed or modified at any time. + BGReadCallback func(addr string, start, read time.Time, errs []error, connClosed bool) +) + +// bgRead sets a new read deadline on the provided connection and tries to read +// any bytes returned by the server. If successful, it checks the connection +// into the provided pool. If there are any errors, it closes the connection. +// +// It calls the package-global BGReadCallback function, if set, with the +// address, timings, and any errors that occurred.
+func bgRead(pool *pool, conn *connection) { + var start, read time.Time + start = time.Now() + errs := make([]error, 0) + connClosed := false + + defer func() { + // No matter what happens, always check the connection back into the + // pool, which will either make it available for other operations or + // remove it from the pool if it was closed. + err := pool.checkInNoEvent(conn) + if err != nil { + errs = append(errs, fmt.Errorf("error checking in: %w", err)) + } + + if BGReadCallback != nil { + BGReadCallback(conn.addr.String(), start, read, errs, connClosed) + } + }() + + err := conn.nc.SetReadDeadline(time.Now().Add(BGReadTimeout)) + if err != nil { + errs = append(errs, fmt.Errorf("error setting a read deadline: %w", err)) + + connClosed = true + err := conn.close() + if err != nil { + errs = append(errs, fmt.Errorf("error closing conn after setting read deadline: %w", err)) + } + + return + } + + // The context here is only used for cancellation, not deadline timeout, so + // use context.Background(). The read timeout is set by calling + // SetReadDeadline above. + _, _, err = conn.read(context.Background()) + read = time.Now() + if err != nil { + errs = append(errs, fmt.Errorf("error reading: %w", err)) + + connClosed = true + err := conn.close() + if err != nil { + errs = append(errs, fmt.Errorf("error closing conn after reading: %w", err)) + } + + return + } +} + // checkIn returns an idle connection to the pool. If the connection is perished or the pool is // closed, it is removed from the connection pool and closed. func (p *pool) checkIn(conn *connection) error { @@ -806,6 +880,20 @@ func (p *pool) checkInNoEvent(conn *connection) error { return ErrWrongPool } + // If the connection has an awaiting server response, try to read the + // response in another goroutine before checking it back into the pool. 
+ // + // Do this here because we want to publish checkIn events when the operation + // is done with the connection, not when it's ready to be used again. That + // means that connections in "awaiting response" state are checked in but + // not usable, which is not covered by the current pool events. We may need + // to add pool event information in the future to communicate that. + if conn.awaitingResponse { + conn.awaitingResponse = false + go bgRead(p, conn) + return nil + } + // Bump the connection idle deadline here because we're about to make the connection "available". // The idle deadline is used to determine when a connection has reached its max idle time and // should be closed. A connection reaches its max idle time when it has been "available" in the diff --git a/x/mongo/driver/topology/rtt_monitor.go b/x/mongo/driver/topology/rtt_monitor.go index 525822a16d..88856b6b50 100644 --- a/x/mongo/driver/topology/rtt_monitor.go +++ b/x/mongo/driver/topology/rtt_monitor.go @@ -261,6 +261,8 @@ func (r *rttMonitor) Stats() string { r.mu.RLock() defer r.mu.RUnlock() - return fmt.Sprintf(`Round-trip-time monitor statistics:`+"\n"+ - `moving average RTT: %v, minimum RTT: %v`+"\n", r.averageRTT, r.minRTT) + return fmt.Sprintf( + "network round-trip time stats: moving avg: %v, min: %v", + r.averageRTT, + r.minRTT) }