Skip to content

Commit d58a5f4

Browse files
serhiy-storchakayileigpshead
authored
[3.12] gh-106905: Use separate structs to track recursion depth in each PyAST_mod2obj call. (GH-113035) (GH-113472)
(cherry picked from commit 48c4973) Co-authored-by: Yilei Yang <[email protected]> Co-authored-by: Gregory P. Smith [Google LLC] <[email protected]>
1 parent 2c07540 commit d58a5f4

File tree

4 files changed

+420
-328
lines changed

4 files changed

+420
-328
lines changed

Include/internal/pycore_ast_state.h

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
Use per AST-parser state rather than global state to track recursion depth
2+
within the AST parser to prevent potential race condition due to
3+
simultaneous parsing.
4+
5+
The issue primarily showed up in 3.11 by multithreaded users of
6+
:func:`ast.parse`. In 3.12 a change to when garbage collection can be
7+
triggered prevented the race condition from occurring.

Parser/asdl_c.py

Lines changed: 32 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -731,7 +731,7 @@ def emit_sequence_constructor(self, name, type):
731731
class PyTypesDeclareVisitor(PickleVisitor):
732732

733733
def visitProduct(self, prod, name):
734-
self.emit("static PyObject* ast2obj_%s(struct ast_state *state, void*);" % name, 0)
734+
self.emit("static PyObject* ast2obj_%s(struct ast_state *state, struct validator *vstate, void*);" % name, 0)
735735
if prod.attributes:
736736
self.emit("static const char * const %s_attributes[] = {" % name, 0)
737737
for a in prod.attributes:
@@ -752,7 +752,7 @@ def visitSum(self, sum, name):
752752
ptype = "void*"
753753
if is_simple(sum):
754754
ptype = get_c_type(name)
755-
self.emit("static PyObject* ast2obj_%s(struct ast_state *state, %s);" % (name, ptype), 0)
755+
self.emit("static PyObject* ast2obj_%s(struct ast_state *state, struct validator *vstate, %s);" % (name, ptype), 0)
756756
for t in sum.types:
757757
self.visitConstructor(t, name)
758758

@@ -984,15 +984,16 @@ def visitModule(self, mod):
984984
985985
/* Conversion AST -> Python */
986986
987-
static PyObject* ast2obj_list(struct ast_state *state, asdl_seq *seq, PyObject* (*func)(struct ast_state *state, void*))
987+
static PyObject* ast2obj_list(struct ast_state *state, struct validator *vstate, asdl_seq *seq,
988+
PyObject* (*func)(struct ast_state *state, struct validator *vstate, void*))
988989
{
989990
Py_ssize_t i, n = asdl_seq_LEN(seq);
990991
PyObject *result = PyList_New(n);
991992
PyObject *value;
992993
if (!result)
993994
return NULL;
994995
for (i = 0; i < n; i++) {
995-
value = func(state, asdl_seq_GET_UNTYPED(seq, i));
996+
value = func(state, vstate, asdl_seq_GET_UNTYPED(seq, i));
996997
if (!value) {
997998
Py_DECREF(result);
998999
return NULL;
@@ -1002,7 +1003,7 @@ def visitModule(self, mod):
10021003
return result;
10031004
}
10041005
1005-
static PyObject* ast2obj_object(struct ast_state *Py_UNUSED(state), void *o)
1006+
static PyObject* ast2obj_object(struct ast_state *Py_UNUSED(state), struct validator *Py_UNUSED(vstate), void *o)
10061007
{
10071008
PyObject *op = (PyObject*)o;
10081009
if (!op) {
@@ -1014,7 +1015,7 @@ def visitModule(self, mod):
10141015
#define ast2obj_identifier ast2obj_object
10151016
#define ast2obj_string ast2obj_object
10161017
1017-
static PyObject* ast2obj_int(struct ast_state *Py_UNUSED(state), long b)
1018+
static PyObject* ast2obj_int(struct ast_state *Py_UNUSED(state), struct validator *Py_UNUSED(vstate), long b)
10181019
{
10191020
return PyLong_FromLong(b);
10201021
}
@@ -1123,8 +1124,6 @@ def visitModule(self, mod):
11231124
for dfn in mod.dfns:
11241125
self.visit(dfn)
11251126
self.file.write(textwrap.dedent('''
1126-
state->recursion_depth = 0;
1127-
state->recursion_limit = 0;
11281127
state->initialized = 1;
11291128
return 1;
11301129
}
@@ -1265,24 +1264,25 @@ class ObjVisitor(PickleVisitor):
12651264
def func_begin(self, name):
12661265
ctype = get_c_type(name)
12671266
self.emit("PyObject*", 0)
1268-
self.emit("ast2obj_%s(struct ast_state *state, void* _o)" % (name), 0)
1267+
self.emit("ast2obj_%s(struct ast_state *state, struct validator *vstate, void* _o)" % (name), 0)
12691268
self.emit("{", 0)
12701269
self.emit("%s o = (%s)_o;" % (ctype, ctype), 1)
12711270
self.emit("PyObject *result = NULL, *value = NULL;", 1)
12721271
self.emit("PyTypeObject *tp;", 1)
12731272
self.emit('if (!o) {', 1)
12741273
self.emit("Py_RETURN_NONE;", 2)
12751274
self.emit("}", 1)
1276-
self.emit("if (++state->recursion_depth > state->recursion_limit) {", 1)
1275+
self.emit("if (++vstate->recursion_depth > vstate->recursion_limit) {", 1)
12771276
self.emit("PyErr_SetString(PyExc_RecursionError,", 2)
12781277
self.emit('"maximum recursion depth exceeded during ast construction");', 3)
12791278
self.emit("return 0;", 2)
12801279
self.emit("}", 1)
12811280

12821281
def func_end(self):
1283-
self.emit("state->recursion_depth--;", 1)
1282+
self.emit("vstate->recursion_depth--;", 1)
12841283
self.emit("return result;", 1)
12851284
self.emit("failed:", 0)
1285+
self.emit("vstate->recursion_depth--;", 1)
12861286
self.emit("Py_XDECREF(value);", 1)
12871287
self.emit("Py_XDECREF(result);", 1)
12881288
self.emit("return NULL;", 1)
@@ -1300,15 +1300,15 @@ def visitSum(self, sum, name):
13001300
self.visitConstructor(t, i + 1, name)
13011301
self.emit("}", 1)
13021302
for a in sum.attributes:
1303-
self.emit("value = ast2obj_%s(state, o->%s);" % (a.type, a.name), 1)
1303+
self.emit("value = ast2obj_%s(state, vstate, o->%s);" % (a.type, a.name), 1)
13041304
self.emit("if (!value) goto failed;", 1)
13051305
self.emit('if (PyObject_SetAttr(result, state->%s, value) < 0)' % a.name, 1)
13061306
self.emit('goto failed;', 2)
13071307
self.emit('Py_DECREF(value);', 1)
13081308
self.func_end()
13091309

13101310
def simpleSum(self, sum, name):
1311-
self.emit("PyObject* ast2obj_%s(struct ast_state *state, %s_ty o)" % (name, name), 0)
1311+
self.emit("PyObject* ast2obj_%s(struct ast_state *state, struct validator *vstate, %s_ty o)" % (name, name), 0)
13121312
self.emit("{", 0)
13131313
self.emit("switch(o) {", 1)
13141314
for t in sum.types:
@@ -1326,7 +1326,7 @@ def visitProduct(self, prod, name):
13261326
for field in prod.fields:
13271327
self.visitField(field, name, 1, True)
13281328
for a in prod.attributes:
1329-
self.emit("value = ast2obj_%s(state, o->%s);" % (a.type, a.name), 1)
1329+
self.emit("value = ast2obj_%s(state, vstate, o->%s);" % (a.type, a.name), 1)
13301330
self.emit("if (!value) goto failed;", 1)
13311331
self.emit("if (PyObject_SetAttr(result, state->%s, value) < 0)" % a.name, 1)
13321332
self.emit('goto failed;', 2)
@@ -1367,7 +1367,7 @@ def set(self, field, value, depth):
13671367
self.emit("for(i = 0; i < n; i++)", depth+1)
13681368
# This cannot fail, so no need for error handling
13691369
self.emit(
1370-
"PyList_SET_ITEM(value, i, ast2obj_{0}(state, ({0}_ty)asdl_seq_GET({1}, i)));".format(
1370+
"PyList_SET_ITEM(value, i, ast2obj_{0}(state, vstate, ({0}_ty)asdl_seq_GET({1}, i)));".format(
13711371
field.type,
13721372
value
13731373
),
@@ -1376,9 +1376,9 @@ def set(self, field, value, depth):
13761376
)
13771377
self.emit("}", depth)
13781378
else:
1379-
self.emit("value = ast2obj_list(state, (asdl_seq*)%s, ast2obj_%s);" % (value, field.type), depth)
1379+
self.emit("value = ast2obj_list(state, vstate, (asdl_seq*)%s, ast2obj_%s);" % (value, field.type), depth)
13801380
else:
1381-
self.emit("value = ast2obj_%s(state, %s);" % (field.type, value), depth, reflow=False)
1381+
self.emit("value = ast2obj_%s(state, vstate, %s);" % (field.type, value), depth, reflow=False)
13821382

13831383

13841384
class PartingShots(StaticVisitor):
@@ -1396,21 +1396,22 @@ class PartingShots(StaticVisitor):
13961396
int COMPILER_STACK_FRAME_SCALE = 2;
13971397
PyThreadState *tstate = _PyThreadState_GET();
13981398
if (!tstate) {
1399-
return 0;
1399+
return NULL;
14001400
}
1401-
state->recursion_limit = C_RECURSION_LIMIT * COMPILER_STACK_FRAME_SCALE;
1401+
struct validator vstate;
1402+
vstate.recursion_limit = C_RECURSION_LIMIT * COMPILER_STACK_FRAME_SCALE;
14021403
int recursion_depth = C_RECURSION_LIMIT - tstate->c_recursion_remaining;
14031404
starting_recursion_depth = recursion_depth * COMPILER_STACK_FRAME_SCALE;
1404-
state->recursion_depth = starting_recursion_depth;
1405+
vstate.recursion_depth = starting_recursion_depth;
14051406
1406-
PyObject *result = ast2obj_mod(state, t);
1407+
PyObject *result = ast2obj_mod(state, &vstate, t);
14071408
14081409
/* Check that the recursion depth counting balanced correctly */
1409-
if (result && state->recursion_depth != starting_recursion_depth) {
1410+
if (result && vstate.recursion_depth != starting_recursion_depth) {
14101411
PyErr_Format(PyExc_SystemError,
14111412
"AST constructor recursion depth mismatch (before=%d, after=%d)",
1412-
starting_recursion_depth, state->recursion_depth);
1413-
return 0;
1413+
starting_recursion_depth, vstate.recursion_depth);
1414+
return NULL;
14141415
}
14151416
return result;
14161417
}
@@ -1478,8 +1479,8 @@ def visit(self, object):
14781479
def generate_ast_state(module_state, f):
14791480
f.write('struct ast_state {\n')
14801481
f.write(' int initialized;\n')
1481-
f.write(' int recursion_depth;\n')
1482-
f.write(' int recursion_limit;\n')
1482+
f.write(' int unused_recursion_depth;\n')
1483+
f.write(' int unused_recursion_limit;\n')
14831484
for s in module_state:
14841485
f.write(' PyObject *' + s + ';\n')
14851486
f.write('};')
@@ -1545,6 +1546,11 @@ def generate_module_def(mod, metadata, f, internal_h):
15451546
#include "structmember.h"
15461547
#include <stddef.h>
15471548
1549+
struct validator {
1550+
int recursion_depth; /* current recursion depth */
1551+
int recursion_limit; /* recursion limit */
1552+
};
1553+
15481554
// Forward declaration
15491555
static int init_types(struct ast_state *state);
15501556

0 commit comments

Comments
 (0)