@@ -731,7 +731,7 @@ def emit_sequence_constructor(self, name, type):
731
731
class PyTypesDeclareVisitor (PickleVisitor ):
732
732
733
733
def visitProduct (self , prod , name ):
734
- self .emit ("static PyObject* ast2obj_%s(struct ast_state *state, void*);" % name , 0 )
734
+ self .emit ("static PyObject* ast2obj_%s(struct ast_state *state, struct validator *vstate, void*);" % name , 0 )
735
735
if prod .attributes :
736
736
self .emit ("static const char * const %s_attributes[] = {" % name , 0 )
737
737
for a in prod .attributes :
@@ -752,7 +752,7 @@ def visitSum(self, sum, name):
752
752
ptype = "void*"
753
753
if is_simple (sum ):
754
754
ptype = get_c_type (name )
755
- self .emit ("static PyObject* ast2obj_%s(struct ast_state *state, %s);" % (name , ptype ), 0 )
755
+ self .emit ("static PyObject* ast2obj_%s(struct ast_state *state, struct validator *vstate, %s);" % (name , ptype ), 0 )
756
756
for t in sum .types :
757
757
self .visitConstructor (t , name )
758
758
@@ -984,15 +984,16 @@ def visitModule(self, mod):
984
984
985
985
/* Conversion AST -> Python */
986
986
987
- static PyObject* ast2obj_list(struct ast_state *state, asdl_seq *seq, PyObject* (*func)(struct ast_state *state, void*))
987
+ static PyObject* ast2obj_list(struct ast_state *state, struct validator *vstate, asdl_seq *seq,
988
+ PyObject* (*func)(struct ast_state *state, struct validator *vstate, void*))
988
989
{
989
990
Py_ssize_t i, n = asdl_seq_LEN(seq);
990
991
PyObject *result = PyList_New(n);
991
992
PyObject *value;
992
993
if (!result)
993
994
return NULL;
994
995
for (i = 0; i < n; i++) {
995
- value = func(state, asdl_seq_GET_UNTYPED(seq, i));
996
+ value = func(state, vstate, asdl_seq_GET_UNTYPED(seq, i));
996
997
if (!value) {
997
998
Py_DECREF(result);
998
999
return NULL;
@@ -1002,7 +1003,7 @@ def visitModule(self, mod):
1002
1003
return result;
1003
1004
}
1004
1005
1005
- static PyObject* ast2obj_object(struct ast_state *Py_UNUSED(state), void *o)
1006
+ static PyObject* ast2obj_object(struct ast_state *Py_UNUSED(state), struct validator *Py_UNUSED(vstate), void *o)
1006
1007
{
1007
1008
PyObject *op = (PyObject*)o;
1008
1009
if (!op) {
@@ -1014,7 +1015,7 @@ def visitModule(self, mod):
1014
1015
#define ast2obj_identifier ast2obj_object
1015
1016
#define ast2obj_string ast2obj_object
1016
1017
1017
- static PyObject* ast2obj_int(struct ast_state *Py_UNUSED(state), long b)
1018
+ static PyObject* ast2obj_int(struct ast_state *Py_UNUSED(state), struct validator *Py_UNUSED(vstate), long b)
1018
1019
{
1019
1020
return PyLong_FromLong(b);
1020
1021
}
@@ -1123,8 +1124,6 @@ def visitModule(self, mod):
1123
1124
for dfn in mod .dfns :
1124
1125
self .visit (dfn )
1125
1126
self .file .write (textwrap .dedent ('''
1126
- state->recursion_depth = 0;
1127
- state->recursion_limit = 0;
1128
1127
state->initialized = 1;
1129
1128
return 1;
1130
1129
}
@@ -1265,24 +1264,25 @@ class ObjVisitor(PickleVisitor):
1265
1264
def func_begin (self , name ):
1266
1265
ctype = get_c_type (name )
1267
1266
self .emit ("PyObject*" , 0 )
1268
- self .emit ("ast2obj_%s(struct ast_state *state, void* _o)" % (name ), 0 )
1267
+ self .emit ("ast2obj_%s(struct ast_state *state, struct validator *vstate, void* _o)" % (name ), 0 )
1269
1268
self .emit ("{" , 0 )
1270
1269
self .emit ("%s o = (%s)_o;" % (ctype , ctype ), 1 )
1271
1270
self .emit ("PyObject *result = NULL, *value = NULL;" , 1 )
1272
1271
self .emit ("PyTypeObject *tp;" , 1 )
1273
1272
self .emit ('if (!o) {' , 1 )
1274
1273
self .emit ("Py_RETURN_NONE;" , 2 )
1275
1274
self .emit ("}" , 1 )
1276
- self .emit ("if (++state ->recursion_depth > state ->recursion_limit) {" , 1 )
1275
+ self .emit ("if (++vstate ->recursion_depth > vstate ->recursion_limit) {" , 1 )
1277
1276
self .emit ("PyErr_SetString(PyExc_RecursionError," , 2 )
1278
1277
self .emit ('"maximum recursion depth exceeded during ast construction");' , 3 )
1279
1278
self .emit ("return 0;" , 2 )
1280
1279
self .emit ("}" , 1 )
1281
1280
1282
1281
def func_end (self ):
1283
- self .emit ("state ->recursion_depth--;" , 1 )
1282
+ self .emit ("vstate ->recursion_depth--;" , 1 )
1284
1283
self .emit ("return result;" , 1 )
1285
1284
self .emit ("failed:" , 0 )
1285
+ self .emit ("vstate->recursion_depth--;" , 1 )
1286
1286
self .emit ("Py_XDECREF(value);" , 1 )
1287
1287
self .emit ("Py_XDECREF(result);" , 1 )
1288
1288
self .emit ("return NULL;" , 1 )
@@ -1300,15 +1300,15 @@ def visitSum(self, sum, name):
1300
1300
self .visitConstructor (t , i + 1 , name )
1301
1301
self .emit ("}" , 1 )
1302
1302
for a in sum .attributes :
1303
- self .emit ("value = ast2obj_%s(state, o->%s);" % (a .type , a .name ), 1 )
1303
+ self .emit ("value = ast2obj_%s(state, vstate, o->%s);" % (a .type , a .name ), 1 )
1304
1304
self .emit ("if (!value) goto failed;" , 1 )
1305
1305
self .emit ('if (PyObject_SetAttr(result, state->%s, value) < 0)' % a .name , 1 )
1306
1306
self .emit ('goto failed;' , 2 )
1307
1307
self .emit ('Py_DECREF(value);' , 1 )
1308
1308
self .func_end ()
1309
1309
1310
1310
def simpleSum (self , sum , name ):
1311
- self .emit ("PyObject* ast2obj_%s(struct ast_state *state, %s_ty o)" % (name , name ), 0 )
1311
+ self .emit ("PyObject* ast2obj_%s(struct ast_state *state, struct validator *vstate, %s_ty o)" % (name , name ), 0 )
1312
1312
self .emit ("{" , 0 )
1313
1313
self .emit ("switch(o) {" , 1 )
1314
1314
for t in sum .types :
@@ -1326,7 +1326,7 @@ def visitProduct(self, prod, name):
1326
1326
for field in prod .fields :
1327
1327
self .visitField (field , name , 1 , True )
1328
1328
for a in prod .attributes :
1329
- self .emit ("value = ast2obj_%s(state, o->%s);" % (a .type , a .name ), 1 )
1329
+ self .emit ("value = ast2obj_%s(state, vstate, o->%s);" % (a .type , a .name ), 1 )
1330
1330
self .emit ("if (!value) goto failed;" , 1 )
1331
1331
self .emit ("if (PyObject_SetAttr(result, state->%s, value) < 0)" % a .name , 1 )
1332
1332
self .emit ('goto failed;' , 2 )
@@ -1367,7 +1367,7 @@ def set(self, field, value, depth):
1367
1367
self .emit ("for(i = 0; i < n; i++)" , depth + 1 )
1368
1368
# This cannot fail, so no need for error handling
1369
1369
self .emit (
1370
- "PyList_SET_ITEM(value, i, ast2obj_{0}(state, ({0}_ty)asdl_seq_GET({1}, i)));" .format (
1370
+ "PyList_SET_ITEM(value, i, ast2obj_{0}(state, vstate, ({0}_ty)asdl_seq_GET({1}, i)));" .format (
1371
1371
field .type ,
1372
1372
value
1373
1373
),
@@ -1376,9 +1376,9 @@ def set(self, field, value, depth):
1376
1376
)
1377
1377
self .emit ("}" , depth )
1378
1378
else :
1379
- self .emit ("value = ast2obj_list(state, (asdl_seq*)%s, ast2obj_%s);" % (value , field .type ), depth )
1379
+ self .emit ("value = ast2obj_list(state, vstate, (asdl_seq*)%s, ast2obj_%s);" % (value , field .type ), depth )
1380
1380
else :
1381
- self .emit ("value = ast2obj_%s(state, %s);" % (field .type , value ), depth , reflow = False )
1381
+ self .emit ("value = ast2obj_%s(state, vstate, %s);" % (field .type , value ), depth , reflow = False )
1382
1382
1383
1383
1384
1384
class PartingShots (StaticVisitor ):
@@ -1396,21 +1396,22 @@ class PartingShots(StaticVisitor):
1396
1396
int COMPILER_STACK_FRAME_SCALE = 2;
1397
1397
PyThreadState *tstate = _PyThreadState_GET();
1398
1398
if (!tstate) {
1399
- return 0 ;
1399
+ return NULL ;
1400
1400
}
1401
- state->recursion_limit = C_RECURSION_LIMIT * COMPILER_STACK_FRAME_SCALE;
1401
+ struct validator vstate;
1402
+ vstate.recursion_limit = C_RECURSION_LIMIT * COMPILER_STACK_FRAME_SCALE;
1402
1403
int recursion_depth = C_RECURSION_LIMIT - tstate->c_recursion_remaining;
1403
1404
starting_recursion_depth = recursion_depth * COMPILER_STACK_FRAME_SCALE;
1404
- state-> recursion_depth = starting_recursion_depth;
1405
+ vstate. recursion_depth = starting_recursion_depth;
1405
1406
1406
- PyObject *result = ast2obj_mod(state, t);
1407
+ PyObject *result = ast2obj_mod(state, &vstate, t);
1407
1408
1408
1409
/* Check that the recursion depth counting balanced correctly */
1409
- if (result && state-> recursion_depth != starting_recursion_depth) {
1410
+ if (result && vstate. recursion_depth != starting_recursion_depth) {
1410
1411
PyErr_Format(PyExc_SystemError,
1411
1412
"AST constructor recursion depth mismatch (before=%d, after=%d)",
1412
- starting_recursion_depth, state-> recursion_depth);
1413
- return 0 ;
1413
+ starting_recursion_depth, vstate. recursion_depth);
1414
+ return NULL ;
1414
1415
}
1415
1416
return result;
1416
1417
}
@@ -1478,8 +1479,8 @@ def visit(self, object):
1478
1479
def generate_ast_state (module_state , f ):
1479
1480
f .write ('struct ast_state {\n ' )
1480
1481
f .write (' int initialized;\n ' )
1481
- f .write (' int recursion_depth ;\n ' )
1482
- f .write (' int recursion_limit ;\n ' )
1482
+ f .write (' int unused_recursion_depth ;\n ' )
1483
+ f .write (' int unused_recursion_limit ;\n ' )
1483
1484
for s in module_state :
1484
1485
f .write (' PyObject *' + s + ';\n ' )
1485
1486
f .write ('};' )
@@ -1545,6 +1546,11 @@ def generate_module_def(mod, metadata, f, internal_h):
1545
1546
#include "structmember.h"
1546
1547
#include <stddef.h>
1547
1548
1549
+ struct validator {
1550
+ int recursion_depth; /* current recursion depth */
1551
+ int recursion_limit; /* recursion limit */
1552
+ };
1553
+
1548
1554
// Forward declaration
1549
1555
static int init_types(struct ast_state *state);
1550
1556
0 commit comments