Skip to content

gh-76785: Add *Channel.is_closed #110606

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Oct 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions Lib/test/support/interpreters.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,14 @@ def __eq__(self, other):
def id(self):
return self._id

@property
def _info(self):
return _channels.get_info(self._id)

@property
def is_closed(self):
return self._info.closed


_NOT_SET = object()

Expand Down Expand Up @@ -213,6 +221,11 @@ class SendChannel(_ChannelEnd):

_end = 'send'

@property
def is_closed(self):
info = self._info
return info.closed or info.closing

def send(self, obj, timeout=None):
"""Send the object (i.e. its data) to the channel's receiving end.

Expand Down
13 changes: 13 additions & 0 deletions Lib/test/test_interpreters.py
Original file line number Diff line number Diff line change
Expand Up @@ -850,6 +850,19 @@ def test_shareable(self):
self.assertEqual(rch2, rch)
self.assertEqual(sch2, sch)

def test_is_closed(self):
rch, sch = interpreters.create_channel()
rbefore = rch.is_closed
sbefore = sch.is_closed
rch.close()
rafter = rch.is_closed
safter = sch.is_closed

self.assertFalse(rbefore)
self.assertFalse(sbefore)
self.assertTrue(rafter)
self.assertTrue(safter)


class TestRecvChannelAttrs(TestBase):

Expand Down
278 changes: 276 additions & 2 deletions Modules/_xxinterpchannelsmodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -223,8 +223,8 @@ static PyTypeObject *
add_new_type(PyObject *mod, PyType_Spec *spec, crossinterpdatafunc shared,
struct xid_class_registry *classes)
{
PyTypeObject *cls = (PyTypeObject *)PyType_FromMetaclass(
NULL, mod, spec, NULL);
PyTypeObject *cls = (PyTypeObject *)PyType_FromModuleAndSpec(
mod, spec, NULL);
if (cls == NULL) {
return NULL;
}
Expand Down Expand Up @@ -402,6 +402,7 @@ typedef struct {
PyTypeObject *recv_channel_type;

/* heap types */
PyTypeObject *ChannelInfoType;
PyTypeObject *ChannelIDType;
PyTypeObject *XIBufferViewType;

Expand Down Expand Up @@ -441,6 +442,7 @@ static int
traverse_module_state(module_state *state, visitproc visit, void *arg)
{
/* heap types */
Py_VISIT(state->ChannelInfoType);
Py_VISIT(state->ChannelIDType);
Py_VISIT(state->XIBufferViewType);

Expand All @@ -457,10 +459,12 @@ traverse_module_state(module_state *state, visitproc visit, void *arg)
static int
clear_module_state(module_state *state)
{
/* external types */
Py_CLEAR(state->send_channel_type);
Py_CLEAR(state->recv_channel_type);

/* heap types */
Py_CLEAR(state->ChannelInfoType);
if (state->ChannelIDType != NULL) {
(void)_PyCrossInterpreterData_UnregisterClass(state->ChannelIDType);
}
Expand Down Expand Up @@ -2088,6 +2092,236 @@ channel_is_associated(_channels *channels, int64_t cid, int64_t interpid,
}


/* channel info */

struct channel_info {
struct {
// 1: closed; -1: closing
int closed;
struct {
Py_ssize_t nsend_only; // not released
Py_ssize_t nsend_only_released;
Py_ssize_t nrecv_only; // not released
Py_ssize_t nrecv_only_released;
Py_ssize_t nboth; // not released
Py_ssize_t nboth_released;
Py_ssize_t nboth_send_released;
Py_ssize_t nboth_recv_released;
} all;
struct {
// 1: associated; -1: released
int send;
int recv;
} cur;
} status;
Py_ssize_t count;
};

static int
_channel_get_info(_channels *channels, int64_t cid, struct channel_info *info)
{
int err = 0;
*info = (struct channel_info){0};

// Get the current interpreter.
PyInterpreterState *interp = _get_current_interp();
if (interp == NULL) {
return -1;
}
Py_ssize_t interpid = PyInterpreterState_GetID(interp);

// Hold the global lock until we're done.
PyThread_acquire_lock(channels->mutex, WAIT_LOCK);

// Find the channel.
_channelref *ref = _channelref_find(channels->head, cid, NULL);
if (ref == NULL) {
err = ERR_CHANNEL_NOT_FOUND;
goto finally;
}
_channel_state *chan = ref->chan;

// Check if open.
if (chan == NULL) {
info->status.closed = 1;
goto finally;
}
if (!chan->open) {
assert(chan->queue->count == 0);
info->status.closed = 1;
goto finally;
}
if (chan->closing != NULL) {
assert(chan->queue->count > 0);
info->status.closed = -1;
}
else {
info->status.closed = 0;
}

// Get the number of queued objects.
info->count = chan->queue->count;

// Get the ends statuses.
assert(info->status.cur.send == 0);
assert(info->status.cur.recv == 0);
_channelend *send = chan->ends->send;
while (send != NULL) {
if (send->interpid == interpid) {
info->status.cur.send = send->open ? 1 : -1;
}

if (send->open) {
info->status.all.nsend_only += 1;
}
else {
info->status.all.nsend_only_released += 1;
}
send = send->next;
}
_channelend *recv = chan->ends->recv;
while (recv != NULL) {
if (recv->interpid == interpid) {
info->status.cur.recv = recv->open ? 1 : -1;
}

// XXX This is O(n*n). Why do we have 2 linked lists?
_channelend *send = chan->ends->send;
while (send != NULL) {
if (send->interpid == recv->interpid) {
break;
}
send = send->next;
}
if (send == NULL) {
if (recv->open) {
info->status.all.nrecv_only += 1;
}
else {
info->status.all.nrecv_only_released += 1;
}
}
else {
if (recv->open) {
if (send->open) {
info->status.all.nboth += 1;
info->status.all.nsend_only -= 1;
}
else {
info->status.all.nboth_recv_released += 1;
info->status.all.nsend_only_released -= 1;
}
}
else {
if (send->open) {
info->status.all.nboth_send_released += 1;
info->status.all.nsend_only -= 1;
}
else {
info->status.all.nboth_released += 1;
info->status.all.nsend_only_released -= 1;
}
}
}
recv = recv->next;
}

finally:
PyThread_release_lock(channels->mutex);
return err;
}

PyDoc_STRVAR(channel_info_doc,
"ChannelInfo\n\
\n\
A named tuple of a channel's state.");

static PyStructSequence_Field channel_info_fields[] = {
{"open", "both ends are open"},
{"closing", "send is closed, recv is non-empty"},
{"closed", "both ends are closed"},
{"count", "queued objects"},

{"num_interp_send", "interpreters bound to the send end"},
{"num_interp_send_released",
"interpreters bound to the send end and released"},

{"num_interp_recv", "interpreters bound to the send end"},
{"num_interp_recv_released",
"interpreters bound to the send end and released"},

{"num_interp_both", "interpreters bound to both ends"},
{"num_interp_both_released",
"interpreters bound to both ends and released_from_both"},
{"num_interp_both_send_released",
"interpreters bound to both ends and released_from_the send end"},
{"num_interp_both_recv_released",
"interpreters bound to both ends and released_from_the recv end"},

{"send_associated", "current interpreter is bound to the send end"},
{"send_released", "current interpreter *was* bound to the send end"},
{"recv_associated", "current interpreter is bound to the recv end"},
{"recv_released", "current interpreter *was* bound to the recv end"},
{0}
};

static PyStructSequence_Desc channel_info_desc = {
.name = MODULE_NAME ".ChannelInfo",
.doc = channel_info_doc,
.fields = channel_info_fields,
.n_in_sequence = 8,
};

static PyObject *
new_channel_info(PyObject *mod, struct channel_info *info)
{
module_state *state = get_module_state(mod);
if (state == NULL) {
return NULL;
}

assert(state->ChannelInfoType != NULL);
PyObject *self = PyStructSequence_New(state->ChannelInfoType);
if (self == NULL) {
return NULL;
}

int pos = 0;
#define SET_BOOL(val) \
PyStructSequence_SET_ITEM(self, pos++, \
Py_NewRef(val ? Py_True : Py_False))
#define SET_COUNT(val) \
do { \
PyObject *obj = PyLong_FromLongLong(val); \
if (obj == NULL) { \
Py_CLEAR(info); \
return NULL; \
} \
PyStructSequence_SET_ITEM(self, pos++, obj); \
} while(0)
SET_BOOL(info->status.closed == 0);
SET_BOOL(info->status.closed == -1);
SET_BOOL(info->status.closed == 1);
SET_COUNT(info->count);
SET_COUNT(info->status.all.nsend_only);
SET_COUNT(info->status.all.nsend_only_released);
SET_COUNT(info->status.all.nrecv_only);
SET_COUNT(info->status.all.nrecv_only_released);
SET_COUNT(info->status.all.nboth);
SET_COUNT(info->status.all.nboth_released);
SET_COUNT(info->status.all.nboth_send_released);
SET_COUNT(info->status.all.nboth_recv_released);
SET_BOOL(info->status.cur.send == 1);
SET_BOOL(info->status.cur.send == -1);
SET_BOOL(info->status.cur.recv == 1);
SET_BOOL(info->status.cur.recv == -1);
#undef SET_COUNT
#undef SET_BOOL
assert(!PyErr_Occurred());
return self;
}


/* ChannelID class */

typedef struct channelid {
Expand Down Expand Up @@ -3079,6 +3313,33 @@ Close the channel for the current interpreter. 'send' and 'recv'\n\
(bool) may be used to indicate the ends to close. By default both\n\
ends are closed. Closing an already closed end is a noop.");

static PyObject *
channelsmod_get_info(PyObject *self, PyObject *args, PyObject *kwds)
{
static char *kwlist[] = {"cid", NULL};
struct channel_id_converter_data cid_data = {
.module = self,
};
if (!PyArg_ParseTupleAndKeywords(args, kwds,
"O&:_get_info", kwlist,
channel_id_converter, &cid_data)) {
return NULL;
}
int64_t cid = cid_data.cid;

struct channel_info info;
int err = _channel_get_info(&_globals.channels, cid, &info);
if (handle_channel_error(err, self, cid)) {
return NULL;
}
return new_channel_info(self, &info);
}

PyDoc_STRVAR(channelsmod_get_info_doc,
"get_info(cid)\n\
\n\
Return details about the channel.");

static PyObject *
channelsmod__channel_id(PyObject *self, PyObject *args, PyObject *kwds)
{
Expand Down Expand Up @@ -3143,6 +3404,8 @@ static PyMethodDef module_functions[] = {
METH_VARARGS | METH_KEYWORDS, channelsmod_close_doc},
{"release", _PyCFunction_CAST(channelsmod_release),
METH_VARARGS | METH_KEYWORDS, channelsmod_release_doc},
{"get_info", _PyCFunction_CAST(channelsmod_get_info),
METH_VARARGS | METH_KEYWORDS, channelsmod_get_info_doc},
{"_channel_id", _PyCFunction_CAST(channelsmod__channel_id),
METH_VARARGS | METH_KEYWORDS, NULL},
{"_register_end_types", _PyCFunction_CAST(channelsmod__register_end_types),
Expand Down Expand Up @@ -3179,19 +3442,30 @@ module_exec(PyObject *mod)

/* Add other types */

// ChannelInfo
state->ChannelInfoType = PyStructSequence_NewType(&channel_info_desc);
if (state->ChannelInfoType == NULL) {
goto error;
}
if (PyModule_AddType(mod, state->ChannelInfoType) < 0) {
goto error;
}

// ChannelID
state->ChannelIDType = add_new_type(
mod, &channelid_typespec, _channelid_shared, xid_classes);
if (state->ChannelIDType == NULL) {
goto error;
}

// XIBufferView
state->XIBufferViewType = add_new_type(mod, &XIBufferViewType_spec, NULL,
xid_classes);
if (state->XIBufferViewType == NULL) {
goto error;
}

// Register external types.
if (register_builtin_xid_types(xid_classes) < 0) {
goto error;
}
Expand Down