Skip to content

Commit ec65a63

Browse files
committed
GODRIVER-2677 Limit the maximum number of items in pool.
1 parent ef0c0ab commit ec65a63

File tree

4 files changed

+125
-31
lines changed

4 files changed

+125
-31
lines changed

x/mongo/driver/byteslicepool.go

+66
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
// Copyright (C) MongoDB, Inc. 2022-present.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License"); you may
4+
// not use this file except in compliance with the License. You may obtain
5+
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6+
7+
package driver
8+
9+
import (
10+
"sync"
11+
)
12+
13+
type byteslicePool struct {
14+
pool interface {
15+
Get() interface{}
16+
Put(interface{})
17+
}
18+
19+
countmax int
20+
count int
21+
mutex *sync.Mutex
22+
}
23+
24+
// newByteSlicePool creates a byte slices pool with a maximum number of items,
25+
// which is specified by the parameter, "size".
26+
func newByteSlicePool(size int) *byteslicePool {
27+
return &byteslicePool{
28+
pool: &sync.Pool{
29+
New: func() interface{} {
30+
// Start with 1kb buffers.
31+
b := make([]byte, 1024)
32+
// Return a pointer as the static analysis tool suggests.
33+
return &b
34+
},
35+
},
36+
countmax: size,
37+
mutex: new(sync.Mutex),
38+
}
39+
}
40+
41+
func (p *byteslicePool) Get() []byte {
42+
p.mutex.Lock()
43+
defer p.mutex.Unlock()
44+
if p.count < p.countmax {
45+
p.count++
46+
return (*p.pool.Get().(*[]byte))[:0]
47+
}
48+
return make([]byte, 0)
49+
}
50+
51+
func (p *byteslicePool) Put(b []byte) {
52+
// Proper usage of a sync.Pool requires each entry to have approximately the same memory
53+
// cost. To obtain this property when the stored type contains a variably-sized buffer,
54+
// we add a hard limit on the maximum buffer to place back in the pool. We limit the
55+
// size to 16MiB because that's the maximum wire message size supported by MongoDB.
56+
//
57+
// Comment copied from https://cs.opensource.google/go/go/+/refs/tags/go1.19:src/fmt/print.go;l=147
58+
if c := cap(b); c <= 16*1024*1024 {
59+
p.mutex.Lock()
60+
defer p.mutex.Unlock()
61+
if p.count > 0 {
62+
p.pool.Put(&b)
63+
p.count--
64+
}
65+
}
66+
}

x/mongo/driver/byteslicepool_test.go

+44
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
// Copyright (C) MongoDB, Inc. 2022-present.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License"); you may
4+
// not use this file except in compliance with the License. You may obtain
5+
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6+
7+
package driver
8+
9+
import (
10+
"testing"
11+
12+
"go.mongodb.org/mongo-driver/internal/assert"
13+
)
14+
15+
type dummypool struct {
16+
getcnt int
17+
putcnt int
18+
}
19+
20+
func (p *dummypool) Get() interface{} {
21+
p.getcnt++
22+
b := make([]byte, 42)
23+
return &b
24+
}
25+
26+
func (p *dummypool) Put(_ interface{}) {
27+
p.putcnt++
28+
}
29+
30+
func TestByteSlicePool(t *testing.T) {
31+
t.Run("allocation", func(t *testing.T) {
32+
var memoryPool = newByteSlicePool(1)
33+
p := &dummypool{}
34+
memoryPool.pool = p
35+
b1 := memoryPool.Get()
36+
assert.Equal(t, 1, p.getcnt, "slice was not allocated correctly")
37+
b2 := memoryPool.Get()
38+
assert.Equal(t, 1, p.getcnt, "slice was not allocated correctly")
39+
memoryPool.Put(b2)
40+
assert.Equal(t, 1, p.putcnt, "slice was not returned correctly")
41+
memoryPool.Put(b1)
42+
assert.Equal(t, 1, p.putcnt, "slice was not returned correctly")
43+
})
44+
}

x/mongo/driver/operation.go

+13-29
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ import (
1313
"fmt"
1414
"strconv"
1515
"strings"
16-
"sync"
1716
"time"
1817

1918
"go.mongodb.org/mongo-driver/bson"
@@ -318,14 +317,8 @@ func (op Operation) Validate() error {
318317
return nil
319318
}
320319

321-
var memoryPool = sync.Pool{
322-
New: func() interface{} {
323-
// Start with 1kb buffers.
324-
b := make([]byte, 1024)
325-
// Return a pointer as the static analysis tool suggests.
326-
return &b
327-
},
328-
}
320+
// Create a pool of maximum 512 byte slices.
321+
var memoryPool = newByteSlicePool(512)
329322

330323
// Execute runs this operation.
331324
func (op Operation) Execute(ctx context.Context) error {
@@ -425,17 +418,8 @@ func (op Operation) Execute(ctx context.Context) error {
425418
conn = nil
426419
}
427420

428-
wm := memoryPool.Get().(*[]byte)
421+
wm := memoryPool.Get()
429422
defer func() {
430-
// Proper usage of a sync.Pool requires each entry to have approximately the same memory
431-
// cost. To obtain this property when the stored type contains a variably-sized buffer,
432-
// we add a hard limit on the maximum buffer to place back in the pool. We limit the
433-
// size to 16MiB because that's the maximum wire message size supported by MongoDB.
434-
//
435-
// Comment copied from https://cs.opensource.google/go/go/+/refs/tags/go1.19:src/fmt/print.go;l=147
436-
if cap(*wm) > 16*1024*1024 {
437-
return
438-
}
439423
memoryPool.Put(wm)
440424
}()
441425
for {
@@ -536,7 +520,7 @@ func (op Operation) Execute(ctx context.Context) error {
536520
}
537521

538522
var startedInfo startedInformation
539-
*wm, startedInfo, err = op.createWireMessage(ctx, (*wm)[:0], desc, maxTimeMS, conn)
523+
wm, startedInfo, err = op.createWireMessage(ctx, wm[:0], desc, maxTimeMS, conn)
540524
if err != nil {
541525
return err
542526
}
@@ -551,12 +535,12 @@ func (op Operation) Execute(ctx context.Context) error {
551535
op.publishStartedEvent(ctx, startedInfo)
552536

553537
// get the moreToCome flag information before we compress
554-
moreToCome := wiremessage.IsMsgMoreToCome(*wm)
538+
moreToCome := wiremessage.IsMsgMoreToCome(wm)
555539

556540
// compress wiremessage if allowed
557541
if compressor, ok := conn.(Compressor); ok && op.canCompress(startedInfo.cmdName) {
558-
b := memoryPool.Get().(*[]byte)
559-
*b, err = compressor.CompressWireMessage(*wm, (*b)[:0])
542+
b := memoryPool.Get()
543+
b, err = compressor.CompressWireMessage(wm, b[:0])
560544
memoryPool.Put(wm)
561545
wm = b
562546
if err != nil {
@@ -595,7 +579,7 @@ func (op Operation) Execute(ctx context.Context) error {
595579
if moreToCome {
596580
roundTrip = op.moreToComeRoundTrip
597581
}
598-
res, *wm, err = roundTrip(ctx, conn, *wm)
582+
res, wm, err = roundTrip(ctx, conn, wm)
599583

600584
if ep, ok := srvr.(ErrorProcessor); ok {
601585
_ = ep.ProcessError(err, conn)
@@ -975,12 +959,12 @@ func (Operation) decompressWireMessage(wm []byte) ([]byte, error) {
975959
}
976960

977961
// Copy msg, which is a subslice of wm. wm will be used to store the return value of the decompressed message.
978-
b := memoryPool.Get().(*[]byte)
962+
b := memoryPool.Get()
979963
msglen := len(msg)
980-
if len(*b) < msglen {
981-
*b = make([]byte, msglen)
964+
if len(b) < msglen {
965+
b = make([]byte, msglen)
982966
}
983-
copy(*b, msg)
967+
copy(b, msg)
984968
defer func() {
985969
memoryPool.Put(b)
986970
}()
@@ -993,7 +977,7 @@ func (Operation) decompressWireMessage(wm []byte) ([]byte, error) {
993977
Compressor: compressorID,
994978
UncompressedSize: uncompressedSize,
995979
}
996-
uncompressed, err := DecompressPayload((*b)[0:msglen], opts)
980+
uncompressed, err := DecompressPayload(b[0:msglen], opts)
997981
if err != nil {
998982
return nil, err
999983
}

x/mongo/driver/operation_exhaust.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,13 @@ func (op Operation) ExecuteExhaust(ctx context.Context, conn StreamerConnection)
1818
return errors.New("exhaust read must be done with a connection that is currently streaming")
1919
}
2020

21-
wm := memoryPool.Get().(*[]byte)
21+
wm := memoryPool.Get()
2222
defer func() {
2323
memoryPool.Put(wm)
2424
}()
2525
var res []byte
2626
var err error
27-
res, *wm, err = op.readWireMessage(ctx, conn, (*wm)[:0])
27+
res, wm, err = op.readWireMessage(ctx, conn, wm[:0])
2828
if err != nil {
2929
return err
3030
}

0 commit comments

Comments
 (0)