Skip to content

gh-76785: Add Interpreter.bind() #111575

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 3 additions & 14 deletions Include/internal/pycore_crossinterp.h
Original file line number Diff line number Diff line change
Expand Up @@ -164,17 +164,6 @@ extern void _PyXI_Fini(PyInterpreterState *interp);
/* short-term data sharing */
/***************************/

// Ultimately we'd like to preserve enough information about the
// exception and traceback that we could re-constitute (or at least
// simulate, a la traceback.TracebackException), and even chain, a copy
// of the exception in the calling interpreter.

typedef struct _excinfo {
const char *type;
const char *msg;
} _Py_excinfo;


typedef enum error_code {
_PyXI_ERR_NO_ERROR = 0,
_PyXI_ERR_UNCAUGHT_EXCEPTION = -1,
Expand All @@ -197,9 +186,6 @@ typedef struct _sharedexception {
_Py_excinfo uncaught;
} _PyXI_exception_info;

PyAPI_FUNC(void) _PyXI_ApplyExceptionInfo(
_PyXI_exception_info *info,
PyObject *exctype);

typedef struct xi_session _PyXI_session;
typedef struct _sharedns _PyXI_namespace;
Expand Down Expand Up @@ -269,6 +255,9 @@ PyAPI_FUNC(void) _PyXI_Exit(_PyXI_session *session);
PyAPI_FUNC(void) _PyXI_ApplyCapturedException(
_PyXI_session *session,
PyObject *excwrapper);
PyAPI_FUNC(PyObject *) _PyXI_ResolveCapturedException(
_PyXI_session *session,
PyObject *excwrapper);
PyAPI_FUNC(int) _PyXI_HasCapturedException(_PyXI_session *session);


Expand Down
16 changes: 14 additions & 2 deletions Include/internal/pycore_exceptions.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,20 @@ extern "C" {
# error "this header requires Py_BUILD_CORE define"
#endif

#include "pycore_pyerrors.h"


/* runtime lifecycle */

extern PyStatus _PyExc_InitState(PyInterpreterState *);
extern PyStatus _PyExc_InitGlobalObjects(PyInterpreterState *);
extern int _PyExc_InitTypes(PyInterpreterState *);
extern void _PyExc_FiniHeapObjects(PyInterpreterState *);
extern void _PyExc_FiniTypes(PyInterpreterState *);
extern void _PyExc_Fini(PyInterpreterState *);


/* other API */
/* runtime state */

struct _Py_exc_state {
// The dict mapping from errno codes to OSError subclasses
Expand All @@ -26,9 +30,17 @@ struct _Py_exc_state {
int memerrors_numfree;
// The ExceptionGroup type
PyObject *PyExc_ExceptionGroup;

PyTypeObject *ExceptionSnapshotType;
};

extern void _PyExc_ClearExceptionGroupType(PyInterpreterState *);

/* other API */

PyAPI_FUNC(PyTypeObject *) _PyExc_GetExceptionSnapshotType(
PyInterpreterState *interp);

extern PyObject * PyExceptionSnapshot_FromInfo(_Py_excinfo *info);


#ifdef __cplusplus
Expand Down
24 changes: 24 additions & 0 deletions Include/internal/pycore_pyerrors.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,30 @@ extern PyStatus _PyErr_InitTypes(PyInterpreterState *);
extern void _PyErr_FiniTypes(PyInterpreterState *);


/* exception snapshots */

// Ultimately we'd like to preserve enough information about the
// exception and traceback that we could re-constitute (or at least
// simulate, a la traceback.TracebackException), and even chain, a copy
// of the exception in the calling interpreter.

typedef struct _excinfo {
const char *type;
const char *msg;
} _Py_excinfo;

extern void _Py_excinfo_Clear(_Py_excinfo *info);
extern int _Py_excinfo_Copy(_Py_excinfo *dest, _Py_excinfo *src);
extern const char * _Py_excinfo_InitFromException(
_Py_excinfo *info,
PyObject *exc);
extern void _Py_excinfo_Apply(_Py_excinfo *info, PyObject *exctype);
extern const char * _Py_excinfo_AsUTF8(
_Py_excinfo *info,
char *buf,
size_t bufsize);


/* other API */

static inline PyObject* _PyErr_Occurred(PyThreadState *tstate)
Expand Down
43 changes: 41 additions & 2 deletions Lib/test/support/interpreters.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,25 @@

__all__ = [
'Interpreter', 'get_current', 'get_main', 'create', 'list_all',
'RunFailedError',
'SendChannel', 'RecvChannel',
'create_channel', 'list_all_channels', 'is_shareable',
'ChannelError', 'ChannelNotFoundError',
'ChannelEmptyError',
]


class RunFailedError(RuntimeError):

def __init__(self, snapshot):
if snapshot.type and snapshot.msg:
msg = f'{snapshot.type}: {snapshot.msg}'
else:
msg = snapshot.type or snapshot.msg
super().__init__(msg)
self.snapshot = snapshot


def create(*, isolated=True):
"""Return a new (idle) Python interpreter."""
id = _interpreters.create(isolated=isolated)
Expand Down Expand Up @@ -91,14 +103,39 @@ def close(self):
"""
return _interpreters.destroy(self._id)

# XXX setattr?
def bind(self, ns=None, /, **kwargs):
"""Bind the given values into the interpreter's __main__.

The values must be shareable.
"""
ns = dict(ns, **kwargs) if ns is not None else kwargs
_interpreters.set___main___attrs(self._id, ns)

# XXX getattr?
def get(self, name, default=None, /):
"""Return the attr value from the interpreter's __main__.

The value must be shareable.
"""
found = _interpreters.get___main___attrs(self._id, (name,), default)
assert len(found) == 1, found
return found[name]

# XXX Rename "run" to "exec"?
def run(self, src_str, /, channels=None):
# XXX Do not allow init to overwrite (by default)?
def run(self, src_str, /, *, init=None):
"""Run the given source code in the interpreter.

This is essentially the same as calling the builtin "exec"
with this interpreter, using the __dict__ of its __main__
module as both globals and locals.

If "init" is provided, it must be a dict mapping attribute names
to "shareable" objects, including channels. These are set as
attributes on the __main__ module before the given code is
executed. If a name is already bound then it is overwritten.

There is no return value.

If the code raises an unhandled exception then a RunFailedError
Expand All @@ -110,7 +147,9 @@ def run(self, src_str, /, channels=None):
that time, the previous interpreter is allowed to run
in other threads.
"""
_interpreters.exec(self._id, src_str, channels)
err = _interpreters.exec(self._id, src_str, init)
if err is not None:
raise RunFailedError(err)


def create_channel():
Expand Down
16 changes: 8 additions & 8 deletions Lib/test/test__xxinterpchannels.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,12 +587,12 @@ def test_run_string_arg_unresolved(self):
cid = channels.create()
interp = interpreters.create()

interpreters.set___main___attrs(interp, dict(cid=cid.send))
out = _run_output(interp, dedent("""
import _xxinterpchannels as _channels
print(cid.end)
_channels.send(cid, b'spam', blocking=False)
"""),
dict(cid=cid.send))
"""))
obj = channels.recv(cid)

self.assertEqual(obj, b'spam')
Expand Down Expand Up @@ -1017,16 +1017,16 @@ def test_close_multiple_users(self):
_channels.recv({cid})
"""))
channels.close(cid)
with self.assertRaises(interpreters.RunFailedError) as cm:
interpreters.run_string(id1, dedent(f"""

excsnap = interpreters.run_string(id1, dedent(f"""
_channels.send({cid}, b'spam')
"""))
self.assertIn('ChannelClosedError', str(cm.exception))
with self.assertRaises(interpreters.RunFailedError) as cm:
interpreters.run_string(id2, dedent(f"""
self.assertIn('ChannelClosedError', excsnap.type)

excsnap = interpreters.run_string(id2, dedent(f"""
_channels.send({cid}, b'spam')
"""))
self.assertIn('ChannelClosedError', str(cm.exception))
self.assertIn('ChannelClosedError', excsnap.type)

def test_close_multiple_times(self):
cid = channels.create()
Expand Down
61 changes: 35 additions & 26 deletions Lib/test/test__xxsubinterpreters.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@ def _captured_script(script):
return wrapped, open(r, encoding="utf-8")


def _run_output(interp, request, shared=None):
def _run_output(interp, request):
script, rpipe = _captured_script(request)
with rpipe:
interpreters.run_string(interp, script, shared)
interpreters.run_string(interp, script)
return rpipe.read()


Expand Down Expand Up @@ -659,10 +659,10 @@ def test_shareable_types(self):
]
for obj in objects:
with self.subTest(obj):
interpreters.set___main___attrs(interp, dict(obj=obj))
interpreters.run_string(
interp,
f'assert(obj == {obj!r})',
shared=dict(obj=obj),
)

def test_os_exec(self):
Expand Down Expand Up @@ -743,30 +743,33 @@ def assert_run_failed(self, exctype, msg=None):
"{}: {}".format(exctype.__name__, msg))

def test_invalid_syntax(self):
with self.assert_run_failed(SyntaxError):
# missing close paren
interpreters.run_string(self.id, 'print("spam"')
# missing close paren
exc = interpreters.run_string(self.id, 'print("spam"')
self.assertEqual(exc.type, 'SyntaxError')

def test_failure(self):
with self.assert_run_failed(Exception, 'spam'):
interpreters.run_string(self.id, 'raise Exception("spam")')
exc = interpreters.run_string(self.id, 'raise Exception("spam")')
self.assertEqual(exc.type, 'Exception')
self.assertEqual(exc.msg, 'spam')

def test_SystemExit(self):
with self.assert_run_failed(SystemExit, '42'):
interpreters.run_string(self.id, 'raise SystemExit(42)')
exc = interpreters.run_string(self.id, 'raise SystemExit(42)')
self.assertEqual(exc.type, 'SystemExit')
self.assertEqual(exc.msg, '42')

def test_sys_exit(self):
with self.assert_run_failed(SystemExit):
interpreters.run_string(self.id, dedent("""
import sys
sys.exit()
"""))
exc = interpreters.run_string(self.id, dedent("""
import sys
sys.exit()
"""))
self.assertEqual(exc.type, 'SystemExit')

with self.assert_run_failed(SystemExit, '42'):
interpreters.run_string(self.id, dedent("""
import sys
sys.exit(42)
"""))
exc = interpreters.run_string(self.id, dedent("""
import sys
sys.exit(42)
"""))
self.assertEqual(exc.type, 'SystemExit')
self.assertEqual(exc.msg, '42')

def test_with_shared(self):
r, w = os.pipe()
Expand All @@ -787,7 +790,8 @@ def test_with_shared(self):
with open({w}, 'wb') as chan:
pickle.dump(ns, chan)
""")
interpreters.run_string(self.id, script, shared)
interpreters.set___main___attrs(self.id, shared)
interpreters.run_string(self.id, script)
with open(r, 'rb') as chan:
ns = pickle.load(chan)

Expand All @@ -808,7 +812,8 @@ def test_shared_overwrites(self):
ns2 = dict(vars())
del ns2['__builtins__']
""")
interpreters.run_string(self.id, script, shared)
interpreters.set___main___attrs(self.id, shared)
interpreters.run_string(self.id, script)

r, w = os.pipe()
script = dedent(f"""
Expand Down Expand Up @@ -839,7 +844,8 @@ def test_shared_overwrites_default_vars(self):
with open({w}, 'wb') as chan:
pickle.dump(ns, chan)
""")
interpreters.run_string(self.id, script, shared)
interpreters.set___main___attrs(self.id, shared)
interpreters.run_string(self.id, script)
with open(r, 'rb') as chan:
ns = pickle.load(chan)

Expand Down Expand Up @@ -945,7 +951,8 @@ def script():
with open(w, 'w', encoding="utf-8") as spipe:
with contextlib.redirect_stdout(spipe):
print('it worked!', end='')
interpreters.run_func(self.id, script, shared=dict(w=w))
interpreters.set___main___attrs(self.id, dict(w=w))
interpreters.run_func(self.id, script)

with open(r, encoding="utf-8") as outfile:
out = outfile.read()
Expand All @@ -961,7 +968,8 @@ def script():
with contextlib.redirect_stdout(spipe):
print('it worked!', end='')
def f():
interpreters.run_func(self.id, script, shared=dict(w=w))
interpreters.set___main___attrs(self.id, dict(w=w))
interpreters.run_func(self.id, script)
t = threading.Thread(target=f)
t.start()
t.join()
Expand All @@ -981,7 +989,8 @@ def script():
with contextlib.redirect_stdout(spipe):
print('it worked!', end='')
code = script.__code__
interpreters.run_func(self.id, code, shared=dict(w=w))
interpreters.set___main___attrs(self.id, dict(w=w))
interpreters.run_func(self.id, code)

with open(r, encoding="utf-8") as outfile:
out = outfile.read()
Expand Down
10 changes: 6 additions & 4 deletions Lib/test/test_import/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1968,10 +1968,12 @@ def test_disallowed_reimport(self):
print(_testsinglephase)
''')
interpid = _interpreters.create()
with self.assertRaises(_interpreters.RunFailedError):
_interpreters.run_string(interpid, script)
with self.assertRaises(_interpreters.RunFailedError):
_interpreters.run_string(interpid, script)

excsnap = _interpreters.run_string(interpid, script)
self.assertIsNot(excsnap, None)

excsnap = _interpreters.run_string(interpid, script)
self.assertIsNot(excsnap, None)


class TestSinglePhaseSnapshot(ModuleSnapshot):
Expand Down
Loading