Skip to content

Commit cc215b3

Browse files
authored
Make eval_frame thread safe (pytorch#239)
This should make eval_frame thread safe. Currently, the eval_frame is a global object, and different threads my step on each other setting a different one. This changes the behavior to instead always* have a "shim" eval_frame which then routes to the correct behavior by looking at the thread-local associated object. This is thread safe because now the callback object is always thread safe, and we only use it to drive logic at frame eval time, as opposed to at callback registration time. Currently, the logic for None/False/Callback is kept, but the False case could be easily collapsed behind the shim in a subsequent diff. *Always here means always when dynamo is running. The shim is installed and removed based on keeping track of how many dynamo threads are running at the moment.
1 parent 9456a72 commit cc215b3

File tree

1 file changed

+88
-92
lines changed

1 file changed

+88
-92
lines changed

torchdynamo/_eval_frame.c

Lines changed: 88 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,6 @@
1010
#undef Py_BUILD_CORE
1111
#endif
1212

13-
//#define TORCHDYNAMO_MULTI_THREAD
14-
//#define TORCHDYNAMO_DEBUG
1513
#define bool char
1614
#define false 0
1715
#define true 1
@@ -60,8 +58,6 @@ static PyObject *guard_error_hook = NULL;
6058

6159
size_t extra_index = -1;
6260

63-
#ifdef TORCHDYNAMO_MULTI_THREAD
64-
6561
static Py_tss_t eval_frame_callback_key = Py_tss_NEEDS_INIT;
6662

6763
inline static PyObject *eval_frame_callback_get(void) {
@@ -77,45 +73,23 @@ inline static void eval_frame_callback_set(PyObject *obj) {
7773
PyThread_tss_set(&eval_frame_callback_key, obj);
7874
}
7975

80-
#else
81-
82-
static PyObject *eval_frame_callback_ = NULL;
83-
84-
inline static PyObject *eval_frame_callback_get(void) {
85-
return eval_frame_callback_;
86-
}
87-
88-
inline static void eval_frame_callback_set(PyObject *obj) {
89-
eval_frame_callback_ = obj;
90-
}
91-
92-
#endif
93-
9476
static void ignored(void *obj) {}
77+
static PyObject *_custom_eval_frame_shim(PyThreadState *tstate,
78+
PyFrameObject *frame, int throw_flag);
9579
static PyObject *_custom_eval_frame(PyThreadState *tstate, PyFrameObject *frame,
96-
int throw_flag);
80+
int throw_flag, PyObject *callback);
9781
static PyObject *_custom_eval_frame_run_only(PyThreadState *tstate,
9882
PyFrameObject *frame,
9983
int throw_flag);
10084
#if PY_VERSION_HEX >= 0x03090000
101-
static PyObject *custom_eval_frame(PyThreadState *tstate, PyFrameObject *frame,
102-
int throw_flag) {
103-
return _custom_eval_frame(tstate, frame, throw_flag);
104-
}
105-
static PyObject *custom_eval_frame_run_only(PyThreadState *tstate,
106-
PyFrameObject *frame,
107-
int throw_flag) {
108-
return _custom_eval_frame_run_only(tstate, frame, throw_flag);
85+
static PyObject *custom_eval_frame_shim(PyThreadState *tstate,
86+
PyFrameObject *frame, int throw_flag) {
87+
return _custom_eval_frame_shim(tstate, frame, throw_flag);
10988
}
11089
#else
111-
static PyObject *custom_eval_frame(PyFrameObject *frame, int throw_flag) {
112-
PyThreadState *tstate = PyThreadState_GET();
113-
return _custom_eval_frame(tstate, frame, throw_flag);
114-
}
115-
static PyObject *custom_eval_frame_run_only(PyFrameObject *frame,
116-
int throw_flag) {
90+
static PyObject *custom_eval_frame_shim(PyFrameObject *frame, int throw_flag) {
11791
PyThreadState *tstate = PyThreadState_GET();
118-
return _custom_eval_frame_run_only(tstate, frame, throw_flag);
92+
return _custom_eval_frame_shim(tstate, frame, throw_flag);
11993
}
12094
#endif
12195

@@ -132,29 +106,33 @@ inline static PyObject *eval_frame_default(PyThreadState *tstate,
132106
#endif
133107
}
134108

135-
inline static void enable_eval_frame(PyThreadState *tstate) {
109+
inline static void enable_eval_frame_shim(PyThreadState *tstate) {
136110
#if PY_VERSION_HEX >= 0x03090000
137-
_PyInterpreterState_SetEvalFrameFunc(tstate->interp, &custom_eval_frame);
138-
#else
139-
tstate->interp->eval_frame = &custom_eval_frame;
140-
#endif
141-
}
142-
143-
inline static void disable_eval_frame(PyThreadState *tstate) {
144-
#if PY_VERSION_HEX >= 0x03090000
145-
_PyInterpreterState_SetEvalFrameFunc(tstate->interp,
146-
&_PyEval_EvalFrameDefault);
111+
if (_PyInterpreterState_GetEvalFrameFunc(tstate->interp) !=
112+
&custom_eval_frame_shim) {
113+
_PyInterpreterState_SetEvalFrameFunc(tstate->interp,
114+
&custom_eval_frame_shim);
115+
}
147116
#else
148-
tstate->interp->eval_frame = &_PyEval_EvalFrameDefault;
117+
if (tstate->interp->eval_frame != &custom_eval_frame_shim) {
118+
// First call
119+
tstate->interp->eval_frame = &custom_eval_frame_shim;
120+
}
149121
#endif
150122
}
151123

152-
inline static void enable_run_only_eval_frame(PyThreadState *tstate) {
124+
inline static void enable_eval_frame_default(PyThreadState *tstate) {
153125
#if PY_VERSION_HEX >= 0x03090000
154-
_PyInterpreterState_SetEvalFrameFunc(tstate->interp,
155-
&custom_eval_frame_run_only);
126+
if (_PyInterpreterState_GetEvalFrameFunc(tstate->interp) !=
127+
&_PyEval_EvalFrameDefault) {
128+
_PyInterpreterState_SetEvalFrameFunc(tstate->interp,
129+
&_PyEval_EvalFrameDefault);
130+
}
156131
#else
157-
tstate->interp->eval_frame = &custom_eval_frame_run_only;
132+
if (tstate->interp->eval_frame != &_PyEval_EvalFrameDefault) {
133+
// First call
134+
tstate->interp->eval_frame = &_PyEval_EvalFrameDefault;
135+
}
158136
#endif
159137
}
160138

@@ -310,8 +288,26 @@ inline static PyObject *eval_custom_code(PyThreadState *tstate,
310288
return result;
311289
}
312290

291+
static PyObject *_custom_eval_frame_shim(PyThreadState *tstate,
292+
PyFrameObject *frame, int throw_flag) {
293+
// Shims logic into one of three states. Can probably be refactored into a
294+
// single func, later:
295+
// - None: disables TorchDynamo
296+
// - False: run-only mode (reuse existing compiles)
297+
// - Python callable(): enables TorchDynamo
298+
PyObject *callback = eval_frame_callback_get();
299+
300+
if (callback == Py_None) {
301+
return eval_frame_default(tstate, frame, throw_flag);
302+
} else if (callback == Py_False) {
303+
return _custom_eval_frame_run_only(tstate, frame, throw_flag);
304+
} else {
305+
return _custom_eval_frame(tstate, frame, throw_flag, callback);
306+
}
307+
}
308+
313309
static PyObject *_custom_eval_frame(PyThreadState *tstate, PyFrameObject *frame,
314-
int throw_flag) {
310+
int throw_flag, PyObject *callback) {
315311
DEBUG_TRACE("begin %s %s %i %i %i %i", name(frame),
316312
PyUnicode_AsUTF8(frame->f_code->co_filename), frame->f_lineno,
317313
frame->f_lasti, frame->f_iblock, frame->f_executing);
@@ -328,15 +324,17 @@ static PyObject *_custom_eval_frame(PyThreadState *tstate, PyFrameObject *frame,
328324
DEBUG_CHECK(PyDict_CheckExact(frame->f_globals));
329325
DEBUG_CHECK(PyDict_CheckExact(frame->f_builtins));
330326

331-
// don't run custom_eval_frame() for guard function
332-
PyObject *callback = eval_frame_callback_get();
333-
disable_eval_frame(tstate);
327+
// We don't run the current custom_eval_frame behavior for guards.
328+
// So we temporarily set the callback to Py_None to drive the correct behavior
329+
// in the shim.
330+
eval_frame_callback_set(Py_None);
334331

335332
PyCodeObject *cached_code = lookup(extra, frame->f_locals);
336333
if (cached_code != NULL) {
337334
// used cached version
338335
DEBUG_TRACE("cache hit %s", name(frame));
339-
enable_eval_frame(tstate);
336+
// Re-enable custom behavior
337+
eval_frame_callback_set(callback);
340338
return eval_custom_code(tstate, frame, cached_code, throw_flag);
341339
}
342340
// cache miss
@@ -353,13 +351,15 @@ static PyObject *_custom_eval_frame(PyThreadState *tstate, PyFrameObject *frame,
353351
extra = create_cache_entry(extra, result);
354352
Py_DECREF(result);
355353
set_extra(frame->f_code, extra);
356-
enable_eval_frame(tstate);
354+
// Re-enable custom behavior
355+
eval_frame_callback_set(callback);
357356
return eval_custom_code(tstate, frame, extra->code, throw_flag);
358357
} else {
359358
DEBUG_TRACE("create skip %s", name(frame));
360359
Py_DECREF(result);
361360
set_extra(frame->f_code, SKIP_CODE);
362-
enable_eval_frame(tstate);
361+
// Re-enable custom behavior
362+
eval_frame_callback_set(callback);
363363
return eval_frame_default(tstate, frame, throw_flag);
364364
}
365365
}
@@ -390,52 +390,50 @@ static PyObject *_custom_eval_frame_run_only(PyThreadState *tstate,
390390
}
391391
}
392392

393+
static int active_dynamo_threads = 0;
394+
395+
static PyObject *increment_working_threads(PyThreadState *tstate) {
396+
active_dynamo_threads = active_dynamo_threads + 1;
397+
if (active_dynamo_threads > 0) {
398+
enable_eval_frame_shim(tstate);
399+
}
400+
Py_RETURN_NONE;
401+
}
402+
403+
static PyObject *decrement_working_threads(PyThreadState *tstate) {
404+
if (active_dynamo_threads > 0) {
405+
active_dynamo_threads = active_dynamo_threads - 1;
406+
if (active_dynamo_threads == 0) {
407+
enable_eval_frame_default(tstate);
408+
}
409+
}
410+
Py_RETURN_NONE;
411+
}
412+
393413
static PyObject *set_eval_frame(PyObject *new_callback, PyThreadState *tstate) {
394414
// Change the eval frame callback and return the old one
395415
// - None: disables TorchDynamo
396416
// - False: run-only mode (reuse existing compiles)
397417
// - Python callable(): enables TorchDynamo
398-
399-
PyObject *old_callback;
400-
#if PY_VERSION_HEX >= 0x03090000
401-
void *old_eval_frame = _PyInterpreterState_GetEvalFrameFunc(tstate->interp);
402-
#else
403-
void *old_eval_frame = tstate->interp->eval_frame;
404-
#endif
405-
if (old_eval_frame == &custom_eval_frame) {
406-
old_callback = eval_frame_callback_get();
407-
} else if (old_eval_frame == &custom_eval_frame_run_only) {
408-
old_callback = Py_False;
409-
} else if (old_eval_frame == &_PyEval_EvalFrameDefault) {
410-
old_callback = Py_None;
411-
} else {
412-
CHECK(false);
413-
}
418+
PyObject *old_callback = eval_frame_callback_get();
414419

415420
// owned by caller
416421
Py_INCREF(old_callback);
417422

418-
if (new_callback == Py_None) {
419-
disable_eval_frame(tstate);
420-
} else if (new_callback == Py_False) {
421-
enable_run_only_eval_frame(tstate);
422-
} else {
423-
enable_eval_frame(tstate);
424-
#ifdef TORCHDYNAMO_MULTI_THREAD
425-
} // callback is private, so always clear it
426-
#endif
423+
if (old_callback != Py_None && new_callback == Py_None) {
424+
decrement_working_threads(tstate);
425+
} else if (old_callback == Py_None && new_callback != Py_None) {
426+
increment_working_threads(tstate);
427+
}
427428

428429
Py_INCREF(new_callback);
429-
Py_DECREF(eval_frame_callback_get());
430-
eval_frame_callback_set(new_callback);
430+
Py_DECREF(old_callback);
431431

432-
#ifndef TORCHDYNAMO_MULTI_THREAD
433-
// only actually set the global variable if we are enabled, otherwise don't
434-
// mess with other threads
435-
}
436-
#endif
432+
// Set thread local callback. This will drive behavior of our shim, if/when it
433+
// is installed.
434+
eval_frame_callback_set(new_callback);
437435

438-
return old_callback;
436+
return old_callback;
439437
}
440438

441439
static PyObject *set_eval_frame_py(PyObject *dummy, PyObject *args) {
@@ -543,10 +541,8 @@ PyMODINIT_FUNC PyInit__eval_frame(void) {
543541
CHECK(sizeof(unsigned long) == sizeof(void *));
544542
extra_index = _PyEval_RequestCodeExtraIndex(ignored);
545543

546-
#ifdef TORCHDYNAMO_MULTI_THREAD
547544
int result = PyThread_tss_create(&eval_frame_callback_key);
548545
CHECK(result == 0);
549-
#endif
550546

551547
Py_INCREF(Py_None);
552548
eval_frame_callback_set(Py_None);

0 commit comments

Comments
 (0)