Skip to content

GH-91079: Decouple C stack overflow checks from Python recursion checks. #96507

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 9 commits into from
9 changes: 6 additions & 3 deletions Include/cpython/pystate.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,12 @@ struct _ts {
/* Was this thread state statically allocated? */
int _static;

int recursion_remaining;
int recursion_limit;
int recursion_headroom; /* Allow 50 more calls to handle any errors. */
int py_recursion_remaining;
int py_recursion_limit;
int py_recursion_headroom; /* Allow 50 more calls to handle any errors. */

int c_recursion_remaining;
int c_recursion_headroom; /* Allow 50 more calls to handle any errors. */

/* 'tracing' keeps track of the execution depth when tracing/profiling.
This is to prevent the actual trace/profile code from being recorded in
Expand Down
2 changes: 0 additions & 2 deletions Include/internal/pycore_ast_state.h

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

35 changes: 20 additions & 15 deletions Include/internal/pycore_ceval.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,26 +114,26 @@ extern void _PyEval_DeactivateOpCache(void);

/* --- _Py_EnterRecursiveCall() ----------------------------------------- */

#ifdef USE_STACKCHECK
/* With USE_STACKCHECK macro defined, trigger stack checks in
_Py_CheckRecursiveCall() on every 64th call to _Py_EnterRecursiveCall. */
static inline int _Py_MakeRecCheck(PyThreadState *tstate) {
return (tstate->recursion_remaining-- <= 0
|| (tstate->recursion_remaining & 63) == 0);
}
#else
static inline int _Py_MakeRecCheck(PyThreadState *tstate) {
return tstate->recursion_remaining-- <= 0;
}
#endif

PyAPI_FUNC(int) _Py_CheckRecursiveCall(
PyThreadState *tstate,
const char *where);

PyAPI_FUNC(int) _Py_CheckRecursiveCallN(PyThreadState *tstate, int n,
const char *where);

static inline int _Py_EnterRecursiveCallTstate(PyThreadState *tstate,
const char *where) {
return (_Py_MakeRecCheck(tstate) && _Py_CheckRecursiveCall(tstate, where));
return (tstate->c_recursion_remaining-- <= 0) &&
_Py_CheckRecursiveCallN(tstate, 1, where);
}

static inline int _Py_EnterRecursiveCallN(PyThreadState *tstate, int n,
const char *where) {
tstate->c_recursion_remaining -= n;
if (tstate->c_recursion_remaining >= 0) {
return 0;
}
return _Py_CheckRecursiveCallN(tstate, n, where);
}

static inline int _Py_EnterRecursiveCall(const char *where) {
Expand All @@ -142,7 +142,11 @@ static inline int _Py_EnterRecursiveCall(const char *where) {
}

static inline void _Py_LeaveRecursiveCallTstate(PyThreadState *tstate) {
tstate->recursion_remaining++;
tstate->c_recursion_remaining++;
}

static inline void _Py_LeaveRecursiveCallN(PyThreadState *tstate, int n) {
tstate->c_recursion_remaining += n;
}

static inline void _Py_LeaveRecursiveCall(void) {
Expand All @@ -156,6 +160,7 @@ extern PyObject* _Py_MakeCoro(PyFunctionObject *func);

extern int _Py_HandlePending(PyThreadState *tstate);

#define C_RECURSION_LIMT 2000

#ifdef __cplusplus
}
Expand Down
3 changes: 0 additions & 3 deletions Include/internal/pycore_compile.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,6 @@ extern PyObject* _Py_Mangle(PyObject *p, PyObject *name);
typedef struct {
int optimize;
int ff_features;

int recursion_depth; /* current recursion depth */
int recursion_limit; /* recursion limit */
} _PyASTOptimizeState;

extern int _PyAST_Optimize(
Expand Down
2 changes: 1 addition & 1 deletion Include/internal/pycore_runtime_init.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ extern "C" {
#define _PyThreadState_INIT \
{ \
._static = 1, \
.recursion_limit = Py_DEFAULT_RECURSION_LIMIT, \
.py_recursion_limit = Py_DEFAULT_RECURSION_LIMIT, \
.context_ver = 1, \
}

Expand Down
2 changes: 0 additions & 2 deletions Include/internal/pycore_symtable.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,6 @@ struct symtable {
PyObject *st_private; /* name of current class or NULL */
PyFutureFeatures *st_future; /* module's future features that affect
the symbol table */
int recursion_depth; /* current recursion depth */
int recursion_limit; /* recursion limit */
};

typedef struct _symtable_entry {
Expand Down
2 changes: 1 addition & 1 deletion Lib/test/list_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def test_repr(self):

def test_repr_deep(self):
a = self.type2test([])
for i in range(sys.getrecursionlimit() + 100):
for i in range(4000):
a = self.type2test([a])
self.assertRaises(RecursionError, repr, a)

Expand Down
6 changes: 3 additions & 3 deletions Lib/test/test_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -816,9 +816,9 @@ def next(self):

@support.cpython_only
def test_ast_recursion_limit(self):
fail_depth = sys.getrecursionlimit() * 3
crash_depth = sys.getrecursionlimit() * 300
success_depth = int(fail_depth * 0.75)
fail_depth = 3000
crash_depth = 100_000
success_depth = 1500

def check_limit(prefix, repeated):
expect_ok = prefix + repeated * success_depth
Expand Down
38 changes: 38 additions & 0 deletions Lib/test/test_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -813,6 +813,44 @@ def test_multiple_values(self):
with self.check_raises_type_error(msg):
A().method_two_args("x", "y", x="oops")

@cpython_only
class TestRecursion(unittest.TestCase):

def test_super_deep(self):

def recurse(n):
if n:
recurse(n-1)

def py_recurse(n, m):
if n:
py_recurse(n-1, m)
else:
c_py_recurse(m-1)

def c_recurse(n):
if n:
_testcapi.pyobject_fastcall(c_recurse, (n-1,))

def c_py_recurse(m):
if m:
_testcapi.pyobject_fastcall(py_recurse, (1000, m))

depth = sys.getrecursionlimit()
sys.setrecursionlimit(100_000)
try:
recurse(90_000)
with self.assertRaises(RecursionError):
recurse(101_000)
c_recurse(100)
with self.assertRaises(RecursionError):
c_recurse(90_000)
c_py_recurse(90)
with self.assertRaises(RecursionError):
c_py_recurse(100_000)
finally:
sys.setrecursionlimit(depth)


if __name__ == "__main__":
unittest.main()
16 changes: 4 additions & 12 deletions Lib/test/test_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,7 @@ def __getitem__(self, key):

@unittest.skipIf(support.is_wasi, "exhausts limited stack on WASI")
def test_extended_arg(self):
# default: 1000 * 2.5 = 2500 repetitions
repeat = int(sys.getrecursionlimit() * 2.5)
repeat = 1500
longexpr = 'x = x or ' + '-x' * repeat
g = {}
code = '''
Expand Down Expand Up @@ -546,16 +545,9 @@ def test_yet_more_evil_still_undecodable(self):
@support.cpython_only
@unittest.skipIf(support.is_wasi, "exhausts limited stack on WASI")
def test_compiler_recursion_limit(self):
# Expected limit is sys.getrecursionlimit() * the scaling factor
# in symtable.c (currently 3)
# We expect to fail *at* that limit, because we use up some of
# the stack depth limit in the test suite code
# So we check the expected limit and 75% of that
# XXX (ncoghlan): duplicating the scaling factor here is a little
# ugly. Perhaps it should be exposed somewhere...
fail_depth = sys.getrecursionlimit() * 3
crash_depth = sys.getrecursionlimit() * 300
success_depth = int(fail_depth * 0.75)
fail_depth = 3000
crash_depth = 100_000
success_depth = 1500

def check_limit(prefix, repeated, mode="single"):
expect_ok = prefix + repeated * success_depth
Expand Down
2 changes: 1 addition & 1 deletion Lib/test/test_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -596,7 +596,7 @@ def __repr__(self):

def test_repr_deep(self):
d = {}
for i in range(sys.getrecursionlimit() + 100):
for i in range(4000):
d = {1: d}
self.assertRaises(RecursionError, repr, d)

Expand Down
2 changes: 1 addition & 1 deletion Lib/test/test_exception_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,7 @@ def test_basics_split_by_predicate__match(self):
class DeepRecursionInSplitAndSubgroup(unittest.TestCase):
def make_deep_eg(self):
e = TypeError(1)
for i in range(2000):
for i in range(4000):
e = ExceptionGroup('eg', [e])
return e

Expand Down
12 changes: 6 additions & 6 deletions Lib/test/test_isinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from test import support



class TestIsInstanceExceptions(unittest.TestCase):
# Test to make sure that an AttributeError when accessing the instance's
# class's bases is masked. This was actually a bug in Python 2.2 and
Expand Down Expand Up @@ -97,7 +97,7 @@ def getclass(self):
class D: pass
self.assertRaises(RuntimeError, isinstance, c, D)


# These tests are similar to above, but tickle certain code paths in
# issubclass() instead of isinstance() -- really PyObject_IsSubclass()
# vs. PyObject_IsInstance().
Expand Down Expand Up @@ -147,7 +147,7 @@ def getbases(self):
self.assertRaises(TypeError, issubclass, B, C())



# meta classes for creating abstract classes and instances
class AbstractClass(object):
def __init__(self, bases):
Expand Down Expand Up @@ -179,7 +179,7 @@ class Super:

class Child(Super):
pass

class TestIsInstanceIsSubclass(unittest.TestCase):
# Tests to ensure that isinstance and issubclass work on abstract
# classes and instances. Before the 2.2 release, TypeErrors were
Expand Down Expand Up @@ -353,10 +353,10 @@ def blowstack(fxn, arg, compare_to):
# Make sure that calling isinstance with a deeply nested tuple for its
# argument will raise RecursionError eventually.
tuple_arg = (compare_to,)
for cnt in range(sys.getrecursionlimit()+5):
for cnt in range(4000):
tuple_arg = (tuple_arg,)
fxn(arg, tuple_arg)


if __name__ == '__main__':
unittest.main()
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Decouple C stack overflow checking from Python recursion checking. Allows
the recursion limit to be increased safely, and reduces the chance of
segfaults.
2 changes: 1 addition & 1 deletion Modules/_testinternalcapi.c
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ get_recursion_depth(PyObject *self, PyObject *Py_UNUSED(args))

/* subtract one to ignore the frame of the get_recursion_depth() call */

return PyLong_FromLong(tstate->recursion_limit - tstate->recursion_remaining - 1);
return PyLong_FromLong(tstate->py_recursion_limit - tstate->py_recursion_remaining - 1);
}


Expand Down
4 changes: 2 additions & 2 deletions Objects/object.c
Original file line number Diff line number Diff line change
Expand Up @@ -421,12 +421,12 @@ PyObject_Repr(PyObject *v)

/* It is possible for a type to have a tp_repr representation that loops
infinitely. */
if (_Py_EnterRecursiveCallTstate(tstate,
if (_Py_EnterRecursiveCallN(tstate, 2,
" while getting the repr of an object")) {
return NULL;
}
res = (*Py_TYPE(v)->tp_repr)(v);
_Py_LeaveRecursiveCallTstate(tstate);
_Py_LeaveRecursiveCallN(tstate, 2);

if (res == NULL) {
return NULL;
Expand Down
36 changes: 3 additions & 33 deletions Parser/asdl_c.py
Original file line number Diff line number Diff line change
Expand Up @@ -1112,8 +1112,6 @@ def visitModule(self, mod):
for dfn in mod.dfns:
self.visit(dfn)
self.file.write(textwrap.dedent('''
state->recursion_depth = 0;
state->recursion_limit = 0;
state->initialized = 1;
return 1;
}
Expand Down Expand Up @@ -1261,14 +1259,12 @@ def func_begin(self, name):
self.emit('if (!o) {', 1)
self.emit("Py_RETURN_NONE;", 2)
self.emit("}", 1)
self.emit("if (++state->recursion_depth > state->recursion_limit) {", 1)
self.emit("PyErr_SetString(PyExc_RecursionError,", 2)
self.emit('"maximum recursion depth exceeded during ast construction");', 3)
self.emit('if (_Py_EnterRecursiveCall("during ast construction")) {', 1)
self.emit("return 0;", 2)
self.emit("}", 1)

def func_end(self):
self.emit("state->recursion_depth--;", 1)
self.emit("_Py_LeaveRecursiveCall();", 1)
self.emit("return result;", 1)
self.emit("failed:", 0)
self.emit("Py_XDECREF(value);", 1)
Expand Down Expand Up @@ -1380,31 +1376,7 @@ class PartingShots(StaticVisitor):
return NULL;
}

int recursion_limit = Py_GetRecursionLimit();
int starting_recursion_depth;
/* Be careful here to prevent overflow. */
int COMPILER_STACK_FRAME_SCALE = 3;
PyThreadState *tstate = _PyThreadState_GET();
if (!tstate) {
return 0;
}
state->recursion_limit = (recursion_limit < INT_MAX / COMPILER_STACK_FRAME_SCALE) ?
recursion_limit * COMPILER_STACK_FRAME_SCALE : recursion_limit;
int recursion_depth = tstate->recursion_limit - tstate->recursion_remaining;
starting_recursion_depth = (recursion_depth < INT_MAX / COMPILER_STACK_FRAME_SCALE) ?
recursion_depth * COMPILER_STACK_FRAME_SCALE : recursion_depth;
state->recursion_depth = starting_recursion_depth;

PyObject *result = ast2obj_mod(state, t);

/* Check that the recursion depth counting balanced correctly */
if (result && state->recursion_depth != starting_recursion_depth) {
PyErr_Format(PyExc_SystemError,
"AST constructor recursion depth mismatch (before=%d, after=%d)",
starting_recursion_depth, state->recursion_depth);
return 0;
}
return result;
return ast2obj_mod(state, t);
}

/* mode is 0 for "exec", 1 for "eval" and 2 for "single" input */
Expand Down Expand Up @@ -1470,8 +1442,6 @@ def visit(self, object):
def generate_ast_state(module_state, f):
f.write('struct ast_state {\n')
f.write(' int initialized;\n')
f.write(' int recursion_depth;\n')
f.write(' int recursion_limit;\n')
for s in module_state:
f.write(' PyObject *' + s + ';\n')
f.write('};')
Expand Down
Loading