Skip to content

Commit f76320a

Browse files
GODRIVER-3616 Apply client-level timeout to tailable cursors (#2174)
1 parent 54bab6d commit f76320a

File tree

6 files changed

+240
-102
lines changed

6 files changed

+240
-102
lines changed

internal/integration/cursor_test.go

Lines changed: 157 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -319,116 +319,192 @@ func parseMaxAwaitTime(mt *mtest.T, evt *event.CommandStartedEvent) int64 {
319319
return got
320320
}
321321

322-
func TestCursor_tailableAwaitData(t *testing.T) {
323-
mt := mtest.New(t, mtest.NewOptions().CreateClient(false))
322+
func tadcFindFactory(ctx context.Context, mt *mtest.T, coll mongo.Collection) *mongo.Cursor {
323+
mt.Helper()
324324

325-
cappedOpts := options.CreateCollection().SetCapped(true).
326-
SetSizeInBytes(1024 * 64)
325+
initCollection(mt, &coll)
326+
cur, err := coll.Find(ctx, bson.D{{"__nomatch", 1}},
327+
options.Find().SetBatchSize(1).SetCursorType(options.TailableAwait))
328+
require.NoError(mt, err, "Find error: %v", err)
327329

328-
// TODO(SERVER-96344): mongos doesn't honor a failpoint's full blockTimeMS.
329-
mtOpts := mtest.NewOptions().MinServerVersion("4.4").
330-
Topologies(mtest.ReplicaSet, mtest.LoadBalanced, mtest.Single).
331-
CollectionCreateOptions(cappedOpts)
330+
return cur
331+
}
332332

333-
mt.RunOpts("apply remaining timeoutMS if less than maxAwaitTimeMS", mtOpts, func(mt *mtest.T) {
334-
initCollection(mt, mt.Coll)
333+
func tadcAggregateFactory(ctx context.Context, mt *mtest.T, coll mongo.Collection) *mongo.Cursor {
334+
mt.Helper()
335335

336-
// Create a 30ms failpoint for getMore.
337-
mt.SetFailPoint(failpoint.FailPoint{
338-
ConfigureFailPoint: "failCommand",
339-
Mode: failpoint.Mode{
340-
Times: 1,
341-
},
342-
Data: failpoint.Data{
343-
FailCommands: []string{"getMore"},
344-
BlockConnection: true,
345-
BlockTimeMS: 30,
346-
},
347-
})
336+
initCollection(mt, &coll)
337+
opts := options.Aggregate()
338+
pipeline := mongo.Pipeline{{{"$changeStream", bson.D{{"fullDocument", "default"}}}},
339+
{{"$match", bson.D{
340+
{"operationType", "insert"},
341+
{"fullDocment.__nomatch", 1},
342+
}}},
343+
}
348344

349-
// Create a find cursor with a 100ms maxAwaitTimeMS and a tailable awaitData
350-
// cursor type.
351-
opts := options.Find().
352-
SetBatchSize(1).
353-
SetMaxAwaitTime(100 * time.Millisecond).
354-
SetCursorType(options.TailableAwait)
345+
cursor, err := coll.Aggregate(ctx, pipeline, opts)
346+
require.NoError(mt, err, "Aggregate error: %v", err)
355347

356-
cursor, err := mt.Coll.Find(context.Background(), bson.D{{"x", 2}}, opts)
357-
require.NoError(mt, err)
348+
return cursor
349+
}
358350

359-
defer cursor.Close(context.Background())
351+
func tadcRunCommandCursorFactory(ctx context.Context, mt *mtest.T, coll mongo.Collection) *mongo.Cursor {
352+
mt.Helper()
360353

361-
// Use a 200ms timeout that caps the lifetime of cursor.Next. The underlying
362-
// getMore loop should run at least two times: the first getMore will block
363-
// for 30ms on the getMore and then an additional 100ms for the
364-
// maxAwaitTimeMS. The second getMore will then use the remaining ~70ms
365-
// left on the timeout.
366-
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
367-
defer cancel()
354+
initCollection(mt, &coll)
368355

369-
// Iterate twice to force a getMore
370-
cursor.Next(ctx)
356+
cur, err := coll.Database().RunCommandCursor(ctx, bson.D{
357+
{"find", coll.Name()},
358+
{"filter", bson.D{{"__nomatch", 1}}},
359+
{"tailable", true},
360+
{"awaitData", true},
361+
{"batchSize", int32(1)},
362+
})
363+
require.NoError(mt, err, "RunCommandCursor error: %v", err)
371364

372-
mt.ClearEvents()
373-
cursor.Next(ctx)
365+
return cur
366+
}
374367

375-
require.Error(mt, cursor.Err(), "expected error from cursor.Next")
376-
assert.ErrorIs(mt, cursor.Err(), context.DeadlineExceeded, "expected context deadline exceeded error")
368+
// For tailable awaitData cursors, the maxTimeMS for a getMore should be
369+
// min(maxAwaitTimeMS, remaining timeoutMS - minRoundTripTime) to allow the
370+
// server more opportunities to respond with an empty batch before a
371+
// client-side timeout.
372+
func TestCursor_tailableAwaitData_applyRemainingTimeout(t *testing.T) {
373+
// These values reflect what is used in the unified spec tests, see
374+
// DRIVERS-2868.
375+
const timeoutMS = 200
376+
const maxAwaitTimeMS = 100
377+
const blockTimeMS = 30
378+
const getMoreBound = 71
379+
380+
// TODO(GODRIVER-3328): mongos doesn't honor a failpoint's full blockTimeMS.
381+
baseTopologies := []mtest.TopologyKind{mtest.Single, mtest.LoadBalanced, mtest.ReplicaSet}
382+
383+
type testCase struct {
384+
name string
385+
factory func(ctx context.Context, mt *mtest.T, coll mongo.Collection) *mongo.Cursor
386+
opTimeout bool
387+
topologies []mtest.TopologyKind
388+
}
377389

378-
// Collect all started events to find the getMore commands.
379-
startedEvents := mt.GetAllStartedEvents()
390+
cases := []testCase{
391+
// TODO(GODRIVER-2944): "find" cursors are tested in the CSOT unified spec
392+
// tests for tailable/awaitData cursors and so these tests can be removed
393+
// once the driver supports timeoutMode.
394+
{
395+
name: "find client-level timeout",
396+
factory: tadcFindFactory,
397+
topologies: baseTopologies,
398+
opTimeout: false,
399+
},
400+
{
401+
name: "find operation-level timeout",
402+
factory: tadcFindFactory,
403+
topologies: baseTopologies,
404+
opTimeout: true,
405+
},
380406

381-
var getMoreStartedEvents []*event.CommandStartedEvent
382-
for _, evt := range startedEvents {
383-
if evt.CommandName == "getMore" {
384-
getMoreStartedEvents = append(getMoreStartedEvents, evt)
385-
}
386-
}
407+
// There is no analogue to tailable/awaitData cursor unified spec tests for
408+
// aggregate and runCommand.
409+
{
410+
name: "aggregate with changeStream client-level timeout",
411+
factory: tadcAggregateFactory,
412+
topologies: []mtest.TopologyKind{mtest.ReplicaSet, mtest.LoadBalanced},
413+
opTimeout: false,
414+
},
415+
{
416+
name: "aggregate with changeStream operation-level timeout",
417+
factory: tadcAggregateFactory,
418+
topologies: []mtest.TopologyKind{mtest.ReplicaSet, mtest.LoadBalanced},
419+
opTimeout: true,
420+
},
421+
{
422+
name: "runCommandCursor client-level timeout",
423+
factory: tadcRunCommandCursorFactory,
424+
topologies: baseTopologies,
425+
opTimeout: false,
426+
},
427+
{
428+
name: "runCommandCursor operation-level timeout",
429+
factory: tadcRunCommandCursorFactory,
430+
topologies: baseTopologies,
431+
opTimeout: true,
432+
},
433+
}
387434

388-
// The first getMore should have a maxTimeMS of <= 100ms.
389-
assert.LessOrEqual(mt, parseMaxAwaitTime(mt, getMoreStartedEvents[0]), int64(100))
435+
mt := mtest.New(t, mtest.NewOptions().CreateClient(false).MinServerVersion("4.2"))
390436

391-
// The second getMore should have a maxTimeMS of <=71, indicating that we
392-
// are using the time remaining in the context rather than the
393-
// maxAwaitTimeMS.
394-
assert.LessOrEqual(mt, parseMaxAwaitTime(mt, getMoreStartedEvents[1]), int64(71))
395-
})
437+
for _, tc := range cases {
438+
// Reset the collection between test cases to avoid leaking timeouts
439+
// between tests.
440+
cappedOpts := options.CreateCollection().SetCapped(true).SetSizeInBytes(1024 * 64)
441+
caseOpts := mtest.NewOptions().
442+
CollectionCreateOptions(cappedOpts).
443+
Topologies(tc.topologies...).
444+
CreateClient(true)
396445

397-
mtOpts.Topologies(mtest.ReplicaSet, mtest.Sharded, mtest.LoadBalanced, mtest.Single)
446+
if !tc.opTimeout {
447+
caseOpts = caseOpts.ClientOptions(options.Client().SetTimeout(timeoutMS * time.Millisecond))
448+
}
398449

399-
mt.RunOpts("apply maxAwaitTimeMS if less than remaining timeout", mtOpts, func(mt *mtest.T) {
400-
initCollection(mt, mt.Coll)
401-
mt.ClearEvents()
450+
mt.RunOpts(tc.name, caseOpts, func(mt *mtest.T) {
451+
mt.SetFailPoint(failpoint.FailPoint{
452+
ConfigureFailPoint: "failCommand",
453+
Mode: failpoint.Mode{Times: 1},
454+
Data: failpoint.Data{
455+
FailCommands: []string{"getMore"},
456+
BlockConnection: true,
457+
BlockTimeMS: int32(blockTimeMS),
458+
},
459+
})
460+
461+
ctx := context.Background()
462+
463+
var cancel context.CancelFunc
464+
if tc.opTimeout {
465+
ctx, cancel = context.WithTimeout(ctx, timeoutMS*time.Millisecond)
466+
defer cancel()
467+
}
402468

403-
// Create a find cursor
404-
opts := options.Find().SetBatchSize(1).SetMaxAwaitTime(50 * time.Millisecond)
469+
cur := tc.factory(ctx, mt, *mt.Coll)
470+
defer func() { assert.NoError(mt, cur.Close(context.Background())) }()
405471

406-
cursor, err := mt.Coll.Find(context.Background(), bson.D{}, opts)
407-
require.NoError(mt, err)
472+
require.NoError(mt, cur.Err())
408473

409-
_ = mt.GetStartedEvent() // Empty find from started list.
474+
cur.SetMaxAwaitTime(maxAwaitTimeMS * time.Millisecond)
410475

411-
defer cursor.Close(context.Background())
476+
mt.ClearEvents()
412477

413-
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
414-
defer cancel()
478+
assert.False(mt, cur.Next(ctx))
415479

416-
// Iterate twice to force a getMore
417-
cursor.Next(ctx)
418-
cursor.Next(ctx)
480+
require.Error(mt, cur.Err(), "expected error from cursor.Next")
481+
assert.ErrorIs(mt, cur.Err(), context.DeadlineExceeded, "expected context deadline exceeded error")
419482

420-
cmd := mt.GetStartedEvent().Command
483+
getMoreEvts := []*event.CommandStartedEvent{}
484+
for _, evt := range mt.GetAllStartedEvents() {
485+
if evt.CommandName == "getMore" {
486+
getMoreEvts = append(getMoreEvts, evt)
487+
}
488+
}
421489

422-
maxTimeMSRaw, err := cmd.LookupErr("maxTimeMS")
423-
require.NoError(mt, err)
490+
// It's possible that three getMore commands are sent: 100ms, 70ms, and
491+
// then some small leftover of remaining time (e.g. 20µs).
492+
require.GreaterOrEqual(mt, len(getMoreEvts), 2)
424493

425-
got, ok := maxTimeMSRaw.AsInt64OK()
426-
require.True(mt, ok)
494+
// The first getMore should have a maxTimeMS of <= 100ms but greater
495+
// than 71ms, indicating that the maxAwaitTimeMS was used.
496+
assert.LessOrEqual(mt, parseMaxAwaitTime(mt, getMoreEvts[0]), int64(maxAwaitTimeMS))
497+
assert.Greater(mt, parseMaxAwaitTime(mt, getMoreEvts[0]), int64(getMoreBound))
427498

428-
assert.LessOrEqual(mt, got, int64(50))
429-
})
499+
// The second getMore should have a maxTimeMS of <=71, indicating that we
500+
// are using the time remaining in the context rather than the
501+
// maxAwaitTimeMS.
502+
assert.LessOrEqual(mt, parseMaxAwaitTime(mt, getMoreEvts[1]), int64(getMoreBound))
503+
})
504+
}
430505
}
431506

507+
// For tailable awaitData cursors, the maxTimeMS for a getMore should be
432508
func TestCursor_tailableAwaitData_ShortCircuitingGetMore(t *testing.T) {
433509
mt := mtest.New(t, mtest.NewOptions().CreateClient(false))
434510

mongo/client_bulk_write.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -476,7 +476,11 @@ func (mb *modelBatches) processResponse(ctx context.Context, resp bsoncore.Docum
476476
return err
477477
}
478478
var cursor *Cursor
479-
cursor, err = newCursor(bCursor, mb.client.bsonOpts, mb.client.registry)
479+
cursor, err = newCursor(bCursor, mb.client.bsonOpts, mb.client.registry,
480+
481+
// This op doesn't return a cursor to the user, so setting the client
482+
// timeout should be a no-op.
483+
withCursorOptionClientTimeout(mb.client.timeout))
480484
if err != nil {
481485
return err
482486
}

mongo/collection.go

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1092,7 +1092,13 @@ func aggregate(a aggregateParams, opts ...options.Lister[options.AggregateOption
10921092
if err != nil {
10931093
return nil, wrapErrors(err)
10941094
}
1095-
cursor, err := newCursorWithSession(bc, a.client.bsonOpts, a.registry, sess)
1095+
cursor, err := newCursorWithSession(bc, a.client.bsonOpts, a.registry, sess,
1096+
1097+
// The only way the server will return a tailable/awaitData cursor for an
1098+
// aggregate operation is for the first stage in the pipeline to
1099+
// be $changeStream; this is the only time maxAwaitTimeMS should be applied.
1100+
// For this reason, we pass the client timeout to the cursor.
1101+
withCursorOptionClientTimeout(a.client.timeout))
10961102
return cursor, wrapErrors(err)
10971103
}
10981104

@@ -1567,7 +1573,9 @@ func (coll *Collection) find(
15671573
if err != nil {
15681574
return nil, wrapErrors(err)
15691575
}
1570-
return newCursorWithSession(bc, coll.bsonOpts, coll.registry, sess)
1576+
1577+
return newCursorWithSession(bc, coll.bsonOpts, coll.registry, sess,
1578+
withCursorOptionClientTimeout(coll.client.timeout))
15711579
}
15721580

15731581
func newFindArgsFromFindOneArgs(args *options.FindOneOptions) *options.FindOptions {

0 commit comments

Comments
 (0)