10
10
#undef Py_BUILD_CORE
11
11
#endif
12
12
13
- //#define TORCHDYNAMO_MULTI_THREAD
14
- //#define TORCHDYNAMO_DEBUG
15
13
#define bool char
16
14
#define false 0
17
15
#define true 1
@@ -60,8 +58,6 @@ static PyObject *guard_error_hook = NULL;
60
58
61
59
size_t extra_index = -1 ;
62
60
63
- #ifdef TORCHDYNAMO_MULTI_THREAD
64
-
65
61
static Py_tss_t eval_frame_callback_key = Py_tss_NEEDS_INIT ;
66
62
67
63
inline static PyObject * eval_frame_callback_get (void ) {
@@ -77,45 +73,23 @@ inline static void eval_frame_callback_set(PyObject *obj) {
77
73
PyThread_tss_set (& eval_frame_callback_key , obj );
78
74
}
79
75
80
- #else
81
-
82
- static PyObject * eval_frame_callback_ = NULL ;
83
-
84
- inline static PyObject * eval_frame_callback_get (void ) {
85
- return eval_frame_callback_ ;
86
- }
87
-
88
- inline static void eval_frame_callback_set (PyObject * obj ) {
89
- eval_frame_callback_ = obj ;
90
- }
91
-
92
- #endif
93
-
94
76
static void ignored (void * obj ) {}
77
+ static PyObject * _custom_eval_frame_shim (PyThreadState * tstate ,
78
+ PyFrameObject * frame , int throw_flag );
95
79
static PyObject * _custom_eval_frame (PyThreadState * tstate , PyFrameObject * frame ,
96
- int throw_flag );
80
+ int throw_flag , PyObject * callback );
97
81
static PyObject * _custom_eval_frame_run_only (PyThreadState * tstate ,
98
82
PyFrameObject * frame ,
99
83
int throw_flag );
100
84
#if PY_VERSION_HEX >= 0x03090000
101
- static PyObject * custom_eval_frame (PyThreadState * tstate , PyFrameObject * frame ,
102
- int throw_flag ) {
103
- return _custom_eval_frame (tstate , frame , throw_flag );
104
- }
105
- static PyObject * custom_eval_frame_run_only (PyThreadState * tstate ,
106
- PyFrameObject * frame ,
107
- int throw_flag ) {
108
- return _custom_eval_frame_run_only (tstate , frame , throw_flag );
85
+ static PyObject * custom_eval_frame_shim (PyThreadState * tstate ,
86
+ PyFrameObject * frame , int throw_flag ) {
87
+ return _custom_eval_frame_shim (tstate , frame , throw_flag );
109
88
}
110
89
#else
111
- static PyObject * custom_eval_frame (PyFrameObject * frame , int throw_flag ) {
112
- PyThreadState * tstate = PyThreadState_GET ();
113
- return _custom_eval_frame (tstate , frame , throw_flag );
114
- }
115
- static PyObject * custom_eval_frame_run_only (PyFrameObject * frame ,
116
- int throw_flag ) {
90
+ static PyObject * custom_eval_frame_shim (PyFrameObject * frame , int throw_flag ) {
117
91
PyThreadState * tstate = PyThreadState_GET ();
118
- return _custom_eval_frame_run_only (tstate , frame , throw_flag );
92
+ return _custom_eval_frame_shim (tstate , frame , throw_flag );
119
93
}
120
94
#endif
121
95
@@ -132,29 +106,33 @@ inline static PyObject *eval_frame_default(PyThreadState *tstate,
132
106
#endif
133
107
}
134
108
135
- inline static void enable_eval_frame (PyThreadState * tstate ) {
109
+ inline static void enable_eval_frame_shim (PyThreadState * tstate ) {
136
110
#if PY_VERSION_HEX >= 0x03090000
137
- _PyInterpreterState_SetEvalFrameFunc (tstate -> interp , & custom_eval_frame );
138
- #else
139
- tstate -> interp -> eval_frame = & custom_eval_frame ;
140
- #endif
141
- }
142
-
143
- inline static void disable_eval_frame (PyThreadState * tstate ) {
144
- #if PY_VERSION_HEX >= 0x03090000
145
- _PyInterpreterState_SetEvalFrameFunc (tstate -> interp ,
146
- & _PyEval_EvalFrameDefault );
111
+ if (_PyInterpreterState_GetEvalFrameFunc (tstate -> interp ) !=
112
+ & custom_eval_frame_shim ) {
113
+ _PyInterpreterState_SetEvalFrameFunc (tstate -> interp ,
114
+ & custom_eval_frame_shim );
115
+ }
147
116
#else
148
- tstate -> interp -> eval_frame = & _PyEval_EvalFrameDefault ;
117
+ if (tstate -> interp -> eval_frame != & custom_eval_frame_shim ) {
118
+ // First call
119
+ tstate -> interp -> eval_frame = & custom_eval_frame_shim ;
120
+ }
149
121
#endif
150
122
}
151
123
152
- inline static void enable_run_only_eval_frame (PyThreadState * tstate ) {
124
+ inline static void enable_eval_frame_default (PyThreadState * tstate ) {
153
125
#if PY_VERSION_HEX >= 0x03090000
154
- _PyInterpreterState_SetEvalFrameFunc (tstate -> interp ,
155
- & custom_eval_frame_run_only );
126
+ if (_PyInterpreterState_GetEvalFrameFunc (tstate -> interp ) !=
127
+ & _PyEval_EvalFrameDefault ) {
128
+ _PyInterpreterState_SetEvalFrameFunc (tstate -> interp ,
129
+ & _PyEval_EvalFrameDefault );
130
+ }
156
131
#else
157
- tstate -> interp -> eval_frame = & custom_eval_frame_run_only ;
132
+ if (tstate -> interp -> eval_frame != & _PyEval_EvalFrameDefault ) {
133
+ // First call
134
+ tstate -> interp -> eval_frame = & _PyEval_EvalFrameDefault ;
135
+ }
158
136
#endif
159
137
}
160
138
@@ -310,8 +288,26 @@ inline static PyObject *eval_custom_code(PyThreadState *tstate,
310
288
return result ;
311
289
}
312
290
291
+ static PyObject * _custom_eval_frame_shim (PyThreadState * tstate ,
292
+ PyFrameObject * frame , int throw_flag ) {
293
+ // Shims logic into one of three states. Can probably be refactored into a
294
+ // single func, later:
295
+ // - None: disables TorchDynamo
296
+ // - False: run-only mode (reuse existing compiles)
297
+ // - Python callable(): enables TorchDynamo
298
+ PyObject * callback = eval_frame_callback_get ();
299
+
300
+ if (callback == Py_None ) {
301
+ return eval_frame_default (tstate , frame , throw_flag );
302
+ } else if (callback == Py_False ) {
303
+ return _custom_eval_frame_run_only (tstate , frame , throw_flag );
304
+ } else {
305
+ return _custom_eval_frame (tstate , frame , throw_flag , callback );
306
+ }
307
+ }
308
+
313
309
static PyObject * _custom_eval_frame (PyThreadState * tstate , PyFrameObject * frame ,
314
- int throw_flag ) {
310
+ int throw_flag , PyObject * callback ) {
315
311
DEBUG_TRACE ("begin %s %s %i %i %i %i" , name (frame ),
316
312
PyUnicode_AsUTF8 (frame -> f_code -> co_filename ), frame -> f_lineno ,
317
313
frame -> f_lasti , frame -> f_iblock , frame -> f_executing );
@@ -328,15 +324,17 @@ static PyObject *_custom_eval_frame(PyThreadState *tstate, PyFrameObject *frame,
328
324
DEBUG_CHECK (PyDict_CheckExact (frame -> f_globals ));
329
325
DEBUG_CHECK (PyDict_CheckExact (frame -> f_builtins ));
330
326
331
- // don't run custom_eval_frame() for guard function
332
- PyObject * callback = eval_frame_callback_get ();
333
- disable_eval_frame (tstate );
327
+ // We don't run the current custom_eval_frame behavior for guards.
328
+ // So we temporarily set the callback to Py_None to drive the correct behavior
329
+ // in the shim.
330
+ eval_frame_callback_set (Py_None );
334
331
335
332
PyCodeObject * cached_code = lookup (extra , frame -> f_locals );
336
333
if (cached_code != NULL ) {
337
334
// used cached version
338
335
DEBUG_TRACE ("cache hit %s" , name (frame ));
339
- enable_eval_frame (tstate );
336
+ // Re-enable custom behavior
337
+ eval_frame_callback_set (callback );
340
338
return eval_custom_code (tstate , frame , cached_code , throw_flag );
341
339
}
342
340
// cache miss
@@ -353,13 +351,15 @@ static PyObject *_custom_eval_frame(PyThreadState *tstate, PyFrameObject *frame,
353
351
extra = create_cache_entry (extra , result );
354
352
Py_DECREF (result );
355
353
set_extra (frame -> f_code , extra );
356
- enable_eval_frame (tstate );
354
+ // Re-enable custom behavior
355
+ eval_frame_callback_set (callback );
357
356
return eval_custom_code (tstate , frame , extra -> code , throw_flag );
358
357
} else {
359
358
DEBUG_TRACE ("create skip %s" , name (frame ));
360
359
Py_DECREF (result );
361
360
set_extra (frame -> f_code , SKIP_CODE );
362
- enable_eval_frame (tstate );
361
+ // Re-enable custom behavior
362
+ eval_frame_callback_set (callback );
363
363
return eval_frame_default (tstate , frame , throw_flag );
364
364
}
365
365
}
@@ -390,52 +390,50 @@ static PyObject *_custom_eval_frame_run_only(PyThreadState *tstate,
390
390
}
391
391
}
392
392
393
+ static int active_dynamo_threads = 0 ;
394
+
395
+ static PyObject * increment_working_threads (PyThreadState * tstate ) {
396
+ active_dynamo_threads = active_dynamo_threads + 1 ;
397
+ if (active_dynamo_threads > 0 ) {
398
+ enable_eval_frame_shim (tstate );
399
+ }
400
+ Py_RETURN_NONE ;
401
+ }
402
+
403
+ static PyObject * decrement_working_threads (PyThreadState * tstate ) {
404
+ if (active_dynamo_threads > 0 ) {
405
+ active_dynamo_threads = active_dynamo_threads - 1 ;
406
+ if (active_dynamo_threads == 0 ) {
407
+ enable_eval_frame_default (tstate );
408
+ }
409
+ }
410
+ Py_RETURN_NONE ;
411
+ }
412
+
393
413
static PyObject * set_eval_frame (PyObject * new_callback , PyThreadState * tstate ) {
394
414
// Change the eval frame callback and return the old one
395
415
// - None: disables TorchDynamo
396
416
// - False: run-only mode (reuse existing compiles)
397
417
// - Python callable(): enables TorchDynamo
398
-
399
- PyObject * old_callback ;
400
- #if PY_VERSION_HEX >= 0x03090000
401
- void * old_eval_frame = _PyInterpreterState_GetEvalFrameFunc (tstate -> interp );
402
- #else
403
- void * old_eval_frame = tstate -> interp -> eval_frame ;
404
- #endif
405
- if (old_eval_frame == & custom_eval_frame ) {
406
- old_callback = eval_frame_callback_get ();
407
- } else if (old_eval_frame == & custom_eval_frame_run_only ) {
408
- old_callback = Py_False ;
409
- } else if (old_eval_frame == & _PyEval_EvalFrameDefault ) {
410
- old_callback = Py_None ;
411
- } else {
412
- CHECK (false);
413
- }
418
+ PyObject * old_callback = eval_frame_callback_get ();
414
419
415
420
// owned by caller
416
421
Py_INCREF (old_callback );
417
422
418
- if (new_callback == Py_None ) {
419
- disable_eval_frame (tstate );
420
- } else if (new_callback == Py_False ) {
421
- enable_run_only_eval_frame (tstate );
422
- } else {
423
- enable_eval_frame (tstate );
424
- #ifdef TORCHDYNAMO_MULTI_THREAD
425
- } // callback is private, so always clear it
426
- #endif
423
+ if (old_callback != Py_None && new_callback == Py_None ) {
424
+ decrement_working_threads (tstate );
425
+ } else if (old_callback == Py_None && new_callback != Py_None ) {
426
+ increment_working_threads (tstate );
427
+ }
427
428
428
429
Py_INCREF (new_callback );
429
- Py_DECREF (eval_frame_callback_get ());
430
- eval_frame_callback_set (new_callback );
430
+ Py_DECREF (old_callback );
431
431
432
- #ifndef TORCHDYNAMO_MULTI_THREAD
433
- // only actually set the global variable if we are enabled, otherwise don't
434
- // mess with other threads
435
- }
436
- #endif
432
+ // Set thread local callback. This will drive behavior of our shim, if/when it
433
+ // is installed.
434
+ eval_frame_callback_set (new_callback );
437
435
438
- return old_callback ;
436
+ return old_callback ;
439
437
}
440
438
441
439
static PyObject * set_eval_frame_py (PyObject * dummy , PyObject * args ) {
@@ -543,10 +541,8 @@ PyMODINIT_FUNC PyInit__eval_frame(void) {
543
541
CHECK (sizeof (unsigned long ) == sizeof (void * ));
544
542
extra_index = _PyEval_RequestCodeExtraIndex (ignored );
545
543
546
- #ifdef TORCHDYNAMO_MULTI_THREAD
547
544
int result = PyThread_tss_create (& eval_frame_callback_key );
548
545
CHECK (result == 0 );
549
- #endif
550
546
551
547
Py_INCREF (Py_None );
552
548
eval_frame_callback_set (Py_None );
0 commit comments