Skip to content

Commit b78c5cd

Browse files
authored
Add recursion limit for dynamic code (#358)
Prevent stack exhaustion on: Decoder: * CopyNext * Skip * ReadIntf * ReadMapStrIntf * WriteToJSON Standalone: * Skip * ReadMapStrIntfBytes * ReadIntfBytes * CopyToJSON * UnmarshalAsJSON Limit is set to 100K recursive map/slice operations.
1 parent bdea0d5 commit b78c5cd

File tree

7 files changed

+250
-30
lines changed

7 files changed

+250
-30
lines changed

msgp/defs.go

+4
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,10 @@ const (
3232
last5 = 0x1f
3333
first3 = 0xe0
3434
last7 = 0x7f
35+
36+
// recursionLimit is the limit of recursive calls.
37+
// This limits the call depth of dynamic code, like Skip and interface conversions.
38+
recursionLimit = 100000
3539
)
3640

3741
func isfixint(b byte) bool {

msgp/errors.go

+9
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,10 @@ var (
1313
// contain the contents of the message
1414
ErrShortBytes error = errShort{}
1515

16+
// ErrRecursion is returned when the maximum recursion limit is reached for an operation.
17+
// This should only realistically be seen on adversarial data trying to exhaust the stack.
18+
ErrRecursion error = errRecursion{}
19+
1620
// this error is only returned
1721
// if we reach code that should
1822
// be unreachable
@@ -134,6 +138,11 @@ func (f errFatal) Resumable() bool { return false }
134138

135139
func (f errFatal) withContext(ctx string) error { f.ctx = addCtx(f.ctx, ctx); return f }
136140

141+
type errRecursion struct{}
142+
143+
func (e errRecursion) Error() string { return "msgp: recursion limit reached" }
144+
func (e errRecursion) Resumable() bool { return false }
145+
137146
// ArrayError is an error returned
138147
// when decoding a fix-sized array
139148
// of the wrong size

msgp/json.go

+14
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,13 @@ func rwMap(dst jsWriter, src *Reader) (n int, err error) {
109109
return dst.WriteString("{}")
110110
}
111111

112+
// This is potentially a recursive call.
113+
if done, err := src.recursiveCall(); err != nil {
114+
return 0, err
115+
} else {
116+
defer done()
117+
}
118+
112119
err = dst.WriteByte('{')
113120
if err != nil {
114121
return
@@ -162,6 +169,13 @@ func rwArray(dst jsWriter, src *Reader) (n int, err error) {
162169
if err != nil {
163170
return
164171
}
172+
// This is potentially a recursive call.
173+
if done, err := src.recursiveCall(); err != nil {
174+
return 0, err
175+
} else {
176+
defer done()
177+
}
178+
165179
var sz uint32
166180
var nn int
167181
sz, err = src.ReadArrayHeader()

msgp/json_bytes.go

+29-23
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,12 @@ import (
99
"time"
1010
)
1111

12-
var unfuns [_maxtype]func(jsWriter, []byte, []byte) ([]byte, []byte, error)
12+
var unfuns [_maxtype]func(jsWriter, []byte, []byte, int) ([]byte, []byte, error)
1313

1414
func init() {
1515
// NOTE(pmh): this is best expressed as a jump table,
1616
// but gc doesn't do that yet. revisit post-go1.5.
17-
unfuns = [_maxtype]func(jsWriter, []byte, []byte) ([]byte, []byte, error){
17+
unfuns = [_maxtype]func(jsWriter, []byte, []byte, int) ([]byte, []byte, error){
1818
StrType: rwStringBytes,
1919
BinType: rwBytesBytes,
2020
MapType: rwMapBytes,
@@ -51,15 +51,15 @@ func UnmarshalAsJSON(w io.Writer, msg []byte) ([]byte, error) {
5151
dst = bufio.NewWriterSize(w, 512)
5252
}
5353
for len(msg) > 0 && err == nil {
54-
msg, scratch, err = writeNext(dst, msg, scratch)
54+
msg, scratch, err = writeNext(dst, msg, scratch, 0)
5555
}
5656
if !cast && err == nil {
5757
err = dst.(*bufio.Writer).Flush()
5858
}
5959
return msg, err
6060
}
6161

62-
func writeNext(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error) {
62+
func writeNext(w jsWriter, msg []byte, scratch []byte, depth int) ([]byte, []byte, error) {
6363
if len(msg) < 1 {
6464
return msg, scratch, ErrShortBytes
6565
}
@@ -76,10 +76,13 @@ func writeNext(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error) {
7676
t = TimeType
7777
}
7878
}
79-
return unfuns[t](w, msg, scratch)
79+
return unfuns[t](w, msg, scratch, depth)
8080
}
8181

82-
func rwArrayBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error) {
82+
func rwArrayBytes(w jsWriter, msg []byte, scratch []byte, depth int) ([]byte, []byte, error) {
83+
if depth >= recursionLimit {
84+
return msg, scratch, ErrRecursion
85+
}
8386
sz, msg, err := ReadArrayHeaderBytes(msg)
8487
if err != nil {
8588
return msg, scratch, err
@@ -95,7 +98,7 @@ func rwArrayBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error
9598
return msg, scratch, err
9699
}
97100
}
98-
msg, scratch, err = writeNext(w, msg, scratch)
101+
msg, scratch, err = writeNext(w, msg, scratch, depth+1)
99102
if err != nil {
100103
return msg, scratch, err
101104
}
@@ -104,7 +107,10 @@ func rwArrayBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error
104107
return msg, scratch, err
105108
}
106109

107-
func rwMapBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error) {
110+
func rwMapBytes(w jsWriter, msg []byte, scratch []byte, depth int) ([]byte, []byte, error) {
111+
if depth >= recursionLimit {
112+
return msg, scratch, ErrRecursion
113+
}
108114
sz, msg, err := ReadMapHeaderBytes(msg)
109115
if err != nil {
110116
return msg, scratch, err
@@ -120,15 +126,15 @@ func rwMapBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error)
120126
return msg, scratch, err
121127
}
122128
}
123-
msg, scratch, err = rwMapKeyBytes(w, msg, scratch)
129+
msg, scratch, err = rwMapKeyBytes(w, msg, scratch, depth)
124130
if err != nil {
125131
return msg, scratch, err
126132
}
127133
err = w.WriteByte(':')
128134
if err != nil {
129135
return msg, scratch, err
130136
}
131-
msg, scratch, err = writeNext(w, msg, scratch)
137+
msg, scratch, err = writeNext(w, msg, scratch, depth+1)
132138
if err != nil {
133139
return msg, scratch, err
134140
}
@@ -137,17 +143,17 @@ func rwMapBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error)
137143
return msg, scratch, err
138144
}
139145

140-
func rwMapKeyBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error) {
141-
msg, scratch, err := rwStringBytes(w, msg, scratch)
146+
func rwMapKeyBytes(w jsWriter, msg []byte, scratch []byte, depth int) ([]byte, []byte, error) {
147+
msg, scratch, err := rwStringBytes(w, msg, scratch, depth)
142148
if err != nil {
143149
if tperr, ok := err.(TypeError); ok && tperr.Encoded == BinType {
144-
return rwBytesBytes(w, msg, scratch)
150+
return rwBytesBytes(w, msg, scratch, depth)
145151
}
146152
}
147153
return msg, scratch, err
148154
}
149155

150-
func rwStringBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error) {
156+
func rwStringBytes(w jsWriter, msg []byte, scratch []byte, depth int) ([]byte, []byte, error) {
151157
str, msg, err := ReadStringZC(msg)
152158
if err != nil {
153159
return msg, scratch, err
@@ -156,7 +162,7 @@ func rwStringBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, erro
156162
return msg, scratch, err
157163
}
158164

159-
func rwBytesBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error) {
165+
func rwBytesBytes(w jsWriter, msg []byte, scratch []byte, depth int) ([]byte, []byte, error) {
160166
bts, msg, err := ReadBytesZC(msg)
161167
if err != nil {
162168
return msg, scratch, err
@@ -180,7 +186,7 @@ func rwBytesBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error
180186
return msg, scratch, err
181187
}
182188

183-
func rwNullBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error) {
189+
func rwNullBytes(w jsWriter, msg []byte, scratch []byte, depth int) ([]byte, []byte, error) {
184190
msg, err := ReadNilBytes(msg)
185191
if err != nil {
186192
return msg, scratch, err
@@ -189,7 +195,7 @@ func rwNullBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error)
189195
return msg, scratch, err
190196
}
191197

192-
func rwBoolBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error) {
198+
func rwBoolBytes(w jsWriter, msg []byte, scratch []byte, depth int) ([]byte, []byte, error) {
193199
b, msg, err := ReadBoolBytes(msg)
194200
if err != nil {
195201
return msg, scratch, err
@@ -202,7 +208,7 @@ func rwBoolBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error)
202208
return msg, scratch, err
203209
}
204210

205-
func rwIntBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error) {
211+
func rwIntBytes(w jsWriter, msg []byte, scratch []byte, depth int) ([]byte, []byte, error) {
206212
i, msg, err := ReadInt64Bytes(msg)
207213
if err != nil {
208214
return msg, scratch, err
@@ -212,7 +218,7 @@ func rwIntBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error)
212218
return msg, scratch, err
213219
}
214220

215-
func rwUintBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error) {
221+
func rwUintBytes(w jsWriter, msg []byte, scratch []byte, depth int) ([]byte, []byte, error) {
216222
u, msg, err := ReadUint64Bytes(msg)
217223
if err != nil {
218224
return msg, scratch, err
@@ -222,7 +228,7 @@ func rwUintBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error)
222228
return msg, scratch, err
223229
}
224230

225-
func rwFloat32Bytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error) {
231+
func rwFloat32Bytes(w jsWriter, msg []byte, scratch []byte, depth int) ([]byte, []byte, error) {
226232
var f float32
227233
var err error
228234
f, msg, err = ReadFloat32Bytes(msg)
@@ -234,7 +240,7 @@ func rwFloat32Bytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, err
234240
return msg, scratch, err
235241
}
236242

237-
func rwFloat64Bytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error) {
243+
func rwFloat64Bytes(w jsWriter, msg []byte, scratch []byte, depth int) ([]byte, []byte, error) {
238244
var f float64
239245
var err error
240246
f, msg, err = ReadFloat64Bytes(msg)
@@ -246,7 +252,7 @@ func rwFloat64Bytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, err
246252
return msg, scratch, err
247253
}
248254

249-
func rwTimeBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error) {
255+
func rwTimeBytes(w jsWriter, msg []byte, scratch []byte, depth int) ([]byte, []byte, error) {
250256
var t time.Time
251257
var err error
252258
t, msg, err = ReadTimeBytes(msg)
@@ -261,7 +267,7 @@ func rwTimeBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error)
261267
return msg, scratch, err
262268
}
263269

264-
func rwExtensionBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error) {
270+
func rwExtensionBytes(w jsWriter, msg []byte, scratch []byte, depth int) ([]byte, []byte, error) {
265271
var err error
266272
var et int8
267273
et, err = peekExtension(msg)

msgp/read.go

+40-3
Original file line numberDiff line numberDiff line change
@@ -143,8 +143,9 @@ type Reader struct {
143143
// is stateless; all the
144144
// buffering is done
145145
// within R.
146-
R *fwd.Reader
147-
scratch []byte
146+
R *fwd.Reader
147+
scratch []byte
148+
recursionDepth int
148149
}
149150

150151
// Read implements `io.Reader`
@@ -190,6 +191,11 @@ func (m *Reader) CopyNext(w io.Writer) (int64, error) {
190191
return n, io.ErrShortWrite
191192
}
192193

194+
if done, err := m.recursiveCall(); err != nil {
195+
return n, err
196+
} else {
197+
defer done()
198+
}
193199
// for maps and slices, read elements
194200
for x := uintptr(0); x < o; x++ {
195201
var n2 int64
@@ -202,6 +208,18 @@ func (m *Reader) CopyNext(w io.Writer) (int64, error) {
202208
return n, nil
203209
}
204210

211+
// recursiveCall will increment the recursion depth and return an error if it is exceeded.
212+
// If a nil error is returned, done must be called to decrement the counter.
213+
func (m *Reader) recursiveCall() (done func(), err error) {
214+
if m.recursionDepth >= recursionLimit {
215+
return func() {}, ErrRecursion
216+
}
217+
m.recursionDepth++
218+
return func() {
219+
m.recursionDepth--
220+
}, nil
221+
}
222+
205223
// ReadFull implements `io.ReadFull`
206224
func (m *Reader) ReadFull(p []byte) (int, error) {
207225
return m.R.ReadFull(p)
@@ -332,7 +350,12 @@ func (m *Reader) Skip() error {
332350
return err
333351
}
334352

335-
// for maps and slices, skip elements
353+
// for maps and slices, skip elements with recursive call
354+
if done, err := m.recursiveCall(); err != nil {
355+
return err
356+
} else {
357+
defer done()
358+
}
336359
for x := uintptr(0); x < o; x++ {
337360
err = m.Skip()
338361
if err != nil {
@@ -1333,6 +1356,13 @@ func (m *Reader) ReadIntf() (i interface{}, err error) {
13331356
return
13341357

13351358
case MapType:
1359+
// This can call back here, so treat as recursive call.
1360+
if done, err := m.recursiveCall(); err != nil {
1361+
return nil, err
1362+
} else {
1363+
defer done()
1364+
}
1365+
13361366
mp := make(map[string]interface{})
13371367
err = m.ReadMapStrIntf(mp)
13381368
i = mp
@@ -1358,6 +1388,13 @@ func (m *Reader) ReadIntf() (i interface{}, err error) {
13581388
if err != nil {
13591389
return
13601390
}
1391+
1392+
if done, err := m.recursiveCall(); err != nil {
1393+
return nil, err
1394+
} else {
1395+
defer done()
1396+
}
1397+
13611398
out := make([]interface{}, int(sz))
13621399
for j := range out {
13631400
out[j], err = m.ReadIntf()

0 commit comments

Comments
 (0)